diff --git a/src/DFTK.jl b/src/DFTK.jl index ab698b269..63ffabf8f 100644 --- a/src/DFTK.jl +++ b/src/DFTK.jl @@ -234,6 +234,7 @@ include("workarounds/dummy_inplace_fft.jl") include("workarounds/forwarddiff_rules.jl") # Optimized generic GPU functions and GPU workarounds +include("gpu/symmetry.jl") include("gpu/linalg.jl") include("gpu/gpu_arrays.jl") diff --git a/src/gpu/gpu_arrays.jl b/src/gpu/gpu_arrays.jl index 6f7132037..a465cbc7a 100644 --- a/src/gpu/gpu_arrays.jl +++ b/src/gpu/gpu_arrays.jl @@ -5,14 +5,6 @@ using Preferences # https://github.com/JuliaGPU/CUDA.jl/issues/1565 LinearAlgebra.dot(x::AbstractGPUArray, D::Diagonal, y::AbstractGPUArray) = x' * (D * y) -function lowpass_for_symmetry!(ρ::AbstractGPUArray, basis; symmetries=basis.symmetries) - all(isone, symmetries) && return ρ - # lowpass_for_symmetry! currently uses scalar indexing, so we have to do this very ugly - # thing for cases where ρ sits on a device (e.g. GPU) - ρ_CPU = lowpass_for_symmetry!(to_cpu(ρ), basis; symmetries) - ρ .= to_device(basis.architecture, ρ_CPU) -end - for fun in (:potential_terms, :kernel_terms) @eval function DftFunctionals.$fun(fun::DispatchFunctional, ρ::AT, args...) where {AT <: AbstractGPUArray{Float64}} diff --git a/src/gpu/symmetry.jl b/src/gpu/symmetry.jl new file mode 100644 index 000000000..4929106e0 --- /dev/null +++ b/src/gpu/symmetry.jl @@ -0,0 +1,46 @@ +using GPUArraysCore + +function accumulate_over_symmetries!(ρaccu::AbstractArray, ρin::AbstractArray, + basis::PlaneWaveBasis{T}, symmetries) where {T} + Gs = reshape(G_vectors(basis), size(ρaccu)) + fft_size = basis.fft_size + + symm_invS = to_device(basis.architecture, [Mat3{Int}(inv(symop.S)) for symop in symmetries]) + symm_τ = to_device(basis.architecture, [symop.τ for symop in symmetries]) + n_symm = length(symmetries) + + map!(ρaccu, Gs) do G + acc = zero(complex(T)) + # Explicit loop over indicies because AMDGPU does not support zip() in map! + for i_symm in 1:n_symm + invS = symm_invS[i_symm] + τ = symm_τ[i_symm] + idx = index_G_vectors(fft_size, invS * G) + acc += isnothing(idx) ? zero(complex(T)) : cis2pi(-T(dot(G, τ))) * ρin[idx] + end + acc + end + ρaccu +end + +function lowpass_for_symmetry!(ρ::AbstractGPUArray, basis::PlaneWaveBasis{T}; + symmetries=basis.symmetries) where {T} + all(isone, symmetries) && return ρ + + Gs = reshape(G_vectors(basis), size(ρ)) + fft_size = basis.fft_size + ρtmp = similar(ρ) + + symm_S = to_device(basis.architecture, [symop.S for symop in symmetries]) + + map!(ρtmp, ρ, Gs) do ρ_i, G + acc = ρ_i + for S in symm_S + idx = index_G_vectors(fft_size, S * G) + acc *= isnothing(idx) ? zero(complex(T)) : one(complex(T)) + end + acc + end + ρ .= ρtmp + ρ +end diff --git a/src/symmetry.jl b/src/symmetry.jl index 16c0621e1..91556b75f 100644 --- a/src/symmetry.jl +++ b/src/symmetry.jl @@ -320,7 +320,7 @@ Symmetrize a density by applying all the basis (by default) symmetries and formi """ @views @timing function symmetrize_ρ(basis, ρ::AbstractArray{T}; symmetries=basis.symmetries, do_lowpass=true) where {T} - ρin_fourier = to_cpu(fft(basis, ρ)) + ρin_fourier = fft(basis, ρ) ρout_fourier = zero(ρin_fourier) for σ = 1:size(ρ, 4) accumulate_over_symmetries!(ρout_fourier[:, :, :, σ], @@ -328,7 +328,7 @@ Symmetrize a density by applying all the basis (by default) symmetries and formi do_lowpass && lowpass_for_symmetry!(ρout_fourier[:, :, :, σ], basis; symmetries) end inv_fft = T <: Real ? irfft : ifft - inv_fft(basis, to_device(basis.architecture, ρout_fourier) ./ length(symmetries)) + inv_fft(basis, ρout_fourier ./ length(symmetries)) end """