Skip to content

Commit 6140e98

Browse files
committed
different approach: use Ptr to Vector and native +
1 parent c5970f3 commit 6140e98

File tree

3 files changed

+12
-12
lines changed

3 files changed

+12
-12
lines changed

.github/workflows/ci.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,6 @@ jobs:
116116
- uses: julia-actions/julia-buildpkg@v1
117117
env:
118118
PYTHON: ""
119-
- run: julia --project=@. -e 'using Pkg; pkg"add MPI#vc/custom_ops"'
120119
- name: Run tests without coverage
121120
uses: julia-actions/julia-runtest@v1
122121
with:

src/auxiliary/mpi.jl

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -128,11 +128,3 @@ parallel execution of Trixi.jl.
128128
See the "Miscellaneous" section of the [documentation](https://docs.sciml.ai/DiffEqDocs/stable/basics/common_solver_opts/).
129129
"""
130130
ode_unstable_check(dt, u, semi, t) = isnan(dt)
131-
132-
# Custom MPI operators to work around
133-
# https://github.com/trixi-framework/Trixi.jl/issues/1922
134-
function reduce_vector_plus(x, y)
135-
x .+ y
136-
end
137-
MPI.@Op(reduce_vector_plus, SVector{4, Float64})
138-
MPI.@Op(reduce_vector_plus, SVector{5, Float64})

src/callbacks_step/analysis_dg2d_parallel.jl

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -162,10 +162,19 @@ function integrate_via_indices(func::Func, u,
162162
normalize = normalize)
163163

164164
# OBS! Global results are only calculated on MPI root, all other domains receive `nothing`
165-
global_integral = MPI.Reduce!(Ref(local_integral), reduce_vector_plus, mpi_root(),
166-
mpi_comm())
165+
if local_integral isa Real
166+
global_integral = MPI.Reduce!(Ref(local_integral), +, mpi_root(), mpi_comm())
167+
else
168+
global_integral = MPI.Reduce!(Base.unsafe_convert(Ptr{Float64}, Ref(local_integral)), +, mpi_root(), mpi_comm())
169+
end
170+
167171
if mpi_isroot()
168-
integral = convert(typeof(local_integral), global_integral[])
172+
if local_integral isa Real
173+
integral = global_integral[]
174+
else
175+
global_wrapped = unsafe_wrap(Array, global_integral, length(local_integral))
176+
integral = convert(typeof(local_integral), global_wrapped)
177+
end
169178
else
170179
integral = convert(typeof(local_integral), NaN * local_integral)
171180
end

0 commit comments

Comments
 (0)