Skip to content

Commit eeed089

Browse files
authored
Improve Datatype handling (#471)
* Improve Datatype handling - Uses generated functions to avoid creating new instances `Datatype`s on every communication operation (#462). These are lazily initialized to allow for use in precompiled modules. - Attach the corresponding Julia to an MPI Datatype as an attribute - Nicer printing of Datatypes - Allow specifying bits type as return type of MPI.Scatter - Minor optimisation when constructing buffers of a single NTuple * duplicate Datatypes if same size as existing type
1 parent a584b85 commit eeed089

10 files changed

+205
-41
lines changed

deps/consts_microsoftmpi.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ const MPI_LOCK_EXCLUSIVE = Cint(234)
8383
const MPI_LOCK_SHARED = Cint(235)
8484
const MPI_MAX_INFO_KEY = Cint(255)
8585
const MPI_MAX_INFO_VAL = Cint(1024)
86+
const MPI_MAX_OBJECT_NAME = Cint(128)
8687
const MPI_TAG_UB = reinterpret(Cint, 0x64400001)
8788
const MPI_COMM_TYPE_SHARED = Cint(1)
8889
const MPI_ORDER_C = Cint(56)

deps/consts_mpich.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ const MPI_LOCK_EXCLUSIVE = Cint(234)
8888
const MPI_LOCK_SHARED = Cint(235)
8989
const MPI_MAX_INFO_KEY = Cint(255)
9090
const MPI_MAX_INFO_VAL = Cint(1024)
91+
const MPI_MAX_OBJECT_NAME = Cint(128)
9192
const MPI_TAG_UB = Cint(1681915905)
9293
const MPI_COMM_TYPE_SHARED = Cint(1)
9394
const MPI_ORDER_C = Cint(56)

deps/consts_openmpi.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ const MPI_LOCK_EXCLUSIVE = Cint(1)
9595
const MPI_LOCK_SHARED = Cint(2)
9696
const MPI_MAX_INFO_KEY = Cint(36)
9797
const MPI_MAX_INFO_VAL = Cint(256)
98+
const MPI_MAX_OBJECT_NAME = Cint(64)
9899
const MPI_TAG_UB = Cint(0)
99100
const MPI_COMM_TYPE_SHARED = Cint(0)
100101
const MPI_ORDER_C = Cint(0)

deps/gen_consts.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@ MPI_Cints = [
126126
:MPI_LOCK_SHARED,
127127
:MPI_MAX_INFO_KEY,
128128
:MPI_MAX_INFO_VAL,
129+
:MPI_MAX_OBJECT_NAME,
129130
:MPI_TAG_UB,
130131
:MPI_COMM_TYPE_SHARED,
131132
:MPI_ORDER_C,

docs/examples/06-scatterv.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ if rank == root
3636

3737
# store sizes in 2 * comm_size Array
3838
sizes = vcat(M_counts', N_counts')
39+
size_ubuf = UBuffer(sizes, 2)
3940

4041
# store number of values to send to each rank in comm_size length Vector
4142
counts = vec(prod(sizes, dims=1))
@@ -44,7 +45,7 @@ if rank == root
4445
output_vbuf = VBuffer(output, counts) # VBuffer for gather
4546
else
4647
# these variables can be set to `nothing` on non-root processes
47-
sizes = nothing
48+
size_ubuf = UBuffer(nothing)
4849
output_vbuf = test_vbuf = VBuffer(nothing)
4950
end
5051

@@ -58,8 +59,8 @@ if rank == root
5859
end
5960
MPI.Barrier(comm)
6061

61-
local_M, local_N = MPI.Scatter!(sizes, zeros(Int, 2), root, comm)
62-
local_test = MPI.Scatterv!(test_vbuf, zeros(Float64, local_M, local_N), root, comm)
62+
local_size = MPI.Scatter(size_ubuf, NTuple{2,Int}, root, comm)
63+
local_test = MPI.Scatterv!(test_vbuf, zeros(Float64, local_size), root, comm)
6364

6465
for i = 0:comm_size-1
6566
if rank == i

docs/src/advanced.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ MPI.Types.create_subarray
1717
MPI.Types.create_struct
1818
MPI.Types.create_resized
1919
MPI.Types.commit!
20+
MPI.Types.duplicate
2021
```
2122

2223
## Operator objects

src/buffers.jl

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -147,22 +147,29 @@ function Buffer(sub::Base.FastContiguousSubArray)
147147
end
148148
function Buffer(sub::Base.FastSubArray)
149149
datatype = Types.create_vector(length(sub), 1, sub.stride1,
150-
Datatype(eltype(sub); commit=false))
150+
Datatype(eltype(sub)))
151151
Types.commit!(datatype)
152152
Buffer(sub, Cint(1), datatype)
153153
end
154154
function Buffer(sub::SubArray{T,N,P,I,false}) where {T,N,P,I<:Tuple{Vararg{Union{Base.ScalarIndex, Base.Slice, AbstractUnitRange}}}}
155155
datatype = Types.create_subarray(size(parent(sub)),
156156
map(length, sub.indices),
157157
map(i -> first(i)-1, sub.indices),
158-
Datatype(eltype(sub), commit=false))
158+
Datatype(eltype(sub)))
159159
Types.commit!(datatype)
160160
Buffer(parent(sub), Cint(1), datatype)
161161
end
162162

163+
# NTuple: avoid creating a new datatype if possible
164+
function Buffer(data::Ref{NTuple{N,T}}) where {N,T}
165+
Buffer(data, Cint(N), Datatype(T))
166+
end
167+
168+
163169
Buffer(::InPlace) = Buffer(IN_PLACE, 0, DATATYPE_NULL)
164170
Buffer(::Nothing) = Buffer(nothing, 0, DATATYPE_NULL)
165171

172+
166173
"""
167174
Buffer_send(data)
168175

src/collective.jl

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,11 +113,25 @@ Scatter!(sendbuf::Nothing, recvbuf, root::Integer, comm::Comm) =
113113
Scatter!(UBuffer(nothing), recvbuf, root, comm)
114114

115115
# determine UBuffer count from recvbuf
116-
Scatter!(sendbuf::AbstractArray, recvbuf::Union{Ref,AbstractArray}, root::Integer, comm::Comm) =
116+
Scatter!(sendbuf::AbstractArray{T}, recvbuf::Union{Ref{T},AbstractArray{T}}, root::Integer, comm::Comm) where {T} =
117117
Scatter!(UBuffer(sendbuf,length(recvbuf)), recvbuf, root, comm)
118118

119119
"""
120-
Scatterv!(sendbuf::Union{VBuffer,Nothing}, recvbuf, root, comm)
120+
Scatter(sendbuf, T, root::Integer, comm::Comm)
121+
122+
Splits the buffer `sendbuf` in the `root` process into `Comm_size(comm)` chunks,
123+
sending the `j`-th chunk to the process of rank `j-1` as an object of type `T`.
124+
125+
# See also
126+
- [`Scatter!`](@ref)
127+
"""
128+
function Scatter(sendbuf, ::Type{T}, root::Integer, comm::Comm) where {T}
129+
Scatter!(sendbuf, Ref{T}(), root, comm)[]
130+
end
131+
132+
133+
"""
134+
Scatterv!(sendbuf, T, root, comm)
121135
122136
Splits the buffer `sendbuf` in the `root` process into `Comm_size(comm)` chunks and sends
123137
the `j`th chunk to the process of rank `j-1` into the `recvbuf` buffer.

0 commit comments

Comments
 (0)