Skip to content

Commit 21949a1

Browse files
committed
allow declaration of custom operations at precompilation
Fixes #404.
1 parent 00a3f56 commit 21949a1

File tree

1 file changed

+54
-5
lines changed

1 file changed

+54
-5
lines changed

src/operators.jl

+54-5
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,23 @@ An MPI reduction operator, for use with [Reduce/Scan collective operations](@ref
1010
Wrap the Julia reduction function `op` for arguments of type `T`. `op` is assumed to be
1111
associative, and if `iscommutative` is true, assumed to be commutative as well.
1212
13+
Certain combinations of `op` and `T` will use the predefined MPI intrinsic operations,
14+
otherwise it will wrap the function in a Julia closure at runtime. The macro [`@Op`](@ref)
15+
can be used to wrap functions ahead of time, which may reduce runtime overhead, and is
16+
required on platforms where closures are not supported (such as ARM and PPC).
17+
18+
User usage of this function is generally unnecessary since it will be called directly
19+
by the relevant MPI collective operations.
20+
1321
## See also
1422
1523
- [`Reduce!`](@ref)/[`Reduce`](@ref)
1624
- [`Allreduce!`](@ref)/[`Allreduce`](@ref)
1725
- [`Scan!`](@ref)/[`Scan`](@ref)
1826
- [`Exscan!`](@ref)/[`Exscan`](@ref)
27+
1928
"""
20-
@mpi_handle Op MPI_Op fptr
29+
@mpi_handle Op MPI_Op cfunc::Union{Base.CFunction, Nothing}
2130

2231
const OP_NULL = _Op(MPI_OP_NULL, nothing)
2332
const BAND = _Op(MPI_BAND, nothing)
@@ -74,16 +83,56 @@ function Op(f, T=Any; iscommutative=false)
7483
error("User-defined reduction operators are not supported on 32-bit Windows.\nSee https://github.com/JuliaParallel/MPI.jl/issues/246 for more details.")
7584
end
7685
w = OpWrapper{typeof(f),T}(f)
77-
fptr = @cfunction($w, Cvoid, (Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cint}, Ptr{MPI_Datatype}))
78-
79-
op = Op(OP_NULL.val, fptr)
86+
cfunc = @cfunction($w, Cvoid, (Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cint}, Ptr{MPI_Datatype}))
87+
88+
op = Op(OP_NULL.val, cfunc)
8089
# int MPI_Op_create(MPI_User_function* user_fn, int commute, MPI_Op* op)
8190
@mpichk ccall((:MPI_Op_create, libmpi), Cint,
8291
(Ptr{Cvoid}, Cint, Ptr{MPI_Op}),
83-
fptr, iscommutative, op)
92+
cfunc, iscommutative, op)
8493

8594
refcount_inc()
8695
finalizer(free, op)
8796
return op
8897
end
8998

99+
"""
100+
@declareOp(op, T[, iscommutative])
101+
102+
Declare a Julia function `op` to be used as a custom MPI operator [`Op`](@ref) for
103+
variables of type `T`. This will create the [`Op`](@ref) object and define an appropriate
104+
constructor method to `Op`. The `iscommutative` argument indicates to MPI whether or not
105+
MPI can assume the operation is commutative (default is `false`).
106+
107+
The usage of this macro is optional: the main advantage of this is that will avoid the use
108+
of a closure (see ["Closure cfunctions" in the Julia
109+
manual](https://docs.julialang.org/en/v1/manual/calling-c-and-fortran-code/#Closure-cfunctions-1),
110+
which may offer some performance advantages.
111+
112+
This should only be called once per combination of `op` and `T`, and should be at the
113+
top-level (e.g. not inside a function). It can be safely used before `MPI.Init()` and
114+
inside a precompiled module.
115+
"""
116+
macro declareOp(f, T, iscommutative=false)
117+
opwrap = gensym(:opwrap) # we need to manually gensym for use with `@cfunction` macro
118+
quote
119+
if !Base.issingletontype(typeof($(esc(f))))
120+
error("@declareOp macro can only be used with instances of singleton types")
121+
end
122+
const op = Op(OP_NULL.val, nothing)
123+
const $(esc(opwrap)) = OpWrapper{typeof($(esc(f))),$(esc(T))}($(esc(f)))
124+
function initop()
125+
fptr = @cfunction($opwrap, Cvoid, (Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cint}, Ptr{MPI_Datatype}))
126+
@mpichk ccall((:MPI_Op_create, libmpi), Cint,
127+
(Ptr{Cvoid}, Cint, Ptr{MPI_Op}),
128+
fptr, $iscommutative, op)
129+
end
130+
if Initialized() && !Finalized()
131+
initop()
132+
else
133+
push!(mpi_init_hooks, initop)
134+
end
135+
MPI.Op(::typeof($(esc(f))), ::Type{$(esc(T))}; iscommutative=$iscommutative) = op
136+
op
137+
end
138+
end

0 commit comments

Comments
 (0)