Patch summary (free text ahead of the first file header; patch tools skip it):
  * lib/cudadrv/devices.jl   - declare `CuDevice <: AbstractGPUDevice`.
  * lib/cudadrv/events.jl    - add `Adapt.get_compute_unit_impl` for `CuEvent`,
                               returning `device(e.ctx)`.
  * lib/cusparse/CUSPARSE.jl - `import Adapt`; also bring `AbstractGPUDevice`
                               into scope from Adapt.
  * lib/cusparse/array.jl    - add `Adapt.get_compute_unit_impl` for
                               `AbstractCuSparseArray`, returning `device(A.nzVal)`.
  * src/CUDA.jl              - `import Adapt`; also bring `AbstractGPUDevice`
                               into scope from Adapt.
  * src/array.jl             - add `Adapt.get_compute_unit_impl` for `CuArray`,
                               `Adapt.adapt_storage(dev::CuDevice, x)`, and
                               `Sys.total_memory` / `Sys.free_memory` methods
                               for `CuDevice`.

NOTE(review): this patch was recovered from a copy whose line breaks had been
collapsed. Every hunk below matches the old/new line counts in its `@@` header,
but the leading indentation of context lines and the exact position of blank
context lines are a best-effort reconstruction. Regenerate the patch from the
repository, or apply with `git apply --ignore-whitespace` and verify the result.

diff --git a/lib/cudadrv/devices.jl b/lib/cudadrv/devices.jl
index 543524cd19..89f1d061e6 100644
--- a/lib/cudadrv/devices.jl
+++ b/lib/cudadrv/devices.jl
@@ -9,7 +9,7 @@ export
 
 Get a handle to a compute device.
 """
-struct CuDevice
+struct CuDevice <: AbstractGPUDevice
     handle::CUdevice
 
     function CuDevice(ordinal::Integer)
diff --git a/lib/cudadrv/events.jl b/lib/cudadrv/events.jl
index 8c4192d9c8..2d7e838e2a 100644
--- a/lib/cudadrv/events.jl
+++ b/lib/cudadrv/events.jl
@@ -133,3 +133,6 @@ macro elapsed(ex)
         elapsed(t0, t1)
     end
 end
+
+
+Adapt.get_compute_unit_impl(@nospecialize(TypeHistory::Type), e::CuEvent) = device(e.ctx)
diff --git a/lib/cusparse/CUSPARSE.jl b/lib/cusparse/CUSPARSE.jl
index 3fddf946f3..8b84303c25 100644
--- a/lib/cusparse/CUSPARSE.jl
+++ b/lib/cusparse/CUSPARSE.jl
@@ -11,7 +11,8 @@ using CEnum: @cenum
 using LinearAlgebra
 using LinearAlgebra: HermOrSym
 
-using Adapt: Adapt, adapt
+import Adapt
+using Adapt: Adapt, adapt, AbstractGPUDevice
 
 using SparseArrays
 
diff --git a/lib/cusparse/array.jl b/lib/cusparse/array.jl
index 34d52c0a85..333ef13776 100644
--- a/lib/cusparse/array.jl
+++ b/lib/cusparse/array.jl
@@ -15,6 +15,9 @@ const AbstractCuSparseMatrix{Tv, Ti} = AbstractCuSparseArray{Tv, Ti, 2}
 
 Base.convert(T::Type{<:AbstractCuSparseArray}, m::AbstractArray) = m isa T ? m : T(m)
 
+Adapt.get_compute_unit_impl(@nospecialize(TypeHistory::Type), A::AbstractCuSparseArray) = device(A.nzVal)
+
+
 mutable struct CuSparseVector{Tv, Ti} <: AbstractCuSparseVector{Tv, Ti}
     iPtr::CuVector{Ti}
     nzVal::CuVector{Tv}
diff --git a/src/CUDA.jl b/src/CUDA.jl
index 395f062d6c..7ad5508c6a 100644
--- a/src/CUDA.jl
+++ b/src/CUDA.jl
@@ -8,7 +8,8 @@ using LLVM
 using LLVM.Interop
 using Core: LLVMPtr
 
-using Adapt: Adapt, adapt, WrappedArray
+import Adapt
+using Adapt: Adapt, adapt, WrappedArray, AbstractGPUDevice
 
 using Requires: @require
 
diff --git a/src/array.jl b/src/array.jl
index c72cca5ce7..e4e0f8b1b3 100644
--- a/src/array.jl
+++ b/src/array.jl
@@ -244,6 +244,13 @@ function device(A::CuArray)
     return device(A.storage.buffer.ctx)
 end
 
+Adapt.get_compute_unit_impl(@nospecialize(TypeHistory::Type), A::CuArray) = device(A)
+
+Adapt.adapt_storage(dev::CuDevice, x) = device!(() -> Adapt.adapt_storage(CuArray, x), dev)
+
+Sys.total_memory(dev::CuDevice) = CUDA.totalmem(dev)
+Sys.free_memory(dev::CuDevice) = unsigned(CUDA.device!(CUDA.available_memory, dev))
+
 
 ## derived types
 