diff --git a/src/compiler/compilation.jl b/src/compiler/compilation.jl index 8f664ad65e..6e2f869f09 100644 --- a/src/compiler/compilation.jl +++ b/src/compiler/compilation.jl @@ -3,12 +3,13 @@ Base.@kwdef struct CUDACompilerParams <: AbstractCompilerParams cap::VersionNumber ptx::VersionNumber + link_libdevice::Bool = true # Used by Reactant.jl end function Base.hash(params::CUDACompilerParams, h::UInt) h = hash(params.cap, h) h = hash(params.ptx, h) - + h = hash(params.link_libdevice, h) return h end @@ -18,15 +19,19 @@ const CUDACompilerJob = CompilerJob{PTXCompilerTarget,CUDACompilerParams} GPUCompiler.runtime_module(@nospecialize(job::CUDACompilerJob)) = CUDA # filter out functions from libdevice and cudadevrt -GPUCompiler.isintrinsic(@nospecialize(job::CUDACompilerJob), fn::String) = - invoke(GPUCompiler.isintrinsic, - Tuple{CompilerJob{PTXCompilerTarget}, typeof(fn)}, - job, fn) || - fn == "__nvvm_reflect" || startswith(fn, "cuda") +function GPUCompiler.isintrinsic(@nospecialize(job::CUDACompilerJob), fn::String) + is_intrinsic = invoke(GPUCompiler.isintrinsic, + Tuple{CompilerJob{PTXCompilerTarget}, typeof(fn)}, job, fn) + is_intrinsic |= fn == "__nvvm_reflect" + is_intrinsic |= startswith(fn, "cuda") + is_intrinsic |= !job.config.params.link_libdevice ? startswith(fn, "__nv_") : false # Reactant.jl wants to handle __nv_ functions + return is_intrinsic +end # link libdevice function GPUCompiler.link_libraries!(@nospecialize(job::CUDACompilerJob), mod::LLVM.Module, undefined_fns::Vector{String}) + job.config.params.link_libdevice || return # only link if there's undefined __nv_ functions if !any(fn->startswith(fn, "__nv_"), undefined_fns) return