Skip to content
2 changes: 2 additions & 0 deletions firedrake/preconditioners/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,5 @@
from firedrake.preconditioners.hiptmair import * # noqa: F401
from firedrake.preconditioners.facet_split import * # noqa: F401
from firedrake.preconditioners.bddc import * # noqa: F401
from firedrake.preconditioners.fieldsplit_snes import * # noqa: F401
from firedrake.preconditioners.auxiliary_snes import * # noqa: F401
79 changes: 79 additions & 0 deletions firedrake/preconditioners/auxiliary_snes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
from firedrake.preconditioners.base import SNESBase
from firedrake.petsc import PETSc
from firedrake.dmhooks import get_appctx, get_function_space

__all__ = ("AuxiliaryOperatorSNES",)


class AuxiliaryOperatorSNES(SNESBase):
prefix = "aux_"

@PETSc.Log.EventDecorator()
def initialize(self, snes):
from firedrake import ( # ImportError if this is at file level
NonlinearVariationalSolver,
NonlinearVariationalProblem,
Function, TestFunction, Cofunction)

ctx = get_appctx(snes.dm)
V = get_function_space(snes.dm).collapse()

appctx = ctx.appctx
fcp = appctx.get("form_compiler_parameters")

u = Function(V)
v = TestFunction(V)

F, bcs, self.u = self.form(snes, u, v)

self.b = Cofunction(V.dual())
F += self.b

prefix = snes.getOptionsPrefix() + self.prefix

self.solver = NonlinearVariationalSolver(
NonlinearVariationalProblem(
F, self.u, bcs=bcs,
form_compiler_parameters=fcp),
appctx=appctx, options_prefix=prefix)
outer_snes = snes
inner_snes = self.solver.snes
inner_snes.incrementTabLevel(1, parent=outer_snes)
inner_snes.ksp.incrementTabLevel(1, parent=outer_snes)
inner_snes.ksp.pc.incrementTabLevel(1, parent=outer_snes)

def update(self, snes):
pass

@PETSc.Log.EventDecorator()
def step(self, snes, x, f, y):
from firedrake import errornorm
with self.u.dat.vec_wo as vec:
x.copy(vec)
# PETSc.Sys.Print(f"{x.norm() = }")
if f is not None:
with self.b.dat.vec_wo as vec:
f.copy(vec)
else:
self.b.zero()
# self.b.zero()

# PETSc.Sys.Print(f"Before: {errornorm(self.un, self.u) = :.5e}")
PETSc.Sys.Print(f"Before: {errornorm(self.un1, self.u) = :.5e}")
self.solver.solve()
# PETSc.Sys.Print(f"After: {errornorm(self.un, self.u) = :.5e}")
PETSc.Sys.Print(f"After: {errornorm(self.un1, self.u) = :.5e}")
with self.u.dat.vec_ro as vec:
# PETSc.Sys.Print(f"{vec.norm() = }")
vec.copy(y)
y.aypx(-1, x)
# PETSc.Sys.Print(f"{y.norm() = }")

def form(self, snes, u, v):
raise NotImplementedError

def view(self, snes, viewer=None):
super().view(snes, viewer)
if hasattr(self, "solver"):
viewer.printfASCII("SNES to apply auxiliary inverse\n")
self.solver.snes.view(viewer)
85 changes: 85 additions & 0 deletions firedrake/preconditioners/fieldsplit_snes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
from firedrake.preconditioners.base import SNESBase
from firedrake.petsc import PETSc
from firedrake.dmhooks import get_appctx, get_function_space
from firedrake.function import Function

__all__ = ("FieldsplitSNES",)


class FieldsplitSNES(SNESBase):
prefix = "fieldsplit_"

# TODO:
# - Allow setting field grouping/ordering like fieldsplit

@PETSc.Log.EventDecorator()
def initialize(self, snes):
from firedrake.variational_solver import NonlinearVariationalSolver # ImportError if we do this at file level
ctx = get_appctx(snes.dm)
W = get_function_space(snes.dm)
self.sol = ctx._problem.u_restrict

# buffer to save solution to outer problem during solve
self.sol_outer = Function(self.sol.function_space())

# buffers for shuffling solutions during solve
self.sol_current = Function(self.sol.function_space())
self.sol_new = Function(self.sol.function_space())

# options for setting up the fieldsplit
snes_prefix = snes.getOptionsPrefix() + 'snes_' + self.prefix
# options for each field
sub_prefix = snes.getOptionsPrefix() + self.prefix

snes_options = PETSc.Options(snes_prefix)
self.fieldsplit_type = snes_options.getString('type', 'additive')
if self.fieldsplit_type not in ('additive', 'multiplicative'):
raise ValueError(
'FieldsplitSNES option snes_fieldsplit_type must be'
' "additive" or "multiplicative"')

split_ctxs = ctx.split([(i,) for i in range(len(W))])

self.solvers = tuple(
NonlinearVariationalSolver(
ctx._problem, appctx=ctx.appctx,
options_prefix=sub_prefix+str(i))
for i, ctx in enumerate(split_ctxs)
)

def update(self, snes):
pass

@PETSc.Log.EventDecorator()
def step(self, snes, x, f, y):
# store current value of outer solution
self.sol_outer.assign(self.sol)

# the full form in ctx now has the most up to date solution
with self.sol_current.dat.vec_wo as vec:
x.copy(vec)
self.sol.assign(self.sol_current)

# The current snes solution x is held in sol_current, and we
# will place the new solution in sol_new.
# The solvers evaluate forms containing sol, so for each
# splitting type sol needs to hold:
# - additive: all fields need to hold sol_current values
# - multiplicative: fields need to hold sol_current before
# they are are solved for, and keep the updated sol_new
# values afterwards.
for solver, u, ucurr, unew in zip(self.solvers,
self.sol.subfunctions,
self.sol_current.subfunctions,
self.sol_new.subfunctions):
solver.solve()
unew.assign(u)
if self.fieldsplit_type == 'additive':
u.assign(ucurr)

with self.sol_new.dat.vec_ro as vec:
vec.copy(y)
y.aypx(-1, x)

# restore outer solution
self.sol.assign(self.sol_outer)
200 changes: 200 additions & 0 deletions tests/firedrake/regression/test_fieldsplit_snes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
from firedrake import *


def test_fieldsplit_snes():
re = Constant(100)
nu = Constant(1/re)

nx = 50
dt = Constant(0.1) # CFL = dt*nx

mesh = PeriodicUnitIntervalMesh(nx)
x, = SpatialCoordinate(mesh)

Vu = VectorFunctionSpace(mesh, "CG", 2)
Vq = FunctionSpace(mesh, "DG", 1)
W = Vu*Vq

w0 = Function(W)
u0, q0 = w0.subfunctions
u0.project(as_vector([0.5 + 1.0*sin(2*pi*x)]))
q0.interpolate(cos(2*pi*x))

def M(u, v):
return inner(u, v)*dx

def Aburgers(u, v, nu):
return (
inner(dot(u, nabla_grad(u)), v)*dx
+ nu*inner(grad(u), grad(v))*dx
)

def Ascalar(q, p, u):
n = FacetNormal(mesh)
un = 0.5*(dot(u, n) + abs(dot(u, n)))
return (- q*div(u*p)*dx
+ jump(un*q)*jump(p)*dS)

# current and next timestep
w = Function(W)
wn = Function(W)

u, q = split(w)
un, qn = split(wn)

v, p = TestFunctions(W)

# Trapezium rule
F = (
M(un - u, v) + 0.5*dt*(Aburgers(un, v, nu) + Aburgers(u, v, nu))
+ M(qn - q, p) + 0.5*dt*(Ascalar(qn, p, un) + Ascalar(q, p, u))
)

common_params = {
'snes_converged_reason': None,
'snes_monitor': None,
'snes_rtol': 1e-8,
'snes_atol': 1e-12,
'ksp_converged_reason': None,
'ksp_monitor': None,
}

newton_params = {
'snes_type': 'newtonls',
'mat_type': 'aij',
'ksp_type': 'preonly',
'pc_type': 'lu',
}

uparams = common_params | newton_params
qparams = common_params | newton_params | {'snes_type': 'ksponly'}

python_params = {
'snes_type': 'nrichardson',
'npc_snes_type': 'python',
'npc_snes_python_type': 'firedrake.FieldsplitSNES',
'npc_snes_fieldsplit_type': 'additive',
'npc_fieldsplit_0': uparams,
'npc_fieldsplit_1': qparams,
}

params = common_params | python_params

w.assign(w0)
wn.assign(w0)
u, q = w.subfunctions
un, qn = wn.subfunctions
solver = NonlinearVariationalSolver(
NonlinearVariationalProblem(F, wn),
solver_parameters=params,
options_prefix="")

nsteps = 2
for i in range(nsteps):
w.assign(wn)
solver.solve()


def M(u, v):
return inner(u, v)*dx


def A(u, v, nu):
return (
inner(dot(u, nabla_grad(u)), v)*dx
+ nu*inner(grad(u), grad(v))*dx
)

class AuxiliaryBurgersSNES(AuxiliaryOperatorSNES):

Check failure on line 108 in tests/firedrake/regression/test_fieldsplit_snes.py

View workflow job for this annotation

GitHub Actions / test / Lint codebase

E302

tests/firedrake/regression/test_fieldsplit_snes.py:108:1: E302 expected 2 blank lines, found 1

Check failure on line 108 in tests/firedrake/regression/test_fieldsplit_snes.py

View workflow job for this annotation

GitHub Actions / test / Lint codebase

E302

tests/firedrake/regression/test_fieldsplit_snes.py:108:1: E302 expected 2 blank lines, found 1
def form(self, snes, u, v):
appctx = self.get_appctx(snes)
nu = appctx["nu"]
dt = appctx["dt"]
un = appctx["un"]
un1 = appctx["un1"]
uh = (u + un)/2
F = M(u - un, v) + dt*A(uh, v, nu)
self.un = un
self.un1 = un1
return F, None, u


def test_auxiliary_snes():
re = Constant(100)
re_aux = Constant(50)

nu = Constant(1/re)
nu_aux = Constant(1/re_aux)

nx = 50
dt = Constant(0.1) # CFL = dt*nx

mesh = PeriodicUnitIntervalMesh(nx)
x, = SpatialCoordinate(mesh)

V = VectorFunctionSpace(mesh, "CG", 2)

# current and next timestep
ic = as_vector([1.0 + 0.5*sin(2*pi*x)])
un = Function(V).project(ic)
un1 = Function(V).project(ic)

v = TestFunction(V)

# Implicit midpoint rule
uh = (un + un1)/2
F = M(un1 - un, v) + dt*A(uh, v, nu)

solver_parameters = {
'snes': {
'view': ':snes_view.log',
'converged_reason': None,
'monitor': None,
'rtol': 1e-8,
'atol': 0,
'max_it': 3,
'convergence_test': 'skip',
'linesearch_type': 'l2',
'linesearch_damping': 1.0,
'linesearch_monitor': None,
},
'snes_type': 'nrichardson',
'npc_snes_type': 'python',
'npc_snes_python_type': f'{__name__}.AuxiliaryBurgersSNES',
'npc_aux': {
'snes': {
'converged_reason': None,
'monitor': None,
'rtol': 1e-4,
'atol': 0,
'max_it': 2,
'convergence_test': 'skip',
},
'snes_type': 'newtonls',
'mat_type': 'aij',
'ksp_type': 'preonly',
'pc_type': 'lu',
'pc_factor_mat_solver_type': 'petsc',
},
}

appctx = {
"nu": nu_aux,
"dt": dt,
"un": un,
"un1": un1,
}

solver = NonlinearVariationalSolver(
NonlinearVariationalProblem(F, un1),
solver_parameters=solver_parameters,
options_prefix="fd", appctx=appctx)

nsteps = 1
for i in range(nsteps):
solver.solve()
un.assign(un1)


if __name__ == "__main__":
test_auxiliary_snes()
Loading