From ef118d69451dcc50e2a100f1049a2f6f68835618 Mon Sep 17 00:00:00 2001
From: Joachim Brand
Date: Tue, 28 Feb 2023 00:00:02 +1300
Subject: [PATCH 1/9] avoid reductions with custom operators for MPI

---
 src/RMPI/helpers.jl |  3 ++-
 src/RMPI/mpidata.jl |  6 ++++--
 test/RMPI.jl        | 34 +++++++++++++++++++---------------
 3 files changed, 25 insertions(+), 18 deletions(-)

diff --git a/src/RMPI/helpers.jl b/src/RMPI/helpers.jl
index 37a8f61cf..e7f2804c4 100644
--- a/src/RMPI/helpers.jl
+++ b/src/RMPI/helpers.jl
@@ -88,7 +88,8 @@ end
 function sort_into_targets!(dtarget::MPIData, w::AbstractDVec, stats)
     # single threaded MPI version
     mpi_combine_walkers!(dtarget,w) # combine walkers from different MPI ranks
-    res_stats = MPI.Allreduce(Rimu.MultiScalar(stats), +, dtarget.comm)
+    res_stats = (MPI.Allreduce(stat, +, dtarget.comm) for stat in stats)
+    # res_stats = MPI.Allreduce(Rimu.MultiScalar(stats), +, dtarget.comm)
     return dtarget, w, res_stats
 end

diff --git a/src/RMPI/mpidata.jl b/src/RMPI/mpidata.jl
index 39bc38e51..0121adba4 100644
--- a/src/RMPI/mpidata.jl
+++ b/src/RMPI/mpidata.jl
@@ -118,9 +118,11 @@ MPI syncronizing.
 """
 function LinearAlgebra.norm(md::MPIData, p::Real=2)
     if p === 2
-        return sqrt(sum(abs2, values(md)))
+        return sqrt(mapreduce(abs2, +, values(md)))
+        # return sqrt(sum(abs2, values(md)))
     elseif p === 1
-        return float(sum(abs, values(md)))
+        return float(mapreduce(abs, +, values(md)))
+        # return float(sum(abs, values(md)))
     elseif p === Inf
         return float(mapreduce(abs, max, values(md); init=real(zero(valtype(md)))))
     else
diff --git a/test/RMPI.jl b/test/RMPI.jl
index da74d697c..abb5b99c4 100644
--- a/test/RMPI.jl
+++ b/test/RMPI.jl
@@ -6,17 +6,17 @@ using Test

 @testset "DistributeStrategies" begin
     # `DistributeStrategy`s
-    ham = HubbardReal1D(BoseFS((1,2,3)))
+    ham = HubbardReal1D(BoseFS((1, 2, 3)))
     for setup in [RMPI.mpi_no_exchange, RMPI.mpi_all_to_all, RMPI.mpi_point_to_point]
-        dv = DVec(starting_address(ham)=>10; style=IsDynamicSemistochastic())
+        dv = DVec(starting_address(ham) => 10; style=IsDynamicSemistochastic())
         v = MPIData(dv; setup)
-        df, state = lomc!(ham,v)
+        df, state = lomc!(ham, v)
         @test size(df) == (100, 12)
     end
     # need to do mpi_one_sided separately
-    dv = DVec(starting_address(ham)=>10; style=IsDynamicSemistochastic())
-    v = RMPI.mpi_one_sided(dv; capacity = 1000)
-    df, state = lomc!(ham,v)
+    dv = DVec(starting_address(ham) => 10; style=IsDynamicSemistochastic())
+    v = RMPI.mpi_one_sided(dv; capacity=1000)
+    df, state = lomc!(ham, v)
     @test size(df) == (100, 12)
 end

@@ -29,13 +29,13 @@ end
     counts = zeros(Int, k)
     displs = zeros(Int, k)

-    RMPI.sort_and_count!(counts, displs, vals, ordfun.(vals), (0, k-1))
+    RMPI.sort_and_count!(counts, displs, vals, ordfun.(vals), (0, k - 1))
     @test issorted(vals, by=ordfun)
     @test sum(counts) == l

-    for i in 0:(k - 1)
-        c = counts[i + 1]
-        d = displs[i + 1]
+    for i in 0:(k-1)
+        c = counts[i+1]
+        d = displs[i+1]
         r = (1:c) .+ d
         ords = ordfun.(vals)
         @test all(ords[r] .== i)
@@ -52,10 +52,14 @@ end

     @testset "Iteration and reductions" begin
         @test sort(collect(localpart(values(dv1)))) == 1:4
-        @test sum(first, pairs(dv1)) == 10
-        @test sum(last, pairs(dv1)) == 10
-        @test prod(keys(dv1)) == 24
-        @test sum(values(dv2)) == 0
+        @test mapreduce(first, +, pairs(dv1)) == 10
+        # @test sum(first, pairs(dv1)) == 10
+        @test mapreduce(last, +, pairs(dv1)) == 10
+        # @test sum(last, pairs(dv1)) == 10
+        @test reduce(*, keys(dv1)) == 24
+        # @test prod(keys(dv1)) == 24
+        @test reduce(+, values(dv2)) == 0
+        # @test sum(values(dv2)) == 0
     end
     @testset "Errors" begin
         @test_throws ErrorException [p for p in pairs(dv1)]
@@ -79,7 +83,7 @@ end
     @testset "dot" begin
         @test dot(dv1, dv2) == 0
         @test dot(dv1, dv1) == dot(localpart(dv1), dv1)
-        rand_ham = MatrixHamiltonian(rand(ComplexF64, 4,4))
+        rand_ham = MatrixHamiltonian(rand(ComplexF64, 4, 4))
         ldv1 = localpart(dv1)
         @test norm(dot(dv1, rand_ham, dv1)) ≈ norm(dot(ldv1, rand_ham, ldv1))
     end
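The workaround in this first patch sidesteps MPI reductions with a custom operator by reducing each statistic separately with the built-in `+`, which MPI.jl maps onto MPI.SUM. A minimal standalone sketch of the same idea (not Rimu code; `stats` here is a stand-in tuple):

    using MPI

    MPI.Init()
    comm = MPI.COMM_WORLD

    stats = (1, 2.0, 3)  # stand-in for per-rank step statistics
    # one scalar Allreduce per entry; no custom MPI operator is needed
    reduced = map(stat -> MPI.Allreduce(stat, +, comm), stats)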
From 7b4db32ef2266c7fc00d78cabb820dd9c9eec366 Mon Sep 17 00:00:00 2001
From: Joachim Brand
Date: Fri, 3 Mar 2023 21:10:16 +1300
Subject: [PATCH 2/9] avoid MultiScalar reduction only on non-intel

---
 src/RMPI/helpers.jl | 14 +++++++++++---
 1 file changed, 11 insertions(+), 3 deletions(-)

diff --git a/src/RMPI/helpers.jl b/src/RMPI/helpers.jl
index e7f2804c4..62e353ff6 100644
--- a/src/RMPI/helpers.jl
+++ b/src/RMPI/helpers.jl
@@ -87,9 +87,17 @@ end

 function sort_into_targets!(dtarget::MPIData, w::AbstractDVec, stats)
     # single threaded MPI version
-    mpi_combine_walkers!(dtarget,w) # combine walkers from different MPI ranks
-    res_stats = (MPI.Allreduce(stat, +, dtarget.comm) for stat in stats)
-    # res_stats = MPI.Allreduce(Rimu.MultiScalar(stats), +, dtarget.comm)
+    mpi_combine_walkers!(dtarget, w) # combine walkers from different MPI ranks
+    @static if Sys.ARCH ∈ (:aarch64, :ppc64le, :powerpc64le) ||
+            startswith(lowercase(String(Sys.ARCH)), "arm")
+        # Reductions of a custom type (`MultiScalar`) are not possible with MPI.jl on
+        # non-intel architectures at the moment
+        # see https://github.com/JuliaParallel/MPI.jl/issues/404
+        res_stats = (MPI.Allreduce(stat, +, dtarget.comm) for stat in stats)
+    else
+        # this should be more efficient if it is allowed
+        res_stats = MPI.Allreduce(Rimu.MultiScalar(stats), +, dtarget.comm)
+    end
     return dtarget, w, res_stats
 end
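The `@static if` above selects the reduction strategy at compile time from the host architecture, so the Intel path pays no runtime cost for the check. The condition itself can be tried anywhere, even without MPI:

    non_intel = Sys.ARCH ∈ (:aarch64, :ppc64le, :powerpc64le) ||
        startswith(lowercase(String(Sys.ARCH)), "arm")
    println(non_intel ? "avoid custom-operator MPI reductions" :
        "custom-operator MPI reductions should be fine")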
From 6d2f7bc7840a0ee49ef56f1b773ba52710b376ea Mon Sep 17 00:00:00 2001
From: Joachim Brand
Date: Fri, 3 Mar 2023 21:25:27 +1300
Subject: [PATCH 3/9] move changes to mapreduce

---
 src/RMPI/helpers.jl |  2 +-
 src/RMPI/mpidata.jl | 20 ++++++++++++++++----
 test/RMPI.jl        | 12 ++++--------
 3 files changed, 21 insertions(+), 13 deletions(-)

diff --git a/src/RMPI/helpers.jl b/src/RMPI/helpers.jl
index 62e353ff6..89b5f18a6 100644
--- a/src/RMPI/helpers.jl
+++ b/src/RMPI/helpers.jl
@@ -91,7 +91,7 @@ function sort_into_targets!(dtarget::MPIData, w::AbstractDVec, stats)
     @static if Sys.ARCH ∈ (:aarch64, :ppc64le, :powerpc64le) ||
             startswith(lowercase(String(Sys.ARCH)), "arm")
         # Reductions of a custom type (`MultiScalar`) are not possible with MPI.jl on
-        # non-intel architectures at the moment
+        # non-Intel architectures at the moment
         # see https://github.com/JuliaParallel/MPI.jl/issues/404
         res_stats = (MPI.Allreduce(stat, +, dtarget.comm) for stat in stats)
     else
diff --git a/src/RMPI/mpidata.jl b/src/RMPI/mpidata.jl
index 0121adba4..a81a05c87 100644
--- a/src/RMPI/mpidata.jl
+++ b/src/RMPI/mpidata.jl
@@ -95,6 +95,20 @@ function Base.mapreduce(f, op, it::MPIDataIterator; kwargs...)
     return MPI.Allreduce(res, op, it.data.comm)
 end

+# Special case for `sum`, which uses a custom (type-widening) reduction operator `add_sum`.
+# Replacing it by `+` is necessary for non-Intel architectures due to a limitation of
+# MPI.jl. On Intel processors, it might be more perfomant.
+# see https://github.com/JuliaParallel/MPI.jl/issues/404
+function Base.mapreduce(f, ::typeof(Base.add_sum), it::MPIDataIterator; kwargs...)
+    res = mapreduce(f, +, it.iter; kwargs...)
+    return MPI.Allreduce(res, +, it.data.comm)
+end
+
+function Base.mapreduce(f, ::typeof(Base.mul_prod), it::MPIDataIterator; kwargs...)
+    res = mapreduce(f, *, it.iter; kwargs...)
+    return MPI.Allreduce(res, *, it.data.comm)
+end
+
 Base.IteratorSize(::MPIDataIterator) = Base.SizeUnknown()
 Base.pairs(data::MPIData) = MPIDataIterator(pairs(localpart(data)), data)
 Base.keys(data::MPIData) = MPIDataIterator(keys(localpart(data)), data)
@@ -118,11 +132,9 @@ MPI syncronizing.
 """
 function LinearAlgebra.norm(md::MPIData, p::Real=2)
     if p === 2
-        return sqrt(mapreduce(abs2, +, values(md)))
-        # return sqrt(sum(abs2, values(md)))
+        return sqrt(sum(abs2, values(md)))
     elseif p === 1
-        return float(mapreduce(abs, +, values(md)))
-        # return float(sum(abs, values(md)))
+        return float(sum(abs, values(md)))
     elseif p === Inf
         return float(mapreduce(abs, max, values(md); init=real(zero(valtype(md)))))
     else
diff --git a/test/RMPI.jl b/test/RMPI.jl
index abb5b99c4..0b3e886f9 100644
--- a/test/RMPI.jl
+++ b/test/RMPI.jl
@@ -52,14 +52,10 @@ end

     @testset "Iteration and reductions" begin
         @test sort(collect(localpart(values(dv1)))) == 1:4
-        @test mapreduce(first, +, pairs(dv1)) == 10
-        # @test sum(first, pairs(dv1)) == 10
-        @test mapreduce(last, +, pairs(dv1)) == 10
-        # @test sum(last, pairs(dv1)) == 10
-        @test reduce(*, keys(dv1)) == 24
-        # @test prod(keys(dv1)) == 24
-        @test reduce(+, values(dv2)) == 0
-        # @test sum(values(dv2)) == 0
+        @test sum(first, pairs(dv1)) == 10
+        @test sum(last, pairs(dv1)) == 10
+        @test prod(keys(dv1)) == 24
+        @test sum(values(dv2)) == 0
     end
     @testset "Errors" begin
         @test_throws ErrorException [p for p in pairs(dv1)]
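Dispatching on `typeof(Base.add_sum)` and `typeof(Base.mul_prod)` works because `sum` and `prod` do not lower to `mapreduce(f, +, ...)` or `mapreduce(f, *, ...)`; they use these private type-widening operators. The widening is visible without any MPI involved:

    xs = UInt8[100, 200]
    sum(xs)        # 300: Base.add_sum widens small integers, no overflow
    reduce(+, xs)  # 0x2c (44): plain UInt8 addition wraps around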
From 2f457eb92418cc4fdb7191c08d5eceb40767d1d1 Mon Sep 17 00:00:00 2001
From: Joachim Brand
Date: Fri, 3 Mar 2023 22:17:01 +1300
Subject: [PATCH 4/9] small tweak

---
 src/RMPI/mpidata.jl | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/RMPI/mpidata.jl b/src/RMPI/mpidata.jl
index a81a05c87..21192a7d5 100644
--- a/src/RMPI/mpidata.jl
+++ b/src/RMPI/mpidata.jl
@@ -99,13 +99,13 @@ end
 # Replacing it by `+` is necessary for non-Intel architectures due to a limitation of
 # MPI.jl. On Intel processors, it might be more perfomant.
 # see https://github.com/JuliaParallel/MPI.jl/issues/404
-function Base.mapreduce(f, ::typeof(Base.add_sum), it::MPIDataIterator; kwargs...)
-    res = mapreduce(f, +, it.iter; kwargs...)
+function Base.mapreduce(f, op::typeof(Base.add_sum), it::MPIDataIterator; kwargs...)
+    res = mapreduce(f, op, it.iter; kwargs...)
     return MPI.Allreduce(res, +, it.data.comm)
 end

-function Base.mapreduce(f, ::typeof(Base.mul_prod), it::MPIDataIterator; kwargs...)
-    res = mapreduce(f, *, it.iter; kwargs...)
+function Base.mapreduce(f, op::typeof(Base.mul_prod), it::MPIDataIterator; kwargs...)
+    res = mapreduce(f, op, it.iter; kwargs...)
     return MPI.Allreduce(res, *, it.data.comm)
 end
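After this tweak, the local reduction keeps the widening operator `op` while the cross-rank reduction still hands MPI the built-in `+` or `*`. A sketch of the resulting split, with hypothetical per-rank data:

    using MPI

    MPI.Init()
    local_part = UInt8[100, 200]                         # hypothetical per-rank values
    res = mapreduce(identity, Base.add_sum, local_part)  # widens locally, here to UInt64
    total = MPI.Allreduce(res, +, MPI.COMM_WORLD)        # built-in operator for MPI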
From 0d8511389d84ca5d00a04858d1dd9742e819e049 Mon Sep 17 00:00:00 2001
From: Joachim Brand
Date: Sat, 4 Mar 2023 01:13:07 +1300
Subject: [PATCH 5/9] fix mpi errors

---
 src/RMPI/helpers.jl  | 11 +----------
 src/RMPI/mpidata.jl  | 15 ++++++++++++++-
 src/helpers.jl       |  1 +
 test/mpi_runtests.jl | 13 +++++++------
 4 files changed, 23 insertions(+), 17 deletions(-)

diff --git a/src/RMPI/helpers.jl b/src/RMPI/helpers.jl
index 89b5f18a6..255e452e0 100644
--- a/src/RMPI/helpers.jl
+++ b/src/RMPI/helpers.jl
@@ -88,16 +88,7 @@ end
 function sort_into_targets!(dtarget::MPIData, w::AbstractDVec, stats)
     # single threaded MPI version
     mpi_combine_walkers!(dtarget, w) # combine walkers from different MPI ranks
-    @static if Sys.ARCH ∈ (:aarch64, :ppc64le, :powerpc64le) ||
-            startswith(lowercase(String(Sys.ARCH)), "arm")
-        # Reductions of a custom type (`MultiScalar`) are not possible with MPI.jl on
-        # non-Intel architectures at the moment
-        # see https://github.com/JuliaParallel/MPI.jl/issues/404
-        res_stats = (MPI.Allreduce(stat, +, dtarget.comm) for stat in stats)
-    else
-        # this should be more efficient if it is allowed
-        res_stats = MPI.Allreduce(Rimu.MultiScalar(stats), +, dtarget.comm)
-    end
+    res_stats = MPI.Allreduce(Rimu.MultiScalar(stats), +, dtarget.comm)
     return dtarget, w, res_stats
 end

diff --git a/src/RMPI/mpidata.jl b/src/RMPI/mpidata.jl
index 21192a7d5..8875f580c 100644
--- a/src/RMPI/mpidata.jl
+++ b/src/RMPI/mpidata.jl
@@ -92,7 +92,13 @@ end

 function Base.mapreduce(f, op, it::MPIDataIterator; kwargs...)
     res = mapreduce(f, op, it.iter; kwargs...)
-    return MPI.Allreduce(res, op, it.data.comm)
+    # println("typeof(op): ",typeof(op))
+    # println("typeof(res): ",typeof(res))
+    T = typeof(res)
+    if T <: Bool # MPI.jl does not support Bool reductions
+        res = convert(UInt8, res)
+    end
+    return T(MPI.Allreduce(res, op, it.data.comm))
 end

@@ -368,3 +374,10 @@ function Rimu.all_overlaps(operators::Tuple, vecs::NTuple{N,MPIData}, ::Val{B})
     num_reports = (N * (N - 1) ÷ 2) * (B + length(operators))
     return Tuple(SVector{num_reports,String}(names)), Tuple(SVector{num_reports,T}(values))
 end
+
+# This is a hack to get MultiScalar reductions to work on Apple Silicon.
+# It is taking a detour via Vector.
+function MPI.Allreduce(ms::Rimu.MultiScalar{T}, op, comm::MPI.Comm) where {T<:Tuple}
+    result_vector = MPI.Allreduce([ms...], op, comm)
+    return Rimu.MultiScalar(T(result_vector))
+end
diff --git a/src/helpers.jl b/src/helpers.jl
index ced7eb2cd..ae37a639d 100644
--- a/src/helpers.jl
+++ b/src/helpers.jl
@@ -34,6 +34,7 @@ end
 MultiScalar(args...) = MultiScalar(args)
 MultiScalar(v::SVector) = MultiScalar(Tuple(v))
 MultiScalar(m::MultiScalar) = m
+MultiScalar{T}(m::MultiScalar{T}) where T<:Tuple = m
 MultiScalar(arg) = MultiScalar((arg,))

 Base.getindex(m::MultiScalar, i) = m.tuple[i]
diff --git a/test/mpi_runtests.jl b/test/mpi_runtests.jl
index f0929aeb1..8a9cb3497 100644
--- a/test/mpi_runtests.jl
+++ b/test/mpi_runtests.jl
@@ -71,7 +71,7 @@ end
     end
     @testset "Single component $type" begin
         for i in 1:N_REPEATS
-            add = BoseFS((0,0,10,0,0))
+            add = BoseFS((0, 0, 10, 0, 0))
             H = HubbardMom1D(add)
             Random.seed!(7350 * i)
             v, dv = setup_dv(
@@ -98,7 +98,7 @@ end
             @test sum(values(v)) ≈ sum(values(dv))
             f((k, v)) = (k == add) + v > 0
             @test mapreduce(f, |, pairs(v); init=true) ==
-                mapreduce(f, |, pairs(dv); init=true)
+                  mapreduce(f, |, pairs(dv); init=true)
         end

         @testset "Operations" begin
@@ -127,7 +127,7 @@ end
     end
     @testset "Two-component $type" begin
         for i in 1:N_REPEATS
-            add = BoseFS2C((0,0,10,0,0), (0,0,2,0,0))
+            add = BoseFS2C((0, 0, 10, 0, 0), (0, 0, 2, 0, 0))
             H = BoseHubbardMom1D2C(add)
             Random.seed!(7350 * i)
             v, dv = setup_dv(
@@ -225,7 +225,7 @@ end
         (RMPI.mpi_one_sided, (; capacity=1000)),
     )
     @testset "Regular with $setup and post-steps" begin
-        H = HubbardReal1D(BoseFS((1,1,1,1,1,1,1)); u=6.0)
+        H = HubbardReal1D(BoseFS((1, 1, 1, 1, 1, 1, 1)); u=6.0)
         dv = MPIData(
             DVec(starting_address(H) => 3; style=IsDynamicSemistochastic());
             setup,
@@ -253,7 +253,7 @@ end
         @test all(0 .≤ df.loneliness .≤ 1)
     end
     @testset "Initiator with $setup" begin
-        H = HubbardMom1D(BoseFS((0,0,0,7,0,0,0)); u=6.0)
+        H = HubbardMom1D(BoseFS((0, 0, 0, 7, 0, 0, 0)); u=6.0)
         dv = MPIData(
             InitiatorDVec(starting_address(H) => 3);
             setup,
@@ -295,7 +295,8 @@ end

 # Make sure all ranks came this far.
 @testset "Finish" begin
-    @test MPI.Allreduce(true, &, mpi_comm())
+    @test MPI.Allreduce(0x01, &, mpi_comm()) == 0x01 # 0x01 for true
+    # @test MPI.Allreduce(true, &, mpi_comm())
 end
 end
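The `MPI.Allreduce` method added above for `MultiScalar` takes a detour via `Vector`. Stripped of the Rimu types, the trick looks as follows; note that splatting a heterogeneous tuple into a `Vector` promotes the element type, and converting the reduced vector back through the tuple type restores the original types:

    using MPI

    MPI.Init()
    t = (1, 2.5)                                  # stand-in for a MultiScalar's tuple
    v = MPI.Allreduce([t...], +, MPI.COMM_WORLD)  # Vector{Float64} after promotion
    back = typeof(t)(v)                           # (Int64, Float64) again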
From c563bbb5308e7d6853f39fc1b18b4ecdf2c3f1b9 Mon Sep 17 00:00:00 2001
From: Joachim Brand
Date: Sun, 5 Mar 2023 13:44:46 +1300
Subject: [PATCH 6/9] clean up

---
 src/RMPI/RMPI.jl        |  1 +
 src/RMPI/mpidata.jl     | 10 +---------
 src/RMPI/multiscalar.jl | 10 ++++++++++
 3 files changed, 12 insertions(+), 9 deletions(-)
 create mode 100644 src/RMPI/multiscalar.jl

diff --git a/src/RMPI/RMPI.jl b/src/RMPI/RMPI.jl
index fe65f2503..2234b258b 100644
--- a/src/RMPI/RMPI.jl
+++ b/src/RMPI/RMPI.jl
@@ -31,6 +31,7 @@ const mpi_registry = Dict{Int,Any}()
 abstract type DistributeStrategy end

 include("mpidata.jl")
+include("multiscalar.jl")
 include("helpers.jl")
 include("noexchange.jl")
 include("pointtopoint.jl")
diff --git a/src/RMPI/mpidata.jl b/src/RMPI/mpidata.jl
index 8875f580c..eb2e7a876 100644
--- a/src/RMPI/mpidata.jl
+++ b/src/RMPI/mpidata.jl
@@ -92,8 +92,6 @@ end

 function Base.mapreduce(f, op, it::MPIDataIterator; kwargs...)
     res = mapreduce(f, op, it.iter; kwargs...)
-    # println("typeof(op): ",typeof(op))
-    # println("typeof(res): ",typeof(res))
     T = typeof(res)
     if T <: Bool # MPI.jl does not support Bool reductions
         res = convert(UInt8, res)
@@ -110,6 +108,7 @@ function Base.mapreduce(f, op::typeof(Base.add_sum), it::MPIDataIterator; kwarg
     return MPI.Allreduce(res, +, it.data.comm)
 end

+# Special case for `prod`, which uses a custom (type-widening) reduction operator `mul_prod`
 function Base.mapreduce(f, op::typeof(Base.mul_prod), it::MPIDataIterator; kwargs...)
     res = mapreduce(f, op, it.iter; kwargs...)
     return MPI.Allreduce(res, *, it.data.comm)
@@ -374,10 +373,3 @@ function Rimu.all_overlaps(operators::Tuple, vecs::NTuple{N,MPIData}, ::Val{B})
     num_reports = (N * (N - 1) ÷ 2) * (B + length(operators))
     return Tuple(SVector{num_reports,String}(names)), Tuple(SVector{num_reports,T}(values))
 end
-
-# This is a hack to get MultiScalar reductions to work on Apple Silicon.
-# It is taking a detour via Vector.
-function MPI.Allreduce(ms::Rimu.MultiScalar{T}, op, comm::MPI.Comm) where {T<:Tuple}
-    result_vector = MPI.Allreduce([ms...], op, comm)
-    return Rimu.MultiScalar(T(result_vector))
-end
diff --git a/src/RMPI/multiscalar.jl b/src/RMPI/multiscalar.jl
new file mode 100644
index 000000000..d4041e6c1
--- /dev/null
+++ b/src/RMPI/multiscalar.jl
@@ -0,0 +1,10 @@
+# Make MPI reduction of a `MultiScalar` work on non-Intel processors.
+# The `MultiScalar` is converted into a vector before sending through MPI.Allreduce.
+# Testing shows that this is about the same speed or even a bit faster on Intel processors
+# than reducing the MultiScalar directly via a custom reduction operator.
+# Defining the method in RMPI is strictly type piracy as MultiScalar belongs to Rimu and
+# not to RMPI. Might clean this up later.
+function MPI.Allreduce(ms::Rimu.MultiScalar{T}, op, comm::MPI.Comm) where {T<:Tuple}
+    result_vector = MPI.Allreduce([ms...], op, comm)
+    return Rimu.MultiScalar(T(result_vector))
+end

From 570a34ff5136aaa53a8b43a8bf427575b702ee59 Mon Sep 17 00:00:00 2001
From: Joachim Brand
Date: Sun, 5 Mar 2023 22:09:11 +1300
Subject: [PATCH 7/9] use MPI logical operator directly

---
 test/mpi_runtests.jl | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/test/mpi_runtests.jl b/test/mpi_runtests.jl
index 8a9cb3497..3b32a961a 100644
--- a/test/mpi_runtests.jl
+++ b/test/mpi_runtests.jl
@@ -295,7 +295,8 @@ end

 # Make sure all ranks came this far.
 @testset "Finish" begin
-    @test MPI.Allreduce(0x01, &, mpi_comm()) == 0x01 # 0x01 for true
+    # MPI.jl currently doesn't properly map logical operators (MPI v0.20.8)
+    @test MPI.Allreduce(true, MPI.LAND, mpi_comm())
     # @test MPI.Allreduce(true, &, mpi_comm())
 end
 end
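Patch 7 stops passing Julia's `&` to `MPI.Allreduce` for a `Bool` and uses one of MPI.jl's predefined operator constants instead, which map directly onto MPI's built-in reductions:

    using MPI

    MPI.Init()
    # MPI.LAND is MPI's built-in logical AND; no operator mapping is involved
    all_ok = MPI.Allreduce(true, MPI.LAND, MPI.COMM_WORLD)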
From b54868fcee688db04b3a14a8fb6b93794df6a7ec Mon Sep 17 00:00:00 2001
From: Joachim Brand
Date: Wed, 8 Mar 2023 00:07:14 +1300
Subject: [PATCH 8/9] Remove MultiScalar from sort_into_targets!

fciqmc_col! may now return stats as a vector
---
 src/RMPI/RMPI.jl               | 1 -
 src/RMPI/helpers.jl            | 9 +++++++--
 src/RMPI/mpidata.jl            | 6 +++++-
 src/StochasticStyles/styles.jl | 4 ++--
 src/helpers.jl                 | 1 +
 5 files changed, 15 insertions(+), 6 deletions(-)

diff --git a/src/RMPI/RMPI.jl b/src/RMPI/RMPI.jl
index 2234b258b..fe65f2503 100644
--- a/src/RMPI/RMPI.jl
+++ b/src/RMPI/RMPI.jl
@@ -31,7 +31,6 @@ const mpi_registry = Dict{Int,Any}()
 abstract type DistributeStrategy end

 include("mpidata.jl")
-include("multiscalar.jl")
 include("helpers.jl")
 include("noexchange.jl")
 include("pointtopoint.jl")
diff --git a/src/RMPI/helpers.jl b/src/RMPI/helpers.jl
index 255e452e0..6f8e129a0 100644
--- a/src/RMPI/helpers.jl
+++ b/src/RMPI/helpers.jl
@@ -85,10 +85,15 @@ function mpi_combine_walkers!(dtarget::MPIData, source::AbstractDVec)
     mpi_combine_walkers!(ltarget, storage(source), strategy)
 end

-function sort_into_targets!(dtarget::MPIData, w::AbstractDVec, stats)
+function sort_into_targets!(dtarget::MPIData, w::AbstractDVec, stats::T) where {T}
     # single threaded MPI version
     mpi_combine_walkers!(dtarget, w) # combine walkers from different MPI ranks
-    res_stats = MPI.Allreduce(Rimu.MultiScalar(stats), +, dtarget.comm)
+    if T<:Vector
+        res_stats = MPI.Allreduce(stats, +, dtarget.comm)
+    else
+        # temporarily convert to Vector for native MPI reduction
+        res_stats = T(MPI.Allreduce([stats...], +, dtarget.comm))
+    end
     return dtarget, w, res_stats
 end
diff --git a/src/RMPI/mpidata.jl b/src/RMPI/mpidata.jl
index eb2e7a876..73951a5bd 100644
--- a/src/RMPI/mpidata.jl
+++ b/src/RMPI/mpidata.jl
@@ -93,8 +93,12 @@ end
 function Base.mapreduce(f, op, it::MPIDataIterator; kwargs...)
     res = mapreduce(f, op, it.iter; kwargs...)
     T = typeof(res)
-    if T <: Bool # MPI.jl does not support Bool reductions
+    if T <: Bool # MPI.jl does not currently support Bool reductions on ARM
+        #TODO remove when https://github.com/JuliaParallel/MPI.jl/pull/719
+        # is merged and released
         res = convert(UInt8, res)
+    elseif T <: Rimu.MultiScalar # MPI.jl does not support MultiScalar reductions on ARM
+        res = [res...]
     end
     return T(MPI.Allreduce(res, op, it.data.comm))
 end
diff --git a/src/StochasticStyles/styles.jl b/src/StochasticStyles/styles.jl
index 4f0c48ebb..815852d77 100644
--- a/src/StochasticStyles/styles.jl
+++ b/src/StochasticStyles/styles.jl
@@ -46,7 +46,7 @@ function step_stats(::IsStochastic2Pop{T}) where {T}
     z = zero(T)
     return (
         (:spawns, :deaths, :clones, :zombies),
-        MultiScalar(z, z, z, z)
+        [z, z, z, z]
     )
 end
 function fciqmc_col!(::IsStochastic2Pop, w, ham, add, val, shift, dτ)
@@ -61,7 +61,7 @@ function fciqmc_col!(::IsStochastic2Pop, w, ham, add, val, shift, dτ)

     clones, deaths, zombies = diagonal_step!(w, ham, add, val, dτ, shift, 0)

-    return (spawns, deaths, clones, zombies)
+    return [spawns, deaths, clones, zombies]
 end

 """
diff --git a/src/helpers.jl b/src/helpers.jl
index ae37a639d..0605fc836 100644
--- a/src/helpers.jl
+++ b/src/helpers.jl
@@ -35,6 +35,7 @@ MultiScalar(args...) = MultiScalar(args)
 MultiScalar(v::SVector) = MultiScalar(Tuple(v))
 MultiScalar(m::MultiScalar) = m
 MultiScalar{T}(m::MultiScalar{T}) where T<:Tuple = m
+MultiScalar{T}(v::Vector) where {T<:Tuple} = MultiScalar{T}(T(v))
 MultiScalar(arg) = MultiScalar((arg,))

 Base.getindex(m::MultiScalar, i) = m.tuple[i]
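The idea of patch 8 (reverted again by patch 9 below): stats that arrive as a `Vector` can be Allreduced natively by MPI.jl, and anything tuple-like makes a round trip through `Vector`. The dispatch in isolation, with a hypothetical name `reduce_stats`:

    using MPI

    reduce_stats(stats::Vector, comm) = MPI.Allreduce(stats, +, comm)
    reduce_stats(stats::T, comm) where {T} = T(MPI.Allreduce([stats...], +, comm))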
From cdfafba80c071f0435a72eb2e7849bcfb773bcdd Mon Sep 17 00:00:00 2001
From: Joachim Brand
Date: Thu, 9 Mar 2023 17:30:54 +1300
Subject: [PATCH 9/9] Revert "Remove MultiScalar from sort_into_targets!"

This reverts commit b54868fcee688db04b3a14a8fb6b93794df6a7ec.
---
 src/RMPI/RMPI.jl               | 1 +
 src/RMPI/helpers.jl            | 9 ++-------
 src/RMPI/mpidata.jl            | 6 +-----
 src/StochasticStyles/styles.jl | 4 ++--
 src/helpers.jl                 | 1 -
 5 files changed, 6 insertions(+), 15 deletions(-)

diff --git a/src/RMPI/RMPI.jl b/src/RMPI/RMPI.jl
index fe65f2503..2234b258b 100644
--- a/src/RMPI/RMPI.jl
+++ b/src/RMPI/RMPI.jl
@@ -31,6 +31,7 @@ const mpi_registry = Dict{Int,Any}()
 abstract type DistributeStrategy end

 include("mpidata.jl")
+include("multiscalar.jl")
 include("helpers.jl")
 include("noexchange.jl")
 include("pointtopoint.jl")
diff --git a/src/RMPI/helpers.jl b/src/RMPI/helpers.jl
index 6f8e129a0..255e452e0 100644
--- a/src/RMPI/helpers.jl
+++ b/src/RMPI/helpers.jl
@@ -85,15 +85,10 @@ function mpi_combine_walkers!(dtarget::MPIData, source::AbstractDVec)
     mpi_combine_walkers!(ltarget, storage(source), strategy)
 end

-function sort_into_targets!(dtarget::MPIData, w::AbstractDVec, stats::T) where {T}
+function sort_into_targets!(dtarget::MPIData, w::AbstractDVec, stats)
     # single threaded MPI version
     mpi_combine_walkers!(dtarget, w) # combine walkers from different MPI ranks
-    if T<:Vector
-        res_stats = MPI.Allreduce(stats, +, dtarget.comm)
-    else
-        # temporarily convert to Vector for native MPI reduction
-        res_stats = T(MPI.Allreduce([stats...], +, dtarget.comm))
-    end
+    res_stats = MPI.Allreduce(Rimu.MultiScalar(stats), +, dtarget.comm)
     return dtarget, w, res_stats
 end
diff --git a/src/RMPI/mpidata.jl b/src/RMPI/mpidata.jl
index 73951a5bd..eb2e7a876 100644
--- a/src/RMPI/mpidata.jl
+++ b/src/RMPI/mpidata.jl
@@ -93,12 +93,8 @@ end
 function Base.mapreduce(f, op, it::MPIDataIterator; kwargs...)
     res = mapreduce(f, op, it.iter; kwargs...)
     T = typeof(res)
-    if T <: Bool # MPI.jl does not currently support Bool reductions on ARM
-        #TODO remove when https://github.com/JuliaParallel/MPI.jl/pull/719
-        # is merged and released
+    if T <: Bool # MPI.jl does not support Bool reductions
         res = convert(UInt8, res)
-    elseif T <: Rimu.MultiScalar # MPI.jl does not support MultiScalar reductions on ARM
-        res = [res...]
     end
     return T(MPI.Allreduce(res, op, it.data.comm))
 end
diff --git a/src/StochasticStyles/styles.jl b/src/StochasticStyles/styles.jl
index 815852d77..4f0c48ebb 100644
--- a/src/StochasticStyles/styles.jl
+++ b/src/StochasticStyles/styles.jl
@@ -46,7 +46,7 @@ function step_stats(::IsStochastic2Pop{T}) where {T}
     z = zero(T)
     return (
         (:spawns, :deaths, :clones, :zombies),
-        [z, z, z, z]
+        MultiScalar(z, z, z, z)
     )
 end
 function fciqmc_col!(::IsStochastic2Pop, w, ham, add, val, shift, dτ)
@@ -61,7 +61,7 @@ function fciqmc_col!(::IsStochastic2Pop, w, ham, add, val, shift, dτ)

     clones, deaths, zombies = diagonal_step!(w, ham, add, val, dτ, shift, 0)

-    return [spawns, deaths, clones, zombies]
+    return (spawns, deaths, clones, zombies)
 end

 """
diff --git a/src/helpers.jl b/src/helpers.jl
index 0605fc836..ae37a639d 100644
--- a/src/helpers.jl
+++ b/src/helpers.jl
@@ -35,7 +35,6 @@ MultiScalar(args...) = MultiScalar(args)
 MultiScalar(v::SVector) = MultiScalar(Tuple(v))
 MultiScalar(m::MultiScalar) = m
 MultiScalar{T}(m::MultiScalar{T}) where T<:Tuple = m
-MultiScalar{T}(v::Vector) where {T<:Tuple} = MultiScalar{T}(T(v))
 MultiScalar(arg) = MultiScalar((arg,))

 Base.getindex(m::MultiScalar, i) = m.tuple[i]
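Since patches 8 and 9 cancel, the net effect of the series is the state after patch 7: the `mapreduce` specializations in src/RMPI/mpidata.jl plus the `MPI.Allreduce` method in src/RMPI/multiscalar.jl, so a `MultiScalar` can be reduced directly on any architecture. A hedged usage sketch (assumes a Rimu version that still ships the RMPI submodule, launched under mpiexec):

    using Rimu, Rimu.RMPI, MPI

    MPI.Init()
    ms = Rimu.MultiScalar((1, 2.0, 3))
    total = MPI.Allreduce(ms, +, MPI.COMM_WORLD)  # takes the Vector detour internally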