Skip to content

Commit ee012b7

Browse files
feat: generalize pycall to support structures (#1701)
* feat: generalize pycall to support structures * chore: update test/integration/python.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent 55d747d commit ee012b7

File tree

5 files changed

+92
-12
lines changed

5 files changed

+92
-12
lines changed

ext/ReactantPythonCallExt/ReactantPythonCallExt.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
module ReactantPythonCallExt
22

33
using PythonCall: PythonCall, Py, pyconvert, pydict, pyfunc, pyimport, pylist
4-
using Reactant: Reactant, TracedRArray
4+
using Reactant: Reactant, TracedRArray, TracedRNumber, @reactant_overlay
55
using Reactant.Ops: @opcall
66

77
const jaxptr = Ref{Py}()
@@ -57,6 +57,7 @@ function __init__()
5757
return nothing
5858
end
5959

60+
include("overlays.jl")
6061
include("pycall.jl")
6162
include("saved_model.jl")
6263

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
@reactant_overlay function PythonCall.pycall(f::Py, args...)
2+
if Reactant.looped_any(Reactant.use_overlayed_version, args)
3+
return pycall_with_jax_tracing(f, args...)
4+
else
5+
return Base.inferencebarrier(PythonCall.pycall)(f, args...)
6+
end
7+
end
Lines changed: 30 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,37 @@
1-
function PythonCall.pycall(f::Py, arg0::TracedRArray, argNs::TracedRArray...; kwargs...)
1+
Reactant.jax_dtype_struct_type(::Type{T}) where {T} = Py
2+
3+
function Reactant.convert_to_jax_dtype_struct(x::Union{TracedRArray,TracedRNumber})
24
JAX_TRACING_SUPPORTED[] || throw("jax could not be loaded.")
5+
return jaxptr[].ShapeDtypeStruct(
6+
size(x), jnpptr[].dtype(string(NUMPY_SIMPLE_TYPES[Reactant.unwrapped_eltype(x)]))
7+
)
8+
end
39

4-
jax = jaxptr[]
5-
jnp = jnpptr[]
10+
function pycall_with_jax_tracing(f::Py, args...)
11+
JAX_TRACING_SUPPORTED[] || throw("jax could not be loaded.")
612

7-
inputs = map((arg0, argNs...)) do arg
8-
jax.ShapeDtypeStruct(
9-
size(arg),
10-
jnp.dtype(string(NUMPY_SIMPLE_TYPES[Reactant.unwrapped_eltype(arg)])),
11-
)
13+
seen_args = Reactant.OrderedIdDict()
14+
jax_inputs = Vector{Any}(undef, length(args))
15+
static_argnums = ()
16+
prev_len = 0
17+
for (i, arg) in enumerate(args)
18+
jax_inputs[i] = Reactant.make_tracer(seen_args, arg, (), Reactant.TracedToJAX)
19+
if length(seen_args) == prev_len
20+
static_argnums = (static_argnums..., i - 1)
21+
end
22+
prev_len = length(seen_args)
1223
end
1324

14-
lowered = jax.jit(f).lower(inputs...)
15-
res = @opcall hlo_call(pyconvert(String, lowered.as_text()), arg0, argNs...)
25+
linear_args = Reactant.TracedType[]
26+
for (k, v) in seen_args
27+
k isa Reactant.TracedType || continue
28+
push!(linear_args, k)
29+
end
1630

17-
return length(res) == 0 ? nothing : res[1]
31+
lowered = jaxptr[].jit(f; static_argnums).lower(jax_inputs...)
32+
# To figure out the exact structure of the pyfunc, we need to execute it. Currently,
33+
# we skip doing that and assume that we are returning nothing, array, or tuple of
34+
# arrays.
35+
res = @opcall hlo_call(pyconvert(String, lowered.as_text()), linear_args...)
36+
return length(res) == 0 ? nothing : (length(res) == 1 ? res[1] : res)
1837
end

src/Tracing.jl

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,12 @@
66
TracedSetPath = 5
77
TracedToTypes = 6
88
NoStopTracedTrack = 7
9+
TracedToJAX = 8
910
end
1011

12+
function convert_to_jax_dtype_struct end
13+
function jax_dtype_struct_type end
14+
1115
struct VisitedObject
1216
id::Int
1317
end
@@ -386,6 +390,8 @@ Base.@nospecializeinfer function traced_type_inner(
386390
}
387391
end
388392
error("Unsupported runtime $runtime")
393+
elseif mode == TracedToJAX
394+
return jax_dtype_struct_type(T)
389395
elseif mode == TracedTrack || mode == NoStopTracedTrack || mode == TracedSetPath
390396
return T
391397
else
@@ -432,6 +438,8 @@ Base.@nospecializeinfer function traced_type_inner(
432438
}
433439
end
434440
error("Unsupported runtime $runtime")
441+
elseif mode == TracedToJAX
442+
return jax_dtype_struct_type(T)
435443
elseif mode == TracedTrack || mode == NoStopTracedTrack || mode == TracedSetPath
436444
return T
437445
else
@@ -1312,6 +1320,17 @@ Base.@nospecializeinfer function make_tracer(
13121320
error("Unsupported runtime $runtime")
13131321
end
13141322

1323+
if mode == TracedToJAX
1324+
haskey(seen, prev) && return seen[prev]
1325+
if !Sharding.is_sharded(sharding)
1326+
res = convert_to_jax_dtype_struct(prev)
1327+
else
1328+
error("TODO: implement sharding")
1329+
end
1330+
seen[prev] = res
1331+
return res
1332+
end
1333+
13151334
throw("Cannot Unknown trace mode $mode")
13161335
end
13171336

@@ -1390,6 +1409,17 @@ Base.@nospecializeinfer function make_tracer(
13901409
error("Unsupported runtime $runtime")
13911410
end
13921411

1412+
if mode == TracedToJAX
1413+
haskey(seen, prev) && return seen[prev]
1414+
if !Sharding.is_sharded(sharding)
1415+
res = convert_to_jax_dtype_struct(prev)
1416+
else
1417+
error("TODO: implement sharding")
1418+
end
1419+
seen[prev] = res
1420+
return res
1421+
end
1422+
13931423
throw("Cannot Unknown trace mode $mode")
13941424
end
13951425

test/integration/python.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,33 @@ fn(x, y) = sin.(x) .+ cos.(y.x[1:2, :])
1111

1212
@testset "PythonCall" begin
1313
jax = pyimport("jax")
14+
jnp = pyimport("jax.numpy")
1415

1516
result = @jit jax.numpy.sum(Reactant.to_rarray(Float32[1, 2, 3]))
1617
@test result isa ConcreteRNumber{Float32}
1718
@test result 6
19+
20+
pyfn = pyfunc(
21+
function (tup, y, num, partial_eval)
22+
return jnp.sin(tup[0]) + jnp.cos(tup[1]) + jnp.sum(y) - num * partial_eval
23+
end;
24+
name="sum_and_sin_cos_jax",
25+
)
26+
27+
tup = (rand(Float32, 3, 4), rand(Float32, 3, 4))
28+
y = rand(Float32, 2, 2)
29+
num = 4.0
30+
partial_eval = 0.5
31+
32+
gt = sin.(tup[1]) .+ cos.(tup[2]) .+ sum(y) .- num .* partial_eval
33+
34+
tup_ra = Reactant.to_rarray(tup)
35+
y_ra = Reactant.to_rarray(y)
36+
num_ra = ConcreteRNumber{Float32}(num)
37+
38+
result = @jit pyfn(tup_ra, y_ra, num_ra, partial_eval)
39+
@test result gt atol = 1e-5
40+
@test result isa ConcreteRArray{Float32,2}
1841
end
1942

2043
@testset "SavedModel Export" begin

0 commit comments

Comments
 (0)