JaxBackend

class JaxBackend(float_dtype=None, complex_dtype=None, int_dtype=None, **kwargs)

Bases: NumpyFFTWBackend

A JAX-based matching backend.

Methods

JaxBackend.abs(*args, **kwargs)

Compute the absolute value of array elements.

JaxBackend.add(*args, **kwargs)

Element-wise addition of arrays.

JaxBackend.arange(*args, **kwargs)

Return evenly spaced values within a given interval.

JaxBackend.argsort(*args, **kwargs)

Compute the indices to sort a given input array.

JaxBackend.astype(arr, dtype)

Change the datatype of arr.

JaxBackend.build_fft(fast_shape, fast_ft_shape)

Build forward and inverse real Fourier transform functions.
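
build_fft pairs a forward and an inverse real-input transform with precomputed shapes so that repeated correlations can reuse them. The following is a minimal sketch of that idea using numpy.fft instead of JAX; the name build_fft_sketch, the shape check, and returning the transforms as a pair of closures are assumptions, not the backend's actual behavior. It also assumes fast_ft_shape equals fast_shape with the last axis reduced to n // 2 + 1.

```python
import numpy as np


def build_fft_sketch(fast_shape, fast_ft_shape):
    """Hypothetical stand-in for build_fft: return (rfftn, irfftn) closures bound
    to fixed transform shapes, implemented with numpy.fft instead of JAX."""
    fast_shape = tuple(fast_shape)
    axes = tuple(range(len(fast_shape)))
    # A real-input FFT keeps only n // 2 + 1 coefficients along the last axis.
    expected_ft = fast_shape[:-1] + (fast_shape[-1] // 2 + 1,)
    if tuple(fast_ft_shape) != expected_ft:
        raise ValueError(f"expected Fourier shape {expected_ft}, got {fast_ft_shape}")

    def rfftn(arr):
        # Forward real-to-complex transform, zero-padded or cropped to fast_shape.
        return np.fft.rfftn(arr, s=fast_shape, axes=axes)

    def irfftn(arr):
        # Inverse transform; s= is needed to recover an odd last-axis length.
        return np.fft.irfftn(arr, s=fast_shape, axes=axes)

    return rfftn, irfftn


rfftn, irfftn = build_fft_sketch((8, 9), (8, 5))
x = np.random.default_rng(0).random((8, 9))
x_ft = rfftn(x)
print(x_ft.shape)                      # (8, 5)
print(np.allclose(irfftn(x_ft), x))    # True
```

Passing s to the inverse matters: without it, an odd last-axis length such as 9 cannot be recovered from the 5 stored coefficients.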

JaxBackend.center_of_mass(arr[, cutoff])

Computes the center of mass of a numpy ndarray instance using all available elements.
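
The center of mass is the intensity-weighted mean of the index coordinates along each axis. A minimal NumPy sketch, assuming that elements at or below cutoff receive zero weight; the exact cutoff semantics of the backend method are an assumption, and center_of_mass_sketch is a hypothetical name.

```python
import numpy as np


def center_of_mass_sketch(arr, cutoff=None):
    """Intensity-weighted center of mass (hypothetical helper).
    Elements <= cutoff get zero weight; this cutoff rule is an assumption."""
    arr = np.asarray(arr, dtype=float)
    if cutoff is not None:
        arr = np.where(arr > cutoff, arr, 0.0)
    total = arr.sum()
    grids = np.indices(arr.shape)  # one coordinate grid per axis
    return np.array([(grid * arr).sum() / total for grid in grids])


a = np.zeros((5, 5))
a[1, 1] = 1.0
a[3, 3] = 3.0
print(center_of_mass_sketch(a))            # [2.5 2.5]
print(center_of_mass_sketch(a, cutoff=2))  # [3. 3.]
```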

JaxBackend.clip(*args, **kwargs)

Clip (limit) the values of arr to a given interval.

JaxBackend.compute_convolution_shapes(...)

Computes the regular, optimized, and Fourier convolution shapes.
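
The three shapes follow from standard FFT convolution arithmetic: a full linear convolution needs n1 + n2 - 1 samples per axis, the optimized shape rounds each length up to one the FFT handles efficiently, and the Fourier shape keeps n // 2 + 1 entries on the last axis of a real-input transform. A minimal pure-Python sketch under those assumptions; both function names are hypothetical, the argument and return order are assumed, and a real implementation might use a library routine such as scipy.fft.next_fast_len instead of the 5-smooth search shown here.

```python
def next_fast_len_sketch(n):
    """Smallest m >= n whose only prime factors are 2, 3 and 5 (hypothetical helper)."""
    m = max(n, 1)
    while True:
        k = m
        for p in (2, 3, 5):
            while k % p == 0:
                k //= p
        if k == 1:
            return m
        m += 1


def compute_convolution_shapes_sketch(arr1_shape, arr2_shape):
    # Full linear convolution needs n1 + n2 - 1 samples per axis to avoid wrap-around.
    conv_shape = tuple(a + b - 1 for a, b in zip(arr1_shape, arr2_shape))
    # FFTs are typically fastest for lengths with only small prime factors.
    fast_shape = tuple(next_fast_len_sketch(n) for n in conv_shape)
    # A real-input FFT stores only n // 2 + 1 coefficients along the last axis.
    fast_ft_shape = fast_shape[:-1] + (fast_shape[-1] // 2 + 1,)
    return conv_shape, fast_shape, fast_ft_shape


print(compute_convolution_shapes_sketch((50, 50, 50), (13, 13, 13)))
# ((62, 62, 62), (64, 64, 64), (64, 64, 33))
```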

JaxBackend.concatenate(*args, **kwargs)

Join a sequence of objects along an existing axis.

JaxBackend.datatype_bytes(dtype)

Return the number of bytes occupied by a given datatype.

JaxBackend.device_count()

Returns the number of available GPU devices.

JaxBackend.divide(*args, **kwargs)

Element-wise division of arrays.

JaxBackend.dot(*args, **kwargs)

Compute the dot product of two arrays.

JaxBackend.einsum(*args, **kwargs)

Evaluate the Einstein summation convention on the operands.

JaxBackend.eps(dtype)

Returns the machine epsilon of dtype, i.e., the difference between 1.0 and the next larger representable value.
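
NumPy exposes machine epsilon as numpy.finfo(dtype).eps. A minimal sketch, assuming the backend method returns the analogous value (eps_sketch is a hypothetical name); the loop checks that adding eps to 1.0 changes it, while adding half of eps rounds back to 1.0.

```python
import numpy as np


def eps_sketch(dtype):
    """Machine epsilon: the gap between 1.0 and the next representable value."""
    return np.finfo(dtype).eps


for dtype in (np.float32, np.float64):
    one = dtype(1.0)
    eps = eps_sketch(dtype)
    # Adding eps changes 1.0; adding half of it rounds back to 1.0 (round-half-to-even).
    print(dtype.__name__, eps, one + eps > one, one + eps / 2 == one)
```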

JaxBackend.extract_center(arr, newshape)

Extract the centered portion of an array based on a new shape.
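
Extracting the center amounts to slicing each axis symmetrically around its midpoint. A minimal NumPy sketch, assuming a start offset of (current - new) // 2 per axis; the backend's rounding when the size difference is odd may differ, and extract_center_sketch is a hypothetical name.

```python
import numpy as np


def extract_center_sketch(arr, newshape):
    """Crop arr to newshape, keeping the central region (hypothetical helper)."""
    arr = np.asarray(arr)
    # Floor-based offset: for odd differences the extra element is dropped at the end.
    start = [(cur - new) // 2 for cur, new in zip(arr.shape, newshape)]
    slices = tuple(slice(s, s + new) for s, new in zip(start, newshape))
    return arr[slices]


a = np.arange(25).reshape(5, 5)
print(extract_center_sketch(a, (3, 3)))
# [[ 6  7  8]
#  [11 12 13]
#  [16 17 18]]
```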

JaxBackend.fill(arr, value)

Fills arr in-place with a given value.

JaxBackend.free_cache()

Free cached objects allocated by backend.

JaxBackend.from_sharedarr(arr)

Returns an array of the given shape and dtype from a shared memory location.

JaxBackend.full(*args, **kwargs)

Returns an array of the specified shape and dtype, filled with fill_value.

JaxBackend.get_available_memory()

Returns the memory available for computations, in bytes.

JaxBackend.get_fundamental_dtype(arr)

Given an array instance, returns the corresponding fundamental python type, i.e., int, float or complex.
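
The mapping follows the NumPy dtype hierarchy: integer, floating, and complex-floating dtypes each collapse onto one Python type. A minimal sketch using numpy.issubdtype; get_fundamental_dtype_sketch is a hypothetical name, and raising TypeError for other dtypes (such as bool) is an assumption.

```python
import numpy as np


def get_fundamental_dtype_sketch(arr):
    """Map an array's dtype onto int, float or complex (hypothetical helper)."""
    dtype = np.asarray(arr).dtype
    if np.issubdtype(dtype, np.integer):
        return int
    if np.issubdtype(dtype, np.floating):
        return float
    if np.issubdtype(dtype, np.complexfloating):
        return complex
    # Assumption: anything else (e.g. bool or str) is rejected.
    raise TypeError(f"unsupported dtype: {dtype}")


print(get_fundamental_dtype_sketch(np.zeros(3, dtype=np.int16)))      # <class 'int'>
print(get_fundamental_dtype_sketch(np.zeros(3, dtype=np.float32)))    # <class 'float'>
print(get_fundamental_dtype_sketch(np.zeros(3, dtype=np.complex64)))  # <class 'complex'>
```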

JaxBackend.identity(*args, **kwargs)

Return the identity array.

JaxBackend.indices(*args, **kwargs)

Creates an array representing the index grid of an input.

JaxBackend.max(*args, **kwargs)

Compute the maximum of array elements.

JaxBackend.max_filter_coordinates(...)

Identifies local maxima in score_space separated by min_distance.
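
One common way to find local maxima with a minimum separation is a maximum filter: an element is kept when it equals the largest value in its neighborhood. A minimal NumPy sketch, assuming a cubic neighborhood of side 2 * min_distance + 1 and a hypothetical threshold parameter that discards flat background; the backend's actual filter size and border handling are assumptions, and max_filter_coordinates_sketch is a hypothetical name.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def max_filter_coordinates_sketch(score_space, min_distance, threshold=-np.inf):
    """Coordinates of elements that are the maximum of their
    (2 * min_distance + 1)-wide neighborhood and exceed threshold (hypothetical)."""
    score_space = np.asarray(score_space, dtype=float)
    ndim = score_space.ndim
    size = 2 * min_distance + 1
    # Pad with -inf so border elements are compared only against real neighbors.
    padded = np.pad(score_space, min_distance, mode="constant", constant_values=-np.inf)
    windows = sliding_window_view(padded, (size,) * ndim)
    local_max = windows.max(axis=tuple(range(ndim, 2 * ndim)))
    is_peak = (score_space == local_max) & (score_space > threshold)
    return np.argwhere(is_peak)


scores = np.zeros((7, 7))
scores[1, 1] = 5.0
scores[1, 3] = 4.0  # within min_distance of the stronger peak at (1, 1): suppressed
scores[5, 5] = 3.0
print(max_filter_coordinates_sketch(scores, min_distance=2, threshold=0).tolist())
# [[1, 1], [5, 5]]
```

Plateaus of equal values yield several coordinates, so this filter alone does not guarantee strict separation.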

JaxBackend.max_score_over_rotations(scores, ...)

Update elements in max_scores and rotations where scores is larger than max_scores with scores and rotation_index, respectively.
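
The update is a masked selection: wherever the new scores beat the running maximum, both the score and the rotation index are replaced. A minimal NumPy sketch (max_score_over_rotations_sketch is a hypothetical name); it returns new arrays instead of writing in place, which is also the style JAX's immutable arrays require.

```python
import numpy as np


def max_score_over_rotations_sketch(scores, max_scores, rotations, rotation_index):
    """Keep the per-element best score and the index of the rotation that produced it."""
    better = scores > max_scores  # strict: ties keep the earlier rotation
    max_scores = np.where(better, scores, max_scores)
    rotations = np.where(better, rotation_index, rotations)
    return max_scores, rotations


max_scores = np.full(4, -np.inf)
rotations = np.full(4, -1)
score_maps = [np.array([0.1, 0.5, 0.2, 0.0]), np.array([0.3, 0.4, 0.2, 0.9])]
for idx, scores in enumerate(score_maps):
    max_scores, rotations = max_score_over_rotations_sketch(scores, max_scores, rotations, idx)
print(max_scores, rotations)  # [0.3 0.5 0.2 0.9] [1 0 0 1]
```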

JaxBackend.maximum(*args, **kwargs)

Compute the element-wise maximum of arr1 and arr2.

JaxBackend.mean(*args, **kwargs)

Compute the mean of array elements.

JaxBackend.min(*args, **kwargs)

Compute the minimum of array elements.

JaxBackend.minimum(*args, **kwargs)

Compute the element-wise minimum of arr1 and arr2.

JaxBackend.mod(*args, **kwargs)

Element-wise modulus of arrays.

JaxBackend.multiply(*args, **kwargs)

Element-wise multiplication of arrays.

JaxBackend.norm_scores(arr, exp_sq, sq_exp, ...)

Normalizes arr by the standard deviation ensuring numerical stability.
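
Normalizing by the standard deviation typically uses Var[X] = E[X^2] - E[X]^2, which can come out slightly negative through floating-point cancellation; clamping at zero and guarding tiny values keeps the division stable. A minimal NumPy sketch, assuming exp_sq holds E[X^2], sq_exp holds E[X]^2, and standard deviations below eps are replaced by 1; the remaining parameters of the backend method are not modeled, and norm_scores_sketch is a hypothetical name.

```python
import numpy as np


def norm_scores_sketch(arr, exp_sq, sq_exp, eps):
    """Divide arr by sqrt(E[X^2] - E[X]^2), guarding against tiny variances.
    Assumes exp_sq = E[X^2] and sq_exp = E[X]^2 (a simplification)."""
    var = np.maximum(exp_sq - sq_exp, 0.0)  # clamp negative round-off to zero
    std = np.sqrt(var)
    std = np.where(std < eps, 1.0, std)     # assumption: leave near-constant regions unscaled
    return arr / std


x = np.array([1.0, 2.0, 3.0, 4.0])
exp_sq = np.full(2, np.mean(x ** 2))   # E[X^2] = 7.5
sq_exp = np.full(2, np.mean(x) ** 2)   # E[X]^2 = 6.25
normalized = norm_scores_sketch(np.array([2.0, 5.0]), exp_sq, sq_exp, eps=1e-6)
print(normalized)
```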

JaxBackend.power(*args, **kwargs)

Compute the n-th power of an array.

JaxBackend.repeat(*args, **kwargs)

Repeat each array element a specified number of times.

JaxBackend.reshape(*args, **kwargs)

Gives a new shape to an array without changing its data.

JaxBackend.reverse(arr)

Reverse the order of elements in an array along all its axes.
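
Reversing along all axes is a negative-step slice on every dimension; for real-valued data, a flip of this kind is what turns a convolution into a cross-correlation. A minimal NumPy sketch (reverse_sketch is a hypothetical name), equivalent to numpy.flip with axis=None.

```python
import numpy as np


def reverse_sketch(arr):
    # A step of -1 on every axis flips the array along all of its axes.
    return arr[(slice(None, None, -1),) * arr.ndim]


a = np.arange(6).reshape(2, 3)
print(reverse_sketch(a))
# [[5 4 3]
#  [2 1 0]]
```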

JaxBackend.rigid_transform(arr, rotation_matrix)

Performs a rigid transformation of arr defined by rotation_matrix.
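
A rigid transformation is usually applied by inverse mapping: for every output voxel, compute where it came from in the input and sample there. A minimal NumPy sketch, assuming rotation about the geometric center, an optional translation applied after the rotation, nearest-neighbor sampling, and zeros outside the input; the backend's interpolation order, centering, and rotation convention are assumptions, and rigid_transform_sketch is a hypothetical name.

```python
import numpy as np


def rigid_transform_sketch(arr, rotation_matrix, translation=None):
    """Rotate arr about its geometric center, then shift it by translation.
    Nearest-neighbor inverse mapping; samples falling outside arr become 0."""
    arr = np.asarray(arr)
    shape = np.array(arr.shape)
    rotation = np.asarray(rotation_matrix, dtype=float)
    shift = np.zeros(arr.ndim) if translation is None else np.asarray(translation, dtype=float)
    center = (shape - 1) / 2
    out_coords = np.indices(arr.shape).reshape(arr.ndim, -1).astype(float)
    # Source of each output voxel: x_in = R^T (x_out - center - shift) + center.
    # R^T equals R^-1 because rotation matrices are orthogonal.
    src = rotation.T @ (out_coords - (center + shift)[:, None]) + center[:, None]
    src = np.rint(src).astype(int)
    inside = np.all((src >= 0) & (src < shape[:, None]), axis=0)
    out = np.zeros(arr.size, dtype=arr.dtype)
    out[inside] = arr[tuple(src[:, inside])]
    return out.reshape(arr.shape)


a = np.arange(9).reshape(3, 3)
quarter_turn = np.array([[0, -1], [1, 0]])
print(rigid_transform_sketch(a, quarter_turn))
# [[2 5 8]
#  [1 4 7]
#  [0 3 6]]
```

Production code would normally use spline interpolation instead of rounding to the nearest voxel, which is kept here only to make the mapping easy to follow.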

JaxBackend.roll(a, shift, axis, **kwargs)

Roll array elements along a specified axis.

JaxBackend.set_device(device_index)

Context manager that sets the active compute device for operations.

JaxBackend.size(arr)

Compute the number of elements of arr.

JaxBackend.sqrt(*args, **kwargs)

Compute the square root of array elements.

JaxBackend.square(*args, **kwargs)

Compute the square of array elements.

JaxBackend.stack(*args, **kwargs)

Join a sequence of objects along a new axis.

JaxBackend.std(*args, **kwargs)

Compute the standard deviation of array elements.

JaxBackend.subtract(*args, **kwargs)

Element-wise subtraction of arrays.

JaxBackend.sum(*args, **kwargs)

Compute the sum of array elements.

JaxBackend.to_backend_array(arr)

Convert a numpy array instance to the backend array type.

JaxBackend.to_cpu_array(arr)

Convert an array of a given backend to a CPU array of that backend.

JaxBackend.to_numpy_array(arr)

Convert an array of a given backend to a numpy array.

JaxBackend.to_sharedarr(arr[, ...])

Converts an array to an object shared in memory.
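
Sharing an array through memory lets worker processes read it without pickling a copy per process. A minimal sketch of the round trip using Python's multiprocessing.shared_memory and NumPy; the helper names and the (block, shape, dtype) return value are hypothetical, and the object the backend actually returns may differ.

```python
from multiprocessing import shared_memory

import numpy as np


def to_sharedarr_sketch(arr):
    """Copy arr into a new shared-memory block; return (block, shape, dtype)."""
    arr = np.ascontiguousarray(arr)
    shm = shared_memory.SharedMemory(create=True, size=max(arr.nbytes, 1))
    view = np.ndarray(arr.shape, dtype=arr.dtype, buffer=shm.buf)
    view[...] = arr  # the only copy: into the shared buffer
    return shm, arr.shape, arr.dtype


def from_sharedarr_sketch(shm, shape, dtype):
    """Wrap an existing shared-memory block as an array without copying."""
    return np.ndarray(shape, dtype=dtype, buffer=shm.buf)


shm, shape, dtype = to_sharedarr_sketch(np.arange(6, dtype=np.float32).reshape(2, 3))
restored = from_sharedarr_sketch(shm, shape, dtype)
total = float(restored.sum())  # 0 + 1 + ... + 5 = 15.0
del restored   # release the buffer export before closing
shm.close()    # detach this process from the block
shm.unlink()   # free the block (once, in the owning process)
```

In a second process, the block would be opened with shared_memory.SharedMemory(name=shm.name) before wrapping it; only the creating process should call unlink().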

JaxBackend.tobytes(arr)

Compute the bytestring representation of arr.

JaxBackend.topk_indices(arr, k)

Determines the indices of the k largest elements.
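
A minimal NumPy sketch using numpy.argpartition, which selects the k largest entries without a full sort. Ordering the result from largest to smallest and returning one index array per axis are assumptions about the backend method's output, and topk_indices_sketch is a hypothetical name.

```python
import numpy as np


def topk_indices_sketch(arr, k):
    """Indices (one array per axis) of the k largest elements, largest first."""
    flat = arr.ravel()
    part = np.argpartition(flat, -k)[-k:]        # k largest, in arbitrary order
    part = part[np.argsort(flat[part])[::-1]]    # order them, largest first
    return np.unravel_index(part, arr.shape)


a = np.array([[1, 9, 3], [7, 2, 8]])
rows, cols = topk_indices_sketch(a, 2)
print(list(zip(rows.tolist(), cols.tolist())))  # [(0, 1), (1, 2)]
```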

JaxBackend.topleft_pad(arr, shape[, padval])

Returns an array padded to a specified shape, with arr placed at the top-left corner and the remaining elements set to a padding value.
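
Top-left padding places the input at the origin and fills the remaining entries with padval, the layout commonly used before FFT-based linear convolution. A minimal NumPy sketch; cropping axes where arr is larger than shape is an assumption, and topleft_pad_sketch is a hypothetical name.

```python
import numpy as np


def topleft_pad_sketch(arr, shape, padval=0):
    """Place arr in the top-left corner of a new array of `shape` filled with padval.
    Axes where arr is larger than shape are cropped (an assumption)."""
    out = np.full(shape, padval, dtype=arr.dtype)
    region = tuple(slice(0, min(a, s)) for a, s in zip(arr.shape, shape))
    out[region] = arr[region]
    return out


print(topleft_pad_sketch(np.ones((2, 2), dtype=int), (3, 4), padval=0))
# [[1 1 0 0]
#  [1 1 0 0]
#  [0 0 0 0]]
```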

JaxBackend.transpose(arr)

Compute the transpose of arr.

JaxBackend.tril_indices(*args, **kwargs)

Compute the indices of the lower triangle of a matrix.

JaxBackend.unique(*args, **kwargs)

Find the unique elements of an array.

JaxBackend.unravel_index(indices, shape)

Convert flat indices into multi-dimensional array indices for a given shape.
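
For row-major (C-order) layout, the conversion is repeated division by the axis lengths, starting from the last, fastest-varying axis. A minimal pure-Python sketch for a single integer index; unravel_index_sketch is a hypothetical name, and the backend method, which takes indices and may accept arrays, is not reproduced here.

```python
from math import prod


def unravel_index_sketch(index, shape):
    """Convert one flat C-order (row-major) index into a tuple of per-axis indices."""
    if not 0 <= index < prod(shape):
        raise IndexError(f"index {index} out of bounds for shape {shape}")
    coords = []
    # Peel off the fastest-varying (last) axis first.
    for dim in reversed(shape):
        index, remainder = divmod(index, dim)
        coords.append(remainder)
    return tuple(reversed(coords))


print(unravel_index_sketch(22, (3, 4, 5)))  # (1, 0, 2)
```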

JaxBackend.where(*args, **kwargs)

Return elements from input depending on condition.

JaxBackend.zeros(shape[, dtype])

Returns an aligned array of zeros with specified shape and dtype.