Skip to content

Commit 1e49e09

Browse files
authored
Switch to PrecompileTools.jl. (#601)
1 parent 6c3c8f6 commit 1e49e09

File tree

4 files changed

+45
-65
lines changed

4 files changed

+45
-65
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
99
LLVM = "929cbde3-209d-540e-8aea-75f648917ca0"
1010
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
1111
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
12+
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
1213
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
1314
Scratch = "6c6a2e73-6563-6170-7368-637461726353"
1415
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
@@ -22,10 +23,11 @@ InteractiveUtils = "1"
2223
LLVM = "8"
2324
Libdl = "1"
2425
Logging = "1"
25-
UUIDs = "1"
26+
PrecompileTools = "1"
2627
Preferences = "1"
2728
Scratch = "1"
2829
Serialization = "1"
2930
TOML = "1"
3031
TimerOutputs = "0.5"
32+
UUIDs = "1"
3133
julia = "1.10"

src/GPUCompiler.jl

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@ using Preferences
1616
const CC = Core.Compiler
1717
using Core: MethodInstance, CodeInstance, CodeInfo
1818

19+
compile_cache = nothing # set during __init__()
20+
const pkgver = Base.pkgversion(GPUCompiler)
21+
1922
include("utils.jl")
2023
include("mangling.jl")
2124

@@ -46,12 +49,6 @@ include("execution.jl")
4649
include("reflection.jl")
4750

4851
include("precompile.jl")
49-
_precompile_()
50-
51-
52-
53-
compile_cache = "" # defined in __init__()
54-
const pkgver = Base.pkgversion(GPUCompiler)
5552

5653
function __init__()
5754
STDERR_HAS_COLOR[] = get(stderr, :color, false)

src/precompile.jl

Lines changed: 34 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1,65 +1,41 @@
1-
const __bodyfunction__ = Dict{Method,Any}()
1+
using PrecompileTools: @setup_workload, @compile_workload
22

3-
# Find keyword "body functions" (the function that contains the body
4-
# as written by the developer, called after all missing keyword-arguments
5-
# have been assigned values), in a manner that doesn't depend on
6-
# gensymmed names.
7-
# `mnokw` is the method that gets called when you invoke it without
8-
# supplying any keywords.
9-
function __lookup_kwbody__(mnokw::Method)
10-
function getsym(arg)
11-
isa(arg, Symbol) && return arg
12-
@assert isa(arg, GlobalRef)
13-
return arg.name
3+
@setup_workload begin
4+
precompile_module = @eval module $(gensym())
5+
using ..GPUCompiler
6+
7+
module DummyRuntime
8+
# dummy methods
9+
signal_exception() = return
10+
malloc(sz) = C_NULL
11+
report_oom(sz) = return
12+
report_exception(ex) = return
13+
report_exception_name(ex) = return
14+
report_exception_frame(idx, func, file, line) = return
15+
end
16+
17+
struct DummyCompilerParams <: AbstractCompilerParams end
18+
const DummyCompilerJob = CompilerJob{NativeCompilerTarget, DummyCompilerParams}
19+
20+
GPUCompiler.runtime_module(::DummyCompilerJob) = DummyRuntime
1421
end
1522

16-
f = get(__bodyfunction__, mnokw, nothing)
17-
if f === nothing
18-
fmod = mnokw.module
19-
# The lowered code for `mnokw` should look like
20-
# %1 = mkw(kwvalues..., #self#, args...)
21-
# return %1
22-
# where `mkw` is the name of the "active" keyword body-function.
23-
ast = Base.uncompressed_ast(mnokw)
24-
if isa(ast, Core.CodeInfo) && length(ast.code) >= 2
25-
callexpr = ast.code[end-1]
26-
if isa(callexpr, Expr) && callexpr.head == :call
27-
fsym = callexpr.args[1]
28-
if isa(fsym, Symbol)
29-
f = getfield(fmod, fsym)
30-
elseif isa(fsym, GlobalRef)
31-
if fsym.mod === Core && fsym.name === :_apply
32-
f = getfield(mnokw.module, getsym(callexpr.args[2]))
33-
elseif fsym.mod === Core && fsym.name === :_apply_iterate
34-
f = getfield(mnokw.module, getsym(callexpr.args[3]))
35-
else
36-
f = getfield(fsym.mod, fsym.name)
37-
end
38-
else
39-
f = missing
40-
end
41-
else
42-
f = missing
43-
end
44-
else
45-
f = missing
23+
kernel() = nothing
24+
25+
@compile_workload begin
26+
source = methodinstance(typeof(kernel), Tuple{})
27+
target = NativeCompilerTarget()
28+
params = precompile_module.DummyCompilerParams()
29+
config = CompilerConfig(target, params)
30+
job = CompilerJob(source, config)
31+
32+
JuliaContext() do ctx
33+
# XXX: on Windows, compiling the GPU runtime leaks GPU code in the native cache,
34+
# so prevent building the runtime library (see JuliaGPU/GPUCompiler.jl#601)
35+
GPUCompiler.compile(:asm, job; libraries=false)
4636
end
47-
__bodyfunction__[mnokw] = f
4837
end
49-
return f
50-
end
5138

52-
function _precompile_()
53-
ccall(:jl_generating_output, Cint, ()) == 1 || return nothing
54-
@assert precompile(Tuple{typeof(GPUCompiler.assign_args!),Expr,Vector{Any}})
55-
@assert precompile(Tuple{typeof(GPUCompiler.lower_unreachable!),LLVM.Function})
56-
@assert precompile(Tuple{typeof(GPUCompiler.lower_gc_frame!),LLVM.Function})
57-
@assert precompile(Tuple{typeof(GPUCompiler.lower_throw!),LLVM.Module})
58-
#@assert precompile(Tuple{typeof(GPUCompiler.split_kwargs),Tuple{},Vector{Symbol},Vararg{Vector{Symbol}, N} where N})
59-
# let fbody = try __lookup_kwbody__(which(GPUCompiler.compile, (Symbol,GPUCompiler.CompilerJob,))) catch missing end
60-
# if !ismissing(fbody)
61-
# @assert precompile(fbody, (Bool,Bool,Bool,Bool,Bool,Bool,Bool,typeof(GPUCompiler.compile),Symbol,GPUCompiler.CompilerJob,))
62-
# @assert precompile(fbody, (Bool,Bool,Bool,Bool,Bool,Bool,Bool,typeof(GPUCompiler.compile),Symbol,GPUCompiler.CompilerJob,))
63-
# end
64-
# end
39+
# reset state that was initialized during precompilation
40+
__llvm_initialized[] = false
6541
end

src/rtlib.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,11 @@ end
120120
const runtime_lock = ReentrantLock()
121121

122122
@locked function load_runtime(@nospecialize(job::CompilerJob))
123+
global compile_cache
124+
if compile_cache === nothing # during precompilation
125+
return build_runtime(job)
126+
end
127+
123128
lock(runtime_lock) do
124129
slug = runtime_slug(job)
125130
if !supports_typed_pointers(context())

0 commit comments

Comments
 (0)