From e164b4170a64d3fba03e1e6e6548501a26674adf Mon Sep 17 00:00:00 2001 From: JHopeCollins Date: Tue, 1 Jul 2025 07:09:34 +0100 Subject: [PATCH 01/19] OptimizationProblem should assume AbstractReducedFunctional not ReducedFunctional --- pyadjoint/optimization/optimization_problem.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyadjoint/optimization/optimization_problem.py b/pyadjoint/optimization/optimization_problem.py index 18f4d59b..fe0de5f3 100644 --- a/pyadjoint/optimization/optimization_problem.py +++ b/pyadjoint/optimization/optimization_problem.py @@ -2,7 +2,7 @@ from .constraints import Constraint, canonicalise from ..overloaded_type import OverloadedType, create_overloaded_object -from ..reduced_functional import ReducedFunctional +from ..reduced_functional import AbstractReducedFunctional __all__ = ['MinimizationProblem', 'MaximizationProblem'] @@ -38,7 +38,7 @@ def __check_arguments(self, reduced_functional, bounds, constraints): if type(self) is OptimizationProblem: raise TypeError("Instantiate a MinimizationProblem or MaximizationProblem.") - if not isinstance(reduced_functional, ReducedFunctional): + if not isinstance(reduced_functional, AbstractReducedFunctional): raise TypeError("reduced_functional should be a ReducedFunctional") if bounds is not None: From 3a9e3e0cb01b02fabbd4e775debea030c92aee82 Mon Sep 17 00:00:00 2001 From: JHopeCollins Date: Tue, 1 Jul 2025 13:24:03 +0100 Subject: [PATCH 02/19] WIP: PETSc python ReducedFunctionalMat for TAOSolver --- pyadjoint/optimization/tao_solver.py | 418 ++++++++++++++++++++------- 1 file changed, 315 insertions(+), 103 deletions(-) diff --git a/pyadjoint/optimization/tao_solver.py b/pyadjoint/optimization/tao_solver.py index 47b6babd..63fcbd7c 100644 --- a/pyadjoint/optimization/tao_solver.py +++ b/pyadjoint/optimization/tao_solver.py @@ -1,6 +1,7 @@ from contextlib import contextmanager -from functools import cached_property +from functools import wraps import itertools +from enum import Enum from numbers import Complex import warnings @@ -127,6 +128,22 @@ def to_petsc(self, x, y): x_sub.restoreSubVector(iset, x_sub) +def new_control_variable(reduced_functional, dual=False): + """Return new variables suitable for storing a control value or its dual. + + Args: + reduced_functional (ReducedFunctional): The reduced functional whose + controls are to be copied. + dual (bool): whether to return a dual type. If False then a primal type is returned. + + Returns: + tuple[OverloadedType]: New variables suitable for storing a control value. + """ + + return tuple(control._ad_init_zero(dual=dual) + for control in reduced_functional.controls) + + # Modified version of flatten_parameters function from firedrake/petsc.py, # Firedrake master branch 57e21cc8ebdb044c1d8423b48f3dbf70975d5548, first # added 2024-08-08 @@ -324,9 +341,225 @@ def inserted_options(self): del self.options_object[self.options_prefix + k] +def get_valid_comm(comm): + """ + Return a valid communicator from a user provided (possibly null) comm. + + Args: + comm: Any[petsc4py.PETSc.Comm,mpi4py.MPI.Comm,None] + + Returns: + mpi4py.MPI.Comm. COMM_WORLD if `comm is None`, otherwise `comm.tompi4py()`. + """ + if comm is None: + comm = PETSc.COMM_WORLD + if hasattr(comm, "tompi4py"): + comm = comm.tompi4py() + return comm + + +class RFAction(Enum): + """ + The type of linear action that a ReducedFunctionalMat should apply. + """ + TLM = 'tlm' + Adjoint = 'adjoint' + Hessian = 'hessian' + + +TLMAction = RFAction.TLM +AdjointAction = RFAction.Adjoint +HessianAction = RFAction.Hessian + + +def check_rf_action(action): + def check_rf_action_decorator(func): + @wraps(func) + def wrapper(self, *args, **kwargs): + if self.action != action: + raise NotImplementedError( + f'Cannot apply {str(action)} action if {self.action=}') + return func(self, *args, **kwargs) + return wrapper + return check_rf_action_decorator + + +class ReducedFunctionalMatCtx: + """ + PETSc.Mat Python context to apply the action of a pyadjoint.ReducedFunctional. + + If V is the control space and U is the functional space, each action has the following map: + Jhat : V -> U + TLM : V -> U + Adjoint : U* -> V* + Hessian : V x U* -> V* | V -> V* + + Args: + rf (ReducedFunctional): Defines the forward model, and used to compute operator actions. + action (RFAction): Whether to apply the TLM, adjoint, or Hessian action. + apply_riesz (bool): Whether to apply the riesz map before returning the + result of the action to PETSc. + appctx (Optional[dict]): User provided context. + comm (Optional[petsc4py.PETSc.Comm,mpi4py.MPI.Comm]): Communicator that the rf is defined over. + """ + + def __init__(self, rf, action=HessianAction, *, + apply_riesz=False, appctx=None, comm=PETSc.COMM_WORLD): + comm = get_valid_comm(comm) + + self.rf = rf + self.appctx = appctx + self.control_interface = PETScVecInterface( + tuple(c.control for c in rf.controls), + comm=comm) + self.apply_riesz = apply_riesz + if action in (AdjointAction, TLMAction): + self.functional_interface = PETScVecInterface( + rf.functional, comm=comm) + + if action == HessianAction: # control -> control + self.xinterface = self.control_interface + self.yinterface = self.control_interface + + self.x = new_control_variable(rf) + self.mult_impl = self._mult_hessian + + elif action == AdjointAction: # functional -> control + self.xinterface = self.functional_interface + self.yinterface = self.control_interface + + self.x = rf.functional._ad_copy() + self.mult_impl = self._mult_adjoint + + elif action == TLMAction: # control -> functional + self.xinterface = self.control_interface + self.yinterface = self.functional_interface + + self.x = new_control_variable(rf) + self.mult_impl = self._mult_tlm + else: + raise ValueError( + 'Unrecognised {action = }.') + + self.action = action + self._m = new_control_variable(rf) + self._shift = 0 + + @classmethod + def update(cls, obj, x, A, P): + ctx = A.getPythonContext() + ctx.control_interface.from_petsc(x, ctx._m) + ctx.update_tape_values(update_adjoint=True) + ctx._shift = 0 + + def shift(self, A, alpha): + self._shift += alpha + + def update_tape_values(self, update_adjoint=True): + _ = self.rf(self._m) + if update_adjoint: + _ = self.rf.derivative(apply_riesz=False) + + def mult(self, A, x, y): + self.xinterface.from_petsc(x, self.x) + out = self.mult_impl(A, self.x) + self.yinterface.to_petsc(y, out) + + if self._shift != 0: + y.axpy(self._shift, x) + + @check_rf_action(action=HessianAction) + def _mult_hessian(self, A, x): + # self.update_tape_values(update_adjoint=True) + return self.rf.hessian( + x, apply_riesz=self.apply_riesz) + + @check_rf_action(TLMAction) + def _mult_tlm(self, A, x): + # self.update_tape_values(update_adjoint=False) + return self.rf.tlm(x) + + @check_rf_action(AdjointAction) + def _mult_adjoint(self, A, x): + # self.update_tape_values(update_adjoint=False) + return self.rf.derivative( + adj_input=x, apply_riesz=self.apply_riesz) + + +def ReducedFunctionalMat(rf, action=HessianAction, *, apply_riesz=False, appctx=None, comm=None): + """ + PETSc.Mat to apply the action of a pyadjoint.ReducedFunctional. + + If V is the control space and U is the functional space, each action has the following map: + Jhat : V -> U + TLM : V -> U + Adjoint : U* -> V* + Hessian : V x U* -> V* | V -> V* + + Args: + rf (ReducedFunctional): Defines the forward model, and used to compute operator actions. + action (RFAction): Whether to apply the TLM, adjoint, or Hessian action. + apply_riesz (bool): Whether to apply the riesz map before returning the + result of the action to PETSc. + appctx (Optional[dict]): User provided context. + comm (Optional[petsc4py.PETSc.Comm,mpi4py.MPI.Comm]): Communicator that the rf is defined over. + """ + ctx = ReducedFunctionalMatCtx( + rf, action, appctx=appctx, apply_riesz=apply_riesz, comm=comm) + + ncol = ctx.xinterface.n + Ncol = ctx.xinterface.N + + nrow = ctx.yinterface.n + Nrow = ctx.yinterface.N + + mat = PETSc.Mat().createPython( + ((nrow, Nrow), (ncol, Ncol)), + ctx, comm=ctx.control_interface.comm) + if action == HessianAction: + mat.setOption(PETSc.Mat.Option.SYMMETRIC, True) + mat.setUp() + mat.assemble() + return mat + + +class RieszMapMatCtx: + def __init__(self, controls, comm=None): + comm = get_valid_comm(comm) + + self.controls = Enlist(controls) + self.vec_interface = PETScVecInterface( + tuple(c.control for c in controls), + comm=comm) + + self.dJ = tuple(c._ad_init_zero(dual=True) + for c in self.controls) + + def mult(self, mat, x, y): + self.vec_interface.from_petsc(x, self.dJ) + dJ = tuple(c._ad_convert_riesz(dJi, riesz_map=c.riesz_map) + for c, dJi in zip(self.controls, self.dJ)) + self.vec_interface.to_petsc(y, dJ) + + +def RieszMapMat(controls, symmetric=True, comm=None): + ctx = RieszMapMatCtx(controls, comm=comm) + + n = ctx.vec_interface.n + N = ctx.vec_interface.N + + mat = PETSc.Mat().createPython( + ((n, N), (n, N)), ctx, + comm=ctx.vec_interface.comm) + if symmetric: + mat.setOption(PETSc.Mat.Option.SYMMETRIC, True) + mat.setUp() + mat.assemble() + return mat + + class TAOObjective: - """Utility class for computing functional values and associated - derivatives. + """Utility class for computing functional values and associated derivatives. Args: rf (ReducedFunctional): Defines the forward, and used to compute @@ -344,6 +577,37 @@ def reduced_functional(self): return self._reduced_functional + def objective(self, m): + """Evaluate the forward. + + Args: + m (OverloadedType or Sequence[OverloadedType]): Defines the control + value. + Returns: + AdjFloat: The value of the functional. + """ + + m = Enlist(m) + J = self.reduced_functional(tuple(m_i._ad_copy() for m_i in m)) + return J + + def gradient(self, m): + """Compute a first derivative. + + Args: + m (OverloadedType or Sequence[OverloadedType]): Defines the control + value. + Returns: + AdjFloat: The value of the functional. + OverloadedType or Sequence[OverloadedType]: The (dual space) + derivative. + """ + + m = Enlist(m) + # J = self.reduced_functional(tuple(m_i._ad_copy() for m_i in m)) + dJ = self.reduced_functional.derivative() + return m.delist(dJ) + def objective_gradient(self, m): """Evaluate the forward, and compute a first derivative. @@ -382,32 +646,6 @@ def hessian(self, m, m_dot): ddJ = self.reduced_functional.hessian(tuple(m_dot_i._ad_copy() for m_dot_i in m_dot)) return m.delist(ddJ) - def new_control_variable(self): - """Return new variables suitable for storing a control value. Not - initialized to zero. - - Returns: - tuple[OverloadedType]: New variables suitable for storing a control - value. - """ - - # Not initialized to zero - return tuple(m.control._ad_copy() - for m in self.reduced_functional.controls) - - def new_dual_control_variable(self): - """Return new variables suitable for storing a value for a (dual space) - derivative of the functional with respect to the control. - - Returns: - tuple[OverloadedType]: New variables suitable for storing a value - for a (dual space) derivative of the functional with respect to - the control. - """ - - return tuple(control._ad_init_zero(dual=True) - for control in self.reduced_functional.controls) - class TAOConvergenceError(Exception): """Raised if a TAO solve fails to converge. @@ -428,15 +666,17 @@ class TAOSolver(OptimizationSolver): """Use TAO to solve an optimization problem. Args: - problem (MinimizationProblem): Defines the optimization problem to be - solved. + problem (MinimizationProblem): Defines the optimization problem to be solved. parameters (Mapping): TAO options. + options_prefix (Optional[str]): prefix for the TAO solver. + appctx (Optional[dict]): User provided context. + Pmat (Optional petsc4py.PETSc.Mat): Hessian preconditioning matrix. comm (petsc4py.PETSc.Comm or mpi4py.MPI.Comm): Communicator. - convert_options (Mapping): Defines the `options` argument to - :meth:`OverloadedType._ad_convert_type`. """ - def __init__(self, problem, parameters, *, comm=None): + def __init__(self, problem, parameters, *, + options_prefix=None, appctx=None, + Pmat=None, comm=None): if PETSc is None: raise RuntimeError("PETSc not available") @@ -445,82 +685,51 @@ def __init__(self, problem, parameters, *, comm=None): if problem.constraints is not None: raise NotImplementedError("Constraints not implemented") - if comm is None: - comm = PETSc.COMM_WORLD - if hasattr(comm, "tompi4py"): - comm = comm.tompi4py() + comm = get_valid_comm(comm) - tao_objective = TAOObjective(problem.reduced_functional) + rf = problem.reduced_functional + tao_objective = TAOObjective(rf) vec_interface = PETScVecInterface( - tuple(control.control for control in tao_objective.reduced_functional.controls), - comm=comm) - n, N = vec_interface.n, vec_interface.N + tuple(control.control for control in rf.controls), comm=comm) + to_petsc, from_petsc = vec_interface.to_petsc, vec_interface.from_petsc tao = PETSc.TAO().create(comm=comm) + def objective(tao, x, g): + m = new_control_variable(rf) + from_petsc(x, m) + J_val = tao_objective.objective(m) + return J_val + + def gradient(tao, x, g): + m = new_control_variable(rf) + from_petsc(x, m) + dJ = tao_objective.gradient(m) + to_petsc(g, dJ) + def objective_gradient(tao, x, g): - m = tao_objective.new_control_variable() + m = new_control_variable(rf) from_petsc(x, m) J_val, dJ = tao_objective.objective_gradient(m) to_petsc(g, dJ) return J_val - tao.setObjectiveGradient(objective_gradient, None) - - def hessian(tao, x, H, P): - H.getPythonContext().set_control_variable(x) - - class Hessian: - """:class:`petsc4py.PETSc.Mat` context. - """ - - def __init__(self): - self._shift = 0.0 - - @cached_property - def _m(self): - return tao_objective.new_control_variable() - - def set_control_variable(self, x): - from_petsc(x, self._m) - self._shift = 0.0 - - def shift(self, A, alpha): - self._shift += alpha - - def mult(self, A, x, y): - m_dot = tao_objective.new_control_variable() - from_petsc(x, m_dot) - ddJ = tao_objective.hessian(self._m, m_dot) - to_petsc(y, ddJ) - if self._shift != 0.0: - y.axpy(self._shift, x) - - H_matrix = PETSc.Mat().createPython(((n, N), (n, N)), - Hessian(), comm=comm) - H_matrix.setOption(PETSc.Mat.Option.SYMMETRIC, True) - H_matrix.setUp() - tao.setHessian(hessian, H_matrix) - - class GradientNorm: - """:class:`petsc4py.PETSc.Mat` context. - """ - - def mult(self, A, x, y): - dJ = tao_objective.new_dual_control_variable() - from_petsc(x, dJ) - assert len(tao_objective.reduced_functional.controls) == len(dJ) - dJ = tuple(control._ad_convert_riesz(dJ_i, riesz_map=control.riesz_map) - for control, dJ_i in zip(tao_objective.reduced_functional.controls, dJ)) - to_petsc(y, dJ) - - M_inv_matrix = PETSc.Mat().createPython(((n, N), (n, N)), - GradientNorm(), comm=comm) - M_inv_matrix.setOption(PETSc.Mat.Option.SYMMETRIC, True) - M_inv_matrix.setUp() - tao.setGradientNorm(M_inv_matrix) + tao.setObjectiveGradient(objective_gradient) + tao.setObjective(objective) + tao.setGradient(gradient) + + hessian_mat = ReducedFunctionalMat( + problem.reduced_functional, appctx=appctx, + action=HessianAction, comm=comm) + + tao.setHessian( + hessian_mat.getPythonContext().update, + H=hessian_mat, P=Pmat or hessian_mat) + + Minv_mat = RieszMapMat(rf.controls, comm=comm) + tao.setGradientNorm(Minv_mat) if problem.bounds is not None: lbs = [] @@ -540,10 +749,12 @@ def mult(self, A, x, y): to_petsc(ub_vec, ubs) tao.setVariableBounds(lb_vec, ub_vec) - options = OptionsManager(parameters, None) - options.set_from_options(tao) + self.options = OptionsManager(parameters, options_prefix) + self.options.set_from_options(tao) if tao.getType() in {PETSc.TAO.Type.LMVM, PETSc.TAO.Type.BLMVM}: + n, N = vec_interface.n, vec_interface.N + class InitialHessian: """:class:`petsc4py.PETSc.Mat` context. """ @@ -553,7 +764,7 @@ class InitialHessianPreconditioner: """ def apply(self, pc, x, y): - dJ = tao_objective.new_dual_control_variable() + dJ = new_control_variable(rf, dual=True) from_petsc(x, dJ) assert len(tao_objective.reduced_functional.controls) == len(dJ) dJ = tuple(control._ad_convert_riesz(dJ_i, riesz_map=control.riesz_map) @@ -626,9 +837,10 @@ def solve(self): m = tuple( control.tape_value()._ad_copy() for control in self.tao_objective.reduced_functional.controls) - self._vec_interface.to_petsc(self.x, m) - self.tao.solve() - self._vec_interface.from_petsc(self.x, m) + with self.options.inserted_options(): + self._vec_interface.to_petsc(self.x, m) + self.tao.solve() + self._vec_interface.from_petsc(self.x, m) if self.tao.getConvergedReason() <= 0: # Using the same format as Firedrake linear solver errors raise TAOConvergenceError( From 225e01996604c73e2b51d26b65fb90b77680ebf7 Mon Sep 17 00:00:00 2001 From: JHopeCollins Date: Tue, 14 Oct 2025 14:45:02 +0100 Subject: [PATCH 03/19] tidy up petsc python RF mat --- pyadjoint/optimization/tao_solver.py | 113 ++++++++++++++++----------- 1 file changed, 66 insertions(+), 47 deletions(-) diff --git a/pyadjoint/optimization/tao_solver.py b/pyadjoint/optimization/tao_solver.py index fd77df17..51665e65 100644 --- a/pyadjoint/optimization/tao_solver.py +++ b/pyadjoint/optimization/tao_solver.py @@ -165,14 +165,16 @@ class RFAction(Enum): """ The type of linear action that a ReducedFunctionalMat should apply. """ + FORWARD = 'forward' TLM = 'tlm' - Adjoint = 'adjoint' - Hessian = 'hessian' + ADJOINT = 'adjoint' + HESSIAN = 'hessian' -TLMAction = RFAction.TLM -AdjointAction = RFAction.Adjoint -HessianAction = RFAction.Hessian +FORWARD = RFAction.FORWARD +TLM = RFAction.TLM +ADJOINT = RFAction.ADJOINT +HESSIAN = RFAction.HESSIAN def check_rf_action(action): @@ -203,11 +205,16 @@ class ReducedFunctionalMatCtx: apply_riesz (bool): Whether to apply the riesz map before returning the result of the action to PETSc. appctx (Optional[dict]): User provided context. + always_update_tape (bool): Whether to force reevaluation of the forward model every time + `mult` is called. If action is HESSIAN then this will also force the adjoint model to + be reevaluated at every call to `mult`. comm (Optional[petsc4py.PETSc.Comm,mpi4py.MPI.Comm]): Communicator that the rf is defined over. """ - def __init__(self, rf, action=HessianAction, *, - apply_riesz=False, appctx=None, comm=PETSc.COMM_WORLD): + def __init__(self, rf, action=HESSIAN, *, + apply_riesz=False, appctx=None, + always_update_tape=False, + comm=PETSc.COMM_WORLD): comm = get_valid_comm(comm) self.rf = rf @@ -216,25 +223,25 @@ def __init__(self, rf, action=HessianAction, *, tuple(c.control for c in rf.controls), comm=comm) self.apply_riesz = apply_riesz - if action in (AdjointAction, TLMAction): + if action in (ADJOINT, TLM): self.functional_interface = PETScVecInterface( rf.functional, comm=comm) - if action == HessianAction: # control -> control + if action == HESSIAN: # control -> control self.xinterface = self.control_interface self.yinterface = self.control_interface self.x = new_control_variable(rf) self.mult_impl = self._mult_hessian - elif action == AdjointAction: # functional -> control + elif action == ADJOINT: # functional -> control self.xinterface = self.functional_interface self.yinterface = self.control_interface self.x = rf.functional._ad_copy() self.mult_impl = self._mult_adjoint - elif action == TLMAction: # control -> functional + elif action == TLM: # control -> functional self.xinterface = self.control_interface self.yinterface = self.functional_interface @@ -247,14 +254,23 @@ def __init__(self, rf, action=HessianAction, *, self.action = action self._m = new_control_variable(rf) self._shift = 0 + self.always_update_tape = always_update_tape @classmethod def update(cls, obj, x, A, P): ctx = A.getPythonContext() ctx.control_interface.from_petsc(x, ctx._m) - ctx.update_tape_values(update_adjoint=True) + ctx.update_tape_values( + update_adjoint=(ctx.action == HESSIAN)) ctx._shift = 0 + pctx = P.getPythonContext() + if pctx is not ctx: + pctx.control_interface.from_petsc(x, pctx._m) + pctx.update_tape_values( + update_adjoint=(pctx.action == HESSIAN)) + pctx._shift = 0 + def shift(self, A, alpha): self._shift += alpha @@ -271,25 +287,29 @@ def mult(self, A, x, y): if self._shift != 0: y.axpy(self._shift, x) - @check_rf_action(action=HessianAction) + @check_rf_action(HESSIAN) def _mult_hessian(self, A, x): - # self.update_tape_values(update_adjoint=True) + if self.always_update_tape: + self.update_tape_values(update_adjoint=True) return self.rf.hessian( x, apply_riesz=self.apply_riesz) - @check_rf_action(TLMAction) + @check_rf_action(TLM) def _mult_tlm(self, A, x): - # self.update_tape_values(update_adjoint=False) + if self.always_update_tape: + self.update_tape_values(update_adjoint=False) return self.rf.tlm(x) - @check_rf_action(AdjointAction) + @check_rf_action(ADJOINT) def _mult_adjoint(self, A, x): - # self.update_tape_values(update_adjoint=False) + if self.always_update_tape: + self.update_tape_values(update_adjoint=False) return self.rf.derivative( adj_input=x, apply_riesz=self.apply_riesz) -def ReducedFunctionalMat(rf, action=HessianAction, *, apply_riesz=False, appctx=None, comm=None): +def ReducedFunctionalMat(rf, action=HESSIAN, *, apply_riesz=False, appctx=None, + always_update_tape=False, comm=None): """ PETSc.Mat to apply the action of a pyadjoint.ReducedFunctional. @@ -305,10 +325,14 @@ def ReducedFunctionalMat(rf, action=HessianAction, *, apply_riesz=False, appctx= apply_riesz (bool): Whether to apply the riesz map before returning the result of the action to PETSc. appctx (Optional[dict]): User provided context. + always_update_tape (bool): Whether to force reevaluation of the forward model every time + `mult` is called. If action is HESSIAN then this will also force the adjoint model to + be reevaluated at every call to `mult`. comm (Optional[petsc4py.PETSc.Comm,mpi4py.MPI.Comm]): Communicator that the rf is defined over. """ ctx = ReducedFunctionalMatCtx( - rf, action, appctx=appctx, apply_riesz=apply_riesz, comm=comm) + rf, action, appctx=appctx, apply_riesz=apply_riesz, + always_update_tape=always_update_tape, comm=comm) ncol = ctx.xinterface.n Ncol = ctx.xinterface.N @@ -319,7 +343,7 @@ def ReducedFunctionalMat(rf, action=HessianAction, *, apply_riesz=False, appctx= mat = PETSc.Mat().createPython( ((nrow, Nrow), (ncol, Ncol)), ctx, comm=ctx.control_interface.comm) - if action == HessianAction: + if action == HESSIAN: mat.setOption(PETSc.Mat.Option.SYMMETRIC, True) mat.setUp() mat.assemble() @@ -367,17 +391,20 @@ class TAOObjective: Args: rf (AbstractReducedFunctional): Defines the forward, and used to compute derivative information. + always_update_tape (bool): Whether to force reevaluation of the forward model every time + gradient or hessian is called. If hessian is called then this will also force the + adjoint model to be reevaluated. """ - def __init__(self, rf): + def __init__(self, rf, always_update_tape=True): self._reduced_functional = rf + self.always_update_tape = always_update_tape @property def reduced_functional(self): """:class:`.AbstractReducedFunctional`. Defines the forward, and used to compute derivative information. """ - return self._reduced_functional def objective(self, m): @@ -389,7 +416,6 @@ def objective(self, m): Returns: AdjFloat: The value of the functional. """ - m = Enlist(m) J = self.reduced_functional(tuple(m_i._ad_copy() for m_i in m)) return J @@ -405,9 +431,9 @@ def gradient(self, m): OverloadedType or Sequence[OverloadedType]: The (dual space) derivative. """ - m = Enlist(m) - # J = self.reduced_functional(tuple(m_i._ad_copy() for m_i in m)) + if self.always_update_tape: + _ = self.reduced_functional(tuple(m_i._ad_copy() for m_i in m)) dJ = self.reduced_functional.derivative() return m.delist(dJ) @@ -444,8 +470,9 @@ def hessian(self, m, m_dot): m = Enlist(m) m_dot = Enlist(m_dot) - _ = self.reduced_functional(tuple(m_i._ad_copy() for m_i in m)) - _ = self.reduced_functional.derivative() + if self.always_update_tape: + _ = self.reduced_functional(tuple(m_i._ad_copy() for m_i in m)) + _ = self.reduced_functional.derivative() ddJ = self.reduced_functional.hessian(tuple(m_dot_i._ad_copy() for m_dot_i in m_dot)) return m.delist(ddJ) @@ -498,27 +525,25 @@ def __init__(self, problem, parameters, *, vec_interface = PETScVecInterface( tuple(control.control for control in rf.controls), comm=comm) - to_petsc, from_petsc = vec_interface.to_petsc, vec_interface.from_petsc - tao = PETSc.TAO().create(comm=comm) def objective(tao, x): m = new_control_variable(rf) - from_petsc(x, m) + vec_interface.from_petsc(x, m) J_val = tao_objective.objective(m) return J_val def gradient(tao, x, g): m = new_control_variable(rf) - from_petsc(x, m) + vec_interface.from_petsc(x, m) dJ = tao_objective.gradient(m) - to_petsc(g, dJ) + vec_interface.to_petsc(g, dJ) def objective_gradient(tao, x, g): m = new_control_variable(rf) - from_petsc(x, m) + vec_interface.from_petsc(x, m) J_val, dJ = tao_objective.objective_gradient(m) - to_petsc(g, dJ) + vec_interface.to_petsc(g, dJ) return J_val tao.setObjectiveGradient(objective_gradient) @@ -527,7 +552,7 @@ def objective_gradient(tao, x, g): hessian_mat = ReducedFunctionalMat( problem.reduced_functional, appctx=appctx, - action=HessianAction, comm=comm) + action=HESSIAN, comm=comm) tao.setHessian( hessian_mat.getPythonContext().update, @@ -550,8 +575,8 @@ def objective_gradient(tao, x, g): lb_vec = vec_interface.new_petsc() ub_vec = vec_interface.new_petsc() - to_petsc(lb_vec, lbs) - to_petsc(ub_vec, ubs) + vec_interface.to_petsc(lb_vec, lbs) + vec_interface.to_petsc(ub_vec, ubs) tao.setVariableBounds(lb_vec, ub_vec) petsctools.set_from_options( @@ -568,14 +593,8 @@ class InitialHessian: class InitialHessianPreconditioner: """:class:`petsc4py.PETSc.PC` context. """ - def apply(self, pc, x, y): - dJ = new_control_variable(rf, dual=True) - from_petsc(x, dJ) - assert len(tao_objective.reduced_functional.controls) == len(dJ) - dJ = tuple(control._ad_convert_riesz(dJ_i, riesz_map=control.riesz_map) - for control, dJ_i in zip(tao_objective.reduced_functional.controls, dJ)) - to_petsc(y, dJ) + Minv_mat.mult(x, y) # B_0_matrix is the initial Hessian approximation (following # Nocedal and Wright doi: 10.1007/978-0-387-40065-5 notation). This @@ -586,8 +605,8 @@ def apply(self, pc, x, y): B_0_matrix.setOption(PETSc.Mat.Option.SYMMETRIC, True) B_0_matrix.setUp() - B_0_matrix_pc = PETSc.PC().createPython(InitialHessianPreconditioner(), - comm=comm) + B_0_matrix_pc = PETSc.PC().createPython( + InitialHessianPreconditioner(), comm=comm) B_0_matrix_pc.setOperators(B_0_matrix) B_0_matrix_pc.setUp() From a3c3736401d0776c805fae55465736e6d4599f27 Mon Sep 17 00:00:00 2001 From: Josh Hope-Collins Date: Wed, 15 Oct 2025 13:00:04 +0100 Subject: [PATCH 04/19] Apply suggestions from code review - require named kwargs Co-authored-by: James R. Maddison --- pyadjoint/optimization/tao_solver.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/pyadjoint/optimization/tao_solver.py b/pyadjoint/optimization/tao_solver.py index 82a9df66..2dd0f7a4 100644 --- a/pyadjoint/optimization/tao_solver.py +++ b/pyadjoint/optimization/tao_solver.py @@ -130,7 +130,7 @@ def to_petsc(self, x, y): x_sub.restoreSubVector(iset, x_sub) -def new_control_variable(reduced_functional, dual=False): +def new_control_variable(reduced_functional, *, dual=False): """Return new variables suitable for storing a control value or its dual. Args: @@ -202,7 +202,7 @@ class ReducedFunctionalMatCtx: Hessian : V x U* -> V* | V -> V* Args: - rf (ReducedFunctional): Defines the forward model, and used to compute operator actions. + rf (ReducedFunctional): Defines the forward model, and used to compute operator actions. action (RFAction): Whether to apply the TLM, adjoint, or Hessian action. apply_riesz (bool): Whether to apply the riesz map before returning the result of the action to PETSc. @@ -276,7 +276,7 @@ def update(cls, obj, x, A, P): def shift(self, A, alpha): self._shift += alpha - def update_tape_values(self, update_adjoint=True): + def update_tape_values(self, *, update_adjoint=True): _ = self.rf(self._m) if update_adjoint: _ = self.rf.derivative(apply_riesz=False) @@ -322,7 +322,7 @@ def ReducedFunctionalMat(rf, action=HESSIAN, *, apply_riesz=False, appctx=None, Hessian : V x U* -> V* | V -> V* Args: - rf (ReducedFunctional): Defines the forward model, and used to compute operator actions. + rf (ReducedFunctional): Defines the forward model, and used to compute operator actions. action (RFAction): Whether to apply the TLM, adjoint, or Hessian action. apply_riesz (bool): Whether to apply the riesz map before returning the result of the action to PETSc. @@ -398,7 +398,7 @@ class TAOObjective: adjoint model to be reevaluated. """ - def __init__(self, rf, always_update_tape=True): + def __init__(self, rf, *, always_update_tape=True): self._reduced_functional = rf self.always_update_tape = always_update_tape From b69de0c782b4d06a70428cb46441d1e9d248fc39 Mon Sep 17 00:00:00 2001 From: Josh Hope-Collins Date: Wed, 15 Oct 2025 13:30:23 +0100 Subject: [PATCH 05/19] correct docstring for optimization.tao_solver.valid_comm Co-authored-by: James R. Maddison --- pyadjoint/optimization/tao_solver.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/pyadjoint/optimization/tao_solver.py b/pyadjoint/optimization/tao_solver.py index 2dd0f7a4..b7ae24d9 100644 --- a/pyadjoint/optimization/tao_solver.py +++ b/pyadjoint/optimization/tao_solver.py @@ -146,12 +146,12 @@ def new_control_variable(reduced_functional, *, dual=False): for control in reduced_functional.controls) -def get_valid_comm(comm): +def valid_comm(comm): """ - Return a valid communicator from a user provided (possibly null) comm. + Return a valid communicator from a user provided Comm or None. Args: - comm: Any[petsc4py.PETSc.Comm,mpi4py.MPI.Comm,None] + comm: Optional[Any[petsc4py.PETSc.Comm,mpi4py.MPI.Comm]] Returns: mpi4py.MPI.Comm. COMM_WORLD if `comm is None`, otherwise `comm.tompi4py()`. @@ -217,7 +217,7 @@ def __init__(self, rf, action=HESSIAN, *, apply_riesz=False, appctx=None, always_update_tape=False, comm=PETSc.COMM_WORLD): - comm = get_valid_comm(comm) + comm = valid_comm(comm) self.rf = rf self.appctx = appctx @@ -354,7 +354,7 @@ def ReducedFunctionalMat(rf, action=HESSIAN, *, apply_riesz=False, appctx=None, class RieszMapMatCtx: def __init__(self, controls, comm=None): - comm = get_valid_comm(comm) + comm = valid_comm(comm) self.controls = Enlist(controls) self.vec_interface = PETScVecInterface( @@ -519,7 +519,7 @@ def __init__(self, problem, parameters, *, if problem.constraints is not None: raise NotImplementedError("Constraints not implemented") - comm = get_valid_comm(comm) + comm = valid_comm(comm) rf = problem.reduced_functional tao_objective = TAOObjective(rf) From b57591febfeda71857dd1fbbdf13c6d8a5f7bfd1 Mon Sep 17 00:00:00 2001 From: Josh Hope-Collins Date: Wed, 15 Oct 2025 14:10:54 +0100 Subject: [PATCH 06/19] split ReducedFunctionalMat into separate base classes for each action type --- pyadjoint/optimization/tao_solver.py | 248 +++++++++++++++++++++++---- 1 file changed, 212 insertions(+), 36 deletions(-) diff --git a/pyadjoint/optimization/tao_solver.py b/pyadjoint/optimization/tao_solver.py index b7ae24d9..311795ff 100644 --- a/pyadjoint/optimization/tao_solver.py +++ b/pyadjoint/optimization/tao_solver.py @@ -1,4 +1,3 @@ -from functools import wraps from enum import Enum from numbers import Complex @@ -173,24 +172,6 @@ class RFAction(Enum): HESSIAN = 'hessian' -FORWARD = RFAction.FORWARD -TLM = RFAction.TLM -ADJOINT = RFAction.ADJOINT -HESSIAN = RFAction.HESSIAN - - -def check_rf_action(action): - def check_rf_action_decorator(func): - @wraps(func) - def wrapper(self, *args, **kwargs): - if self.action != action: - raise NotImplementedError( - f'Cannot apply {str(action)} action if {self.action=}') - return func(self, *args, **kwargs) - return wrapper - return check_rf_action_decorator - - class ReducedFunctionalMatCtx: """ PETSc.Mat Python context to apply the action of a pyadjoint.ReducedFunctional. @@ -213,7 +194,7 @@ class ReducedFunctionalMatCtx: comm (Optional[petsc4py.PETSc.Comm,mpi4py.MPI.Comm]): Communicator that the rf is defined over. """ - def __init__(self, rf, action=HESSIAN, *, + def __init__(self, rf, action=RFAction.HESSIAN, *, apply_riesz=False, appctx=None, always_update_tape=False, comm=PETSc.COMM_WORLD): @@ -225,25 +206,25 @@ def __init__(self, rf, action=HESSIAN, *, tuple(c.control for c in rf.controls), comm=comm) self.apply_riesz = apply_riesz - if action in (ADJOINT, TLM): + if action in (RFAction.ADJOINT, RFAction.TLM): self.functional_interface = PETScVecInterface( rf.functional, comm=comm) - if action == HESSIAN: # control -> control + if action == RFAction.HESSIAN: # control -> control self.xinterface = self.control_interface self.yinterface = self.control_interface self.x = new_control_variable(rf) self.mult_impl = self._mult_hessian - elif action == ADJOINT: # functional -> control + elif action == RFAction.ADJOINT: # functional -> control self.xinterface = self.functional_interface self.yinterface = self.control_interface self.x = rf.functional._ad_copy() self.mult_impl = self._mult_adjoint - elif action == TLM: # control -> functional + elif action == RFAction.TLM: # control -> functional self.xinterface = self.control_interface self.yinterface = self.functional_interface @@ -263,14 +244,14 @@ def update(cls, obj, x, A, P): ctx = A.getPythonContext() ctx.control_interface.from_petsc(x, ctx._m) ctx.update_tape_values( - update_adjoint=(ctx.action == HESSIAN)) + update_adjoint=(ctx.action == RFAction.HESSIAN)) ctx._shift = 0 pctx = P.getPythonContext() if pctx is not ctx: pctx.control_interface.from_petsc(x, pctx._m) pctx.update_tape_values( - update_adjoint=(pctx.action == HESSIAN)) + update_adjoint=(pctx.action == RFAction.HESSIAN)) pctx._shift = 0 def shift(self, A, alpha): @@ -289,20 +270,17 @@ def mult(self, A, x, y): if self._shift != 0: y.axpy(self._shift, x) - @check_rf_action(HESSIAN) def _mult_hessian(self, A, x): if self.always_update_tape: self.update_tape_values(update_adjoint=True) return self.rf.hessian( x, apply_riesz=self.apply_riesz) - @check_rf_action(TLM) def _mult_tlm(self, A, x): if self.always_update_tape: self.update_tape_values(update_adjoint=False) return self.rf.tlm(x) - @check_rf_action(ADJOINT) def _mult_adjoint(self, A, x): if self.always_update_tape: self.update_tape_values(update_adjoint=False) @@ -310,7 +288,190 @@ def _mult_adjoint(self, A, x): adj_input=x, apply_riesz=self.apply_riesz) -def ReducedFunctionalMat(rf, action=HESSIAN, *, apply_riesz=False, appctx=None, +class ReducedFunctionalMatBase: + """ + PETSc.Mat Python context to apply the action of a pyadjoint.ReducedFunctional. + + If V is the control space and U is the functional space, each action has the following map: + Jhat : V -> U + TLM : V -> U + Adjoint : U* -> V* + Hessian : V x U* -> V* | V -> V* + + Args: + rf (ReducedFunctional): Defines the forward model, and used to compute operator actions. + action (RFAction): Whether to apply the TLM, adjoint, or Hessian action. + apply_riesz (bool): Whether to apply the riesz map before returning the + result of the action to PETSc. + appctx (Optional[dict]): User provided context. + always_update_tape (bool): Whether to force reevaluation of the forward model every time + `mult` is called. If action is HESSIAN then this will also force the adjoint model to + be reevaluated at every call to `mult`. + needs_functional_interface: Whether to create a PETScVecInterface for the rf.functional. + comm (Optional[petsc4py.PETSc.Comm,mpi4py.MPI.Comm]): Communicator that the rf is defined over. + """ + + def __init__(self, rf, action=RFAction.HESSIAN, *, + apply_riesz=False, appctx=None, + always_update_tape=False, + needs_functional_interface=False, + comm=PETSc.COMM_WORLD): + comm = valid_comm(comm) + + self.rf = rf + self.appctx = appctx + self.apply_riesz = apply_riesz + + self.control_interface = PETScVecInterface( + tuple(c.control for c in rf.controls), + comm=comm) + if needs_functional_interface: + self.functional_interface = PETScVecInterface( + rf.functional, comm=comm) + + self.action = action + self._m = new_control_variable(rf) + self._shift = 0 + self.always_update_tape = always_update_tape + + @classmethod + def update(cls, obj, x, A, P): + ctx = A.getPythonContext() + ctx.control_interface.from_petsc(x, ctx._m) + ctx.update_tape_values( + update_adjoint=(ctx.action == RFAction.HESSIAN)) + ctx._shift = 0 + + pctx = P.getPythonContext() + if pctx is not ctx: + pctx.control_interface.from_petsc(x, pctx._m) + pctx.update_tape_values( + update_adjoint=(pctx.action == RFAction.HESSIAN)) + pctx._shift = 0 + + def shift(self, A, alpha): + self._shift += alpha + + def update_tape_values(self, *, update_adjoint=True): + _ = self.rf(self._m) + if update_adjoint: + _ = self.rf.derivative(apply_riesz=False) + + def mult(self, A, x, y): + self.xinterface.from_petsc(x, self.x) + out = self.mult_impl(A, self.x) + self.yinterface.to_petsc(y, out) + + if self._shift != 0: + y.axpy(self._shift, x) + + +class ReducedFunctionalHessianMat(ReducedFunctionalMatBase): + """ + PETSc.Mat Python context to apply the Hessian action of a pyadjoint.ReducedFunctional. + + If V is the control space and U is the functional space, the Hessian action has the following map: + Jhat : V -> U + Hessian : V x U* -> V* | V -> V* + + Args: + rf (ReducedFunctional): Defines the forward model, and used to compute operator actions. + apply_riesz (bool): Whether to apply the riesz map before returning the result of the + action to PETSc. + appctx (Optional[dict]): User provided context. + always_update_tape (bool): Whether to force reevaluation of the forward and adjoint models + every time `mult` is called. + comm (Optional[petsc4py.PETSc.Comm,mpi4py.MPI.Comm]): Communicator that the rf is defined over. + """ + + def __init__(self, rf, *, apply_riesz=False, appctx=None, + always_update_tape=False, comm=PETSc.COMM_WORLD): + + super().__init__(rf, RFAction.HESSIAN, apply_riesz=apply_riesz, + appctx=appctx, needs_functional_interface=False, + always_update_tape=always_update_tape, comm=comm) + + self.xinterface = self.control_interface + self.yinterface = self.control_interface + self.x = new_control_variable(rf) + + def mult_impl(self, A, x): + if self.always_update_tape: + self.update_tape_values(update_adjoint=True) + return self.rf.hessian( + x, apply_riesz=self.apply_riesz) + + +class ReducedFunctionalAdjointMat(ReducedFunctionalMatBase): + """ + PETSc.Mat Python context to apply the adjoint action of a pyadjoint.ReducedFunctional. + + If V is the control space and U is the functional space, the adjoint action has the following map: + Jhat : V -> U + Adjoint : U* -> V* + + Args: + rf (ReducedFunctional): Defines the forward model, and used to compute operator actions. + apply_riesz (bool): Whether to apply the riesz map before returning the result of the + action to PETSc. + appctx (Optional[dict]): User provided context. + always_update_tape (bool): Whether to force reevaluation of the forward model every time + `mult` is called. + comm (Optional[petsc4py.PETSc.Comm,mpi4py.MPI.Comm]): Communicator that the rf is defined over. + """ + + def __init__(self, rf, *, apply_riesz=False, appctx=None, + always_update_tape=False, comm=PETSc.COMM_WORLD): + + super().__init__(rf, RFAction.HESSIAN, apply_riesz=apply_riesz, + appctx=appctx, needs_functional_interface=True, + always_update_tape=always_update_tape, comm=comm) + + self.xinterface = self.functional_interface + self.yinterface = self.control_interface + self.x = rf.functional._ad_copy() + + def mult_impl(self, A, x): + if self.always_update_tape: + self.update_tape_values(update_adjoint=False) + return self.rf.derivative( + adj_input=x, apply_riesz=self.apply_riesz) + + +class ReducedFunctionalTLMMat(ReducedFunctionalMatBase): + """ + PETSc.Mat Python context to apply the tangent linear action of a pyadjoint.ReducedFunctional. + + If V is the control space and U is the functional space, the tangent linear action has the following map: + Jhat : V -> U + TLM : V -> U + + Args: + rf (ReducedFunctional): Defines the forward model, and used to compute operator actions. + appctx (Optional[dict]): User provided context. + always_update_tape (bool): Whether to force reevaluation of the forward model every time + `mult` is called. + comm (Optional[petsc4py.PETSc.Comm,mpi4py.MPI.Comm]): Communicator that the rf is defined over. + """ + + def __init__(self, rf, *, apply_riesz=False, appctx=None, + always_update_tape=False, comm=PETSc.COMM_WORLD): + + super().__init__(rf, RFAction.HESSIAN, apply_riesz=apply_riesz, + appctx=appctx, needs_functional_interface=True, + always_update_tape=always_update_tape, comm=comm) + + self.xinterface = self.control_interface + self.yinterface = self.functional_interface + self.x = new_control_variable(rf) + + def mult_impl(self, A, x): + if self.always_update_tape: + self.update_tape_values(update_adjoint=False) + return self.rf.tlm(x) + + +def ReducedFunctionalMat(rf, action=RFAction.HESSIAN, *, apply_riesz=False, appctx=None, always_update_tape=False, comm=None): """ PETSc.Mat to apply the action of a pyadjoint.ReducedFunctional. @@ -325,16 +486,31 @@ def ReducedFunctionalMat(rf, action=HESSIAN, *, apply_riesz=False, appctx=None, rf (ReducedFunctional): Defines the forward model, and used to compute operator actions. action (RFAction): Whether to apply the TLM, adjoint, or Hessian action. apply_riesz (bool): Whether to apply the riesz map before returning the - result of the action to PETSc. + result of the action to PETSc. Ignored if action is TLM. appctx (Optional[dict]): User provided context. always_update_tape (bool): Whether to force reevaluation of the forward model every time `mult` is called. If action is HESSIAN then this will also force the adjoint model to be reevaluated at every call to `mult`. comm (Optional[petsc4py.PETSc.Comm,mpi4py.MPI.Comm]): Communicator that the rf is defined over. """ - ctx = ReducedFunctionalMatCtx( - rf, action, appctx=appctx, apply_riesz=apply_riesz, - always_update_tape=always_update_tape, comm=comm) + if action == RFAction.HESSIAN: + ctx = ReducedFunctionalHessianMat( + rf, appctx=appctx, apply_riesz=apply_riesz, + always_update_tape=always_update_tape, comm=comm) + + elif action == RFAction.ADJOINT: + ctx = ReducedFunctionalAdjointMat( + rf, appctx=appctx, apply_riesz=apply_riesz, + always_update_tape=always_update_tape, comm=comm) + + elif action == RFAction.TLM: + ctx = ReducedFunctionalTLMMat( + rf, appctx=appctx, + always_update_tape=always_update_tape, comm=comm) + + else: + raise ValueError( + 'Unrecognised RFAction: {action}.') ncol = ctx.xinterface.n Ncol = ctx.xinterface.N @@ -345,7 +521,7 @@ def ReducedFunctionalMat(rf, action=HESSIAN, *, apply_riesz=False, appctx=None, mat = PETSc.Mat().createPython( ((nrow, Nrow), (ncol, Ncol)), ctx, comm=ctx.control_interface.comm) - if action == HESSIAN: + if action == RFAction.HESSIAN: mat.setOption(PETSc.Mat.Option.SYMMETRIC, True) mat.setUp() mat.assemble() @@ -554,7 +730,7 @@ def objective_gradient(tao, x, g): hessian_mat = ReducedFunctionalMat( problem.reduced_functional, appctx=appctx, - action=HESSIAN, comm=comm) + action=RFAction.HESSIAN, comm=comm) tao.setHessian( hessian_mat.getPythonContext().update, From f870144df07654835d861ad01c0dd3faf6afc052 Mon Sep 17 00:00:00 2001 From: Josh Hope-Collins Date: Wed, 15 Oct 2025 14:13:19 +0100 Subject: [PATCH 07/19] remove apply_riesz kwarg for ReducedFunctionalTLMMat --- pyadjoint/optimization/tao_solver.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/pyadjoint/optimization/tao_solver.py b/pyadjoint/optimization/tao_solver.py index 311795ff..a1579108 100644 --- a/pyadjoint/optimization/tao_solver.py +++ b/pyadjoint/optimization/tao_solver.py @@ -454,11 +454,10 @@ class ReducedFunctionalTLMMat(ReducedFunctionalMatBase): comm (Optional[petsc4py.PETSc.Comm,mpi4py.MPI.Comm]): Communicator that the rf is defined over. """ - def __init__(self, rf, *, apply_riesz=False, appctx=None, - always_update_tape=False, comm=PETSc.COMM_WORLD): + def __init__(self, rf, *, appctx=None, always_update_tape=False, comm=PETSc.COMM_WORLD): - super().__init__(rf, RFAction.HESSIAN, apply_riesz=apply_riesz, - appctx=appctx, needs_functional_interface=True, + super().__init__(rf, RFAction.HESSIAN, appctx=appctx, + needs_functional_interface=True, always_update_tape=always_update_tape, comm=comm) self.xinterface = self.control_interface From fbc8cac7f929f2011266a467f6c98238b5f61b82 Mon Sep 17 00:00:00 2001 From: Josh Hope-Collins Date: Thu, 16 Oct 2025 14:00:01 +0100 Subject: [PATCH 08/19] do not assume that tao pmat is ReducedFunctionalMat --- pyadjoint/optimization/tao_solver.py | 123 --------------------------- 1 file changed, 123 deletions(-) diff --git a/pyadjoint/optimization/tao_solver.py b/pyadjoint/optimization/tao_solver.py index a1579108..79f2cc14 100644 --- a/pyadjoint/optimization/tao_solver.py +++ b/pyadjoint/optimization/tao_solver.py @@ -172,122 +172,6 @@ class RFAction(Enum): HESSIAN = 'hessian' -class ReducedFunctionalMatCtx: - """ - PETSc.Mat Python context to apply the action of a pyadjoint.ReducedFunctional. - - If V is the control space and U is the functional space, each action has the following map: - Jhat : V -> U - TLM : V -> U - Adjoint : U* -> V* - Hessian : V x U* -> V* | V -> V* - - Args: - rf (ReducedFunctional): Defines the forward model, and used to compute operator actions. - action (RFAction): Whether to apply the TLM, adjoint, or Hessian action. - apply_riesz (bool): Whether to apply the riesz map before returning the - result of the action to PETSc. - appctx (Optional[dict]): User provided context. - always_update_tape (bool): Whether to force reevaluation of the forward model every time - `mult` is called. If action is HESSIAN then this will also force the adjoint model to - be reevaluated at every call to `mult`. - comm (Optional[petsc4py.PETSc.Comm,mpi4py.MPI.Comm]): Communicator that the rf is defined over. - """ - - def __init__(self, rf, action=RFAction.HESSIAN, *, - apply_riesz=False, appctx=None, - always_update_tape=False, - comm=PETSc.COMM_WORLD): - comm = valid_comm(comm) - - self.rf = rf - self.appctx = appctx - self.control_interface = PETScVecInterface( - tuple(c.control for c in rf.controls), - comm=comm) - self.apply_riesz = apply_riesz - if action in (RFAction.ADJOINT, RFAction.TLM): - self.functional_interface = PETScVecInterface( - rf.functional, comm=comm) - - if action == RFAction.HESSIAN: # control -> control - self.xinterface = self.control_interface - self.yinterface = self.control_interface - - self.x = new_control_variable(rf) - self.mult_impl = self._mult_hessian - - elif action == RFAction.ADJOINT: # functional -> control - self.xinterface = self.functional_interface - self.yinterface = self.control_interface - - self.x = rf.functional._ad_copy() - self.mult_impl = self._mult_adjoint - - elif action == RFAction.TLM: # control -> functional - self.xinterface = self.control_interface - self.yinterface = self.functional_interface - - self.x = new_control_variable(rf) - self.mult_impl = self._mult_tlm - else: - raise ValueError( - 'Unrecognised {action = }.') - - self.action = action - self._m = new_control_variable(rf) - self._shift = 0 - self.always_update_tape = always_update_tape - - @classmethod - def update(cls, obj, x, A, P): - ctx = A.getPythonContext() - ctx.control_interface.from_petsc(x, ctx._m) - ctx.update_tape_values( - update_adjoint=(ctx.action == RFAction.HESSIAN)) - ctx._shift = 0 - - pctx = P.getPythonContext() - if pctx is not ctx: - pctx.control_interface.from_petsc(x, pctx._m) - pctx.update_tape_values( - update_adjoint=(pctx.action == RFAction.HESSIAN)) - pctx._shift = 0 - - def shift(self, A, alpha): - self._shift += alpha - - def update_tape_values(self, *, update_adjoint=True): - _ = self.rf(self._m) - if update_adjoint: - _ = self.rf.derivative(apply_riesz=False) - - def mult(self, A, x, y): - self.xinterface.from_petsc(x, self.x) - out = self.mult_impl(A, self.x) - self.yinterface.to_petsc(y, out) - - if self._shift != 0: - y.axpy(self._shift, x) - - def _mult_hessian(self, A, x): - if self.always_update_tape: - self.update_tape_values(update_adjoint=True) - return self.rf.hessian( - x, apply_riesz=self.apply_riesz) - - def _mult_tlm(self, A, x): - if self.always_update_tape: - self.update_tape_values(update_adjoint=False) - return self.rf.tlm(x) - - def _mult_adjoint(self, A, x): - if self.always_update_tape: - self.update_tape_values(update_adjoint=False) - return self.rf.derivative( - adj_input=x, apply_riesz=self.apply_riesz) - - class ReducedFunctionalMatBase: """ PETSc.Mat Python context to apply the action of a pyadjoint.ReducedFunctional. @@ -342,13 +226,6 @@ def update(cls, obj, x, A, P): update_adjoint=(ctx.action == RFAction.HESSIAN)) ctx._shift = 0 - pctx = P.getPythonContext() - if pctx is not ctx: - pctx.control_interface.from_petsc(x, pctx._m) - pctx.update_tape_values( - update_adjoint=(pctx.action == RFAction.HESSIAN)) - pctx._shift = 0 - def shift(self, A, alpha): self._shift += alpha From 0cd9c9f58e65a5349c6b85e67ddbc1dca4a3c210 Mon Sep 17 00:00:00 2001 From: JHopeCollins Date: Tue, 21 Oct 2025 14:28:25 +0100 Subject: [PATCH 09/19] remove old ReducedFunctionalMatCtx implementation --- pyadjoint/optimization/tao_solver.py | 116 --------------------------- 1 file changed, 116 deletions(-) diff --git a/pyadjoint/optimization/tao_solver.py b/pyadjoint/optimization/tao_solver.py index a1579108..f5d6c332 100644 --- a/pyadjoint/optimization/tao_solver.py +++ b/pyadjoint/optimization/tao_solver.py @@ -172,122 +172,6 @@ class RFAction(Enum): HESSIAN = 'hessian' -class ReducedFunctionalMatCtx: - """ - PETSc.Mat Python context to apply the action of a pyadjoint.ReducedFunctional. - - If V is the control space and U is the functional space, each action has the following map: - Jhat : V -> U - TLM : V -> U - Adjoint : U* -> V* - Hessian : V x U* -> V* | V -> V* - - Args: - rf (ReducedFunctional): Defines the forward model, and used to compute operator actions. - action (RFAction): Whether to apply the TLM, adjoint, or Hessian action. - apply_riesz (bool): Whether to apply the riesz map before returning the - result of the action to PETSc. - appctx (Optional[dict]): User provided context. - always_update_tape (bool): Whether to force reevaluation of the forward model every time - `mult` is called. If action is HESSIAN then this will also force the adjoint model to - be reevaluated at every call to `mult`. - comm (Optional[petsc4py.PETSc.Comm,mpi4py.MPI.Comm]): Communicator that the rf is defined over. - """ - - def __init__(self, rf, action=RFAction.HESSIAN, *, - apply_riesz=False, appctx=None, - always_update_tape=False, - comm=PETSc.COMM_WORLD): - comm = valid_comm(comm) - - self.rf = rf - self.appctx = appctx - self.control_interface = PETScVecInterface( - tuple(c.control for c in rf.controls), - comm=comm) - self.apply_riesz = apply_riesz - if action in (RFAction.ADJOINT, RFAction.TLM): - self.functional_interface = PETScVecInterface( - rf.functional, comm=comm) - - if action == RFAction.HESSIAN: # control -> control - self.xinterface = self.control_interface - self.yinterface = self.control_interface - - self.x = new_control_variable(rf) - self.mult_impl = self._mult_hessian - - elif action == RFAction.ADJOINT: # functional -> control - self.xinterface = self.functional_interface - self.yinterface = self.control_interface - - self.x = rf.functional._ad_copy() - self.mult_impl = self._mult_adjoint - - elif action == RFAction.TLM: # control -> functional - self.xinterface = self.control_interface - self.yinterface = self.functional_interface - - self.x = new_control_variable(rf) - self.mult_impl = self._mult_tlm - else: - raise ValueError( - 'Unrecognised {action = }.') - - self.action = action - self._m = new_control_variable(rf) - self._shift = 0 - self.always_update_tape = always_update_tape - - @classmethod - def update(cls, obj, x, A, P): - ctx = A.getPythonContext() - ctx.control_interface.from_petsc(x, ctx._m) - ctx.update_tape_values( - update_adjoint=(ctx.action == RFAction.HESSIAN)) - ctx._shift = 0 - - pctx = P.getPythonContext() - if pctx is not ctx: - pctx.control_interface.from_petsc(x, pctx._m) - pctx.update_tape_values( - update_adjoint=(pctx.action == RFAction.HESSIAN)) - pctx._shift = 0 - - def shift(self, A, alpha): - self._shift += alpha - - def update_tape_values(self, *, update_adjoint=True): - _ = self.rf(self._m) - if update_adjoint: - _ = self.rf.derivative(apply_riesz=False) - - def mult(self, A, x, y): - self.xinterface.from_petsc(x, self.x) - out = self.mult_impl(A, self.x) - self.yinterface.to_petsc(y, out) - - if self._shift != 0: - y.axpy(self._shift, x) - - def _mult_hessian(self, A, x): - if self.always_update_tape: - self.update_tape_values(update_adjoint=True) - return self.rf.hessian( - x, apply_riesz=self.apply_riesz) - - def _mult_tlm(self, A, x): - if self.always_update_tape: - self.update_tape_values(update_adjoint=False) - return self.rf.tlm(x) - - def _mult_adjoint(self, A, x): - if self.always_update_tape: - self.update_tape_values(update_adjoint=False) - return self.rf.derivative( - adj_input=x, apply_riesz=self.apply_riesz) - - class ReducedFunctionalMatBase: """ PETSc.Mat Python context to apply the action of a pyadjoint.ReducedFunctional. From 8a84e7cc4082dd9ef9d1fdfd4519fa0dc96d312b Mon Sep 17 00:00:00 2001 From: JHopeCollins Date: Thu, 23 Oct 2025 13:27:22 +0100 Subject: [PATCH 10/19] RFMat naming and docstring updates from code review --- pyadjoint/optimization/tao_solver.py | 34 ++++++++++++++++------------ 1 file changed, 19 insertions(+), 15 deletions(-) diff --git a/pyadjoint/optimization/tao_solver.py b/pyadjoint/optimization/tao_solver.py index 79f2cc14..646689eb 100644 --- a/pyadjoint/optimization/tao_solver.py +++ b/pyadjoint/optimization/tao_solver.py @@ -162,7 +162,7 @@ def valid_comm(comm): return comm -class RFAction(Enum): +class RFOperation(Enum): """ The type of linear action that a ReducedFunctionalMat should apply. """ @@ -184,18 +184,22 @@ class ReducedFunctionalMatBase: Args: rf (ReducedFunctional): Defines the forward model, and used to compute operator actions. - action (RFAction): Whether to apply the TLM, adjoint, or Hessian action. + action (RFOperation): Whether to apply the TLM, adjoint, or Hessian action. apply_riesz (bool): Whether to apply the riesz map before returning the result of the action to PETSc. appctx (Optional[dict]): User provided context. always_update_tape (bool): Whether to force reevaluation of the forward model every time `mult` is called. If action is HESSIAN then this will also force the adjoint model to be reevaluated at every call to `mult`. + The default is False because PETSc will call the update method before each KSP solve + and we evaluate the forward model there (and adjoint in the case of RFOperation.HESSIAN). needs_functional_interface: Whether to create a PETScVecInterface for the rf.functional. + This is required when RFOperation is ADJOINT or TLM, because then this Mat is rectangular + rather than square as in the case of HESSIAN. comm (Optional[petsc4py.PETSc.Comm,mpi4py.MPI.Comm]): Communicator that the rf is defined over. """ - def __init__(self, rf, action=RFAction.HESSIAN, *, + def __init__(self, rf, action=RFOperation.HESSIAN, *, apply_riesz=False, appctx=None, always_update_tape=False, needs_functional_interface=False, @@ -223,7 +227,7 @@ def update(cls, obj, x, A, P): ctx = A.getPythonContext() ctx.control_interface.from_petsc(x, ctx._m) ctx.update_tape_values( - update_adjoint=(ctx.action == RFAction.HESSIAN)) + update_adjoint=(ctx.action == RFOperation.HESSIAN)) ctx._shift = 0 def shift(self, A, alpha): @@ -264,7 +268,7 @@ class ReducedFunctionalHessianMat(ReducedFunctionalMatBase): def __init__(self, rf, *, apply_riesz=False, appctx=None, always_update_tape=False, comm=PETSc.COMM_WORLD): - super().__init__(rf, RFAction.HESSIAN, apply_riesz=apply_riesz, + super().__init__(rf, RFOperation.HESSIAN, apply_riesz=apply_riesz, appctx=appctx, needs_functional_interface=False, always_update_tape=always_update_tape, comm=comm) @@ -300,7 +304,7 @@ class ReducedFunctionalAdjointMat(ReducedFunctionalMatBase): def __init__(self, rf, *, apply_riesz=False, appctx=None, always_update_tape=False, comm=PETSc.COMM_WORLD): - super().__init__(rf, RFAction.HESSIAN, apply_riesz=apply_riesz, + super().__init__(rf, RFOperation.HESSIAN, apply_riesz=apply_riesz, appctx=appctx, needs_functional_interface=True, always_update_tape=always_update_tape, comm=comm) @@ -333,7 +337,7 @@ class ReducedFunctionalTLMMat(ReducedFunctionalMatBase): def __init__(self, rf, *, appctx=None, always_update_tape=False, comm=PETSc.COMM_WORLD): - super().__init__(rf, RFAction.HESSIAN, appctx=appctx, + super().__init__(rf, RFOperation.HESSIAN, appctx=appctx, needs_functional_interface=True, always_update_tape=always_update_tape, comm=comm) @@ -347,7 +351,7 @@ def mult_impl(self, A, x): return self.rf.tlm(x) -def ReducedFunctionalMat(rf, action=RFAction.HESSIAN, *, apply_riesz=False, appctx=None, +def ReducedFunctionalMat(rf, action=RFOperation.HESSIAN, *, apply_riesz=False, appctx=None, always_update_tape=False, comm=None): """ PETSc.Mat to apply the action of a pyadjoint.ReducedFunctional. @@ -360,7 +364,7 @@ def ReducedFunctionalMat(rf, action=RFAction.HESSIAN, *, apply_riesz=False, appc Args: rf (ReducedFunctional): Defines the forward model, and used to compute operator actions. - action (RFAction): Whether to apply the TLM, adjoint, or Hessian action. + action (RFOperation): Whether to apply the TLM, adjoint, or Hessian action. apply_riesz (bool): Whether to apply the riesz map before returning the result of the action to PETSc. Ignored if action is TLM. appctx (Optional[dict]): User provided context. @@ -369,24 +373,24 @@ def ReducedFunctionalMat(rf, action=RFAction.HESSIAN, *, apply_riesz=False, appc be reevaluated at every call to `mult`. comm (Optional[petsc4py.PETSc.Comm,mpi4py.MPI.Comm]): Communicator that the rf is defined over. """ - if action == RFAction.HESSIAN: + if action == RFOperation.HESSIAN: ctx = ReducedFunctionalHessianMat( rf, appctx=appctx, apply_riesz=apply_riesz, always_update_tape=always_update_tape, comm=comm) - elif action == RFAction.ADJOINT: + elif action == RFOperation.ADJOINT: ctx = ReducedFunctionalAdjointMat( rf, appctx=appctx, apply_riesz=apply_riesz, always_update_tape=always_update_tape, comm=comm) - elif action == RFAction.TLM: + elif action == RFOperation.TLM: ctx = ReducedFunctionalTLMMat( rf, appctx=appctx, always_update_tape=always_update_tape, comm=comm) else: raise ValueError( - 'Unrecognised RFAction: {action}.') + f'Unrecognised RFOperation: {action}.') ncol = ctx.xinterface.n Ncol = ctx.xinterface.N @@ -397,7 +401,7 @@ def ReducedFunctionalMat(rf, action=RFAction.HESSIAN, *, apply_riesz=False, appc mat = PETSc.Mat().createPython( ((nrow, Nrow), (ncol, Ncol)), ctx, comm=ctx.control_interface.comm) - if action == RFAction.HESSIAN: + if action == RFOperation.HESSIAN: mat.setOption(PETSc.Mat.Option.SYMMETRIC, True) mat.setUp() mat.assemble() @@ -606,7 +610,7 @@ def objective_gradient(tao, x, g): hessian_mat = ReducedFunctionalMat( problem.reduced_functional, appctx=appctx, - action=RFAction.HESSIAN, comm=comm) + action=RFOperation.HESSIAN, comm=comm) tao.setHessian( hessian_mat.getPythonContext().update, From 6109e9f4a354c99bc553b1743d75f1f141ea7496 Mon Sep 17 00:00:00 2001 From: JHopeCollins Date: Thu, 23 Oct 2025 13:59:20 +0100 Subject: [PATCH 11/19] ReducedFunctionalMat.multHermitian --- pyadjoint/optimization/tao_solver.py | 38 ++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/pyadjoint/optimization/tao_solver.py b/pyadjoint/optimization/tao_solver.py index 646689eb..e334512e 100644 --- a/pyadjoint/optimization/tao_solver.py +++ b/pyadjoint/optimization/tao_solver.py @@ -246,6 +246,30 @@ def mult(self, A, x, y): if self._shift != 0: y.axpy(self._shift, x) + def multHermitian(self, A, x, y): + self.yinterface.from_petsc(x, self.x) + out = self.multHermitian_impl(A, self.x) + self.xinterface.to_petsc(y, out) + + if self._shift != 0: + y.axpy(self._shift, x) + + def mult_impl(self, A, x): + """ + This method must be overriden. + + Provides the implementation of a particular type of ReducedFunctional action. + """ + raise NotImplementedError + + def multHermitian_impl(self, A, y): + """ + This method must be overriden. + + Provides the implementation of the Hermitian of a particular type of ReducedFunctional action. + """ + raise NotImplementedError + class ReducedFunctionalHessianMat(ReducedFunctionalMatBase): """ @@ -282,6 +306,9 @@ def mult_impl(self, A, x): return self.rf.hessian( x, apply_riesz=self.apply_riesz) + def multHermitian_impl(self, A, x): + return self.mult_impl(A, x) + class ReducedFunctionalAdjointMat(ReducedFunctionalMatBase): """ @@ -318,6 +345,11 @@ def mult_impl(self, A, x): return self.rf.derivative( adj_input=x, apply_riesz=self.apply_riesz) + def multHermitian_impl(self, A, x): + if self.always_update_tape: + self.update_tape_values(update_adjoint=False) + return self.rf.tlm(x) + class ReducedFunctionalTLMMat(ReducedFunctionalMatBase): """ @@ -350,6 +382,12 @@ def mult_impl(self, A, x): self.update_tape_values(update_adjoint=False) return self.rf.tlm(x) + def multHermitian_impl(self, A, x): + if self.always_update_tape: + self.update_tape_values(update_adjoint=False) + return self.rf.derivative( + adj_input=x, apply_riesz=self.apply_riesz) + def ReducedFunctionalMat(rf, action=RFOperation.HESSIAN, *, apply_riesz=False, appctx=None, always_update_tape=False, comm=None): From ba5b56b3615a17265f2095dc8f40289784a2212f Mon Sep 17 00:00:00 2001 From: Josh Hope-Collins Date: Fri, 24 Oct 2025 14:49:29 +0100 Subject: [PATCH 12/19] Update docstring pyadjoint/optimization/tao_solver.py Co-authored-by: Connor Ward --- pyadjoint/optimization/tao_solver.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyadjoint/optimization/tao_solver.py b/pyadjoint/optimization/tao_solver.py index e334512e..1fdcd0e6 100644 --- a/pyadjoint/optimization/tao_solver.py +++ b/pyadjoint/optimization/tao_solver.py @@ -286,7 +286,7 @@ class ReducedFunctionalHessianMat(ReducedFunctionalMatBase): appctx (Optional[dict]): User provided context. always_update_tape (bool): Whether to force reevaluation of the forward and adjoint models every time `mult` is called. - comm (Optional[petsc4py.PETSc.Comm,mpi4py.MPI.Comm]): Communicator that the rf is defined over. + comm (Optional[Union[petsc4py.PETSc.Comm,mpi4py.MPI.Comm]]): Communicator that the rf is defined over. """ def __init__(self, rf, *, apply_riesz=False, appctx=None, From 5d99fcb916e06cdf4ce35864c8167b784253d522 Mon Sep 17 00:00:00 2001 From: Josh Hope-Collins Date: Fri, 24 Oct 2025 14:49:56 +0100 Subject: [PATCH 13/19] Update pyadjoint/optimization/tao_solver.py Co-authored-by: Connor Ward --- pyadjoint/optimization/tao_solver.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pyadjoint/optimization/tao_solver.py b/pyadjoint/optimization/tao_solver.py index 1fdcd0e6..a9d06d72 100644 --- a/pyadjoint/optimization/tao_solver.py +++ b/pyadjoint/optimization/tao_solver.py @@ -140,7 +140,6 @@ def new_control_variable(reduced_functional, *, dual=False): Returns: tuple[OverloadedType]: New variables suitable for storing a control value. """ - return tuple(control._ad_init_zero(dual=dual) for control in reduced_functional.controls) From 9bb846f433719669c4e83e8a3ec3b4d224e8b49a Mon Sep 17 00:00:00 2001 From: Josh Hope-Collins Date: Fri, 24 Oct 2025 14:50:14 +0100 Subject: [PATCH 14/19] Update pyadjoint/optimization/tao_solver.py Co-authored-by: Connor Ward --- pyadjoint/optimization/tao_solver.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyadjoint/optimization/tao_solver.py b/pyadjoint/optimization/tao_solver.py index a9d06d72..b1679cc2 100644 --- a/pyadjoint/optimization/tao_solver.py +++ b/pyadjoint/optimization/tao_solver.py @@ -152,7 +152,7 @@ def valid_comm(comm): comm: Optional[Any[petsc4py.PETSc.Comm,mpi4py.MPI.Comm]] Returns: - mpi4py.MPI.Comm. COMM_WORLD if `comm is None`, otherwise `comm.tompi4py()`. + mpi4py.MPI.Comm.COMM_WORLD if `comm is None`, otherwise `comm.tompi4py()`. """ if comm is None: comm = PETSc.COMM_WORLD From e63affb06dc7a29262be8efed2441196f8e14732 Mon Sep 17 00:00:00 2001 From: Josh Hope-Collins Date: Fri, 24 Oct 2025 16:29:22 +0100 Subject: [PATCH 15/19] remove action attribute from ReducedFunctionalMat base class --- pyadjoint/optimization/tao_solver.py | 36 +++++++++++++--------------- 1 file changed, 17 insertions(+), 19 deletions(-) diff --git a/pyadjoint/optimization/tao_solver.py b/pyadjoint/optimization/tao_solver.py index b1679cc2..d93ee60d 100644 --- a/pyadjoint/optimization/tao_solver.py +++ b/pyadjoint/optimization/tao_solver.py @@ -182,25 +182,24 @@ class ReducedFunctionalMatBase: Hessian : V x U* -> V* | V -> V* Args: - rf (ReducedFunctional): Defines the forward model, and used to compute operator actions. - action (RFOperation): Whether to apply the TLM, adjoint, or Hessian action. + rf (ReducedFunctional): Defines the forward model. Used to compute Mat actions. apply_riesz (bool): Whether to apply the riesz map before returning the result of the action to PETSc. appctx (Optional[dict]): User provided context. - always_update_tape (bool): Whether to force reevaluation of the forward model every time - `mult` is called. If action is HESSIAN then this will also force the adjoint model to - be reevaluated at every call to `mult`. - The default is False because PETSc will call the update method before each KSP solve - and we evaluate the forward model there (and adjoint in the case of RFOperation.HESSIAN). - needs_functional_interface: Whether to create a PETScVecInterface for the rf.functional. - This is required when RFOperation is ADJOINT or TLM, because then this Mat is rectangular - rather than square as in the case of HESSIAN. + always_update_tape (bool): Whether to force reevaluation of the forward model + every time `mult` is called. If needs_adjoint_update then this will also force + the adjoint model to be reevaluated at every call to `mult`. + update_adjoint (bool): Whether to update the adjoint as well as the forward + model when updating the Mat. If True then `rf.derivative` will be called by + the `update` method. Required for Hessian but not for TLM or Adjoint Mats. + needs_functional_interface: Whether to create a PETScVecInterface for rf.functional. comm (Optional[petsc4py.PETSc.Comm,mpi4py.MPI.Comm]): Communicator that the rf is defined over. """ - def __init__(self, rf, action=RFOperation.HESSIAN, *, + def __init__(self, rf, *, apply_riesz=False, appctx=None, always_update_tape=False, + update_adjoint=False, needs_functional_interface=False, comm=PETSc.COMM_WORLD): comm = valid_comm(comm) @@ -216,7 +215,7 @@ def __init__(self, rf, action=RFOperation.HESSIAN, *, self.functional_interface = PETScVecInterface( rf.functional, comm=comm) - self.action = action + self.update_adjoint = update_adjoint self._m = new_control_variable(rf) self._shift = 0 self.always_update_tape = always_update_tape @@ -225,8 +224,7 @@ def __init__(self, rf, action=RFOperation.HESSIAN, *, def update(cls, obj, x, A, P): ctx = A.getPythonContext() ctx.control_interface.from_petsc(x, ctx._m) - ctx.update_tape_values( - update_adjoint=(ctx.action == RFOperation.HESSIAN)) + ctx.update_tape_values(update_adjoint=self.update_adjoint) ctx._shift = 0 def shift(self, A, alpha): @@ -291,8 +289,8 @@ class ReducedFunctionalHessianMat(ReducedFunctionalMatBase): def __init__(self, rf, *, apply_riesz=False, appctx=None, always_update_tape=False, comm=PETSc.COMM_WORLD): - super().__init__(rf, RFOperation.HESSIAN, apply_riesz=apply_riesz, - appctx=appctx, needs_functional_interface=False, + super().__init__(rf, apply_riesz=apply_riesz, appctx=appctx, + needs_functional_interface=False, update_adjoint=True, always_update_tape=always_update_tape, comm=comm) self.xinterface = self.control_interface @@ -330,8 +328,8 @@ class ReducedFunctionalAdjointMat(ReducedFunctionalMatBase): def __init__(self, rf, *, apply_riesz=False, appctx=None, always_update_tape=False, comm=PETSc.COMM_WORLD): - super().__init__(rf, RFOperation.HESSIAN, apply_riesz=apply_riesz, - appctx=appctx, needs_functional_interface=True, + super().__init__(rf, apply_riesz=apply_riesz, appctx=appctx, + needs_functional_interface=True, update_adjoint=False, always_update_tape=always_update_tape, comm=comm) self.xinterface = self.functional_interface @@ -368,7 +366,7 @@ class ReducedFunctionalTLMMat(ReducedFunctionalMatBase): def __init__(self, rf, *, appctx=None, always_update_tape=False, comm=PETSc.COMM_WORLD): - super().__init__(rf, RFOperation.HESSIAN, appctx=appctx, + super().__init__(rf, appctx=appctx, update_adjoint=False, needs_functional_interface=True, always_update_tape=always_update_tape, comm=comm) From d2d720d3360d7110e78dcaf5cc32b1c7335d74a5 Mon Sep 17 00:00:00 2001 From: Josh Hope-Collins Date: Fri, 24 Oct 2025 17:13:45 +0100 Subject: [PATCH 16/19] petsc mat docstrings and error messages --- pyadjoint/optimization/tao_solver.py | 37 ++++++++++++++++++++++++++-- 1 file changed, 35 insertions(+), 2 deletions(-) diff --git a/pyadjoint/optimization/tao_solver.py b/pyadjoint/optimization/tao_solver.py index d93ee60d..03a8ceb3 100644 --- a/pyadjoint/optimization/tao_solver.py +++ b/pyadjoint/optimization/tao_solver.py @@ -256,16 +256,28 @@ def mult_impl(self, A, x): This method must be overriden. Provides the implementation of a particular type of ReducedFunctional action. + + Args: + A (PETSc.Mat): The Mat that this python context is attached to. + x (Union[OverloadedType, list[OverloadedType]]): An element in either the control + or functional space of the ReducedFunctional that this Mat will act on. """ - raise NotImplementedError + raise NotImplementedError( + "Must provide implementation of the action of this matrix on an OverloadedType") def multHermitian_impl(self, A, y): """ This method must be overriden. Provides the implementation of the Hermitian of a particular type of ReducedFunctional action. + + Args: + A (PETSc.Mat): The Mat that this python context is attached to. + y (Union[OverloadedType, list[OverloadedType]]): An element in either the control + or functional space of the ReducedFunctional that this Mat will act on. """ - raise NotImplementedError + raise NotImplementedError( + "Must provide implementation of the Hermitian action of this matrix on an OverloadedType") class ReducedFunctionalHessianMat(ReducedFunctionalMatBase): @@ -444,6 +456,16 @@ def ReducedFunctionalMat(rf, action=RFOperation.HESSIAN, *, apply_riesz=False, a class RieszMapMatCtx: + """ + PETSc.Mat to apply the Riesz map to an element in the dual of the control space. + + If V is the control space then this has the followiung signature: + RieszMap : V* -> V + + Args: + controls (Union[Control,list[Control]]): The controls defining the primal control space. + comm (Optional[petsc4py.PETSc.Comm,mpi4py.MPI.Comm]): Communicator that the controls are defined over. + """ def __init__(self, controls, comm=None): comm = valid_comm(comm) @@ -463,6 +485,17 @@ def mult(self, mat, x, y): def RieszMapMat(controls, symmetric=True, comm=None): + """ + PETSc.Mat to apply the Riesz map to an element in the dual of the control space. + + If V is the control space then this has the followiung signature: + RieszMap : V* -> V + + Args: + controls (Union[Control,list[Control]]): The controls defining the primal control space. + symmetric (bool): Whether the Riesz map attached to the Control is symmetric. + comm (Optional[petsc4py.PETSc.Comm,mpi4py.MPI.Comm]): Communicator that the controls are defined over. + """ ctx = RieszMapMatCtx(controls, comm=comm) n = ctx.vec_interface.n From 3e8662d63e9f858ebb3fb2f598e8ffcea86dfb57 Mon Sep 17 00:00:00 2001 From: Josh Hope-Collins Date: Fri, 24 Oct 2025 17:27:21 +0100 Subject: [PATCH 17/19] petsc mat docstrings correction --- pyadjoint/optimization/tao_solver.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyadjoint/optimization/tao_solver.py b/pyadjoint/optimization/tao_solver.py index 03a8ceb3..7c245e54 100644 --- a/pyadjoint/optimization/tao_solver.py +++ b/pyadjoint/optimization/tao_solver.py @@ -457,7 +457,7 @@ def ReducedFunctionalMat(rf, action=RFOperation.HESSIAN, *, apply_riesz=False, a class RieszMapMatCtx: """ - PETSc.Mat to apply the Riesz map to an element in the dual of the control space. + PETSc.Mat Python context to apply the Riesz map to an element in the dual of the control space. If V is the control space then this has the followiung signature: RieszMap : V* -> V From 96f75e0baafc4e14d6e241b9f965c73db86784ec Mon Sep 17 00:00:00 2001 From: Josh Hope-Collins Date: Fri, 24 Oct 2025 17:45:30 +0100 Subject: [PATCH 18/19] petsc mat update fix --- pyadjoint/optimization/tao_solver.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyadjoint/optimization/tao_solver.py b/pyadjoint/optimization/tao_solver.py index 7c245e54..3cb97463 100644 --- a/pyadjoint/optimization/tao_solver.py +++ b/pyadjoint/optimization/tao_solver.py @@ -224,7 +224,7 @@ def __init__(self, rf, *, def update(cls, obj, x, A, P): ctx = A.getPythonContext() ctx.control_interface.from_petsc(x, ctx._m) - ctx.update_tape_values(update_adjoint=self.update_adjoint) + ctx.update_tape_values(update_adjoint=ctx.update_adjoint) ctx._shift = 0 def shift(self, A, alpha): From 3a214ebaa13fc3efa7ba42443a1d8214d1d8e454 Mon Sep 17 00:00:00 2001 From: Josh Hope-Collins Date: Fri, 31 Oct 2025 16:42:18 +0000 Subject: [PATCH 19/19] make update_adjoint a class method for ReducedFunctionalMat --- pyadjoint/optimization/tao_solver.py | 40 ++++++++++++++++++++-------- 1 file changed, 29 insertions(+), 11 deletions(-) diff --git a/pyadjoint/optimization/tao_solver.py b/pyadjoint/optimization/tao_solver.py index 3cb97463..53812af8 100644 --- a/pyadjoint/optimization/tao_solver.py +++ b/pyadjoint/optimization/tao_solver.py @@ -181,6 +181,11 @@ class ReducedFunctionalMatBase: Adjoint : U* -> V* Hessian : V x U* -> V* | V -> V* + Child classes must implement: + - mult_impl + - multHermitian_impl + - update_adjoint + Args: rf (ReducedFunctional): Defines the forward model. Used to compute Mat actions. apply_riesz (bool): Whether to apply the riesz map before returning the @@ -189,9 +194,6 @@ class ReducedFunctionalMatBase: always_update_tape (bool): Whether to force reevaluation of the forward model every time `mult` is called. If needs_adjoint_update then this will also force the adjoint model to be reevaluated at every call to `mult`. - update_adjoint (bool): Whether to update the adjoint as well as the forward - model when updating the Mat. If True then `rf.derivative` will be called by - the `update` method. Required for Hessian but not for TLM or Adjoint Mats. needs_functional_interface: Whether to create a PETScVecInterface for rf.functional. comm (Optional[petsc4py.PETSc.Comm,mpi4py.MPI.Comm]): Communicator that the rf is defined over. """ @@ -199,7 +201,6 @@ class ReducedFunctionalMatBase: def __init__(self, rf, *, apply_riesz=False, appctx=None, always_update_tape=False, - update_adjoint=False, needs_functional_interface=False, comm=PETSc.COMM_WORLD): comm = valid_comm(comm) @@ -215,7 +216,6 @@ def __init__(self, rf, *, self.functional_interface = PETScVecInterface( rf.functional, comm=comm) - self.update_adjoint = update_adjoint self._m = new_control_variable(rf) self._shift = 0 self.always_update_tape = always_update_tape @@ -224,7 +224,7 @@ def __init__(self, rf, *, def update(cls, obj, x, A, P): ctx = A.getPythonContext() ctx.control_interface.from_petsc(x, ctx._m) - ctx.update_tape_values(update_adjoint=ctx.update_adjoint) + ctx.update_tape_values(update_adjoint=cls.update_adjoint()) ctx._shift = 0 def shift(self, A, alpha): @@ -279,6 +279,14 @@ def multHermitian_impl(self, A, y): raise NotImplementedError( "Must provide implementation of the Hermitian action of this matrix on an OverloadedType") + @classmethod + def update_adjoint(self): + """ + Whether to update the adjoint as well as the forward model when updating + the Mat. If True then `rf.derivative` will be called by the `update` method. + """ + raise NotImplementedError + class ReducedFunctionalHessianMat(ReducedFunctionalMatBase): """ @@ -302,13 +310,17 @@ def __init__(self, rf, *, apply_riesz=False, appctx=None, always_update_tape=False, comm=PETSc.COMM_WORLD): super().__init__(rf, apply_riesz=apply_riesz, appctx=appctx, - needs_functional_interface=False, update_adjoint=True, + needs_functional_interface=False, always_update_tape=always_update_tape, comm=comm) self.xinterface = self.control_interface self.yinterface = self.control_interface self.x = new_control_variable(rf) + @classmethod + def update_adjoint(self): + return True + def mult_impl(self, A, x): if self.always_update_tape: self.update_tape_values(update_adjoint=True) @@ -341,13 +353,17 @@ def __init__(self, rf, *, apply_riesz=False, appctx=None, always_update_tape=False, comm=PETSc.COMM_WORLD): super().__init__(rf, apply_riesz=apply_riesz, appctx=appctx, - needs_functional_interface=True, update_adjoint=False, + needs_functional_interface=True, always_update_tape=always_update_tape, comm=comm) self.xinterface = self.functional_interface self.yinterface = self.control_interface self.x = rf.functional._ad_copy() + @classmethod + def update_adjoint(self): + return False + def mult_impl(self, A, x): if self.always_update_tape: self.update_tape_values(update_adjoint=False) @@ -377,15 +393,17 @@ class ReducedFunctionalTLMMat(ReducedFunctionalMatBase): """ def __init__(self, rf, *, appctx=None, always_update_tape=False, comm=PETSc.COMM_WORLD): - - super().__init__(rf, appctx=appctx, update_adjoint=False, - needs_functional_interface=True, + super().__init__(rf, appctx=appctx, needs_functional_interface=True, always_update_tape=always_update_tape, comm=comm) self.xinterface = self.control_interface self.yinterface = self.functional_interface self.x = new_control_variable(rf) + @classmethod + def update_adjoint(self): + return False + def mult_impl(self, A, x): if self.always_update_tape: self.update_tape_values(update_adjoint=False)