Skip to content

Commit 8a0bff4

Browse files
vchuravywsmoses
andauthored
add custom handler for ptr_to_array runtime call (EnzymeAD#2258)
* add custom handler for ptr_to_array runtime call * Update array.jl * Update llvmrules.jl --------- Co-authored-by: William Moses <[email protected]>
1 parent 2309abd commit 8a0bff4

File tree

3 files changed

+74
-0
lines changed

3 files changed

+74
-0
lines changed

src/compiler/validation.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,11 +78,13 @@ function __init__()
7878
"jl_get_keyword_sorter",
7979
"ijl_get_keyword_sorter",
8080
"jl_ptr_to_array",
81+
"ijl_ptr_to_array",
8182
"jl_box_float32",
8283
"ijl_box_float32",
8384
"jl_box_float64",
8485
"ijl_box_float64",
8586
"jl_ptr_to_array_1d",
87+
"ijl_ptr_to_array_1d",
8688
"jl_eqtable_get",
8789
"ijl_eqtable_get",
8890
"memcmp",

src/rules/llvmrules.jl

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1779,6 +1779,57 @@ end
17791779
return nothing
17801780
end
17811781

1782+
@register_fwd function jl_ptr_to_array_fwd(B, orig, gutils, normalR, shadowR)
1783+
if is_constant_inst(gutils, orig)
1784+
return true
1785+
end
1786+
origops = collect(operands(orig))
1787+
width = get_width(gutils)
1788+
shadowin = invert_pointer(gutils, origops[2], B)
1789+
1790+
valTys = API.CValueType[
1791+
API.VT_Primal,
1792+
API.VT_Shadow,
1793+
API.VT_Primal,
1794+
API.VT_Primal,
1795+
]
1796+
1797+
shadowres = UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig))))
1798+
for idx = 1:width
1799+
ev = if width == 1
1800+
shadowin
1801+
else
1802+
extract_value!(B, shadowin, idx - 1)
1803+
end
1804+
1805+
args = LLVM.Value[
1806+
new_from_original(gutils, origops[1]),
1807+
ev, # data
1808+
new_from_original(gutils, origops[3]),
1809+
new_from_original(gutils, origops[4]),
1810+
]
1811+
# TODO do runtime activity relevant errors and checks
1812+
1813+
cal = call_samefunc_with_inverted_bundles!(B, gutils, orig, args, valTys, false) #=lookup=#
1814+
debug_from_orig!(gutils, cal, orig)
1815+
callconv!(cal, callconv(orig))
1816+
if width == 1
1817+
shadowres = cal
1818+
else
1819+
shadowres = insert_value!(B, shadowres, call, idx - 1)
1820+
end
1821+
end
1822+
unsafe_store!(shadowR, shadowres.ref)
1823+
1824+
return false
1825+
end
1826+
@register_aug function jl_ptr_to_array_augfwd(B, orig, gutils, normalR, shadowR, tapeR)
1827+
jl_ptr_to_array_fwd(B, orig, gutils, normalR, shadowR)
1828+
end
1829+
@register_rev function jl_ptr_to_array_rev(B, orig, gutils, tape)
1830+
return nothing
1831+
end
1832+
17821833
@register_fwd function genericmemory_copyto_fwd(B, orig, gutils, normalR, shadowR)
17831834
if is_constant_inst(gutils, orig)
17841835
return true
@@ -2400,6 +2451,12 @@ end
24002451
@revfunc(jl_array_ptr_copy_rev),
24012452
@fwdfunc(jl_array_ptr_copy_fwd),
24022453
)
2454+
register_handler!(
2455+
("jl_ptr_to_array_1d", "ijl_ptr_to_array_1d", "jl_ptr_to_array", "ijl_ptr_to_array"),
2456+
@augfunc(jl_ptr_to_array_augfwd),
2457+
@revfunc(jl_ptr_to_array_rev),
2458+
@fwdfunc(jl_ptr_to_array_fwd),
2459+
)
24032460
register_handler!(
24042461
(),
24052462
@augfunc(jl_unhandled_augfwd),

test/array.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,3 +23,18 @@ end
2323
@test dB[1] === dA1
2424
@test dB[2] === dA2
2525
end
26+
27+
function unsafe_wrap_test(a, i, x)
28+
GC.@preserve a begin
29+
ptr = pointer(a)
30+
b = Base.unsafe_wrap(Array, ptr, length(a))
31+
b[i] = x
32+
end
33+
a[i]
34+
end
35+
36+
@testset "Unsafe wrap" begin
37+
autodiff(Forward, unsafe_wrap_test, Duplicated(zeros(1), zeros(1)), Const(1), Duplicated(1.0, 2.0))
38+
39+
# TODO test for batch and reverse
40+
end

0 commit comments

Comments
 (0)