@@ -931,9 +931,6 @@ def callable():
931
931
else :
932
932
loops = []
933
933
934
- if access == op2 .INC :
935
- loops .append (tensor .zero )
936
-
937
934
# Arguments in the operand are allowed to be from a MixedFunctionSpace
938
935
# We need to split the target space V and generate separate kernels
939
936
if len (arguments ) == 2 :
@@ -961,12 +958,23 @@ def callable():
961
958
if bcs and rank == 1 :
962
959
loops .extend (partial (bc .apply , f ) for bc in bcs )
963
960
964
- def callable (loops , f ):
965
- for l in loops :
966
- l ()
961
+ def callable (loops , f , access ):
962
+ if access is op2 .WRITE :
963
+ for l in loops :
964
+ l ()
965
+ return f
966
+ # We are repeatedly incrementing into the same Dat so intermediate halo exchanges
967
+ # can be skipped.
968
+ f .dat .local_to_global_begin (access )
969
+ with f .dat .frozen_halo (access ):
970
+ if access is op2 .INC :
971
+ f .dat .zero ()
972
+ for l in loops :
973
+ l ()
974
+ f .dat .local_to_global_end (access )
967
975
return f
968
976
969
- return partial (callable , loops , f )
977
+ return partial (callable , loops , f , access )
970
978
971
979
972
980
@utils .known_pyop2_safe
@@ -1076,7 +1084,7 @@ def _interpolator(tensor, expr, subset, access, bcs=None):
1076
1084
coefficient_numbers = kernel .coefficient_numbers
1077
1085
needs_external_coords = kernel .needs_external_coords
1078
1086
name = kernel .name
1079
- kernel = op2 .Kernel (ast , name , requires_zeroed_output_arguments = True ,
1087
+ kernel = op2 .Kernel (ast , name , requires_zeroed_output_arguments = ( access is op2 . WRITE ) ,
1080
1088
flop_count = kernel .flop_count , events = (kernel .event ,))
1081
1089
1082
1090
parloop_args = [kernel , cell_set ]
@@ -1099,7 +1107,7 @@ def _interpolator(tensor, expr, subset, access, bcs=None):
1099
1107
if isinstance (tensor , op2 .Global ):
1100
1108
parloop_args .append (tensor (access ))
1101
1109
elif isinstance (tensor , op2 .Dat ):
1102
- V_dest = arguments [- 1 ].function_space () if isinstance ( dual_arg , ufl . Cofunction ) else V
1110
+ V_dest = arguments [0 ].function_space ()
1103
1111
m_ = get_interp_node_map (source_mesh , target_mesh , V_dest )
1104
1112
parloop_args .append (tensor (access , m_ ))
1105
1113
else :
@@ -1159,11 +1167,10 @@ def _interpolator(tensor, expr, subset, access, bcs=None):
1159
1167
parloop_args .append (target_ref_coords .dat (op2 .READ , m_ ))
1160
1168
1161
1169
parloop = op2 .ParLoop (* parloop_args )
1162
- parloop_compute_callable = parloop .compute
1163
1170
if isinstance (tensor , op2 .Mat ):
1164
- return parloop_compute_callable , tensor .assemble
1171
+ return parloop , tensor .assemble
1165
1172
else :
1166
- return copyin + callables + (parloop_compute_callable , ) + copyout
1173
+ return copyin + callables + (parloop , ) + copyout
1167
1174
1168
1175
1169
1176
def get_interp_node_map (source_mesh , target_mesh , fs ):
0 commit comments