Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 25 additions & 27 deletions firedrake/interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -930,8 +930,8 @@ def callable():
return callable
else:
loops = []

if access == op2.INC:
# Initialise to zero if needed
if access is op2.INC:
loops.append(tensor.zero)

# Arguments in the operand are allowed to be from a MixedFunctionSpace
Expand All @@ -957,7 +957,7 @@ def callable():
for indices, sub_expr in expressions.items():
sub_tensor = tensor[indices[0]] if rank == 1 else tensor
loops.extend(_interpolator(sub_tensor, sub_expr, subset, access, bcs=bcs))

# Apply bcs
if bcs and rank == 1:
loops.extend(partial(bc.apply, f) for bc in bcs)

Expand Down Expand Up @@ -1038,32 +1038,36 @@ def _interpolator(tensor, expr, subset, access, bcs=None):
parameters = {}
parameters['scalar_type'] = utils.ScalarType

callables = ()
copyin = ()
copyout = ()

# For the matfree adjoint 1-form and the 0-form, the cellwise kernel will add multiple
# contributions from the facet DOFs of the dual argument.
# The incoming Cofunction needs to be weighted by the reciprocal of the DOF multiplicity.
needs_weight = isinstance(dual_arg, ufl.Cofunction) and not to_element.is_dg()
if needs_weight:
# Compute the reciprocal of the DOF multiplicity
# Create a buffer for the weighted Cofunction
W = dual_arg.function_space()
v = firedrake.Function(W)
expr = expr._ufl_expr_reconstruct_(operand, v=v)
copyin += (partial(dual_arg.dat.copy, v.dat),)

# Compute the reciprocal of the DOF multiplicity
wdat = W.make_dat()
m_ = get_interp_node_map(source_mesh, target_mesh, W)
wsize = W.finat_element.space_dimension() * W.block_size
kernel_code = f"""
void multiplicity(PetscScalar *restrict w) {{
for (PetscInt i=0; i<{wsize}; i++) w[i] += 1;
}}"""
kernel = op2.Kernel(kernel_code, "multiplicity", requires_zeroed_output_arguments=False)
weight = firedrake.Function(W)
m_ = get_interp_node_map(source_mesh, target_mesh, W)
op2.par_loop(kernel, cell_set, weight.dat(op2.INC, m_))
with weight.dat.vec as w:
kernel = op2.Kernel(kernel_code, "multiplicity")
op2.par_loop(kernel, cell_set, wdat(op2.INC, m_))
with wdat.vec as w:
w.reciprocal()

# Create a buffer for the weighted Cofunction and a callable to apply the weight
v = firedrake.Function(W)
expr = expr._ufl_expr_reconstruct_(operand, v=v)
with weight.dat.vec_ro as w, dual_arg.dat.vec_ro as x, v.dat.vec_wo as y:
callables += (partial(y.pointwiseMult, x, w),)
# Create a callable to apply the weight
with wdat.vec_ro as w, v.dat.vec as y:
copyin += (partial(y.pointwiseMult, y, w),)

# We need to pass both the ufl element and the finat element
# because the finat elements might not have the right mapping
Expand All @@ -1079,7 +1083,7 @@ def _interpolator(tensor, expr, subset, access, bcs=None):
coefficient_numbers = kernel.coefficient_numbers
needs_external_coords = kernel.needs_external_coords
name = kernel.name
kernel = op2.Kernel(ast, name, requires_zeroed_output_arguments=True,
kernel = op2.Kernel(ast, name, requires_zeroed_output_arguments=(access is not op2.INC),
flop_count=kernel.flop_count, events=(kernel.event,))

parloop_args = [kernel, cell_set]
Expand All @@ -1092,17 +1096,12 @@ def _interpolator(tensor, expr, subset, access, bcs=None):
output = tensor
tensor = op2.Dat(tensor.dataset)
if access is not op2.WRITE:
copyin = (partial(output.copy, tensor), )
else:
copyin = ()
copyout = (partial(tensor.copy, output), )
else:
copyin = ()
copyout = ()
copyin += (partial(output.copy, tensor), )
copyout += (partial(tensor.copy, output), )
if isinstance(tensor, op2.Global):
parloop_args.append(tensor(access))
elif isinstance(tensor, op2.Dat):
V_dest = arguments[-1].function_space() if isinstance(dual_arg, ufl.Cofunction) else V
V_dest = arguments[-1].function_space()
m_ = get_interp_node_map(source_mesh, target_mesh, V_dest)
parloop_args.append(tensor(access, m_))
else:
Expand Down Expand Up @@ -1162,11 +1161,10 @@ def _interpolator(tensor, expr, subset, access, bcs=None):
parloop_args.append(target_ref_coords.dat(op2.READ, m_))

parloop = op2.ParLoop(*parloop_args)
parloop_compute_callable = parloop.compute
if isinstance(tensor, op2.Mat):
return parloop_compute_callable, tensor.assemble
return parloop, tensor.assemble
else:
return copyin + callables + (parloop_compute_callable, ) + copyout
return copyin + (parloop, ) + copyout


def get_interp_node_map(source_mesh, target_mesh, fs):
Expand Down
33 changes: 33 additions & 0 deletions tests/firedrake/regression/test_interpolate.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,7 @@ def test_trace():
assert np.allclose(x_tr_cg.dat.data, x_tr_dir.dat.data)


@pytest.mark.parallel([1, 3])
@pytest.mark.parametrize("rank", (0, 1))
@pytest.mark.parametrize("mat_type", ("matfree", "aij"))
@pytest.mark.parametrize("degree", (1, 3))
Expand Down Expand Up @@ -566,3 +567,35 @@ def test_mixed_matrix(mode):
result_explicit = assemble(action(a, u))
for x, y in zip(result_explicit.subfunctions, result_matfree.subfunctions):
assert np.allclose(x.dat.data, y.dat.data)


@pytest.mark.parallel(2)
@pytest.mark.parametrize("mode", ["forward", "adjoint"])
@pytest.mark.parametrize("family,degree", [("CG", 1), ("DG", 0)])
def test_interpolator_reuse(family, degree, mode):
mesh = UnitSquareMesh(1, 1)
V = FunctionSpace(mesh, family, degree)
rg = RandomGenerator(PCG64(seed=123456789))
if mode == "forward":
u = Function(V)
expr = interpolate(u, V)

elif mode == "adjoint":
u = Function(V.dual())
expr = interpolate(TestFunction(V), u)

I = Interpolator(expr, V)

for k in range(3):
u.assign(rg.uniform(u.function_space()))
expected = u.dat.data.copy()

tensor = Function(expr.function_space())
result = I.assemble(tensor=tensor)
assert result is tensor

# Test that the input was not modified
assert np.allclose(u.dat.data, expected)

# Test for correctness
assert np.allclose(result.dat.data, expected)