Description
An algorithm I have written with applications in stellarator optimization would be improved by replacing some matrix multiplication transforms with type-2 non-uniform fast Fourier transforms (NUFFTs). The algorithm is both forward and reverse mode differentiable, so I would like to use jax-finufft, since the Flatiron Institute has a nice implementation of this transform. However, my testing shows that the vector Jacobian products (VJPs) that jax-finufft assigns to these transforms are incorrect.
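For context, a 1D type-2 NUFFT evaluates a Fourier series at nonuniform points: given coefficients $f_k$ on a uniform grid of modes $k$ and nonuniform targets $x_j$, it computes

$$c_j = \sum_{k=-N/2}^{N/2-1} f_k \, e^{\pm i k x_j},$$

with the sign fixed by FINUFFT's `isign` convention.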
For example, here is a function that calls nufft2. Below, I compute its reverse mode derivative at a point and compare it against the slope of the best linear approximation at that point.
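A minimal sketch of this kind of check, assuming jax-finufft's `nufft2(f, x)` call signature (coefficients on a uniform mode grid, evaluation at nonuniform points); `loss` is a hypothetical stand-in for the actual computation in DESC, which is more involved:

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

from jax_finufft import nufft2


def loss(x):
    """Scalar function whose gradient flows through a type-2 NUFFT."""
    k = jnp.arange(-8, 9)
    f = jnp.exp(-0.5 * k**2).astype(jnp.complex128)  # fixed Fourier coefficients
    y = nufft2(f, x)  # evaluate the series at the nonuniform points x
    return jnp.sum(jnp.abs(y) ** 2)


x0 = jnp.linspace(-jnp.pi / 2, jnp.pi / 2, 5)
grad_ad = jax.grad(loss)(x0)  # reverse mode, uses jax-finufft's VJP

# Central finite differences as a ground truth; this comparison is
# self-consistent regardless of the sign or mode-ordering convention,
# since both sides evaluate the same function.
eps = 1e-6
grad_fd = jnp.array(
    [
        (loss(x0.at[i].add(eps)) - loss(x0.at[i].add(-eps))) / (2 * eps)
        for i in range(x0.size)
    ]
)
print(jnp.max(jnp.abs(grad_ad - grad_fd)))  # large if the VJP is wrong
```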
Clearly, the derivative at the marked point should match the slope of the best linear approximation at that point. However, the derivative computed using the vector Jacobian product from jax-finufft is off by four orders of magnitude.
This is an issue with jax-finufft: when I remove the NUFFTs from the computation and use matrix multiplication transforms instead, vanilla JAX computes the derivative correctly.
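For reference, a sketch of the kind of matrix multiplication transform that stands in for nufft2; the name `nufft2_dense` is illustrative, and the mode ordering and sign must match the options actually passed to jax-finufft:

```python
import jax.numpy as jnp


def nufft2_dense(f, x):
    """Dense equivalent of a 1D type-2 NUFFT (up to FINUFFT's tolerance).

    Assumes modes k = -N//2, ..., (N-1)//2 and a positive-sign convention;
    adjust to match the iflag/modeord options given to jax-finufft.
    """
    n = f.size
    k = jnp.arange(-(n // 2), (n - 1) // 2 + 1)
    # E[j, k] = exp(i * k * x_j); a matmul replaces the fast transform.
    E = jnp.exp(1j * k[None, :] * x[:, None])
    return E @ f
```

Differentiating through this dense version with jax.grad gives the correct derivative, which is what the matrix multiplication variant of the test below exercises.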
Illustration
Below is an illustration of the reverse mode derivative computation.
Reproduction
```bash
git clone https://github.com/PlasmaControl/DESC.git
cd DESC
git switch ku/nufft
conda create --name desc-env 'python>=3.10,<=3.13'
conda activate desc-env
pip install --editable .
pip install -r devtools/dev-requirements.txt
```
Now run the following command.
```bash
pytest -v --mpl -k test_binormal_drift_bounce2d
```
This command runs two tests, each of which checks the correctness of the computation and of its derivative. The first test computes the algorithm using matrix multiplication, without any reference to jax-finufft. The second test uses nufft2 from jax-finufft with the options detailed here.
In this second test, the derivative assertion is wrapped in a with pytest.raises(AssertionError) block, so the test passes when jax-finufft computes the derivatives incorrectly. To recover the more natural behavior, where the test fails when the derivatives are wrong, move the code out of the with pytest.raises(AssertionError) block here.
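A minimal sketch of that pattern with stand-in values (in the real test, the compared quantities are the NUFFT-based derivative and the matmul-based reference):

```python
import numpy as np
import pytest


def test_derivative_mismatch_pattern():
    # Stand-in values; in DESC these are the NUFFT-based and matmul-based
    # derivatives of the same quantity.
    grad_nufft = np.array([1.0e4])  # wrong by orders of magnitude
    grad_matmul = np.array([1.0])   # reference
    # As currently written, the test PASSES because the mismatch is expected.
    # Deleting this context manager (and dedenting the assertion) makes the
    # test fail for as long as jax-finufft's VJP is wrong.
    with pytest.raises(AssertionError):
        np.testing.assert_allclose(grad_nufft, grad_matmul)
```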