Skip to content

Switch to PrecompileTools.jl. #601

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
LLVM = "929cbde3-209d-540e-8aea-75f648917ca0"
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
Scratch = "6c6a2e73-6563-6170-7368-637461726353"
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
Expand All @@ -22,10 +23,11 @@ InteractiveUtils = "1"
LLVM = "8"
Libdl = "1"
Logging = "1"
UUIDs = "1"
PrecompileTools = "1"
Preferences = "1"
Scratch = "1"
Serialization = "1"
TOML = "1"
TimerOutputs = "0.5"
UUIDs = "1"
julia = "1.10"
9 changes: 3 additions & 6 deletions src/GPUCompiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ using Preferences
const CC = Core.Compiler
using Core: MethodInstance, CodeInstance, CodeInfo

compile_cache = nothing # set during __init__()
const pkgver = Base.pkgversion(GPUCompiler)

include("utils.jl")
include("mangling.jl")

Expand Down Expand Up @@ -46,12 +49,6 @@ include("execution.jl")
include("reflection.jl")

include("precompile.jl")
_precompile_()



compile_cache = "" # defined in __init__()
const pkgver = Base.pkgversion(GPUCompiler)

function __init__()
STDERR_HAS_COLOR[] = get(stderr, :color, false)
Expand Down
92 changes: 34 additions & 58 deletions src/precompile.jl
Original file line number Diff line number Diff line change
@@ -1,65 +1,41 @@
const __bodyfunction__ = Dict{Method,Any}()
using PrecompileTools: @setup_workload, @compile_workload

# Find keyword "body functions" (the function that contains the body
# as written by the developer, called after all missing keyword-arguments
# have been assigned values), in a manner that doesn't depend on
# gensymmed names.
# `mnokw` is the method that gets called when you invoke it without
# supplying any keywords.
function __lookup_kwbody__(mnokw::Method)
function getsym(arg)
isa(arg, Symbol) && return arg
@assert isa(arg, GlobalRef)
return arg.name
@setup_workload begin
precompile_module = @eval module $(gensym())
using ..GPUCompiler

module DummyRuntime
# dummy methods
signal_exception() = return
malloc(sz) = C_NULL
report_oom(sz) = return
report_exception(ex) = return
report_exception_name(ex) = return
report_exception_frame(idx, func, file, line) = return
end

struct DummyCompilerParams <: AbstractCompilerParams end
const DummyCompilerJob = CompilerJob{NativeCompilerTarget, DummyCompilerParams}

GPUCompiler.runtime_module(::DummyCompilerJob) = DummyRuntime
end

f = get(__bodyfunction__, mnokw, nothing)
if f === nothing
fmod = mnokw.module
# The lowered code for `mnokw` should look like
# %1 = mkw(kwvalues..., #self#, args...)
# return %1
# where `mkw` is the name of the "active" keyword body-function.
ast = Base.uncompressed_ast(mnokw)
if isa(ast, Core.CodeInfo) && length(ast.code) >= 2
callexpr = ast.code[end-1]
if isa(callexpr, Expr) && callexpr.head == :call
fsym = callexpr.args[1]
if isa(fsym, Symbol)
f = getfield(fmod, fsym)
elseif isa(fsym, GlobalRef)
if fsym.mod === Core && fsym.name === :_apply
f = getfield(mnokw.module, getsym(callexpr.args[2]))
elseif fsym.mod === Core && fsym.name === :_apply_iterate
f = getfield(mnokw.module, getsym(callexpr.args[3]))
else
f = getfield(fsym.mod, fsym.name)
end
else
f = missing
end
else
f = missing
end
else
f = missing
kernel() = nothing

@compile_workload begin
source = methodinstance(typeof(kernel), Tuple{})
target = NativeCompilerTarget()
params = precompile_module.DummyCompilerParams()
config = CompilerConfig(target, params)
job = CompilerJob(source, config)

JuliaContext() do ctx
# XXX: on Windows, compiling the GPU runtime leaks GPU code in the native cache,
# so prevent building the runtime library (see JuliaGPU/GPUCompiler.jl#601)
GPUCompiler.compile(:asm, job; libraries=false)
end
__bodyfunction__[mnokw] = f
end
return f
end

function _precompile_()
ccall(:jl_generating_output, Cint, ()) == 1 || return nothing
@assert precompile(Tuple{typeof(GPUCompiler.assign_args!),Expr,Vector{Any}})
@assert precompile(Tuple{typeof(GPUCompiler.lower_unreachable!),LLVM.Function})
@assert precompile(Tuple{typeof(GPUCompiler.lower_gc_frame!),LLVM.Function})
@assert precompile(Tuple{typeof(GPUCompiler.lower_throw!),LLVM.Module})
#@assert precompile(Tuple{typeof(GPUCompiler.split_kwargs),Tuple{},Vector{Symbol},Vararg{Vector{Symbol}, N} where N})
# let fbody = try __lookup_kwbody__(which(GPUCompiler.compile, (Symbol,GPUCompiler.CompilerJob,))) catch missing end
# if !ismissing(fbody)
# @assert precompile(fbody, (Bool,Bool,Bool,Bool,Bool,Bool,Bool,typeof(GPUCompiler.compile),Symbol,GPUCompiler.CompilerJob,))
# @assert precompile(fbody, (Bool,Bool,Bool,Bool,Bool,Bool,Bool,typeof(GPUCompiler.compile),Symbol,GPUCompiler.CompilerJob,))
# end
# end
# reset state that was initialized during precompilation
__llvm_initialized[] = false
end
5 changes: 5 additions & 0 deletions src/rtlib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,11 @@ end
const runtime_lock = ReentrantLock()

@locked function load_runtime(@nospecialize(job::CompilerJob))
global compile_cache
if compile_cache === nothing # during precompilation
return build_runtime(job)
end

lock(runtime_lock) do
slug = runtime_slug(job)
if !supports_typed_pointers(context())
Expand Down