Skip to content

Bringing GPU programming to DFTK #697

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 19 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,14 @@ AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
AtomsBase = "a963bdd2-2df7-4f54-a1ee-49d51e6be12a"
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
Brillouin = "23470ee3-d0df-4052-8b1a-8cbd6363e7f0"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Conda = "8f4d0f93-b110-5947-807f-2305c1781a2d"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
DftFunctionals = "6bd331d2-b28d-4fd3-880e-1a1c7f37947f"
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
InteratomicPotentials = "a9efe35a-c65d-452d-b8a8-82646cd5cb04"
Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59"
IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e"
Expand All @@ -33,6 +35,7 @@ Primes = "27ebfcd6-29c5-5fa9-bf4b-fb8fc14df3ae"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
Roots = "f2b01f46-fcfa-551c-844a-d8ac1e96c665"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Expand Down
26 changes: 26 additions & 0 deletions examples/gpu.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
using DFTK
using CUDA
using MKL
setup_threading(n_blas=1)

a = 10.263141334305942 # Lattice constant in Bohr
lattice = a / 2 .* [[0 1 1.]; [1 0 1.]; [1 1 0.]]
Si = ElementPsp(:Si, psp=load_psp("hgh/lda/Si-q4"))
atoms = [Si, Si]
positions = [ones(3)/8, -ones(3)/8];
terms_LDA = [Kinetic(), AtomicLocal(), AtomicNonlocal()]

# Setup an LDA model and discretize using
# a single k-point and a small `Ecut` of 5 Hartree.
mod = Model(lattice, atoms, positions; terms=terms_LDA,symmetries=false)
basis = PlaneWaveBasis(mod; Ecut=30, kgrid=(1, 1, 1))
basis_gpu = PlaneWaveBasis(mod; Ecut=30, kgrid=(1, 1, 1), array_type = CuArray)


DFTK.reset_timer!(DFTK.timer)
scfres = self_consistent_field(basis; solver=scf_damping_solver(1.0), is_converged=DFTK.ScfConvergenceDensity(1e-3))
println(DFTK.timer)

DFTK.reset_timer!(DFTK.timer)
scfres_gpu = self_consistent_field(basis_gpu; solver=scf_damping_solver(1.0), is_converged=DFTK.ScfConvergenceDensity(1e-3))
println(DFTK.timer)
4 changes: 4 additions & 0 deletions src/DFTK.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ using spglib_jll
using Unitful
using UnitfulAtomic
using ForwardDiff
using AbstractFFTs
using GPUArrays
using CUDA
using Random
Comment on lines +16 to +19
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure they should be here (and a hard dependency of DFTK) long-term.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we will need to discuss dependencies (especially if we want to move LOBPCG out of DFTK, that can take some work): I also didn't really know where to put my imports and how they were managed in a big package, so there is room for improvement.

using ChainRulesCore

export Vec3
Expand Down
47 changes: 41 additions & 6 deletions src/Model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -263,13 +263,48 @@ Examples of covectors are forces.
Reciprocal vectors are a special case: they are covectors, but conventionally have an
additional factor of 2π in their definition, so they transform rather with 2π times the
inverse lattice transpose: q_cart = 2π lattice' \ q_red = recip_lattice * q_red.

The trans_mat functions return the transition matrices required to do such a change of basis.
=#
vector_red_to_cart(model::Model, rred) = model.lattice * rred
vector_cart_to_red(model::Model, rcart) = model.inv_lattice * rcart
covector_red_to_cart(model::Model, fred) = model.inv_lattice' * fred
covector_cart_to_red(model::Model, fcart) = model.lattice' * fcart
recip_vector_red_to_cart(model::Model, qred) = model.recip_lattice * qred
recip_vector_cart_to_red(model::Model, qcart) = model.inv_recip_lattice * qcart

trans_mat_vector_red_to_cart(model::Model) = model.lattice
trans_mat_vector_cart_to_red(model::Model) = model.inv_lattice
trans_mat_covector_red_to_cart(model::Model) = model.inv_lattice'
trans_mat_covector_cart_to_red(model::Model) = model.lattice'
trans_mat_recip_vector_red_to_cart(model::Model) = model.recip_lattice
trans_mat_recip_vector_cart_to_red(model::Model) = model.inv_recip_lattice

fun_mat_list =(:vector_red_to_cart,
:vector_cart_to_red,
:covector_red_to_cart,
:covector_cart_to_red,
:recip_vector_red_to_cart,
:recip_vector_cart_to_red
)

for fun1 in fun_mat_list
#=
The following functions compute the change of basis for a given vector. To do so,
they call the trans_mat functions to get the corresponding transition matrix.
These functions can be broadcasted over an Array of vectors: however, they are
not GPU compatible, as they require the model, which is no isbits.
=#
@eval $fun1(model::Model, vec) = $(Symbol("trans_mat_"*string(fun1)))(model::Model) * vec
#=
The following functions take an AbstractArray of vectors and compute the change of basis
for every vector in the AbstractArray: they return an AbstractArray of the same type
and size as the input, but containing the vectors in a new basis.
These functions are GPU compatible (ie the AbstractArray can be a GPUArray), since
they use a map and the transition matrices are static arrays.
=#
@eval function $(Symbol("map_"*string(fun1)))(model::Model, A::AbstractArray{AT}) where {AT <: Vec3}
trans_matrix = $(Symbol("trans_mat_"*string(fun1)))(model)
in_new_basis = map(A) do Ai
trans_matrix * Ai
end
in_new_basis
end
end

#=
Transformations on vectors and covectors are matrices and comatrices.
Expand Down
40 changes: 26 additions & 14 deletions src/PlaneWaveBasis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ Normalization conventions:

`G_to_r` and `r_to_G` convert between these representations.
"""
struct PlaneWaveBasis{T, VT} <: AbstractBasis{T} where {VT <: Real}
struct PlaneWaveBasis{T, VT, AT, GT, RT} <: AbstractBasis{T} where {VT <: Real, GT <: AT, RT <: AT, AT <: AbstractArray}
# T is the default type to express data, VT the corresponding bare value type (i.e. not dual)
model::Model{T, VT}

Expand Down Expand Up @@ -67,8 +67,8 @@ struct PlaneWaveBasis{T, VT} <: AbstractBasis{T} where {VT <: Real}
G_to_r_normalization::T # G_to_r = G_to_r_normalization * BFFT

# "cubic" basis in reciprocal and real space, on which potentials and densities are stored
G_vectors::Array{Vec3{Int}, 3}
r_vectors::Array{Vec3{VT }, 3}
G_vectors::GT
r_vectors::RT

## MPI-local information of the kpoints this processor treats
# Irreducible kpoints. In the case of collinear spin,
Expand Down Expand Up @@ -148,7 +148,7 @@ end
# and are stored in PlaneWaveBasis for easy reconstruction.
function PlaneWaveBasis(model::Model{T}, Ecut::Number, fft_size, variational,
kcoords, kweights, kgrid, kshift,
symmetries_respect_rgrid, comm_kpts) where {T <: Real}
symmetries_respect_rgrid, comm_kpts, array_type = Array) where {T <: Real}
# Validate fft_size
if variational
max_E = sum(abs2, model.recip_lattice * floor.(Int, Vec3(fft_size) ./ 2)) / 2
Expand Down Expand Up @@ -191,7 +191,8 @@ function PlaneWaveBasis(model::Model{T}, Ecut::Number, fft_size, variational,
kweights_global = kweights

# Setup FFT plans
(ipFFT, opFFT, ipBFFT, opBFFT) = build_fft_plans(T, fft_size)
Gs = G_vectors(fft_size, array_type)
(ipFFT, opFFT, ipBFFT, opBFFT) = build_fft_plans(similar(Gs,T), fft_size)

# Normalization constants
# r_to_G = r_to_G_normalization * FFT
Expand Down Expand Up @@ -244,22 +245,25 @@ function PlaneWaveBasis(model::Model{T}, Ecut::Number, fft_size, variational,
end
@assert mpi_sum(sum(kweights_thisproc), comm_kpts) ≈ model.n_spin_components
@assert length(kpoints) == length(kweights_thisproc)
Threads.nthreads() != 1 && Gs isa AbstractGPUArray && error("Can't mix multi-threading and GPU computations yet.")

VT = value_type(T)
dvol = model.unit_cell_volume ./ prod(fft_size)
r_vectors = [Vec3{VT}(VT(i-1) / N1, VT(j-1) / N2, VT(k-1) / N3) for i = 1:N1, j = 1:N2, k = 1:N3]
terms = Vector{Any}(undef, length(model.term_types)) # Dummy terms array, filled below

basis = PlaneWaveBasis{T,value_type(T)}(
RT = array_type{Vec3{VT }, 3}
GT = array_type{Vec3{Int }, 3}

basis = PlaneWaveBasis{T,value_type(T), array_type, GT, RT}(
model, fft_size, dvol,
Ecut, variational,
opFFT, ipFFT, opBFFT, ipBFFT,
r_to_G_normalization, G_to_r_normalization,
G_vectors(fft_size), r_vectors,
Gs, r_vectors,
kpoints, kweights_thisproc, kgrid, kshift,
kcoords_global, kweights_global, comm_kpts, krange_thisproc, krange_allprocs,
symmetries, symmetries_respect_rgrid, terms)

# Instantiate the terms with the basis
for (it, t) in enumerate(model.term_types)
term_name = string(nameof(typeof(t)))
Expand All @@ -277,7 +281,7 @@ end
variational=true, fft_size=nothing,
kgrid=nothing, kshift=nothing,
symmetries_respect_rgrid=isnothing(fft_size),
comm_kpts=MPI.COMM_WORLD) where {T <: Real}
comm_kpts=MPI.COMM_WORLD, array_type = Array) where {T <: Real}
if isnothing(fft_size)
@assert variational
if symmetries_respect_rgrid
Expand All @@ -295,7 +299,7 @@ end
fft_size = compute_fft_size(model, Ecut, kcoords; factors)
end
PlaneWaveBasis(model, Ecut, fft_size, variational, kcoords, kweights,
kgrid, kshift, symmetries_respect_rgrid, comm_kpts)
kgrid, kshift, symmetries_respect_rgrid, comm_kpts, array_type)
end

@doc raw"""
Expand All @@ -322,7 +326,7 @@ Creates a new basis identical to `basis`, but with a custom set of kpoints
PlaneWaveBasis(basis.model, basis.Ecut,
basis.fft_size, basis.variational,
kcoords, kweights, kgrid, kshift,
basis.symmetries_respect_rgrid, basis.comm_kpts)
basis.symmetries_respect_rgrid, basis.comm_kpts, array_type = array_type(basis))
end

"""
Expand All @@ -331,13 +335,15 @@ end
The wave vectors `G` in reduced (integer) coordinates for a cubic basis set
of given sizes.
"""
function G_vectors(fft_size::Union{Tuple,AbstractVector})

function G_vectors(fft_size::Union{Tuple,AbstractVector}, array_type = Array)
# Note that a collect(G_vectors_generator(fft_size)) is 100-fold slower
# than this implementation, hence the code duplication.
start = .- cld.(fft_size .- 1, 2)
stop = fld.(fft_size .- 1, 2)
axes = [[collect(0:stop[i]); collect(start[i]:-1)] for i in 1:3]
[Vec3{Int}(i, j, k) for i in axes[1], j in axes[2], k in axes[3]]
Gs = [Vec3{Int}(i, j, k) for i in axes[1], j in axes[2], k in axes[3]]
convert(array_type, Gs) #Offload to GPU if needed.
end
function G_vectors_generator(fft_size::Union{Tuple,AbstractVector})
# The generator version is used mainly in symmetry.jl for lowpass_for_symmetry! and
Expand All @@ -358,14 +364,20 @@ or a ``k``-point `kpt`.
G_vectors(basis::PlaneWaveBasis) = basis.G_vectors
G_vectors(::PlaneWaveBasis, kpt::Kpoint) = kpt.G_vectors

"""
Return the type of array used for computations (Array if on CPU, CuArray,
ROCArray... if on GPU).
"""
array_type(basis::PlaneWaveBasis{T,VT,AT}) where {T, VT, AT} = AT


@doc raw"""
G_vectors_cart(basis::PlaneWaveBasis)
G_vectors_cart(basis::PlaneWaveBasis, kpt::Kpoint)

The list of ``G`` vectors of a given `basis` or `kpt`, in cartesian coordinates.
"""
G_vectors_cart(basis::PlaneWaveBasis) = recip_vector_red_to_cart.(basis.model, G_vectors(basis))
G_vectors_cart(basis::PlaneWaveBasis) = map_recip_vector_red_to_cart(basis.model, G_vectors(basis))
function G_vectors_cart(basis::PlaneWaveBasis, kpt::Kpoint)
recip_vector_red_to_cart.(basis.model, G_vectors(basis, kpt))
end
Expand Down
3 changes: 2 additions & 1 deletion src/common/ortho.jl
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
# Orthonormalize
ortho_qr(φk) = Matrix(qr(φk).Q)
ortho_qr(φk::AbstractArray) = Matrix(qr(φk).Q) #LinearAlgebra.QRCompactWYQ -> Matrix
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if this should be ::Array instead. Also it somehow feels wrong to need to put CuArray explicitly here. We should think of a way to generalise this (perhaps also with some "stripping off type arguments" construct as discussed on slack.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Couldn't we simply do this?

ortho_qr(φk::Array) = Matrix(qr(φk).Q) 
ortho_qr(φk::T) where T <: AbstractGPUArray = T(qr(φk).Q) 

Another way to do it would be to have only one function and to get the the "base type" of φk, then convert qr(φk).Q to this type: this can be done by calling T.name.wrapper (or maybe one day a dedicated function in Base). We would then have the following code:
ortho_qr(φk::T) where T <: AbstractArray = T.name.wrapper(qr(φk).Q)

ortho_qr(φk::CuArray) = CuArray(qr(φk).Q) #CUDA.CUSOLVER.CuQRPackedQ -> CuArray
12 changes: 6 additions & 6 deletions src/densities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,18 +26,18 @@ grid `basis`, where the individual k-points are occupied according to `occupatio
chunk_length = cld(length(ik_n), Threads.nthreads())

# chunk-local variables
ρ_chunklocal = Array{T,4}[zeros(T, basis.fft_size..., basis.model.n_spin_components)
for _ = 1:Threads.nthreads()]
ψnk_real_chunklocal = Array{complex(T),3}[zeros(complex(T), basis.fft_size)
for _ = 1:Threads.nthreads()]
ρ_chunklocal = [convert(array_type(basis), zeros(T, basis.fft_size..., basis.model.n_spin_components))
for _ = 1:Threads.nthreads()]
ψnk_real_chunklocal = [convert(array_type(basis), zeros(complex(T), basis.fft_size))
for _ = 1:Threads.nthreads()]

@sync for (ichunk, chunk) in enumerate(Iterators.partition(ik_n, chunk_length))
Threads.@spawn for (ik, n) in chunk # spawn a task per chunk
kpt = basis.kpoints[ik]
ψnk_real = ψnk_real_chunklocal[ichunk]
ρ_loc = ρ_chunklocal[ichunk]

kpt = basis.kpoints[ik]
G_to_r!(ψnk_real, basis, kpt, ψ[ik][:, n])
G_to_r!(ψnk_real, basis, kpt, ψ[ik][:, n])
ρ_loc[:, :, :, kpt.spin] .+= occupation[ik][n] .* basis.kweights[ik] .* abs2.(ψnk_real)
end
end
Expand Down
5 changes: 4 additions & 1 deletion src/eigen/diag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,10 @@ function diagonalize_all_kblocks(eigensolver, ham::Hamiltonian, nev_per_kpoint::
end

# Transform results into a nicer datastructure
(λ=[real.(res.λ) for res in results],
# TODO: keep λ on the gpu? Careful then, as self_consistent_field's eigenvalues
# will be a CuArray -> due to the Smearing.occupation function, occupation will also
# be a CuArray, so no scalar indexing (in ene_ops, in compute_density...)
(λ=[Array(real.(res.λ)) for res in results],
X=[res.X for res in results],
residual_norms=[res.residual_norms for res in results],
iterations=[res.iterations for res in results],
Expand Down
1 change: 1 addition & 0 deletions src/eigen/diag_lobpcg_hyper.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ function lobpcg_hyper(A, X0; maxiter=100, prec=nothing,
result = LOBPCG(A, X0, I, prec, tol, maxiter; n_conv_check=n_conv_check, kwargs...)

n_conv_check === nothing && (n_conv_check = size(X0, 2))

converged = maximum(result.residual_norms[1:n_conv_check]) < tol
iterations = size(result.residual_history, 2) - 1

Expand Down
Loading