Skip to content

Commit 05fa39b

Browse files
committed
Fix assemble(slate.Tensor, diagonal=True)
1 parent 6e5c84a commit 05fa39b

File tree

3 files changed

+6
-13
lines changed

3 files changed

+6
-13
lines changed

firedrake/assemble.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1745,7 +1745,7 @@ def _as_global_kernel_arg_output(_, self):
17451745
if rank == 0:
17461746
return op2.GlobalKernelArg((1,))
17471747
elif rank == 1 or rank == 2 and self._diagonal:
1748-
V, = Vs
1748+
V = Vs[0]
17491749
if V.ufl_element().family() == "Real":
17501750
return op2.GlobalKernelArg((1,))
17511751
else:
@@ -2052,7 +2052,7 @@ def _as_parloop_arg_output(_, self):
20522052
if rank == 0:
20532053
return op2.GlobalParloopArg(self._tensor)
20542054
elif rank == 1 or rank == 2 and self._diagonal:
2055-
V, = Vs
2055+
V = Vs[0]
20562056
if V.ufl_element().family() == "Real":
20572057
return op2.GlobalParloopArg(self._tensor)
20582058
else:

firedrake/slate/slate.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1387,8 +1387,7 @@ def arg_function_spaces(self):
13871387
"""Returns a tuple of function spaces that the tensor
13881388
is defined on.
13891389
"""
1390-
tensor, = self.operands
1391-
return tuple(arg.function_space() for arg in tensor.arguments())
1390+
return tuple(arg.function_space() for arg in self.arguments())
13921391

13931392
def arguments(self):
13941393
"""Returns a tuple of arguments associated with the tensor."""

tests/firedrake/slate/test_linear_algebra.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -152,10 +152,8 @@ def test_inverse_action(mat_type, rhs_type):
152152
assert np.allclose(x.dat.data, f.dat.data, rtol=1.e-13)
153153

154154

155-
@pytest.mark.parametrize("mat_type, rhs_type", [
156-
("slate", "slate"), ("slate", "form"), ("slate", "cofunction"),
157-
("aij", "cofunction"), ("aij", "form"),
158-
("matfree", "cofunction"), ("matfree", "form")])
155+
@pytest.mark.parametrize("rhs_type", ["slate", "form", "cofunction"])
156+
@pytest.mark.parametrize("mat_type", ["slate", "aij", "matfree"])
159157
def test_solve_interface(mat_type, rhs_type):
160158
mesh = UnitSquareMesh(1, 1)
161159
V = FunctionSpace(mesh, "HDivT", 0)
@@ -180,12 +178,8 @@ def test_solve_interface(mat_type, rhs_type):
180178
else:
181179
raise ValueError("Invalid rhs type")
182180

183-
sp = None
184-
if mat_type == "matfree":
185-
sp = {"pc_type": "none"}
186-
187181
x = Function(V)
188182
problem = LinearVariationalProblem(A, b, x, bcs=bcs)
189-
solver = LinearVariationalSolver(problem, solver_parameters=sp)
183+
solver = LinearVariationalSolver(problem)
190184
solver.solve()
191185
assert np.allclose(x.dat.data, f.dat.data, rtol=1.e-13)

0 commit comments

Comments
 (0)