
[Bug] Derivatives are wrong #158

@unalmis

Description


An algorithm I have written, which has applications in stellarator optimization, would be improved by replacing some matrix multiplication transforms with type 2 non-uniform fast Fourier transforms (NUFFTs). This algorithm is forward- and reverse-mode differentiable, so I would like to use jax-finufft, since the Flatiron Institute has a nice implementation of this transform. However, my testing shows that the vector-Jacobian products that jax-finufft assigns to these transforms are incorrect.
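For reference, a type 2 NUFFT evaluates a Fourier series with uniformly spaced modes at arbitrary (non-uniform) points. The NumPy sketch below computes that sum exactly as a dense matrix product, i.e. a matrix multiplication transform. The function name `nufft2_direct`, the centered mode ordering, and the default `isign=+1` sign convention are assumptions for illustration; FINUFFT has its own sign and mode-ordering options, which should be checked against its documentation.

```python
import numpy as np


def nufft2_direct(coeffs, x, isign=1):
    """Type 2 transform by direct summation, i.e. a matrix multiplication transform.

    Computes c_j = sum_k coeffs[k] * exp(isign * 1j * k * x_j) for M points x_j
    and N centered modes k = -N//2, ..., (N-1)//2. This costs O(N * M); a NUFFT
    approximates the same sums to a requested tolerance much more cheaply.
    """
    n = coeffs.shape[0]
    k = np.arange(-(n // 2), (n - 1) // 2 + 1)      # centered mode numbers
    matrix = np.exp(isign * 1j * np.outer(x, k))     # shape (M, N)
    return matrix @ coeffs


# Sanity check: at uniform points x_j = 2*pi*j/N the sum reduces to an inverse FFT.
n = 8
rng = np.random.default_rng(0)
coeffs = rng.standard_normal(n) + 1j * rng.standard_normal(n)
x_uniform = 2 * np.pi * np.arange(n) / n
fft_result = n * np.fft.ifft(np.fft.ifftshift(coeffs))
assert np.allclose(nufft2_direct(coeffs, x_uniform), fft_result)
```

Because the uniform-point case collapses to an FFT, this reference is a cheap way to pin down sign and ordering conventions before comparing against a NUFFT library.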

For example, consider a function $f$ that uses nufft2. In the plot below, I evaluate $f$ at several values of its argument $\lambda$.
The derivative at the marked point should clearly match the slope of the best linear approximation to $f$ at that point. However, the derivative computed using the vector-Jacobian product from jax-finufft is off by four orders of magnitude.
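The comparison against the best linear approximation can be sketched as a finite-difference test. In this NumPy sketch, the objective `f`, the point parametrization $x_j = x_{0,j} + \lambda w_j$, and all helper names are hypothetical stand-ins for the actual DESC computation. The derivative is taken by hand with the chain rule (the derivative of a type 2 sum with respect to its points is itself a type 2 sum of the mode-weighted coefficients) and compared against a central difference.

```python
import numpy as np


def centered_modes(n):
    # Mode numbers k = -n//2, ..., (n-1)//2 (assumed centered ordering).
    return np.arange(-(n // 2), (n - 1) // 2 + 1)


def nufft2_direct(coeffs, x):
    # Type 2 sum c_j = sum_k coeffs[k] * exp(+i k x_j), by dense matrix product.
    k = centered_modes(coeffs.shape[0])
    return np.exp(1j * np.outer(x, k)) @ coeffs


def f(lam, coeffs, x0, w):
    # Hypothetical scalar objective: the points move with lam as x_j = x0_j + lam * w_j.
    return np.sum(np.real(nufft2_direct(coeffs, x0 + lam * w)))


def df_exact(lam, coeffs, x0, w):
    # Chain rule. d c_j / d x_j = sum_k (i k coeffs[k]) exp(i k x_j), which is a
    # type 2 transform of the mode-weighted coefficients; d x_j / d lam = w_j.
    k = centered_modes(coeffs.shape[0])
    dc_dx = nufft2_direct(1j * k * coeffs, x0 + lam * w)
    return np.sum(np.real(dc_dx) * w)


def df_central(lam, coeffs, x0, w, h=1e-5):
    # Central difference: slope of the local linear approximation of f at lam.
    return (f(lam + h, coeffs, x0, w) - f(lam - h, coeffs, x0, w)) / (2 * h)


rng = np.random.default_rng(1)
coeffs = rng.standard_normal(8) + 1j * rng.standard_normal(8)
x0 = rng.uniform(0.0, 2.0 * np.pi, 16)
w = rng.standard_normal(16)
print(df_exact(0.3, coeffs, x0, w), df_central(0.3, coeffs, x0, w))
```

The two printed values should agree to many digits; a discrepancy of several orders of magnitude, as reported above, points at the derivative rule rather than at finite-difference error.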

This is an issue with jax-finufft: when I remove the NUFFTs from the computation and use matrix multiplication transforms instead, vanilla JAX computes the derivative correctly.
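Since a type 2 transform is linear in its coefficients, its reverse-mode derivative with respect to the coefficients is built from its adjoint (up to the complex-conjugation conventions of the autodiff framework), which is a type 1 transform with the opposite sign. A dot-product test against the matrix multiplication transform is one library-independent way to check such an adjoint. The NumPy sketch below is illustrative only; the helper names and the sign convention are assumptions, not jax-finufft's API.

```python
import numpy as np


def centered_modes(n):
    # Mode numbers k = -n//2, ..., (n-1)//2 (assumed centered ordering).
    return np.arange(-(n // 2), (n - 1) // 2 + 1)


def nufft2_direct(coeffs, x):
    # Type 2: c_j = sum_k coeffs[k] * exp(+i k x_j), as a dense (M, N) matrix product.
    k = centered_modes(coeffs.shape[0])
    return np.exp(1j * np.outer(x, k)) @ coeffs


def nufft1_direct(values, x, n):
    # Type 1 with the opposite sign: f_k = sum_j values[j] * exp(-i k x_j).
    # This (N, M) matrix is the conjugate transpose of the one in nufft2_direct.
    k = centered_modes(n)
    return np.exp(-1j * np.outer(k, x)) @ values


def adjoint_mismatch(n=8, m=20, seed=0):
    # Dot-product test: <v, A u> must equal <A^H v, u> for arbitrary u and v.
    rng = np.random.default_rng(seed)
    x = rng.uniform(0.0, 2.0 * np.pi, m)
    u = rng.standard_normal(n) + 1j * rng.standard_normal(n)
    v = rng.standard_normal(m) + 1j * rng.standard_normal(m)
    lhs = np.vdot(v, nufft2_direct(u, x))      # v^H (A u); vdot conjugates its first argument
    rhs = np.vdot(nufft1_direct(v, x, n), u)   # (A^H v)^H u
    return abs(lhs - rhs) / abs(lhs)


print(adjoint_mismatch())  # relative mismatch; should be at round-off level
```

The same test applied to a library transform (replacing the direct sums) isolates the coefficient part of the VJP; the derivative with respect to the points needs a separate check, such as a finite-difference comparison.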

Illustration

Below is an illustration using the reverse-mode derivative.

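[Image: $f$ evaluated at several values of $\lambda$, with the reverse-mode derivative from jax-finufft at the marked point]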

Reproduction

```sh
git clone https://github.com/PlasmaControl/DESC.git
cd DESC
git switch ku/nufft
```

```sh
conda create --name desc-env 'python>=3.10, <=3.13'
conda activate desc-env
pip install --editable .
pip install -r devtools/dev-requirements.txt
```

Now run the following command.

```sh
pytest -v --mpl -k test_binormal_drift_bounce2d
```

This command runs two tests, each of which checks the correctness of both the computation and its derivative. The first test runs the algorithm with matrix multiplication transforms, without any reference to jax-finufft. The second test uses nufft2 from jax-finufft with the options detailed here.

The second test is written so that it passes if jax-finufft computes the derivatives incorrectly. To recover the more natural behavior, where the test fails when the derivatives are incorrect, move the code out of the `with pytest.raises(AssertionError)` block here.
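The inverted test described above follows the usual pytest idiom for an expected failure. A minimal sketch, with hypothetical helper names that are not taken from the DESC test suite:

```python
import numpy as np
import pytest


def assert_grad_close(grad, grad_ref, rtol=1e-6):
    # Raises AssertionError when the gradients disagree beyond rtol.
    np.testing.assert_allclose(grad, grad_ref, rtol=rtol)


def check_while_known_broken(grad, grad_ref):
    # Inverted guard: succeeds only if assert_grad_close raises, i.e. while the
    # gradient is still wrong. If the gradients agree, pytest.raises reports a
    # failure ("DID NOT RAISE") instead.
    with pytest.raises(AssertionError):
        assert_grad_close(grad, grad_ref)


def check_normally(grad, grad_ref):
    # The same assertion moved outside the pytest.raises block: it fails when
    # the gradient is wrong, which is the natural behavior for a test.
    assert_grad_close(grad, grad_ref)
```

Moving the assertion out of the `pytest.raises` block corresponds to switching from `check_while_known_broken` to `check_normally`.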
