From aad402ffaa9093e2dfb1df45737c35307ecf3b45 Mon Sep 17 00:00:00 2001 From: pat-alt Date: Tue, 10 Dec 2024 10:07:17 +0100 Subject: [PATCH 1/3] let's see --- ext/MPIExt/MPIExt.jl | 4 +++- ext/MPIExt/evaluate.jl | 10 +++++----- ext/MPIExt/generate_counterfactual.jl | 10 +++++----- 3 files changed, 13 insertions(+), 11 deletions(-) diff --git a/ext/MPIExt/MPIExt.jl b/ext/MPIExt/MPIExt.jl index 2668d79..1066cf0 100644 --- a/ext/MPIExt/MPIExt.jl +++ b/ext/MPIExt/MPIExt.jl @@ -15,6 +15,7 @@ struct MPIParallelizer <: TaijaParallel.AbstractParallelizer n_proc::Int n_each::Union{Nothing,Int} threaded::Bool + active_comm::MPI.Comm end """ @@ -26,6 +27,7 @@ function TaijaParallel.MPIParallelizer( comm::MPI.Comm; n_each::Union{Nothing,Int} = nothing, threaded::Bool = false, + active_comm::MPI.Comm = comm ) rank = MPI.Comm_rank(comm) # Rank of this process in the world 🌍 n_proc = MPI.Comm_size(comm) # Number of processes in the world 🌍 @@ -42,7 +44,7 @@ function TaijaParallel.MPIParallelizer( end end - return MPIParallelizer(comm, rank, n_proc, n_each, threaded) + return MPIParallelizer(comm, rank, n_proc, n_each, threaded, active_comm) end include("generate_counterfactual.jl") diff --git a/ext/MPIExt/evaluate.jl b/ext/MPIExt/evaluate.jl index 1bbd296..fb9051f 100644 --- a/ext/MPIExt/evaluate.jl +++ b/ext/MPIExt/evaluate.jl @@ -73,7 +73,7 @@ function TaijaBase.parallelize( kwargs..., ) end - MPI.Barrier(parallelizer.comm) + MPI.Barrier(parallelizer.active_comm) # Collect output from all processe in rank 0: collected_output = MPI.gather(output, parallelizer.comm) @@ -81,11 +81,11 @@ function TaijaBase.parallelize( output = vcat(collected_output...) Serialization.serialize(joinpath(storage_path, "output_$i.jls"), output) end - MPI.Barrier(parallelizer.comm) + MPI.Barrier(parallelizer.active_comm) end # Collect all chunks in rank 0: - MPI.Barrier(parallelizer.comm) + MPI.Barrier(parallelizer.active_comm) # Load output from rank 0: if parallelizer.rank == 0 @@ -101,8 +101,8 @@ function TaijaBase.parallelize( end # Broadcast output to all processes: - final_output = MPI.bcast(output, parallelizer.comm; root = 0) - MPI.Barrier(parallelizer.comm) + final_output = MPI.bcast(output, parallelizer.active_comm; root = 0) + MPI.Barrier(parallelizer.active_comm) return final_output end diff --git a/ext/MPIExt/generate_counterfactual.jl b/ext/MPIExt/generate_counterfactual.jl index 95fc51b..b1eb015 100644 --- a/ext/MPIExt/generate_counterfactual.jl +++ b/ext/MPIExt/generate_counterfactual.jl @@ -73,7 +73,7 @@ function TaijaBase.parallelize( kwargs..., ) end - MPI.Barrier(parallelizer.comm) + MPI.Barrier(parallelizer.active_comm) # Collect output from all processe in rank 0: collected_output = MPI.gather(output, parallelizer.comm) @@ -81,11 +81,11 @@ function TaijaBase.parallelize( output = vcat(collected_output...) Serialization.serialize(joinpath(storage_path, "output_$i.jls"), output) end - MPI.Barrier(parallelizer.comm) + MPI.Barrier(parallelizer.active_comm) end # Collect all chunks in rank 0: - MPI.Barrier(parallelizer.comm) + MPI.Barrier(parallelizer.active_comm) # Load output from rank 0: if parallelizer.rank == 0 @@ -101,8 +101,8 @@ function TaijaBase.parallelize( end # Broadcast output to all processes: - final_output = MPI.bcast(output, parallelizer.comm; root = 0) - MPI.Barrier(parallelizer.comm) + final_output = MPI.bcast(output, parallelizer.active_comm; root = 0) + MPI.Barrier(parallelizer.active_comm) return final_output end From f5a8e20ec1b70b19ce7ed51c1d07d1b706310d56 Mon Sep 17 00:00:00 2001 From: pat-alt Date: Tue, 10 Dec 2024 10:23:26 +0100 Subject: [PATCH 2/3] let's see --- ext/MPIExt/MPIExt.jl | 8 +++++--- ext/MPIExt/comms.jl | 8 ++++++++ 2 files changed, 13 insertions(+), 3 deletions(-) create mode 100644 ext/MPIExt/comms.jl diff --git a/ext/MPIExt/MPIExt.jl b/ext/MPIExt/MPIExt.jl index 1066cf0..3160c97 100644 --- a/ext/MPIExt/MPIExt.jl +++ b/ext/MPIExt/MPIExt.jl @@ -27,10 +27,11 @@ function TaijaParallel.MPIParallelizer( comm::MPI.Comm; n_each::Union{Nothing,Int} = nothing, threaded::Bool = false, - active_comm::MPI.Comm = comm + active_comm::Union{Nothing,MPI.Comm} = comm ) - rank = MPI.Comm_rank(comm) # Rank of this process in the world 🌍 - n_proc = MPI.Comm_size(comm) # Number of processes in the world 🌍 + rank = MPI.Comm_rank(comm) # Rank of this process in the world 🌍 + n_proc = MPI.Comm_size(comm) # Number of processes in the world 🌍 + active_comm = isnothing(active_comm) ? comm : active_comm # Active communication channel (if specified) if rank == 0 @info "Using `MPI.jl` for multi-processing." @@ -49,5 +50,6 @@ end include("generate_counterfactual.jl") include("evaluate.jl") +include("comms.jl") end diff --git a/ext/MPIExt/comms.jl b/ext/MPIExt/comms.jl new file mode 100644 index 0000000..a6be9cf --- /dev/null +++ b/ext/MPIExt/comms.jl @@ -0,0 +1,8 @@ +global _active_comm::Union{Nothing,MPI.Comm} = nothing + +function set_active_comm(comm::MPI.Comm) + global _active_comm = comm + return _active_comm +end + +get_active_comm() = _active_comm \ No newline at end of file From 96ea68ed400e5b1c1cc7f664a6a21a3c60b1adfe Mon Sep 17 00:00:00 2001 From: pat-alt Date: Tue, 10 Dec 2024 10:42:08 +0100 Subject: [PATCH 3/3] k --- ext/MPIExt/MPIExt.jl | 1 - ext/MPIExt/comms.jl | 8 -------- src/extensions/MPIExt.jl | 9 +++++++++ 3 files changed, 9 insertions(+), 9 deletions(-) delete mode 100644 ext/MPIExt/comms.jl diff --git a/ext/MPIExt/MPIExt.jl b/ext/MPIExt/MPIExt.jl index 3160c97..8914a61 100644 --- a/ext/MPIExt/MPIExt.jl +++ b/ext/MPIExt/MPIExt.jl @@ -50,6 +50,5 @@ end include("generate_counterfactual.jl") include("evaluate.jl") -include("comms.jl") end diff --git a/ext/MPIExt/comms.jl b/ext/MPIExt/comms.jl deleted file mode 100644 index a6be9cf..0000000 --- a/ext/MPIExt/comms.jl +++ /dev/null @@ -1,8 +0,0 @@ -global _active_comm::Union{Nothing,MPI.Comm} = nothing - -function set_active_comm(comm::MPI.Comm) - global _active_comm = comm - return _active_comm -end - -get_active_comm() = _active_comm \ No newline at end of file diff --git a/src/extensions/MPIExt.jl b/src/extensions/MPIExt.jl index 0973f47..0858eca 100644 --- a/src/extensions/MPIExt.jl +++ b/src/extensions/MPIExt.jl @@ -5,3 +5,12 @@ Exposes the `MPIParallelizer` function from the `MPIExt` extension. """ function MPIParallelizer end export MPIParallelizer + +global _active_comm = nothing + +function set_active_comm(comm) + global _active_comm = comm + return _active_comm +end + +get_active_comm() = _active_comm \ No newline at end of file