From 806f8dba7d6144607ed44bd3a135afc315701ea7 Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Tue, 7 Nov 2023 11:51:57 -0500 Subject: [PATCH 1/6] Begin sketching out t3 --- pytorch_finufft/functional.py | 119 ++++++++++++++++++++++++++++++---- 1 file changed, 108 insertions(+), 11 deletions(-) diff --git a/pytorch_finufft/functional.py b/pytorch_finufft/functional.py index cdd71bf..ead7239 100644 --- a/pytorch_finufft/functional.py +++ b/pytorch_finufft/functional.py @@ -266,8 +266,6 @@ def backward( # type: ignore[override] grad_values, None, None, - None, - None, ) @@ -383,13 +381,7 @@ def vmap( # type: ignore[override] @staticmethod def backward( # type: ignore[override] ctx: Any, grad_output: torch.Tensor - ) -> Tuple[ - Union[torch.Tensor, None], - Union[torch.Tensor, None], - None, - None, - None, - ]: + ) -> Tuple[Union[torch.Tensor, None], ...]: _i_sign = ctx.isign _mode_ordering = ctx.mode_ordering finufftkwargs = ctx.finufftkwargs @@ -450,11 +442,106 @@ def backward( # type: ignore[override] grad_points, grad_targets, None, - None, - None, ) +class FinufftType3(torch.autograd.Function): + """ + FINUFFT problem type 3 + """ + + ISIGN_DEFAULT = -1 # note: FINUFFT default is 1 + MODEORD_DEFAULT = 1 # note: FINUFFT default is 0 + + @staticmethod + def setup_context( # type: ignore[override] + ctx: Any, + inputs: Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + Optional[Dict[str, Union[int, float]]], + ], + output: Any, + ) -> None: + points, strengths, targets, finufftkwargs = inputs + if finufftkwargs is None: + finufftkwargs = {} + else: # copy to avoid mutating caller's dictionary + finufftkwargs = finufftkwargs.copy() + ctx.save_for_backward(points, strengths, targets) + ctx.isign = finufftkwargs.pop("isign", FinufftType3.ISIGN_DEFAULT) + ctx.mode_ordering = finufftkwargs.pop("modeord", FinufftType3.MODEORD_DEFAULT) + ctx.finufftkwargs = finufftkwargs + + @staticmethod + def forward( # type: ignore[override] + points: torch.Tensor, + strengths: torch.Tensor, + targets: torch.Tensor, + finufftkwargs: Optional[Dict[str, Union[int, float]]] = None, + ) -> torch.Tensor: + checks.check_devices(targets, strengths, points) + checks.check_dtypes(strengths, points, "Strengths") + checks.check_dtypes(strengths, targets, "Strengths") + # TODO size checks + + if finufftkwargs is None: + finufftkwargs = dict() + else: + finufftkwargs = finufftkwargs.copy() + + finufftkwargs.setdefault("isign", FinufftType3.ISIGN_DEFAULT) + + modeord = finufftkwargs.pop("modeord", FinufftType3.MODEORD_DEFAULT) + + points = torch.atleast_2d(points) + targets = torch.atleast_2d(targets) + + ndim = points.shape[0] + npoints = points.shape[1] + + if points.device.type != 'cpu': + raise NotImplementedError("Type 3 is not currently implemented for GPU") + + if modeord: + strengths = batch_fftshift(strengths, ndim) + + nufft_func = get_nufft_func(ndim, 3, points.device.type) + batch_dims = strengths.shape[:-1] + + finufft_out = nufft_func( + *points, + strengths.reshape(-1, npoints), + *targets, + **finufftkwargs, + ) + finufft_out = finufft_out.reshape(*batch_dims, targets.shape[-1]) + + return finufft_out + + @staticmethod + def backward( # type: ignore[override] + ctx: Any, grad_output: torch.Tensor + ) -> Tuple[Union[torch.Tensor, None], ...]: + _i_sign = ctx.isign + _mode_ordering = ctx.mode_ordering + finufftkwargs = ctx.finufftkwargs + + points, strengths, targets = ctx.saved_tensors + points = torch.atleast_2d(points) + targets = torch.atleast_2d(targets) + + device = points.device + ndim = points.shape[0] + + grad_points = None + grad_strengths = None + grad_targets = None + + return grad_points, grad_strengths, grad_targets, None + + def finufft_type1( points: torch.Tensor, values: torch.Tensor, @@ -534,3 +621,13 @@ def finufft_type2( """ res: torch.Tensor = FinufftType2.apply(points, targets, finufftkwargs) return res + + +def finufft_type3( + points: torch.Tensor, + strengths: torch.Tensor, + targets: torch.Tensor, + **finufftkwargs: Union[int, float], +) -> torch.Tensor: + res: torch.Tensor = FinufftType3.apply(points, strengths, targets, finufftkwargs) + return res From 906732ff0370b72b332e2e0e71da0799862ff61d Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Tue, 16 Jan 2024 17:08:35 -0500 Subject: [PATCH 2/6] Start tests, doc all failing --- docs/api.rst | 2 ++ docs/installation.rst | 6 ++-- pytorch_finufft/checks.py | 50 ++++++++++++++++++++++++-- pytorch_finufft/functional.py | 47 +++++++++++++++++++----- tests/test_t3_forward.py | 67 +++++++++++++++++++++++++++++++++++ 5 files changed, 157 insertions(+), 15 deletions(-) create mode 100644 tests/test_t3_forward.py diff --git a/docs/api.rst b/docs/api.rst index c1d1601..7b0d305 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -6,3 +6,5 @@ API Reference .. autofunction:: pytorch_finufft.functional.finufft_type1 .. autofunction:: pytorch_finufft.functional.finufft_type2 + +.. autofunction:: pytorch_finufft.functional.finufft_type3 diff --git a/docs/installation.rst b/docs/installation.rst index 49dbcd3..4a7723e 100644 --- a/docs/installation.rst +++ b/docs/installation.rst @@ -6,10 +6,10 @@ Pre-requistes ------------- Pytorch-FINUFFT requires either ``finufft`` *and/or* ``cufinufft`` -2.1.0 or greater. +2.2.0 or greater. -Note that currently, this version of ``cufinufft`` is unreleased -and can only be installed from source. See the relevant installation pages for +These are available via `pip` or can be built from source. +See the relevant installation pages for :external+finufft:doc:`finufft ` and :external+finufft:doc:`cufinufft `. diff --git a/pytorch_finufft/checks.py b/pytorch_finufft/checks.py index eedabae..7a34d3d 100644 --- a/pytorch_finufft/checks.py +++ b/pytorch_finufft/checks.py @@ -21,7 +21,9 @@ def check_devices(*tensors: torch.Tensor) -> None: ) -def check_dtypes(data: torch.Tensor, points: torch.Tensor, name: str) -> None: +def check_dtypes( + data: torch.Tensor, points: torch.Tensor, name: str, points_name: str = "Points" +) -> None: """ Checks that data is complex-valued and that points is real-valued of the same precision @@ -38,8 +40,8 @@ def check_dtypes(data: torch.Tensor, points: torch.Tensor, name: str) -> None: if points.dtype is not real_dtype: raise TypeError( - f"Points must have a dtype of {real_dtype} as {name.lower()} has a " - f"dtype of {complex_dtype}" + f"{points_name} must have a dtype of {real_dtype} as {name.lower()} has a " + f"dtype of {complex_dtype}, but got {points.dtype} instead" ) @@ -102,3 +104,45 @@ def check_sizes_t2(targets: torch.Tensor, points: torch.Tensor) -> None: f"For type 2 {points_dim}d FINUFFT, targets must be at " f"least a {points_dim}d tensor" ) + + +def check_sizes_t3( + points: torch.Tensor, strengths: torch.Tensor, targets: torch.Tensor +) -> None: + """ + Checks that targets and points are of the same dimension. + This is used in type3. + """ + points_len = len(points.shape) + targets_len = len(targets.shape) + + if points_len == 1: + points_dim = 1 + elif points_len == 2: + points_dim = points.shape[0] + else: + raise ValueError("The points tensor must be 1d or 2d") + + if targets_len == 1: + targets_dim = 1 + elif targets_len == 2: + targets_dim = targets.shape[0] + else: + raise ValueError("The targets tensor must be 1d or 2d") + + if targets_dim != points_dim: + raise ValueError( + "The points tensor and targets tensor must be of the same dimension!" + + f"Got {points_dim=} and {targets_dim=} instead" + ) + + if points_dim not in {1, 2, 3}: + raise ValueError( + f"Points and targets can be at most 3d, got {points_dim} instead" + ) + + n_points = points.shape[-1] + n_strengths = strengths.shape[-1] + + if n_points != n_strengths: + raise ValueError("The same number of points and strengths must be supplied") diff --git a/pytorch_finufft/functional.py b/pytorch_finufft/functional.py index ead7239..e9a6b85 100644 --- a/pytorch_finufft/functional.py +++ b/pytorch_finufft/functional.py @@ -450,7 +450,7 @@ class FinufftType3(torch.autograd.Function): FINUFFT problem type 3 """ - ISIGN_DEFAULT = -1 # note: FINUFFT default is 1 + ISIGN_DEFAULT = 1 # note: FINUFFT default is 1 MODEORD_DEFAULT = 1 # note: FINUFFT default is 0 @staticmethod @@ -483,8 +483,8 @@ def forward( # type: ignore[override] ) -> torch.Tensor: checks.check_devices(targets, strengths, points) checks.check_dtypes(strengths, points, "Strengths") - checks.check_dtypes(strengths, targets, "Strengths") - # TODO size checks + checks.check_dtypes(strengths, targets, "Strengths", points_name="Targets") + checks.check_sizes_t3(points, strengths, targets) if finufftkwargs is None: finufftkwargs = dict() @@ -493,22 +493,19 @@ def forward( # type: ignore[override] finufftkwargs.setdefault("isign", FinufftType3.ISIGN_DEFAULT) - modeord = finufftkwargs.pop("modeord", FinufftType3.MODEORD_DEFAULT) + finufftkwargs.setdefault("modeord", FinufftType3.MODEORD_DEFAULT) points = torch.atleast_2d(points) targets = torch.atleast_2d(targets) ndim = points.shape[0] npoints = points.shape[1] + batch_dims = strengths.shape[:-1] - if points.device.type != 'cpu': + if points.device.type != "cpu": raise NotImplementedError("Type 3 is not currently implemented for GPU") - if modeord: - strengths = batch_fftshift(strengths, ndim) - nufft_func = get_nufft_func(ndim, 3, points.device.type) - batch_dims = strengths.shape[:-1] finufft_out = nufft_func( *points, @@ -629,5 +626,37 @@ def finufft_type3( targets: torch.Tensor, **finufftkwargs: Union[int, float], ) -> torch.Tensor: + """ + Evaluates the Type 3 (nonuniform-to-nonuniform) NUFFT on the inputs. + + This is a wrapper around :func:`finufft.nufft1d3`, :func:`finufft.nufft2d3`, and + :func:`finufft.nufft3d3` on CPU. + + Note that this function is **not implemented** for GPUs at the time of writing. + + Parameters + ---------- + points : torch.Tensor + DxM tensor of locations of the non-uniform source points. + Points must lie in the range ``[-3pi, 3pi]``. + strengths: torch.Tensor + Complex-valued tensor of source strengths at the non-uniform points. + All dimensions except the final dimension are treated as batch + dimensions. The final dimension must have size ``M``. + targets : torch.Tensor + DxN tensor of locations of the non-uniform target points. + **finufftkwargs : int | float + Additional keyword arguments are forwarded to the underlying + FINUFFT functions. A few notable options are + + - ``eps``: precision requested (default: ``1e-6``) + - ``modeord``: 0 for FINUFFT default, 1 for Pytorch default (default: ``1``) + - ``isign``: Sign of the exponent in the Fourier transform (default: ``-1``) + + Returns + ------- + torch.Tensor + A ``[batch]xN`` tensor of values at the target non-uniform points. + """ res: torch.Tensor = FinufftType3.apply(points, strengths, targets, finufftkwargs) return res diff --git a/tests/test_t3_forward.py b/tests/test_t3_forward.py new file mode 100644 index 0000000..323cdc7 --- /dev/null +++ b/tests/test_t3_forward.py @@ -0,0 +1,67 @@ +import numpy as np +import pytest +import torch + +import pytorch_finufft + + +def check_t3_forward(N: int, dim: int, device: str) -> None: + """ + Tests against implementations of the FFT by setting up a uniform grid + over which to call FINUFFT through the API. + """ + slices = tuple(slice(None, N) for _ in range(dim)) + g = np.mgrid[slices] * 2 * np.pi / N + points = torch.from_numpy(g.reshape(dim, -1)).to(device) + values = torch.randn(*g[0].shape, dtype=torch.complex128).to(device) + targets = points.clone() + + print("N is " + str(N)) + print("dim is " + str(dim)) + print("shape of points is " + str(points.shape)) + print("shape of values is " + str(values.shape)) + print("shape of targets is " + str(targets.shape)) + + finufft_out = pytorch_finufft.functional.finufft_type3( + points, + values.flatten(), + targets, + ) + + against_torch = torch.fft.fftn(values.reshape(g[0].shape)) + + abs_errors = torch.abs(finufft_out - against_torch.flatten()) + l_inf_error = abs_errors.max() + l_2_error = torch.sqrt(torch.sum(abs_errors**2)) + l_1_error = torch.sum(abs_errors) + + assert l_inf_error < 1.5e-5 * N**1.5 + assert l_2_error < 1e-5 * N**3 + assert l_1_error < 1e-5 * N**4.5 + + +Ns_and_dims = [ + (2, 1), + (3, 1), + (5, 1), + (10, 1), + (100, 1), + (101, 1), + (1000, 1), + (10001, 1), + (2, 2), + (3, 2), + (5, 2), + (10, 2), + (101, 2), + (2, 3), + (3, 3), + (5, 3), + (10, 3), + (37, 3), +] + + +@pytest.mark.parametrize("N, dim", Ns_and_dims) +def test_t3_forward_CPU(N, dim) -> None: + check_t3_forward(N, dim, "cpu") From 4f2be93adcef020bd7d89b45de0a6b412ddf741d Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Wed, 17 Jan 2024 14:37:29 -0500 Subject: [PATCH 3/6] Error and forward tests --- pytorch_finufft/checks.py | 4 +- pytorch_finufft/functional.py | 13 +-- tests/test_errors.py | 171 ++++++++++++++++++++++++++++++++++ tests/test_t3_forward.py | 15 ++- 4 files changed, 187 insertions(+), 16 deletions(-) diff --git a/pytorch_finufft/checks.py b/pytorch_finufft/checks.py index 7a34d3d..b1a20db 100644 --- a/pytorch_finufft/checks.py +++ b/pytorch_finufft/checks.py @@ -132,8 +132,8 @@ def check_sizes_t3( if targets_dim != points_dim: raise ValueError( - "The points tensor and targets tensor must be of the same dimension!" - + f"Got {points_dim=} and {targets_dim=} instead" + "Points and targets must be of the same dimension!" + + f" Got {points_dim=} and {targets_dim=} instead" ) if points_dim not in {1, 2, 3}: diff --git a/pytorch_finufft/functional.py b/pytorch_finufft/functional.py index e9a6b85..df3d262 100644 --- a/pytorch_finufft/functional.py +++ b/pytorch_finufft/functional.py @@ -450,8 +450,7 @@ class FinufftType3(torch.autograd.Function): FINUFFT problem type 3 """ - ISIGN_DEFAULT = 1 # note: FINUFFT default is 1 - MODEORD_DEFAULT = 1 # note: FINUFFT default is 0 + ISIGN_DEFAULT = -1 # note: FINUFFT default is 1 @staticmethod def setup_context( # type: ignore[override] @@ -471,7 +470,6 @@ def setup_context( # type: ignore[override] finufftkwargs = finufftkwargs.copy() ctx.save_for_backward(points, strengths, targets) ctx.isign = finufftkwargs.pop("isign", FinufftType3.ISIGN_DEFAULT) - ctx.mode_ordering = finufftkwargs.pop("modeord", FinufftType3.MODEORD_DEFAULT) ctx.finufftkwargs = finufftkwargs @staticmethod @@ -493,8 +491,6 @@ def forward( # type: ignore[override] finufftkwargs.setdefault("isign", FinufftType3.ISIGN_DEFAULT) - finufftkwargs.setdefault("modeord", FinufftType3.MODEORD_DEFAULT) - points = torch.atleast_2d(points) targets = torch.atleast_2d(targets) @@ -522,15 +518,14 @@ def backward( # type: ignore[override] ctx: Any, grad_output: torch.Tensor ) -> Tuple[Union[torch.Tensor, None], ...]: _i_sign = ctx.isign - _mode_ordering = ctx.mode_ordering - finufftkwargs = ctx.finufftkwargs + # finufftkwargs = ctx.finufftkwargs points, strengths, targets = ctx.saved_tensors points = torch.atleast_2d(points) targets = torch.atleast_2d(targets) - device = points.device - ndim = points.shape[0] + # device = points.device + # ndim = points.shape[0] grad_points = None grad_strengths = None diff --git a/tests/test_errors.py b/tests/test_errors.py index 69d128e..ad73d5d 100644 --- a/tests/test_errors.py +++ b/tests/test_errors.py @@ -46,6 +46,37 @@ def test_t2_mismatch_cuda_index() -> None: pytorch_finufft.functional.finufft_type2(points, targets) +def test_t3_mismatch_device_cuda_cpu() -> None: + points = torch.rand((2, 10), dtype=torch.float64) + values = torch.randn(10, dtype=torch.complex128) + targets = torch.rand((2, 10), dtype=torch.float64) + + with pytest.raises(ValueError, match="Some tensors are not on the same device"): + pytorch_finufft.functional.finufft_type3(points.to("cuda:0"), values, targets) + + with pytest.raises(ValueError, match="Some tensors are not on the same device"): + pytorch_finufft.functional.finufft_type3(points, values.to("cuda:0"), targets) + + with pytest.raises(ValueError, match="Some tensors are not on the same device"): + pytorch_finufft.functional.finufft_type3(points, values, targets.to("cuda:0")) + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="require multiple GPUs") +def test_t3_mismatch_cuda_index() -> None: + points = torch.rand((2, 10), dtype=torch.float64).to("cuda:0") + values = torch.randn(10, dtype=torch.complex128).to("cuda:0") + targets = torch.rand((2, 10), dtype=torch.float64).to("cuda:0") + + with pytest.raises(ValueError, match="Some tensors are not on the same device"): + pytorch_finufft.functional.finufft_type3(points.to("cuda:1"), values, targets) + + with pytest.raises(ValueError, match="Some tensors are not on the same device"): + pytorch_finufft.functional.finufft_type3(points, values.to("cuda:1"), targets) + + with pytest.raises(ValueError, match="Some tensors are not on the same device"): + pytorch_finufft.functional.finufft_type3(points, values, targets.to("cuda:1")) + + # dtypes @@ -171,6 +202,101 @@ def test_t2_mismatch_precision() -> None: pytorch_finufft.functional.finufft_type2(points, targets) +def test_t3_non_complex_targets() -> None: + points = torch.rand((2, 10), dtype=torch.float64) + values = torch.randn(10, dtype=torch.float64) + targets = torch.rand((2, 10), dtype=torch.float64) + + with pytest.raises( + TypeError, + match="Strengths must have a dtype of torch.complex64 or torch.complex128", + ): + pytorch_finufft.functional.finufft_type3(points, values, targets) + + +def test_t3_half_complex_targets() -> None: + points = torch.rand((2, 10), dtype=torch.float64) + targets = torch.rand((2, 10), dtype=torch.float64) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore", UserWarning) + values = torch.randn(10, dtype=torch.complex32) + + with pytest.raises( + TypeError, + match="Strengths must have a dtype of torch.complex64 or torch.complex128", + ): + pytorch_finufft.functional.finufft_type3(points, values, targets) + + +def test_t3_non_real_points() -> None: + points = torch.rand((2, 10), dtype=torch.complex128) + values = torch.randn(10, dtype=torch.complex128) + targets = torch.rand((2, 10), dtype=torch.float64) + + with pytest.raises( + TypeError, + match="Points must have a dtype of torch.float64 as strengths has " + "a dtype of torch.complex128", + ): + pytorch_finufft.functional.finufft_type3(points, values, targets) + + +def test_t3_non_real_targets() -> None: + points = torch.rand((2, 10), dtype=torch.float64) + values = torch.randn(10, dtype=torch.complex128) + targets = torch.rand((2, 10), dtype=torch.complex128) + + with pytest.raises( + TypeError, + match="Targets must have a dtype of torch.float64 as strengths has " + "a dtype of torch.complex128", + ): + pytorch_finufft.functional.finufft_type3(points, values, targets) + + +def test_t3_mismatch_precision() -> None: + points = torch.rand((2, 10), dtype=torch.float32) + values = torch.randn(10, dtype=torch.complex128) + targets = torch.rand((2, 10), dtype=torch.float64) + + with pytest.raises( + TypeError, + match="Points must have a dtype of torch.float64 as strengths has " + "a dtype of torch.complex128", + ): + pytorch_finufft.functional.finufft_type3(points, values, targets) + + points = points.to(torch.float64) + targets = targets.to(torch.float32) + + with pytest.raises( + TypeError, + match="Targets must have a dtype of torch.float64 as strengths has " + "a dtype of torch.complex128", + ): + pytorch_finufft.functional.finufft_type3(points, values, targets) + + values = values.to(torch.complex64) + + with pytest.raises( + TypeError, + match="Points must have a dtype of torch.float32 as strengths has " + "a dtype of torch.complex64", + ): + pytorch_finufft.functional.finufft_type3(points, values, targets) + + points = points.to(torch.float32) + targets = targets.to(torch.float64) + + with pytest.raises( + TypeError, + match="Targets must have a dtype of torch.float32 as strengths has " + "a dtype of torch.complex64", + ): + pytorch_finufft.functional.finufft_type3(points, values, targets) + + # sizes @@ -272,6 +398,51 @@ def test_t2_mismatch_dims() -> None: pytorch_finufft.functional.finufft_type2(points, targets) +def test_t3_points_targets_4d() -> None: + points = torch.rand((4, 10), dtype=torch.float64) + values = torch.randn(10, dtype=torch.complex128) + targets = torch.rand((4, 10), dtype=torch.float64) + + with pytest.raises(ValueError, match="Points and targets can be at most 3d, got"): + pytorch_finufft.functional.finufft_type3(points, values, targets) + + +def test_t3_points_targets_mismatch_dims() -> None: + points = torch.rand((2, 10), dtype=torch.float64) + values = torch.randn(10, dtype=torch.complex128) + targets = torch.rand((3, 10), dtype=torch.float64) + + with pytest.raises( + ValueError, match="Points and targets must be of the same dimension" + ): + pytorch_finufft.functional.finufft_type3(points, values, targets) + + +def test_t3_too_many_dims() -> None: + points = torch.rand((1, 4, 10), dtype=torch.float64) + values = torch.randn(10, dtype=torch.complex128) + targets = torch.rand((1, 4, 10), dtype=torch.float64) + + with pytest.raises(ValueError, match="The points tensor must be 1d or 2d"): + pytorch_finufft.functional.finufft_type3(points, values, targets) + + points = torch.rand((2, 10), dtype=torch.float64) + + with pytest.raises(ValueError, match="The targets tensor must be 1d or 2d"): + pytorch_finufft.functional.finufft_type3(points, values, targets) + + +def test_t3_mismatch_dims() -> None: + points = torch.rand((2, 10), dtype=torch.float64) + values = torch.randn(11, dtype=torch.complex128) + targets = torch.rand((2, 12), dtype=torch.float64) + + with pytest.raises( + ValueError, match="The same number of points and strengths must be supplied" + ): + pytorch_finufft.functional.finufft_type3(points, values, targets) + + # dependencies def test_finufft_not_installed(): if not pytorch_finufft.functional.CUFINUFFT_AVAIL: diff --git a/tests/test_t3_forward.py b/tests/test_t3_forward.py index 323cdc7..3e606b2 100644 --- a/tests/test_t3_forward.py +++ b/tests/test_t3_forward.py @@ -10,11 +10,16 @@ def check_t3_forward(N: int, dim: int, device: str) -> None: Tests against implementations of the FFT by setting up a uniform grid over which to call FINUFFT through the API. """ + slices = tuple(slice(None, N) for _ in range(dim)) g = np.mgrid[slices] * 2 * np.pi / N points = torch.from_numpy(g.reshape(dim, -1)).to(device) values = torch.randn(*g[0].shape, dtype=torch.complex128).to(device) - targets = points.clone() + targets = ( + torch.from_numpy(np.mgrid[slices].astype(np.float64)) + .reshape(dim, -1) + .to(device) + ) print("N is " + str(N)) print("dim is " + str(dim)) @@ -23,18 +28,18 @@ def check_t3_forward(N: int, dim: int, device: str) -> None: print("shape of targets is " + str(targets.shape)) finufft_out = pytorch_finufft.functional.finufft_type3( - points, - values.flatten(), - targets, + points, values.flatten(), targets, eps=1e-9 ) - against_torch = torch.fft.fftn(values.reshape(g[0].shape)) + against_torch = torch.fft.fftn(values) abs_errors = torch.abs(finufft_out - against_torch.flatten()) l_inf_error = abs_errors.max() l_2_error = torch.sqrt(torch.sum(abs_errors**2)) l_1_error = torch.sum(abs_errors) + # TODO consider updating tolerances below + # in conjunction with `eps` parameter above assert l_inf_error < 1.5e-5 * N**1.5 assert l_2_error < 1e-5 * N**3 assert l_1_error < 1e-5 * N**4.5 From 307989391b5ed9ba65b1203cfb2a219d0aa05a6f Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Wed, 17 Jan 2024 15:01:13 -0500 Subject: [PATCH 4/6] Tighter tolerances --- tests/test_t3_forward.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/tests/test_t3_forward.py b/tests/test_t3_forward.py index 3e606b2..9deb7df 100644 --- a/tests/test_t3_forward.py +++ b/tests/test_t3_forward.py @@ -28,7 +28,7 @@ def check_t3_forward(N: int, dim: int, device: str) -> None: print("shape of targets is " + str(targets.shape)) finufft_out = pytorch_finufft.functional.finufft_type3( - points, values.flatten(), targets, eps=1e-9 + points, values.flatten(), targets ) against_torch = torch.fft.fftn(values) @@ -38,11 +38,9 @@ def check_t3_forward(N: int, dim: int, device: str) -> None: l_2_error = torch.sqrt(torch.sum(abs_errors**2)) l_1_error = torch.sum(abs_errors) - # TODO consider updating tolerances below - # in conjunction with `eps` parameter above - assert l_inf_error < 1.5e-5 * N**1.5 - assert l_2_error < 1e-5 * N**3 - assert l_1_error < 1e-5 * N**4.5 + assert l_inf_error < 5e-5 * N**1.5 + assert l_2_error < 1.5e-5 * N**3.2 + assert l_1_error < 1.5e-5 * N**4.5 Ns_and_dims = [ From 71d14350fee7701c085f5c2513b227f860b69759 Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Mon, 5 Feb 2024 11:12:36 -0500 Subject: [PATCH 5/6] Lint fix --- pytorch_finufft/functional.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_finufft/functional.py b/pytorch_finufft/functional.py index df3d262..cae6886 100644 --- a/pytorch_finufft/functional.py +++ b/pytorch_finufft/functional.py @@ -453,7 +453,7 @@ class FinufftType3(torch.autograd.Function): ISIGN_DEFAULT = -1 # note: FINUFFT default is 1 @staticmethod - def setup_context( # type: ignore[override] + def setup_context( ctx: Any, inputs: Tuple[ torch.Tensor, From 10b455224f610f5408ee7859e6e9069ea9bb416b Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Mon, 5 Feb 2024 11:24:15 -0500 Subject: [PATCH 6/6] Fix forward test tolerance --- pytorch_finufft/functional.py | 2 +- tests/test_t3_forward.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/pytorch_finufft/functional.py b/pytorch_finufft/functional.py index cae6886..7133fe9 100644 --- a/pytorch_finufft/functional.py +++ b/pytorch_finufft/functional.py @@ -501,7 +501,7 @@ def forward( # type: ignore[override] if points.device.type != "cpu": raise NotImplementedError("Type 3 is not currently implemented for GPU") - nufft_func = get_nufft_func(ndim, 3, points.device.type) + nufft_func = get_nufft_func(ndim, 3, points.device) finufft_out = nufft_func( *points, diff --git a/tests/test_t3_forward.py b/tests/test_t3_forward.py index 9deb7df..0b2142c 100644 --- a/tests/test_t3_forward.py +++ b/tests/test_t3_forward.py @@ -4,6 +4,8 @@ import pytorch_finufft +torch.manual_seed(45678) + def check_t3_forward(N: int, dim: int, device: str) -> None: """ @@ -28,7 +30,7 @@ def check_t3_forward(N: int, dim: int, device: str) -> None: print("shape of targets is " + str(targets.shape)) finufft_out = pytorch_finufft.functional.finufft_type3( - points, values.flatten(), targets + points, values.flatten(), targets, eps=1e-7 ) against_torch = torch.fft.fftn(values)