@@ -197,50 +197,71 @@ class SubfunctionBlock(Block):
197197    def  __init__ (self , func , idx , ad_block_tag = None ):
198198        super ().__init__ (ad_block_tag = ad_block_tag )
199199        self .add_dependency (func )
200-         self .idx  =  idx 
200+         self .sub_idx  =  idx 
201201
202202    def  evaluate_adj_component (self , inputs , adj_inputs , block_variable , idx ,
203203                               prepared = None ):
204204        eval_adj  =  firedrake .Cofunction (block_variable .output .function_space ().dual ())
205205        if  type (adj_inputs [0 ]) is  firedrake .Cofunction :
206-             eval_adj .sub (self .idx ).assign (adj_inputs [0 ])
206+             eval_adj .sub (self .sub_idx ).assign (adj_inputs [0 ])
207207        else :
208-             eval_adj .sub (self .idx ).assign (adj_inputs [0 ].function )
208+             eval_adj .sub (self .sub_idx ).assign (adj_inputs [0 ].function )
209209        return  eval_adj 
210210
211211    def  evaluate_tlm_component (self , inputs , tlm_inputs , block_variable , idx ,
212212                               prepared = None ):
213-         return  firedrake .Function .sub (tlm_inputs [0 ], self .idx )
213+         return  firedrake .Function .sub (tlm_inputs [0 ], self .sub_idx )
214214
215215    def  evaluate_hessian_component (self , inputs , hessian_inputs , adj_inputs ,
216216                                   block_variable , idx ,
217217                                   relevant_dependencies , prepared = None ):
218218        eval_hessian  =  firedrake .Cofunction (block_variable .output .function_space ().dual ())
219-         eval_hessian .sub (self .idx ).assign (hessian_inputs [0 ])
219+         eval_hessian .sub (self .sub_idx ).assign (hessian_inputs [0 ])
220220        return  eval_hessian 
221221
222222    def  recompute_component (self , inputs , block_variable , idx , prepared ):
223223        return  maybe_disk_checkpoint (
224-             firedrake .Function .sub (inputs [0 ], self .idx )
224+             firedrake .Function .sub (inputs [0 ], self .sub_idx )
225225        )
226226
227227    def  __str__ (self ):
228-         return  f"{ self .get_dependencies ()[0 ]}  [{ self .idx }  ]" 
228+         return  f"{ self .get_dependencies ()[0 ]}  [{ self .sub_idx }  ]" 
229229
230230
231231class  FunctionMergeBlock (Block ):
232232    def  __init__ (self , func , idx , ad_block_tag = None ):
233233        super ().__init__ (ad_block_tag = ad_block_tag )
234234        self .add_dependency (func )
235-         self .idx  =  idx 
235+         self .sub_idx  =  idx 
236236        for  output  in  func ._ad_outputs :
237237            self .add_dependency (output )
238238
239239    def  evaluate_adj_component (self , inputs , adj_inputs , block_variable , idx ,
240240                               prepared = None ):
241+         # The merge block appears whenever a subfunction is the output of a block. 
242+         # This means that the subfunction has been modified, so we need to make 
243+         # sure that this modification is accounted for when evaluating the adjoint. 
244+         # 
245+         # When recomputing the merge block, the indexed subfunction in the full 
246+         # Function is completely overwritten, meaning that the pre-existing value 
247+         # of the subfunction in the full Function is ignored. 
248+         # The equivalent adjoint operation is to: 
249+         #   1. send a copy of the subfunction component of the adjoint value 
250+         #      back up the branch of the tape corresponding to the 
251+         #      subfunction dependency (idx=0). 
252+         #   2. zero out, in place, the subfunction component of the adjoint 
253+         #      value sent back up the full Function branch of the tape (idx=1). 
254+         # This means that when the adjoint values of each branch are combined 
255+         # after the SubfunctionBlock, only the adjoint value from the subfunction 
256+         # branch is used. 
257+         # 
258+         # See https://github.com/firedrakeproject/firedrake/pull/4177 for more 
259+         # detail and for diagrams of the tape produced when accessing subfunctions. 
260+ 
241261        if  idx  ==  0 :
242-             return  adj_inputs [0 ].subfunctions [self .idx ] 
262+             return  adj_inputs [0 ].subfunctions [self .sub_idx ]. copy ( deepcopy = True ) 
243263        else :
264+             adj_inputs [0 ].subfunctions [self .sub_idx ].zero ()
244265            return  adj_inputs [0 ]
245266
246267    def  evaluate_tlm (self , markings = False ):
@@ -253,7 +274,7 @@ def evaluate_tlm(self, markings=False):
253274        fs  =  output .output .function_space ()
254275        f  =  type (output .output )(fs )
255276        output .add_tlm_output (
256-             type (output .output ).assign (f .sub (self .idx ), tlm_input )
277+             type (output .output ).assign (f .sub (self .sub_idx ), tlm_input )
257278        )
258279
259280    def  evaluate_hessian_component (self , inputs , hessian_inputs , adj_inputs ,
@@ -265,12 +286,12 @@ def recompute_component(self, inputs, block_variable, idx, prepared):
265286        sub_func  =  inputs [0 ]
266287        parent_in  =  inputs [1 ]
267288        parent_out  =  type (parent_in )(parent_in )
268-         parent_out .sub (self .idx ).assign (sub_func )
289+         parent_out .sub (self .sub_idx ).assign (sub_func )
269290        return  maybe_disk_checkpoint (parent_out )
270291
271292    def  __str__ (self ):
272293        deps  =  self .get_dependencies ()
273-         return  f"{ deps [1 ]}  [{ self .idx }  ].assign({ deps [0 ]}  )" 
294+         return  f"{ deps [1 ]}  [{ self .sub_idx }  ].assign({ deps [0 ]}  )" 
274295
275296
276297class  CofunctionAssignBlock (Block ):
0 commit comments