@@ -1779,6 +1779,57 @@ end
1779
1779
return nothing
1780
1780
end
1781
1781
1782
+ @register_fwd function jl_ptr_to_array_fwd (B, orig, gutils, normalR, shadowR)
1783
+ if is_constant_inst (gutils, orig)
1784
+ return true
1785
+ end
1786
+ origops = collect (operands (orig))
1787
+ width = get_width (gutils)
1788
+ shadowin = invert_pointer (gutils, origops[2 ], B)
1789
+
1790
+ valTys = API. CValueType[
1791
+ API. VT_Primal,
1792
+ API. VT_Shadow,
1793
+ API. VT_Primal,
1794
+ API. VT_Primal,
1795
+ ]
1796
+
1797
+ shadowres = UndefValue (LLVM. LLVMType (API. EnzymeGetShadowType (width, value_type (orig))))
1798
+ for idx = 1 : width
1799
+ ev = if width == 1
1800
+ shadowin
1801
+ else
1802
+ extract_value! (B, shadowin, idx - 1 )
1803
+ end
1804
+
1805
+ args = LLVM. Value[
1806
+ new_from_original (gutils, origops[1 ]),
1807
+ ev, # data
1808
+ new_from_original (gutils, origops[3 ]),
1809
+ new_from_original (gutils, origops[4 ]),
1810
+ ]
1811
+ # TODO do runtime activity relevant errors and checks
1812
+
1813
+ cal = call_samefunc_with_inverted_bundles! (B, gutils, orig, args, valTys, false ) #= lookup=#
1814
+ debug_from_orig! (gutils, cal, orig)
1815
+ callconv! (cal, callconv (orig))
1816
+ if width == 1
1817
+ shadowres = cal
1818
+ else
1819
+ shadowres = insert_value! (B, shadowres, call, idx - 1 )
1820
+ end
1821
+ end
1822
+ unsafe_store! (shadowR, shadowres. ref)
1823
+
1824
+ return false
1825
+ end
1826
+ @register_aug function jl_ptr_to_array_augfwd (B, orig, gutils, normalR, shadowR, tapeR)
1827
+ jl_ptr_to_array_fwd (B, orig, gutils, normalR, shadowR)
1828
+ end
1829
+ @register_rev function jl_ptr_to_array_rev (B, orig, gutils, tape)
1830
+ return nothing
1831
+ end
1832
+
1782
1833
@register_fwd function genericmemory_copyto_fwd (B, orig, gutils, normalR, shadowR)
1783
1834
if is_constant_inst (gutils, orig)
1784
1835
return true
@@ -2400,6 +2451,12 @@ end
2400
2451
@revfunc (jl_array_ptr_copy_rev),
2401
2452
@fwdfunc (jl_array_ptr_copy_fwd),
2402
2453
)
2454
+ register_handler! (
2455
+ (" jl_ptr_to_array_1d" , " ijl_ptr_to_array_1d" , " jl_ptr_to_array" , " ijl_ptr_to_array" ),
2456
+ @augfunc (jl_ptr_to_array_augfwd),
2457
+ @revfunc (jl_ptr_to_array_rev),
2458
+ @fwdfunc (jl_ptr_to_array_fwd),
2459
+ )
2403
2460
register_handler! (
2404
2461
(),
2405
2462
@augfunc (jl_unhandled_augfwd),
0 commit comments