JaxBackend.build_fft

JaxBackend.build_fft(fwd_shape, inv_shape=None, inv_output_shape=None, fwd_axes=None, inv_axes=None, **kwargs)

Build forward and inverse real Fourier transform functions. Each returned callable takes two parameters, arr and out, which correspond to the input and output of the Fourier transform. Analogous to most NumPy functions, each callable returns the result of the transform regardless of whether out is provided.
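The arr/out convention can be illustrated with plain jax.numpy. The sketch below is a minimal, hypothetical stand-in (make_rfft_pair is not part of the library) that assumes the transforms are n-dimensional real FFTs as computed by jnp.fft.rfftn and jnp.fft.irfftn. Because JAX arrays are immutable, out cannot be filled in place, so the sketch accepts it only for signature compatibility and always returns the result.

```python
import jax.numpy as jnp


def make_rfft_pair(fwd_shape, inv_output_shape=None, fwd_axes=None, inv_axes=None):
    """Hypothetical stand-in returning (rfftn, irfftn) callables with an arr/out interface."""
    if inv_output_shape is None:
        inv_output_shape = tuple(fwd_shape)
    if fwd_axes is None:
        fwd_axes = tuple(range(len(fwd_shape)))
    if inv_axes is None:
        inv_axes = fwd_axes
    # Output lengths along the inverse axes; irfftn cannot infer odd lengths on its own.
    inv_s = tuple(inv_output_shape[ax] for ax in inv_axes)

    def rfftn(arr, out=None):
        # JAX arrays are immutable: `out` is ignored and the result is returned instead.
        return jnp.fft.rfftn(arr, axes=fwd_axes)

    def irfftn(arr, out=None):
        # Same convention as rfftn: the result is always returned.
        return jnp.fft.irfftn(arr, s=inv_s, axes=inv_axes)

    return rfftn, irfftn


x = jnp.arange(24.0).reshape(4, 6)
rfftn, irfftn = make_rfft_pair(x.shape)
spectrum = rfftn(x)          # complex half-spectrum, shape (4, 4)
restored = irfftn(spectrum)  # real array, shape (4, 6)
```

Returning the result even when out is passed keeps call sites identical across backends that can and cannot write into a preallocated buffer.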

Parameters:
fwd_shape : tuple

Input shape for the forward Fourier transform (see compute_convolution_shapes).

inv_shape : tuple, optional

Input shape for the inverse Fourier transform.

real_dtype : dtype

Data type of the forward Fourier transform.

complex_dtype : dtype

Data type of the inverse Fourier transform.

inv_output_shape : tuple, optional

Output shape of the inverse Fourier transform. Defaults to fast_shape.

fftargs : dict, optional

Dictionary passed to pyFFTW builders.

temp_fwd : NDArray, optional

Temporary array used to build the forward transform. If provided, its shape supersedes the shape defined by fwd_shape.

temp_inv : NDArray, optional

Temporary array used to build the inverse transform. If provided, its shape supersedes the shape defined by inv_shape.

fwd_axes : tuple of int, optional

Axes to perform the forward Fourier transform over.

inv_axes : tuple of int, optional

Axes to perform the inverse Fourier transform over.
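The shape and axes parameters follow the usual real-to-complex conventions, illustrated below with jax.numpy rather than the backend itself (a sketch, assuming the backend follows jnp.fft.rfftn/irfftn semantics). Along the last transformed axis only n // 2 + 1 frequencies are kept, so inv_shape is generally smaller than fwd_shape. The half-spectrum alone cannot tell an even output length from an odd one, which is why an explicit inverse output shape (inv_output_shape) matters. Restricting fwd_axes and inv_axes leaves a leading batch axis untouched.

```python
import jax.numpy as jnp

# 1. Real input with an odd length along the last axis.
fwd_shape = (4, 5)
x = jnp.arange(20.0).reshape(fwd_shape)
spectrum = jnp.fft.rfftn(x)
print(spectrum.shape)  # (4, 3): the last axis keeps 5 // 2 + 1 frequencies

# 2. Without an explicit output shape, irfftn assumes the even length 2 * (3 - 1) = 4.
guessed = jnp.fft.irfftn(spectrum)
explicit = jnp.fft.irfftn(spectrum, s=fwd_shape)
print(guessed.shape, explicit.shape)  # (4, 4) (4, 5)

# 3. Transforming only axes (1, 2) treats axis 0 as a batch of independent images.
batch = jnp.arange(48.0).reshape(2, 4, 6)
batch_spectra = jnp.fft.rfftn(batch, axes=(1, 2))
print(batch_spectra.shape)  # (2, 4, 4)
restored = jnp.fft.irfftn(batch_spectra, s=(4, 6), axes=(1, 2))
```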

Returns:
tuple

Tuple of callables for forward and inverse real Fourier transform.