|
1 |
| -function PythonCall.pycall(f::Py, arg0::TracedRArray, argNs::TracedRArray...; kwargs...) |
| 1 | +Reactant.jax_dtype_struct_type(::Type{T}) where {T} = Py |
| 2 | + |
| 3 | +function Reactant.convert_to_jax_dtype_struct(x::Union{TracedRArray,TracedRNumber}) |
2 | 4 | JAX_TRACING_SUPPORTED[] || throw("jax could not be loaded.")
|
| 5 | + return jaxptr[].ShapeDtypeStruct( |
| 6 | + size(x), jnpptr[].dtype(string(NUMPY_SIMPLE_TYPES[Reactant.unwrapped_eltype(x)])) |
| 7 | + ) |
| 8 | +end |
3 | 9 |
|
4 |
| - jax = jaxptr[] |
5 |
| - jnp = jnpptr[] |
| 10 | +function pycall_with_jax_tracing(f::Py, args...) |
| 11 | + JAX_TRACING_SUPPORTED[] || throw("jax could not be loaded.") |
6 | 12 |
|
7 |
| - inputs = map((arg0, argNs...)) do arg |
8 |
| - jax.ShapeDtypeStruct( |
9 |
| - size(arg), |
10 |
| - jnp.dtype(string(NUMPY_SIMPLE_TYPES[Reactant.unwrapped_eltype(arg)])), |
11 |
| - ) |
| 13 | + seen_args = Reactant.OrderedIdDict() |
| 14 | + jax_inputs = Vector{Any}(undef, length(args)) |
| 15 | + static_argnums = () |
| 16 | + prev_len = 0 |
| 17 | + for (i, arg) in enumerate(args) |
| 18 | + jax_inputs[i] = Reactant.make_tracer(seen_args, arg, (), Reactant.TracedToJAX) |
| 19 | + if length(seen_args) == prev_len |
| 20 | + static_argnums = (static_argnums..., i - 1) |
| 21 | + end |
| 22 | + prev_len = length(seen_args) |
12 | 23 | end
|
13 | 24 |
|
14 |
| - lowered = jax.jit(f).lower(inputs...) |
15 |
| - res = @opcall hlo_call(pyconvert(String, lowered.as_text()), arg0, argNs...) |
| 25 | + linear_args = Reactant.TracedType[] |
| 26 | + for (k, v) in seen_args |
| 27 | + k isa Reactant.TracedType || continue |
| 28 | + push!(linear_args, k) |
| 29 | + end |
16 | 30 |
|
17 |
| - return length(res) == 0 ? nothing : res[1] |
| 31 | + lowered = jaxptr[].jit(f; static_argnums).lower(jax_inputs...) |
| 32 | + # To figure out the exact structure of the pyfunc, we need to execute it. Currently, |
| 33 | + # we skip doing that and assume that we are returning nothing, array, or tuple of |
| 34 | + # arrays. |
| 35 | + res = @opcall hlo_call(pyconvert(String, lowered.as_text()), linear_args...) |
| 36 | + return length(res) == 0 ? nothing : (length(res) == 1 ? res[1] : res) |
18 | 37 | end
|
0 commit comments