1
# Forward-mode kernel body.
#
# Runs `EnzymeCore.autodiff_deferred` in `Forward` mode on `f`, passing both the
# kernel function and the kernel context as `Const` and annotating the return
# as `Const{Nothing}`. The (possibly `Duplicated`) kernel arguments in `args`
# are forwarded unchanged. Always returns `nothing`.
#
# NOTE(review): an earlier revision kept a CPU-only variant that called
# `EnzymeCore.autodiff` instead, because `autodiff_deferred` could deadlock on
# the CPU (https://github.com/EnzymeAD/Enzyme.jl/issues/1516). Confirm that the
# CPU backend no longer hits that path now that a single variant is used.
function fwd(ctx, f, args...)
    wrapped_f = Const(f)
    wrapped_ctx = Const(ctx)
    EnzymeCore.autodiff_deferred(Forward, wrapped_f, Const{Nothing}, wrapped_ctx, args...)
    return nothing
end
13
5
14
6
# Forward-mode rule for calling a `Const`-annotated `Kernel` whose return
# annotation is `Const{Nothing}`.
#
# The wrapped kernel is re-created with `fwd` in place of its function via
# `similar(kernel, fwd)`; the original kernel function is then passed as the
# first argument so that `fwd` can differentiate it. `ndrange` and
# `workgroupsize` are forwarded unchanged to the derived kernel.
function EnzymeRules.forward(
        func::Const{<:Kernel},
        ::Type{Const{Nothing}},
        args...;
        ndrange = nothing,
        workgroupsize = nothing,
    )
    primal_kernel = func.val
    primal_f = primal_kernel.f
    launch = similar(primal_kernel, fwd)

    return launch(
        primal_f,
        args...;
        ndrange = ndrange,
        workgroupsize = workgroupsize,
    )
end
41
19
42
# Build the kernel context used by the Enzyme rules by delegating to
# `mkcontext(kernel, ndrange, iterspace)`.
#
# `dynamic` is accepted so existing call sites keep working, but it is not
# forwarded to `mkcontext`.
function _enzyme_mkcontext(kernel::Kernel, ndrange, iterspace, dynamic)
    return mkcontext(kernel, ndrange, iterspace)
end
46
22
47
# Package the state produced by the augmented forward pass.
#
# Returns an `AugmentedReturn{Nothing, Nothing, Any}` whose first two slots are
# `nothing` and whose third slot carries the tuple
# `(subtape, arg_refs, tape_type)`. The kernel argument is used only for
# dispatch.
function _augmented_return(::Kernel, subtape, arg_refs, tape_type)
    tape = (subtape, arg_refs, tape_type)
    return AugmentedReturn{Nothing, Nothing, Any}(nothing, nothing, tape)
end
55
25
56
26
function _create_tape_kernel (
57
- kernel:: Kernel{CPU} ,
58
- ModifiedBetween,
59
- FT,
60
- ctxTy,
61
- ndrange,
62
- iterspace,
63
- args2... ,
64
- )
65
- TapeType = EnzymeCore. tape_type (
66
- ReverseSplitModified (ReverseSplitWithPrimal, ModifiedBetween),
67
- FT,
68
- Const{Nothing},
69
- Const{ctxTy},
70
- map (Core. Typeof, args2)... ,
71
- )
72
- subtape = Array {TapeType} (undef, size (blocks (iterspace)))
73
- aug_kernel = similar (kernel, cpu_aug_fwd)
74
- return TapeType, subtape, aug_kernel
75
- end
76
-
77
- function _create_tape_kernel (
78
- kernel:: Kernel{<:GPU} ,
27
+ kernel:: Kernel ,
79
28
ModifiedBetween,
80
29
FT,
81
30
ctxTy,
@@ -104,60 +53,11 @@ function _create_tape_kernel(
104
53
# Allocate per thread
105
54
subtape = allocate (backend (kernel), TapeType, prod (ndrange))
106
55
107
- aug_kernel = similar (kernel, gpu_aug_fwd )
56
+ aug_kernel = similar (kernel, aug_fwd )
108
57
return TapeType, subtape, aug_kernel
109
58
end
110
59
111
- _create_rev_kernel (kernel:: Kernel{CPU} ) = similar (kernel, cpu_rev)
112
- _create_rev_kernel (kernel:: Kernel{<:GPU} ) = similar (kernel, gpu_rev)
113
-
114
- function cpu_aug_fwd (
115
- ctx,
116
- f:: FT ,
117
- :: Val{ModifiedBetween} ,
118
- subtape,
119
- :: Val{TapeType} ,
120
- args... ,
121
- ) where {ModifiedBetween, FT, TapeType}
122
- # A2 = Const{Nothing} -- since f->Nothing
123
- forward, _ = EnzymeCore. autodiff_thunk (
124
- ReverseSplitModified (ReverseSplitWithPrimal, Val (ModifiedBetween)),
125
- Const{Core. Typeof (f)},
126
- Const{Nothing},
127
- Const{Core. Typeof (ctx)},
128
- map (Core. Typeof, args)... ,
129
- )
130
-
131
- # On the CPU: F is a per block function
132
- # On the CPU: subtape::Vector{Vector}
133
- I = __index_Group_Cartesian (ctx, CartesianIndex (1 , 1 )) #= fake=#
134
- subtape[I] = forward (Const (f), Const (ctx), args... )[1 ]
135
- return nothing
136
- end
137
-
138
- function cpu_rev (
139
- ctx,
140
- f:: FT ,
141
- :: Val{ModifiedBetween} ,
142
- subtape,
143
- :: Val{TapeType} ,
144
- args... ,
145
- ) where {ModifiedBetween, FT, TapeType}
146
- _, reverse = EnzymeCore. autodiff_thunk (
147
- ReverseSplitModified (ReverseSplitWithPrimal, Val (ModifiedBetween)),
148
- Const{Core. Typeof (f)},
149
- Const{Nothing},
150
- Const{Core. Typeof (ctx)},
151
- map (Core. Typeof, args)... ,
152
- )
153
- I = __index_Group_Cartesian (ctx, CartesianIndex (1 , 1 )) #= fake=#
154
- tp = subtape[I]
155
- reverse (Const (f), Const (ctx), args... , tp)
156
- return nothing
157
- end
158
-
159
- # GPU support
160
- function gpu_aug_fwd (
60
+ function aug_fwd (
161
61
ctx,
162
62
f:: FT ,
163
63
:: Val{ModifiedBetween} ,
@@ -184,7 +84,7 @@ function gpu_aug_fwd(
184
84
return nothing
185
85
end
186
86
187
- function gpu_rev (
87
+ function rev (
188
88
ctx,
189
89
f:: FT ,
190
90
:: Val{ModifiedBetween} ,
@@ -232,11 +132,7 @@ function EnzymeRules.augmented_primal(
232
132
arg_refs = ntuple (Val (N)) do i
233
133
Base. @_inline_meta
234
134
if args[i] isa Active
235
- if func. val isa Kernel{<: GPU }
236
- error (" Active kernel arguments not supported on GPU" )
237
- else
238
- Ref (EnzymeCore. make_zero (args[i]. val))
239
- end
135
+ error (" Active kernel arguments not supported" )
240
136
else
241
137
nothing
242
138
end
@@ -292,7 +188,7 @@ function EnzymeRules.reverse(
292
188
293
189
ModifiedBetween = Val ((overwritten (config)[1 ], false , overwritten (config)[2 : end ]. .. ))
294
190
295
- rev_kernel = _create_rev_kernel (kernel)
191
+ rev_kernel = similar (kernel, rev )
296
192
rev_kernel (
297
193
f,
298
194
ModifiedBetween,
0 commit comments