diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index 45ad877c0f..309d6da4e8 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -253,8 +253,10 @@ steps: using Pkg println("--- :julia: Instantiating project") + Pkg.resolve() Pkg.instantiate() Pkg.activate("perf") + Pkg.resolve() Pkg.instantiate() push!(LOAD_PATH, @__DIR__) diff --git a/lib/cublas/CUBLAS.jl b/lib/cublas/CUBLAS.jl index f102351bb5..782c4dffb5 100644 --- a/lib/cublas/CUBLAS.jl +++ b/lib/cublas/CUBLAS.jl @@ -19,7 +19,7 @@ import LLVM using LLVM.Interop: assume using CEnum: @cenum - +using TaskLocalValues # core library include("libcublas.jl") @@ -73,14 +73,15 @@ end const idle_handles = HandleCache{CuContext,cublasHandle_t}() const idle_xt_handles = HandleCache{Any,cublasXtHandle_t}() +const LIBRARY_STATE = @NamedTuple{handle::cublasHandle_t, stream::CuStream, math_mode::CUDA.MathMode} +const CUBLAS_STATE = + TaskLocalValue{Dict{CuContext,LibraryState}}(()-> Dict{CuContext,LibraryState}()) + function handle() cuda = CUDA.active_state() # every task maintains library state per device - LibraryState = @NamedTuple{handle::cublasHandle_t, stream::CuStream, math_mode::CUDA.MathMode} - states = get!(task_local_storage(), :CUBLAS) do - Dict{CuContext,LibraryState}() - end::Dict{CuContext,LibraryState} + states = CUBLAS_STATE[] # get library state @noinline function new_state(cuda) diff --git a/lib/cudadrv/state.jl b/lib/cudadrv/state.jl index b99134596a..b56f3430c1 100644 --- a/lib/cudadrv/state.jl +++ b/lib/cudadrv/state.jl @@ -65,33 +65,20 @@ function validate_task_local_state(state::TaskLocalState) return state end -# get or create the task local state, and make sure it's valid -function task_local_state!(args...) - tls = task_local_storage() - if haskey(tls, :CUDA) - validate_task_local_state(@inbounds(tls[:CUDA])::TaskLocalState) - else - # verify that CUDA.jl is functional. this doesn't belong here, but since we can't - # error during `__init__`, we do it here instead as this is the first function - # that's likely executed when using CUDA.jl - @assert functional(true) +const CUDA_STATE = TaskLocalValue{TaskLocalState}() do + # verify that CUDA.jl is functional. this doesn't belong here, but since we can't + # error during `__init__`, we do it here instead as this is the first function + # that's likely executed when using CUDA.jl + @assert functional(true) - tls[:CUDA] = TaskLocalState(args...) - end::TaskLocalState + return TaskLocalState() end -# only get the task local state (it may be invalid!), or return nothing if unitialized -function task_local_state() - tls = task_local_storage() - if haskey(tls, :CUDA) - @inbounds(tls[:CUDA]) - else - nothing - end::Union{TaskLocalState,Nothing} -end +# get or create the task local state, and make sure it's valid +task_local_state() = validate_task_local_state(CUDA_STATE[]) @inline function prepare_cuda_state() - state = task_local_state!() + state = task_local_state() # NOTE: current_context() is too slow to use here (taking a lock, accessing a dict) # so we use the raw handle. is that safe though, when we reset the device? @@ -109,7 +96,7 @@ end # without querying task local storage multiple times @inline function active_state() # inline to remove unused state properties - state = task_local_state!() + state = task_local_state() return (device=state.device, context=state.context, stream=stream(state), math_mode=state.math_mode, math_precision=state.math_precision) end @@ -125,7 +112,7 @@ Get or create a CUDA context for the current thread (as opposed to current thread). """ function context() - task_local_state!().context + task_local_state().context end """ @@ -144,19 +131,12 @@ function context!(ctx::CuContext) # NOTE: if we actually need to switch contexts, we eagerly activate it so that we can # query its device (we normally only do so lazily in `prepare_cuda_state`) state = task_local_state() - if state === nothing - old_ctx = nothing + old_ctx = state.context + if old_ctx != ctx activate(ctx) dev = current_device() - task_local_state!(dev, ctx) - else - old_ctx = state.context - if old_ctx != ctx - activate(ctx) - dev = current_device() - state.device = dev - state.context = ctx - end + state.device = dev + state.context = ctx end return old_ctx @@ -169,7 +149,7 @@ end try f() finally - if old_ctx !== nothing && old_ctx != ctx && isvalid(old_ctx) + if old_ctx != ctx && isvalid(old_ctx) context!(old_ctx) end end @@ -188,7 +168,7 @@ Get the CUDA device for the current thread, similar to how [`context()`](@ref) w compared to [`current_context()`](@ref). """ function device() - task_local_state!().device + task_local_state().device end const __device_contexts = LazyInitialized{Vector{Union{Nothing,CuContext}}}() @@ -286,12 +266,8 @@ function device!(dev::CuDevice, flags=nothing) # switch contexts ctx = context(dev) state = task_local_state() - if state === nothing - task_local_state!(dev) - else - state.device = dev - state.context = ctx - end + state.device = dev + state.context = ctx activate(ctx) dev @@ -349,7 +325,7 @@ deviceid(dev::CuDevice=device()) = Int(convert(CUdevice, dev)) ## math mode function math_mode!(mode::MathMode; precision=nothing) - state = task_local_state!() + state = task_local_state() state.math_mode = mode default_math_mode[] = mode @@ -362,8 +338,8 @@ function math_mode!(mode::MathMode; precision=nothing) return end -math_mode() = task_local_state!().math_mode -math_precision() = task_local_state!().math_precision +math_mode() = task_local_state().math_mode +math_precision() = task_local_state().math_precision ## streams @@ -373,7 +349,7 @@ math_precision() = task_local_state!().math_precision Get the CUDA stream that should be used as the default one for the currently executing task. """ -@inline function stream(state=task_local_state!()) +@inline function stream(state=task_local_state()) # @inline so that it can be DCE'd when unused from active_state devidx = deviceid(state.device)+1 @inbounds if state.streams[devidx] === nothing @@ -396,14 +372,14 @@ end end function stream!(stream::CuStream) - state = task_local_state!() + state = task_local_state() devidx = deviceid(state.device)+1 state.streams[devidx] = stream return end function stream!(f::Function, stream::CuStream) - state = task_local_state!() + state = task_local_state() devidx = deviceid(state.device)+1 old_stream = state.streams[devidx] state.streams[devidx] = stream diff --git a/lib/cudnn/src/cuDNN.jl b/lib/cudnn/src/cuDNN.jl index 77f350f8b9..951ec6a2ab 100644 --- a/lib/cudnn/src/cuDNN.jl +++ b/lib/cudnn/src/cuDNN.jl @@ -13,6 +13,7 @@ using CUDA: CUstream, libraryPropertyType using CUDA: retry_reclaim, isdebug, initialize_context using CEnum: @cenum +using TaskLocalValues if CUDA.local_toolkit using CUDA_Runtime_Discovery @@ -65,14 +66,15 @@ end # cache for created, but unused handles const idle_handles = HandleCache{CuContext,cudnnHandle_t}() +const LibraryState = @NamedTuple{handle::cudnnHandle_t, stream::CuStream} +const cuDNN_STATE = + TaskLocalValue{Dict{CuContext,LibraryState}}(()-> Dict{CuContext,LibraryState}()) + function handle() cuda = CUDA.active_state() # every task maintains library state per device - LibraryState = @NamedTuple{handle::cudnnHandle_t, stream::CuStream} - states = get!(task_local_storage(), :cuDNN) do - Dict{CuContext,LibraryState}() - end::Dict{CuContext,LibraryState} + states = cuDNN_STATE[] # get library state @noinline function new_state(cuda) diff --git a/perf/Manifest.toml b/perf/Manifest.toml index 2ce488f1a5..4d906a7b4a 100644 --- a/perf/Manifest.toml +++ b/perf/Manifest.toml @@ -1,132 +1,237 @@ # This file is machine-generated - editing it directly is not advised -[[Artifacts]] +julia_version = "1.8.5" +manifest_format = "2.0" +project_hash = "1ce326a27fb6c56582b2e987b4aeb51b674af2bd" + +[[deps.Artifacts]] uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" -[[Base64]] +[[deps.Base64]] uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" -[[BenchmarkTools]] +[[deps.BenchmarkTools]] deps = ["JSON", "Logging", "Printf", "Profile", "Statistics", "UUIDs"] -git-tree-sha1 = "940001114a0147b6e4d10624276d56d531dd9b49" +git-tree-sha1 = "d9a9701b899b30332bbcb3e1679c41cce81fb0e8" uuid = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" -version = "1.2.2" +version = "1.3.2" + +[[deps.BitFlags]] +git-tree-sha1 = "43b1a4a8f797c1cddadf60499a8a077d4af2cd2d" +uuid = "d1d4a3ce-64b1-5f1a-9ba4-7e7e69966f35" +version = "0.1.7" + +[[deps.CodecZlib]] +deps = ["TranscodingStreams", "Zlib_jll"] +git-tree-sha1 = "02aa26a4cf76381be7f66e020a3eddeb27b0a092" +uuid = "944b1d66-785c-5afd-91f1-9de20f533193" +version = "0.7.2" -[[Dates]] +[[deps.CompilerSupportLibraries_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" +version = "1.0.1+0" + +[[deps.ConcurrentUtilities]] +deps = ["Serialization", "Sockets"] +git-tree-sha1 = "5372dbbf8f0bdb8c700db5367132925c0771ef7e" +uuid = "f0e56b4a-5159-44fe-b623-3e5288b988bb" +version = "2.2.1" + +[[deps.Dates]] deps = ["Printf"] uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" -[[HTTP]] -deps = ["Base64", "Dates", "IniFile", "Logging", "MbedTLS", "NetworkOptions", "Sockets", "URIs"] -git-tree-sha1 = "0fa77022fe4b511826b39c894c90daf5fce3334a" -uuid = "cd3eb016-35fb-5094-929b-558a96fad6f3" -version = "0.9.17" - -[[IniFile]] +[[deps.ExceptionUnwrapping]] deps = ["Test"] -git-tree-sha1 = "098e4d2c533924c921f9f9847274f2ad89e018b8" -uuid = "83e8ac13-25f8-5344-8a64-a9f2b223428f" -version = "0.5.0" +git-tree-sha1 = "e90caa41f5a86296e014e148ee061bd6c3edec96" +uuid = "460bff9d-24e4-43bc-9d9f-a8973cb893f4" +version = "0.1.9" + +[[deps.HTTP]] +deps = ["Base64", "CodecZlib", "ConcurrentUtilities", "Dates", "ExceptionUnwrapping", "Logging", "LoggingExtras", "MbedTLS", "NetworkOptions", "OpenSSL", "Random", "SimpleBufferStream", "Sockets", "URIs", "UUIDs"] +git-tree-sha1 = "cb56ccdd481c0dd7f975ad2b3b62d9eda088f7e2" +uuid = "cd3eb016-35fb-5094-929b-558a96fad6f3" +version = "1.9.14" -[[InteractiveUtils]] +[[deps.InteractiveUtils]] deps = ["Markdown"] uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" -[[JSON]] +[[deps.JLLWrappers]] +deps = ["Artifacts", "Preferences"] +git-tree-sha1 = "7e5d6779a1e09a36db2a7b6cff50942a0a7d0fca" +uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210" +version = "1.5.0" + +[[deps.JSON]] deps = ["Dates", "Mmap", "Parsers", "Unicode"] -git-tree-sha1 = "8076680b162ada2a031f707ac7b4953e30667a37" +git-tree-sha1 = "31e996f0a15c7b280ba9f76636b3ff9e2ae58c9a" uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" -version = "0.21.2" +version = "0.21.4" -[[Libdl]] +[[deps.Libdl]] uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" -[[LinearAlgebra]] -deps = ["Libdl"] +[[deps.LinearAlgebra]] +deps = ["Libdl", "libblastrampoline_jll"] uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -[[Logging]] +[[deps.Logging]] uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" -[[Markdown]] +[[deps.LoggingExtras]] +deps = ["Dates", "Logging"] +git-tree-sha1 = "0d097476b6c381ab7906460ef1ef1638fbce1d91" +uuid = "e6f89c97-d47a-5376-807f-9c37f3926c36" +version = "1.0.2" + +[[deps.Markdown]] deps = ["Base64"] uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" -[[MbedTLS]] -deps = ["Dates", "MbedTLS_jll", "Random", "Sockets"] -git-tree-sha1 = "1c38e51c3d08ef2278062ebceade0e46cefc96fe" +[[deps.MbedTLS]] +deps = ["Dates", "MbedTLS_jll", "MozillaCACerts_jll", "Random", "Sockets"] +git-tree-sha1 = "03a9b9718f5682ecb107ac9f7308991db4ce395b" uuid = "739be429-bea8-5141-9913-cc70e7f3736d" -version = "1.0.3" +version = "1.1.7" -[[MbedTLS_jll]] +[[deps.MbedTLS_jll]] deps = ["Artifacts", "Libdl"] uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" +version = "2.28.0+0" -[[Mmap]] +[[deps.Mmap]] uuid = "a63ad114-7e13-5084-954f-fe012c677804" -[[NetworkOptions]] -uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" +[[deps.MozillaCACerts_jll]] +uuid = "14a3606d-f60d-562e-9121-12d972cd8159" +version = "2022.2.1" -[[Parsers]] -deps = ["Dates"] -git-tree-sha1 = "92f91ba9e5941fc781fecf5494ac1da87bdac775" +[[deps.NetworkOptions]] +uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" +version = "1.2.0" + +[[deps.OpenBLAS_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"] +uuid = "4536629a-c528-5b80-bd46-f80d51c5b363" +version = "0.3.20+0" + +[[deps.OpenSSL]] +deps = ["BitFlags", "Dates", "MozillaCACerts_jll", "OpenSSL_jll", "Sockets"] +git-tree-sha1 = "51901a49222b09e3743c65b8847687ae5fc78eb2" +uuid = "4d8831e6-92b7-49fb-bdf8-b643e874388c" +version = "1.4.1" + +[[deps.OpenSSL_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "e78db7bd5c26fc5a6911b50a47ee302219157ea8" +uuid = "458c3c95-2e84-50aa-8efc-19380b2a3a95" +version = "3.0.10+0" + +[[deps.Parsers]] +deps = ["Dates", "PrecompileTools", "UUIDs"] +git-tree-sha1 = "716e24b21538abc91f6205fd1d8363f39b442851" uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" -version = "2.2.0" +version = "2.7.2" + +[[deps.PrecompileTools]] +deps = ["Preferences"] +git-tree-sha1 = "03b4c25b43cb84cee5c90aa9b5ea0a78fd848d2f" +uuid = "aea7be01-6a6a-4083-8856-8a6e6704d82a" +version = "1.2.0" + +[[deps.Preferences]] +deps = ["TOML"] +git-tree-sha1 = "7eb1686b4f04b82f96ed7a4ea5890a4f0c7a09f1" +uuid = "21216c6a-2e73-6563-6e65-726566657250" +version = "1.4.0" -[[Printf]] +[[deps.Printf]] deps = ["Unicode"] uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" -[[Profile]] +[[deps.Profile]] deps = ["Printf"] uuid = "9abbd945-dff8-562f-b5e8-e1ebf5ef1b79" -[[Random]] -deps = ["Serialization"] +[[deps.Random]] +deps = ["SHA", "Serialization"] uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -[[SHA]] +[[deps.SHA]] uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" +version = "0.7.0" -[[Serialization]] +[[deps.Serialization]] uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" -[[Sockets]] +[[deps.SimpleBufferStream]] +git-tree-sha1 = "874e8867b33a00e784c8a7e4b60afe9e037b74e1" +uuid = "777ac1f9-54b0-4bf8-805c-2214025038e7" +version = "1.1.0" + +[[deps.Sockets]] uuid = "6462fe0b-24de-5631-8697-dd941f90decc" -[[SparseArrays]] +[[deps.SparseArrays]] deps = ["LinearAlgebra", "Random"] uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" -[[StableRNGs]] +[[deps.StableRNGs]] deps = ["Random", "Test"] git-tree-sha1 = "3be7d49667040add7ee151fefaf1f8c04c8c8276" uuid = "860ef19b-820b-49d6-a774-d7a799459cd3" version = "1.0.0" -[[StaticArrays]] -deps = ["LinearAlgebra", "Random", "Statistics"] -git-tree-sha1 = "2ae4fe21e97cd13efd857462c1869b73c9f61be3" +[[deps.StaticArrays]] +deps = ["LinearAlgebra", "Random", "StaticArraysCore", "Statistics"] +git-tree-sha1 = "9cabadf6e7cd2349b6cf49f1915ad2028d65e881" uuid = "90137ffa-7385-5640-81b9-e52037218182" -version = "1.3.2" +version = "1.6.2" -[[Statistics]] +[[deps.StaticArraysCore]] +git-tree-sha1 = "36b3d696ce6366023a0ea192b4cd442268995a0d" +uuid = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" +version = "1.4.2" + +[[deps.Statistics]] deps = ["LinearAlgebra", "SparseArrays"] uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" -[[Test]] +[[deps.TOML]] +deps = ["Dates"] +uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76" +version = "1.0.0" + +[[deps.Test]] deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" -[[URIs]] -git-tree-sha1 = "97bbe755a53fe859669cd907f2d96aee8d2c1355" +[[deps.TranscodingStreams]] +deps = ["Random", "Test"] +git-tree-sha1 = "9a6ae7ed916312b41236fcef7e0af564ef934769" +uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa" +version = "0.9.13" + +[[deps.URIs]] +git-tree-sha1 = "b7a5e99f24892b6824a954199a45e9ffcc1c70f0" uuid = "5c2747f8-b7ea-4ff2-ba2e-563bfd36b1d4" -version = "1.3.0" +version = "1.5.0" -[[UUIDs]] +[[deps.UUIDs]] deps = ["Random", "SHA"] uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" -[[Unicode]] +[[deps.Unicode]] uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" + +[[deps.Zlib_jll]] +deps = ["Libdl"] +uuid = "83775a58-1f1d-513f-b197-d71354ab007a" +version = "1.2.12+3" + +[[deps.libblastrampoline_jll]] +deps = ["Artifacts", "Libdl", "OpenBLAS_jll"] +uuid = "8e850b90-86db-534c-a0d3-1478176c7d93" +version = "5.1.1+0" diff --git a/src/pool.jl b/src/pool.jl index 11b9aee724..7e99da4581 100644 --- a/src/pool.jl +++ b/src/pool.jl @@ -320,7 +320,7 @@ struct OutOfGPUMemoryError <: Exception info::Union{Nothing,MemoryInfo} function OutOfGPUMemoryError(sz::Integer=0) - info = if task_local_state() === nothing + info = if false && task_local_state() === nothing # if this error was triggered before the TLS was initialized, we should not try to # fetch memory info as those API calls will just trigger TLS initialization again. nothing