Skip to content

Commit 043270b

Browse files
committed
remove special handling of CPU
1 parent 2716056 commit 043270b

File tree

2 files changed

+20
-229
lines changed

2 files changed

+20
-229
lines changed

ext/EnzymeCore07Ext.jl

Lines changed: 11 additions & 115 deletions
Original file line numberDiff line numberDiff line change
@@ -1,81 +1,30 @@
1-
# https://github.com/EnzymeAD/Enzyme.jl/issues/1516
2-
# On the CPU `autodiff_deferred` can deadlock.
3-
# Hence a specialized CPU version
4-
function cpu_fwd(ctx, f, args...)
5-
EnzymeCore.autodiff(Forward, Const(f), Const{Nothing}, Const(ctx), args...)
6-
return nothing
7-
end
8-
9-
function gpu_fwd(ctx, f, args...)
1+
function fwd(ctx, f, args...)
102
EnzymeCore.autodiff_deferred(Forward, Const(f), Const{Nothing}, Const(ctx), args...)
113
return nothing
124
end
135

146
function EnzymeRules.forward(
15-
func::Const{<:Kernel{CPU}},
16-
::Type{Const{Nothing}},
17-
args...;
18-
ndrange = nothing,
19-
workgroupsize = nothing,
20-
)
21-
kernel = func.val
22-
f = kernel.f
23-
fwd_kernel = similar(kernel, cpu_fwd)
24-
25-
return fwd_kernel(f, args...; ndrange, workgroupsize)
26-
end
27-
28-
function EnzymeRules.forward(
29-
func::Const{<:Kernel{<:GPU}},
7+
func::Const{<:Kernel},
308
::Type{Const{Nothing}},
319
args...;
3210
ndrange = nothing,
3311
workgroupsize = nothing,
3412
)
3513
kernel = func.val
3614
f = kernel.f
37-
fwd_kernel = similar(kernel, gpu_fwd)
15+
fwd_kernel = similar(kernel, fwd)
3816

3917
return fwd_kernel(f, args...; ndrange, workgroupsize)
4018
end
4119

42-
_enzyme_mkcontext(kernel::Kernel{CPU}, ndrange, iterspace, dynamic) =
43-
mkcontext(kernel, first(blocks(iterspace)), ndrange, iterspace, dynamic)
44-
_enzyme_mkcontext(kernel::Kernel{<:GPU}, ndrange, iterspace, dynamic) =
20+
_enzyme_mkcontext(kernel::Kernel, ndrange, iterspace, dynamic) =
4521
mkcontext(kernel, ndrange, iterspace)
4622

47-
_augmented_return(::Kernel{CPU}, subtape, arg_refs, tape_type) =
48-
AugmentedReturn{Nothing, Nothing, Tuple{Array, typeof(arg_refs), typeof(tape_type)}}(
49-
nothing,
50-
nothing,
51-
(subtape, arg_refs, tape_type),
52-
)
53-
_augmented_return(::Kernel{<:GPU}, subtape, arg_refs, tape_type) =
23+
_augmented_return(::Kernel, subtape, arg_refs, tape_type) =
5424
AugmentedReturn{Nothing, Nothing, Any}(nothing, nothing, (subtape, arg_refs, tape_type))
5525

5626
function _create_tape_kernel(
57-
kernel::Kernel{CPU},
58-
ModifiedBetween,
59-
FT,
60-
ctxTy,
61-
ndrange,
62-
iterspace,
63-
args2...,
64-
)
65-
TapeType = EnzymeCore.tape_type(
66-
ReverseSplitModified(ReverseSplitWithPrimal, ModifiedBetween),
67-
FT,
68-
Const{Nothing},
69-
Const{ctxTy},
70-
map(Core.Typeof, args2)...,
71-
)
72-
subtape = Array{TapeType}(undef, size(blocks(iterspace)))
73-
aug_kernel = similar(kernel, cpu_aug_fwd)
74-
return TapeType, subtape, aug_kernel
75-
end
76-
77-
function _create_tape_kernel(
78-
kernel::Kernel{<:GPU},
27+
kernel::Kernel,
7928
ModifiedBetween,
8029
FT,
8130
ctxTy,
@@ -104,60 +53,11 @@ function _create_tape_kernel(
10453
# Allocate per thread
10554
subtape = allocate(backend(kernel), TapeType, prod(ndrange))
10655

107-
aug_kernel = similar(kernel, gpu_aug_fwd)
56+
aug_kernel = similar(kernel, aug_fwd)
10857
return TapeType, subtape, aug_kernel
10958
end
11059

111-
_create_rev_kernel(kernel::Kernel{CPU}) = similar(kernel, cpu_rev)
112-
_create_rev_kernel(kernel::Kernel{<:GPU}) = similar(kernel, gpu_rev)
113-
114-
function cpu_aug_fwd(
115-
ctx,
116-
f::FT,
117-
::Val{ModifiedBetween},
118-
subtape,
119-
::Val{TapeType},
120-
args...,
121-
) where {ModifiedBetween, FT, TapeType}
122-
# A2 = Const{Nothing} -- since f->Nothing
123-
forward, _ = EnzymeCore.autodiff_thunk(
124-
ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween)),
125-
Const{Core.Typeof(f)},
126-
Const{Nothing},
127-
Const{Core.Typeof(ctx)},
128-
map(Core.Typeof, args)...,
129-
)
130-
131-
# On the CPU: F is a per block function
132-
# On the CPU: subtape::Vector{Vector}
133-
I = __index_Group_Cartesian(ctx, CartesianIndex(1, 1)) #=fake=#
134-
subtape[I] = forward(Const(f), Const(ctx), args...)[1]
135-
return nothing
136-
end
137-
138-
function cpu_rev(
139-
ctx,
140-
f::FT,
141-
::Val{ModifiedBetween},
142-
subtape,
143-
::Val{TapeType},
144-
args...,
145-
) where {ModifiedBetween, FT, TapeType}
146-
_, reverse = EnzymeCore.autodiff_thunk(
147-
ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween)),
148-
Const{Core.Typeof(f)},
149-
Const{Nothing},
150-
Const{Core.Typeof(ctx)},
151-
map(Core.Typeof, args)...,
152-
)
153-
I = __index_Group_Cartesian(ctx, CartesianIndex(1, 1)) #=fake=#
154-
tp = subtape[I]
155-
reverse(Const(f), Const(ctx), args..., tp)
156-
return nothing
157-
end
158-
159-
# GPU support
160-
function gpu_aug_fwd(
60+
function aug_fwd(
16161
ctx,
16262
f::FT,
16363
::Val{ModifiedBetween},
@@ -184,7 +84,7 @@ function gpu_aug_fwd(
18484
return nothing
18585
end
18686

187-
function gpu_rev(
87+
function rev(
18888
ctx,
18989
f::FT,
19090
::Val{ModifiedBetween},
@@ -232,11 +132,7 @@ function EnzymeRules.augmented_primal(
232132
arg_refs = ntuple(Val(N)) do i
233133
Base.@_inline_meta
234134
if args[i] isa Active
235-
if func.val isa Kernel{<:GPU}
236-
error("Active kernel arguments not supported on GPU")
237-
else
238-
Ref(EnzymeCore.make_zero(args[i].val))
239-
end
135+
error("Active kernel arguments not supported")
240136
else
241137
nothing
242138
end
@@ -292,7 +188,7 @@ function EnzymeRules.reverse(
292188

293189
ModifiedBetween = Val((overwritten(config)[1], false, overwritten(config)[2:end]...))
294190

295-
rev_kernel = _create_rev_kernel(kernel)
191+
rev_kernel = similar(kernel, rev)
296192
rev_kernel(
297193
f,
298194
ModifiedBetween,

ext/EnzymeCore08Ext.jl

Lines changed: 9 additions & 114 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,11 @@
1-
# https://github.com/EnzymeAD/Enzyme.jl/issues/1516
2-
# On the CPU `autodiff_deferred` can deadlock.
3-
# Hence a specialized CPU version
4-
function cpu_fwd(ctx, config, f, args...)
5-
EnzymeCore.autodiff(EnzymeCore.set_runtime_activity(Forward, config), Const(f), Const{Nothing}, Const(ctx), args...)
6-
return nothing
7-
end
8-
9-
function gpu_fwd(ctx, config, f, args...)
1+
function fwd(ctx, config, f, args...)
102
EnzymeCore.autodiff_deferred(EnzymeCore.set_runtime_activity(Forward, config), Const(f), Const{Nothing}, Const(ctx), args...)
113
return nothing
124
end
135

146
function EnzymeRules.forward(
157
config,
16-
func::Const{<:Kernel{CPU}},
17-
::Type{Const{Nothing}},
18-
args...;
19-
ndrange = nothing,
20-
workgroupsize = nothing,
21-
)
22-
kernel = func.val
23-
f = kernel.f
24-
fwd_kernel = similar(kernel, cpu_fwd)
25-
26-
return fwd_kernel(config, f, args...; ndrange, workgroupsize)
27-
end
28-
29-
function EnzymeRules.forward(
30-
config,
31-
func::Const{<:Kernel{<:GPU}},
8+
func::Const{<:Kernel},
329
::Type{Const{Nothing}},
3310
args...;
3411
ndrange = nothing,
@@ -41,41 +18,12 @@ function EnzymeRules.forward(
4118
return fwd_kernel(config, f, args...; ndrange, workgroupsize)
4219
end
4320

44-
_enzyme_mkcontext(kernel::Kernel{CPU}, ndrange, iterspace, dynamic) =
45-
mkcontext(kernel, first(blocks(iterspace)), ndrange, iterspace, dynamic)
46-
_enzyme_mkcontext(kernel::Kernel{<:GPU}, ndrange, iterspace, dynamic) =
21+
_enzyme_mkcontext(kernel::Kernel, ndrange, iterspace, dynamic) =
4722
mkcontext(kernel, ndrange, iterspace)
4823

49-
_augmented_return(::Kernel{CPU}, subtape, arg_refs, tape_type) =
50-
AugmentedReturn{Nothing, Nothing, Tuple{Array, typeof(arg_refs), typeof(tape_type)}}(
51-
nothing,
52-
nothing,
53-
(subtape, arg_refs, tape_type),
54-
)
55-
_augmented_return(::Kernel{<:GPU}, subtape, arg_refs, tape_type) =
24+
_augmented_return(::Kernel, subtape, arg_refs, tape_type) =
5625
AugmentedReturn{Nothing, Nothing, Any}(nothing, nothing, (subtape, arg_refs, tape_type))
5726

58-
function _create_tape_kernel(
59-
kernel::Kernel{CPU},
60-
Mode,
61-
FT,
62-
ctxTy,
63-
ndrange,
64-
iterspace,
65-
args2...,
66-
)
67-
TapeType = EnzymeCore.tape_type(
68-
Mode,
69-
FT,
70-
Const{Nothing},
71-
Const{ctxTy},
72-
map(Core.Typeof, args2)...,
73-
)
74-
subtape = Array{TapeType}(undef, size(blocks(iterspace)))
75-
aug_kernel = similar(kernel, cpu_aug_fwd)
76-
return TapeType, subtape, aug_kernel
77-
end
78-
7927
function _create_tape_kernel(
8028
kernel::Kernel{<:GPU},
8129
Mode,
@@ -106,60 +54,11 @@ function _create_tape_kernel(
10654
# Allocate per thread
10755
subtape = allocate(backend(kernel), TapeType, prod(ndrange))
10856

109-
aug_kernel = similar(kernel, gpu_aug_fwd)
57+
aug_kernel = similar(kernel, aug_fwd)
11058
return TapeType, subtape, aug_kernel
11159
end
11260

113-
_create_rev_kernel(kernel::Kernel{CPU}) = similar(kernel, cpu_rev)
114-
_create_rev_kernel(kernel::Kernel{<:GPU}) = similar(kernel, gpu_rev)
115-
116-
function cpu_aug_fwd(
117-
ctx,
118-
f::FT,
119-
mode::Mode,
120-
subtape,
121-
::Val{TapeType},
122-
args...,
123-
) where {Mode, FT, TapeType}
124-
# A2 = Const{Nothing} -- since f->Nothing
125-
forward, _ = EnzymeCore.autodiff_thunk(
126-
mode,
127-
Const{Core.Typeof(f)},
128-
Const{Nothing},
129-
Const{Core.Typeof(ctx)},
130-
map(Core.Typeof, args)...,
131-
)
132-
133-
# On the CPU: F is a per block function
134-
# On the CPU: subtape::Vector{Vector}
135-
I = __index_Group_Cartesian(ctx, CartesianIndex(1, 1)) #=fake=#
136-
subtape[I] = forward(Const(f), Const(ctx), args...)[1]
137-
return nothing
138-
end
139-
140-
function cpu_rev(
141-
ctx,
142-
f::FT,
143-
mode::Mode,
144-
subtape,
145-
::Val{TapeType},
146-
args...,
147-
) where {Mode, FT, TapeType}
148-
_, reverse = EnzymeCore.autodiff_thunk(
149-
mode,
150-
Const{Core.Typeof(f)},
151-
Const{Nothing},
152-
Const{Core.Typeof(ctx)},
153-
map(Core.Typeof, args)...,
154-
)
155-
I = __index_Group_Cartesian(ctx, CartesianIndex(1, 1)) #=fake=#
156-
tp = subtape[I]
157-
reverse(Const(f), Const(ctx), args..., tp)
158-
return nothing
159-
end
160-
161-
# GPU support
162-
function gpu_aug_fwd(
61+
function fwd(
16362
ctx,
16463
f::FT,
16564
mode::Mode,
@@ -186,7 +85,7 @@ function gpu_aug_fwd(
18685
return nothing
18786
end
18887

189-
function gpu_rev(
88+
function rev(
19089
ctx,
19190
f::FT,
19291
mode::Mode,
@@ -234,11 +133,7 @@ function EnzymeRules.augmented_primal(
234133
arg_refs = ntuple(Val(N)) do i
235134
Base.@_inline_meta
236135
if args[i] isa Active
237-
if func.val isa Kernel{<:GPU}
238-
error("Active kernel arguments not supported on GPU")
239-
else
240-
Ref(EnzymeCore.make_zero(args[i].val))
241-
end
136+
error("Active kernel arguments not supported")
242137
else
243138
nothing
244139
end
@@ -294,7 +189,7 @@ function EnzymeRules.reverse(
294189

295190
ModifiedBetween = Val((overwritten(config)[1], false, overwritten(config)[2:end]...))
296191
Mode = EnzymeCore.set_runtime_activity(ReverseSplitModified(ReverseSplitWithPrimal, ModifiedBetween), config)
297-
rev_kernel = _create_rev_kernel(kernel)
192+
rev_kernel = similar(kernel, rev)
298193
rev_kernel(
299194
f,
300195
Mode,

0 commit comments

Comments
 (0)