Skip to content

Commit ed3f5b9

Browse files
committed
Interpolator: bugfix for reusable matfree adjoint Interpolator in parallel
1 parent 302b76d commit ed3f5b9

File tree

2 files changed

+52
-12
lines changed

2 files changed

+52
-12
lines changed

firedrake/interpolation.py

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -931,9 +931,6 @@ def callable():
931931
else:
932932
loops = []
933933

934-
if access == op2.INC:
935-
loops.append(tensor.zero)
936-
937934
# Arguments in the operand are allowed to be from a MixedFunctionSpace
938935
# We need to split the target space V and generate separate kernels
939936
if len(arguments) == 2:
@@ -961,12 +958,23 @@ def callable():
961958
if bcs and rank == 1:
962959
loops.extend(partial(bc.apply, f) for bc in bcs)
963960

964-
def callable(loops, f):
965-
for l in loops:
966-
l()
961+
def callable(loops, f, access):
962+
if access is op2.WRITE:
963+
for l in loops:
964+
l()
965+
return f
966+
# We are repeatedly incrementing into the same Dat so intermediate halo exchanges
967+
# can be skipped.
968+
f.dat.local_to_global_begin(access)
969+
with f.dat.frozen_halo(access):
970+
if access is op2.INC:
971+
f.dat.zero()
972+
for l in loops:
973+
l()
974+
f.dat.local_to_global_end(access)
967975
return f
968976

969-
return partial(callable, loops, f)
977+
return partial(callable, loops, f, access)
970978

971979

972980
@utils.known_pyop2_safe
@@ -1076,7 +1084,7 @@ def _interpolator(tensor, expr, subset, access, bcs=None):
10761084
coefficient_numbers = kernel.coefficient_numbers
10771085
needs_external_coords = kernel.needs_external_coords
10781086
name = kernel.name
1079-
kernel = op2.Kernel(ast, name, requires_zeroed_output_arguments=True,
1087+
kernel = op2.Kernel(ast, name, requires_zeroed_output_arguments=(access is op2.WRITE),
10801088
flop_count=kernel.flop_count, events=(kernel.event,))
10811089

10821090
parloop_args = [kernel, cell_set]
@@ -1099,7 +1107,7 @@ def _interpolator(tensor, expr, subset, access, bcs=None):
10991107
if isinstance(tensor, op2.Global):
11001108
parloop_args.append(tensor(access))
11011109
elif isinstance(tensor, op2.Dat):
1102-
V_dest = arguments[-1].function_space() if isinstance(dual_arg, ufl.Cofunction) else V
1110+
V_dest = arguments[0].function_space()
11031111
m_ = get_interp_node_map(source_mesh, target_mesh, V_dest)
11041112
parloop_args.append(tensor(access, m_))
11051113
else:
@@ -1159,11 +1167,10 @@ def _interpolator(tensor, expr, subset, access, bcs=None):
11591167
parloop_args.append(target_ref_coords.dat(op2.READ, m_))
11601168

11611169
parloop = op2.ParLoop(*parloop_args)
1162-
parloop_compute_callable = parloop.compute
11631170
if isinstance(tensor, op2.Mat):
1164-
return parloop_compute_callable, tensor.assemble
1171+
return parloop, tensor.assemble
11651172
else:
1166-
return copyin + callables + (parloop_compute_callable, ) + copyout
1173+
return copyin + callables + (parloop, ) + copyout
11671174

11681175

11691176
def get_interp_node_map(source_mesh, target_mesh, fs):

tests/firedrake/regression/test_interpolate.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,7 @@ def test_trace():
327327
assert np.allclose(x_tr_cg.dat.data, x_tr_dir.dat.data)
328328

329329

330+
@pytest.mark.parallel(nprocs=[1, 3])
330331
@pytest.mark.parametrize("rank", (0, 1))
331332
@pytest.mark.parametrize("mat_type", ("matfree", "aij"))
332333
@pytest.mark.parametrize("degree", (1, 3))
@@ -566,3 +567,35 @@ def test_mixed_matrix(mode):
566567
result_explicit = assemble(action(a, u))
567568
for x, y in zip(result_explicit.subfunctions, result_matfree.subfunctions):
568569
assert np.allclose(x.dat.data, y.dat.data)
570+
571+
572+
@pytest.mark.parallel(nprocs=2)
573+
@pytest.mark.parametrize("mode", ["forward", "adjoint"])
574+
@pytest.mark.parametrize("family,degree", [("CG", 1)])
575+
def test_reuse_interpolate(family, degree, mode):
576+
mesh = UnitSquareMesh(1, 1)
577+
V = FunctionSpace(mesh, family, degree)
578+
rg = RandomGenerator(PCG64(seed=123456789))
579+
if mode == "forward":
580+
u = Function(V)
581+
expr = interpolate(u, V)
582+
583+
elif mode == "adjoint":
584+
u = Function(V.dual())
585+
expr = interpolate(TestFunction(V), u)
586+
587+
I = Interpolator(expr, V)
588+
589+
for k in range(2):
590+
u.assign(k+1)
591+
expected = u.dat.data.copy()
592+
result = I.assemble()
593+
594+
# Test that the input was not modified
595+
x = u.dat.data
596+
assert np.allclose(x, expected)
597+
598+
# Test for correctness
599+
y = result.dat.data
600+
assert np.allclose(y, expected)
601+
print("pass", k)

0 commit comments

Comments
 (0)