Skip to content

Commit 10b4552

Browse files
committed
Fix forward test tolerance
1 parent 71d1435 commit 10b4552

File tree

2 files changed

+4
-2
lines changed

2 files changed

+4
-2
lines changed

pytorch_finufft/functional.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -501,7 +501,7 @@ def forward( # type: ignore[override]
501501
if points.device.type != "cpu":
502502
raise NotImplementedError("Type 3 is not currently implemented for GPU")
503503

504-
nufft_func = get_nufft_func(ndim, 3, points.device.type)
504+
nufft_func = get_nufft_func(ndim, 3, points.device)
505505

506506
finufft_out = nufft_func(
507507
*points,

tests/test_t3_forward.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
import pytorch_finufft
66

7+
torch.manual_seed(45678)
8+
79

810
def check_t3_forward(N: int, dim: int, device: str) -> None:
911
"""
@@ -28,7 +30,7 @@ def check_t3_forward(N: int, dim: int, device: str) -> None:
2830
print("shape of targets is " + str(targets.shape))
2931

3032
finufft_out = pytorch_finufft.functional.finufft_type3(
31-
points, values.flatten(), targets
33+
points, values.flatten(), targets, eps=1e-7
3234
)
3335

3436
against_torch = torch.fft.fftn(values)

0 commit comments

Comments
 (0)