Skip to content

Commit 148dd67

Browse files
benegeeranocha
andauthored
Use Base.min / Base.max in MPI reductions (#2054)
* use Base.min/max in MPI.Allreduce MPI.jl's reduce currently does not work for custom operators (such as Trixi's min/max) on ARM * add comments * explain workdaround * typo * Apply suggestions from code review Co-authored-by: Hendrik Ranocha <[email protected]> * switch to macos-latest in mpi tests * remove arch specification for macos-latest macos-latest is 14, which is ARM * readd arch, required by julia-actions/setup-julia * back to macos-13 and x64 --------- Co-authored-by: Hendrik Ranocha <[email protected]>
1 parent 56d5420 commit 148dd67

File tree

6 files changed

+31
-13
lines changed

6 files changed

+31
-13
lines changed

src/auxiliary/math.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,11 @@ end
284284
# when using `@fastmath`, which we also get from
285285
# [Fortran](https://godbolt.org/z/Yrsa1js7P)
286286
# or [C++](https://godbolt.org/z/674G7Pccv).
287+
#
288+
# Note however that such a custom reimplementation can cause incompatibilities with other
289+
# packages. Currently we are affected by an issue with MPI.jl on ARM, see
290+
# https://github.com/trixi-framework/Trixi.jl/issues/1922
291+
# The workaround is to resort to Base.min / Base.max when using MPI reductions.
287292
"""
288293
Trixi.max(x, y, ...)
289294

src/callbacks_step/analysis.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -434,7 +434,8 @@ function (analysis_callback::AnalysisCallback)(io, du, u, u_ode, t, semi)
434434
res = maximum(abs, view(du, v, ..))
435435
if mpi_isparallel()
436436
# TODO: Debugging, here is a type instability
437-
global_res = MPI.Reduce!(Ref(res), max, mpi_root(), mpi_comm())
437+
# Base.max instead of max needed, see comment in src/auxiliary/math.jl
438+
global_res = MPI.Reduce!(Ref(res), Base.max, mpi_root(), mpi_comm())
438439
if mpi_isroot()
439440
res::eltype(du) = global_res[]
440441
end

src/callbacks_step/analysis_dg2d_parallel.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,8 @@ function calc_error_norms(func, u, t, analyzer,
131131
global_l2_error = Vector(l2_error)
132132
global_linf_error = Vector(linf_error)
133133
MPI.Reduce!(global_l2_error, +, mpi_root(), mpi_comm())
134-
MPI.Reduce!(global_linf_error, max, mpi_root(), mpi_comm())
134+
# Base.max instead of max needed, see comment in src/auxiliary/math.jl
135+
MPI.Reduce!(global_linf_error, Base.max, mpi_root(), mpi_comm())
135136
total_volume = MPI.Reduce(volume, +, mpi_root(), mpi_comm())
136137
if mpi_isroot()
137138
l2_error = convert(typeof(l2_error), global_l2_error)

src/callbacks_step/analysis_dg3d_parallel.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,8 @@ function calc_error_norms(func, u, t, analyzer,
4949
global_l2_error = Vector(l2_error)
5050
global_linf_error = Vector(linf_error)
5151
MPI.Reduce!(global_l2_error, +, mpi_root(), mpi_comm())
52-
MPI.Reduce!(global_linf_error, max, mpi_root(), mpi_comm())
52+
# Base.max instead of max needed, see comment in src/auxiliary/math.jl
53+
MPI.Reduce!(global_linf_error, Base.max, mpi_root(), mpi_comm())
5354
total_volume = MPI.Reduce(volume, +, mpi_root(), mpi_comm())
5455
if mpi_isroot()
5556
l2_error = convert(typeof(l2_error), global_l2_error)

src/callbacks_step/stepsize_dg2d.jl

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,8 @@ function max_dt(u, t, mesh::ParallelTreeMesh{2},
5454
typeof(constant_speed), typeof(equations), typeof(dg),
5555
typeof(cache)},
5656
u, t, mesh, constant_speed, equations, dg, cache)
57-
dt = MPI.Allreduce!(Ref(dt), min, mpi_comm())[]
57+
# Base.min instead of min needed, see comment in src/auxiliary/math.jl
58+
dt = MPI.Allreduce!(Ref(dt), Base.min, mpi_comm())[]
5859

5960
return dt
6061
end
@@ -70,7 +71,8 @@ function max_dt(u, t, mesh::ParallelTreeMesh{2},
7071
typeof(constant_speed), typeof(equations), typeof(dg),
7172
typeof(cache)},
7273
u, t, mesh, constant_speed, equations, dg, cache)
73-
dt = MPI.Allreduce!(Ref(dt), min, mpi_comm())[]
74+
# Base.min instead of min needed, see comment in src/auxiliary/math.jl
75+
dt = MPI.Allreduce!(Ref(dt), Base.min, mpi_comm())[]
7476

7577
return dt
7678
end
@@ -154,7 +156,8 @@ function max_dt(u, t, mesh::ParallelP4estMesh{2},
154156
typeof(constant_speed), typeof(equations), typeof(dg),
155157
typeof(cache)},
156158
u, t, mesh, constant_speed, equations, dg, cache)
157-
dt = MPI.Allreduce!(Ref(dt), min, mpi_comm())[]
159+
# Base.min instead of min needed, see comment in src/auxiliary/math.jl
160+
dt = MPI.Allreduce!(Ref(dt), Base.min, mpi_comm())[]
158161

159162
return dt
160163
end
@@ -170,7 +173,8 @@ function max_dt(u, t, mesh::ParallelP4estMesh{2},
170173
typeof(constant_speed), typeof(equations), typeof(dg),
171174
typeof(cache)},
172175
u, t, mesh, constant_speed, equations, dg, cache)
173-
dt = MPI.Allreduce!(Ref(dt), min, mpi_comm())[]
176+
# Base.min instead of min needed, see comment in src/auxiliary/math.jl
177+
dt = MPI.Allreduce!(Ref(dt), Base.min, mpi_comm())[]
174178

175179
return dt
176180
end
@@ -186,7 +190,8 @@ function max_dt(u, t, mesh::ParallelT8codeMesh{2},
186190
typeof(constant_speed), typeof(equations), typeof(dg),
187191
typeof(cache)},
188192
u, t, mesh, constant_speed, equations, dg, cache)
189-
dt = MPI.Allreduce!(Ref(dt), min, mpi_comm())[]
193+
# Base.min instead of min needed, see comment in src/auxiliary/math.jl
194+
dt = MPI.Allreduce!(Ref(dt), Base.min, mpi_comm())[]
190195

191196
return dt
192197
end
@@ -202,7 +207,8 @@ function max_dt(u, t, mesh::ParallelT8codeMesh{2},
202207
typeof(constant_speed), typeof(equations), typeof(dg),
203208
typeof(cache)},
204209
u, t, mesh, constant_speed, equations, dg, cache)
205-
dt = MPI.Allreduce!(Ref(dt), min, mpi_comm())[]
210+
# Base.min instead of min needed, see comment in src/auxiliary/math.jl
211+
dt = MPI.Allreduce!(Ref(dt), Base.min, mpi_comm())[]
206212

207213
return dt
208214
end

src/callbacks_step/stepsize_dg3d.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,8 @@ function max_dt(u, t, mesh::ParallelP4estMesh{3},
130130
typeof(constant_speed), typeof(equations), typeof(dg),
131131
typeof(cache)},
132132
u, t, mesh, constant_speed, equations, dg, cache)
133-
dt = MPI.Allreduce!(Ref(dt), min, mpi_comm())[]
133+
# Base.min instead of min needed, see comment in src/auxiliary/math.jl
134+
dt = MPI.Allreduce!(Ref(dt), Base.min, mpi_comm())[]
134135

135136
return dt
136137
end
@@ -146,7 +147,8 @@ function max_dt(u, t, mesh::ParallelP4estMesh{3},
146147
typeof(constant_speed), typeof(equations), typeof(dg),
147148
typeof(cache)},
148149
u, t, mesh, constant_speed, equations, dg, cache)
149-
dt = MPI.Allreduce!(Ref(dt), min, mpi_comm())[]
150+
# Base.min instead of min needed, see comment in src/auxiliary/math.jl
151+
dt = MPI.Allreduce!(Ref(dt), Base.min, mpi_comm())[]
150152

151153
return dt
152154
end
@@ -162,7 +164,8 @@ function max_dt(u, t, mesh::ParallelT8codeMesh{3},
162164
typeof(constant_speed), typeof(equations), typeof(dg),
163165
typeof(cache)},
164166
u, t, mesh, constant_speed, equations, dg, cache)
165-
dt = MPI.Allreduce!(Ref(dt), min, mpi_comm())[]
167+
# Base.min instead of min needed, see comment in src/auxiliary/math.jl
168+
dt = MPI.Allreduce!(Ref(dt), Base.min, mpi_comm())[]
166169

167170
return dt
168171
end
@@ -178,7 +181,8 @@ function max_dt(u, t, mesh::ParallelT8codeMesh{3},
178181
typeof(constant_speed), typeof(equations), typeof(dg),
179182
typeof(cache)},
180183
u, t, mesh, constant_speed, equations, dg, cache)
181-
dt = MPI.Allreduce!(Ref(dt), min, mpi_comm())[]
184+
# Base.min instead of min needed, see comment in src/auxiliary/math.jl
185+
dt = MPI.Allreduce!(Ref(dt), Base.min, mpi_comm())[]
182186

183187
return dt
184188
end

0 commit comments

Comments
 (0)