from scipy.fftpack import fft2, ifft2, fftfreq
from numpy import newaxis, ones, array, float64

def Operators2D(N, Ro):
    """
    Helper function to pre-define the operators that appear in the
    Navier-Stokes equations via (pseudo)-spectral computation.

    The wavenumbers are the integers ``N * fftfreq(N)``. The spectral
    derivatives ``1j*k`` therefore correspond to a square domain of period
    2*pi, sampled on an N x N grid. Axis 0 of every array is the x direction
    and axis 1 is the y direction.

    Parameters
    ----------
        N : int
            number of grid points in each spatial direction
        Ro : float
            coefficient of the Laplacian in the Helmholtz operator
            ``1 - Ro*Laplace`` that ``Helm_inverse`` inverts. Its spectral
            symbol is ``1 + Ro*(kx^2 + ky^2)``. For Ro >= 0 this is >= 1,
            so the inverse is always finite.

    Returns
    -------
        Laplace : Function that computes the Laplace operator spectrally
        Helm_inverse :
            Function that computes the inverse Helmholtz operator
            ``(1 - Ro*Laplace)^(-1)`` spectrally
        Filt :
            Function that applies the dealiasing filter spectrally. It keeps
            only modes with kx^2 + ky^2 < (N/3)^2 and removes the
            wavenumber-0 (mean) mode.
        J : Function that computes the Jacobian operator pseudospectrally
    """

    # Integer wavenumbers in FFT order: 0, 1, ..., -1.
    k = N * fftfreq(N)
    # Spectral d/dx varies along axis 0; d/dy varies along axis 1.
    dx_spectral = 1j*k[:,newaxis]*ones((N,N))
    dy_spectral = 1j*k[newaxis,:]*ones((N,N))
    # Symbol of the Laplacian: -(kx^2 + ky^2), which is real and <= 0.
    lap_spectral = (dx_spectral**2 + dy_spectral**2).real
    # Circular truncation at |k| < N/3, i.e. 2/3 of the Nyquist
    # wavenumber N/2.
    N_alias = N/3
    filt_spectral = array(N_alias**2 + lap_spectral > 0, dtype=float64)
    # Also drop the mean mode.
    filt_spectral[0,0] = 0
    helm_spectral_inverse = 1/(1-Ro*lap_spectral)

    def Laplace(u_spectral):
        """
        Function that computes the Laplace operator spectrally

        Parameters
        ----------
            u_spectral : two-dimensional array of shape (N, N)

        Returns
        -------
            two-dimensional array of shape (N, N)
        """
        return u_spectral*lap_spectral

    def Helm_inverse(u_spectral):
        """
        Function that computes the inverse Helmholtz operator
        ``(1 - Ro*Laplace)^(-1)`` spectrally.

        The symbol at wavenumber 0 is 1, so the mean mode passes through
        unchanged.

        Parameters
        ----------
            u_spectral : two-dimensional array of shape (N, N)

        Returns
        -------
            two-dimensional array of shape (N, N)
        """
        return u_spectral*helm_spectral_inverse

    def Filt(u_spectral):
        """
        Function that applies the dealiasing filter spectrally. It zeroes
        every mode with kx^2 + ky^2 >= (N/3)^2, and also the mode with
        wavenumber 0.

        Parameters
        ----------
            u_spectral : two-dimensional array of shape (N, N)

        Returns
        -------
            two-dimensional array of shape (N, N)
        """
        return u_spectral*filt_spectral

    def J(psi_spectral, phi_spectral):
        """
        Function that computes the Jacobian operator pseudo-spectrally.

        Returns the dealiased spectrum of
        ``dpsi/dx * dphi/dy - dpsi/dy * dphi/dx``. The derivatives are taken
        spectrally and the products are formed in physical space.

        Parameters
        ----------
            psi_spectral : two-dimensional array of shape (N, N)
            phi_spectral : two-dimensional array of shape (N, N)

        Returns
        -------
            two-dimensional complex array of shape (N, N), with the filter
            of ``Filt`` already applied
        """
        return filt_spectral * fft2(
            ifft2(dx_spectral*psi_spectral) * ifft2(dy_spectral*phi_spectral)
            - ifft2(dy_spectral*psi_spectral) * ifft2(dx_spectral*phi_spectral))

    return Laplace, Helm_inverse, Filt, J

