GPU optimized symmetry operations #1097

Open
wants to merge 2 commits into
base: master
1 change: 1 addition & 0 deletions src/DFTK.jl
@@ -234,6 +234,7 @@ include("workarounds/dummy_inplace_fft.jl")
include("workarounds/forwarddiff_rules.jl")

# Optimized generic GPU functions and GPU workarounds
+include("gpu/symmetry.jl")
include("gpu/linalg.jl")
include("gpu/gpu_arrays.jl")

8 changes: 0 additions & 8 deletions src/gpu/gpu_arrays.jl
@@ -5,14 +5,6 @@ using Preferences
# https://github.com/JuliaGPU/CUDA.jl/issues/1565
LinearAlgebra.dot(x::AbstractGPUArray, D::Diagonal, y::AbstractGPUArray) = x' * (D * y)

-function lowpass_for_symmetry!(ρ::AbstractGPUArray, basis; symmetries=basis.symmetries)
-    all(isone, symmetries) && return ρ
-    # lowpass_for_symmetry! currently uses scalar indexing, so we have to do this very ugly
-    # thing for cases where ρ sits on a device (e.g. GPU)
-    ρ_CPU = lowpass_for_symmetry!(to_cpu(ρ), basis; symmetries)
-    ρ .= to_device(basis.architecture, ρ_CPU)
-end
-
for fun in (:potential_terms, :kernel_terms)
@eval function DftFunctionals.$fun(fun::DispatchFunctional, ρ::AT,
args...) where {AT <: AbstractGPUArray{Float64}}
46 changes: 46 additions & 0 deletions src/gpu/symmetry.jl
@@ -0,0 +1,46 @@
using GPUArraysCore

function accumulate_over_symmetries!(ρaccu::AbstractArray, ρin::AbstractArray,
basis::PlaneWaveBasis{T}, symmetries) where {T}
Comment on lines +3 to +4
Member:
Formatting

Member:
Also, how much worse is this implementation compared to the CPU version? Can we not just make this the CPU version, too? It looks like it should not be very much worse than what we have.

    Gs = reshape(G_vectors(basis), size(ρaccu))
    fft_size = basis.fft_size

    symm_invS = to_device(basis.architecture, [Mat3{Int}(inv(symop.S)) for symop in symmetries])
    symm_τ = to_device(basis.architecture, [symop.τ for symop in symmetries])
    n_symm = length(symmetries)

    map!(ρaccu, Gs) do G
        acc = zero(complex(T))
        # Explicit loop over indices because AMDGPU does not support zip() in map!
        for i_symm in 1:n_symm
            invS = symm_invS[i_symm]
            τ = symm_τ[i_symm]
            idx = index_G_vectors(fft_size, invS * G)
            acc += isnothing(idx) ? zero(complex(T)) : cis2pi(-T(dot(G, τ))) * ρin[idx]
        end
        acc

Check warning (Codecov / codecov/patch): added line #L21 in src/gpu/symmetry.jl was not covered by tests.

    end
    ρaccu

Check warning (Codecov / codecov/patch): added line #L23 in src/gpu/symmetry.jl was not covered by tests.
end

function lowpass_for_symmetry!(ρ::AbstractGPUArray, basis::PlaneWaveBasis{T};

Check warning (Codecov / codecov/patch): added line #L26 in src/gpu/symmetry.jl was not covered by tests.

Member:
The only place where we need this is in symmetrize_ρ, where it comes right after accumulate_over_symmetries!. For the GPU version, I think it would make a lot of sense to fuse these two functions into one and thus have a single map! going over all Gs. You should be able to do this with a boolean flag do_lowpass, which Julia is hopefully smart enough to constant-propagate into the GPU kernel and compile away entirely if set to false.

Again, feel free to also make this fused function the CPU version if this does not hurt performance too much.

                               symmetries=basis.symmetries) where {T}
    all(isone, symmetries) && return ρ

Check warning (Codecov / codecov/patch): added line #L28 in src/gpu/symmetry.jl was not covered by tests.

    Gs = reshape(G_vectors(basis), size(ρ))
    fft_size = basis.fft_size
    ρtmp = similar(ρ)

Check warning (Codecov / codecov/patch): added lines #L30 - L32 in src/gpu/symmetry.jl were not covered by tests.

    symm_S = to_device(basis.architecture, [symop.S for symop in symmetries])

Check warning (Codecov / codecov/patch): added line #L34 in src/gpu/symmetry.jl was not covered by tests.

    map!(ρtmp, ρ, Gs) do ρ_i, G
        acc = ρ_i
        for S in symm_S
            idx = index_G_vectors(fft_size, S * G)
            acc *= isnothing(idx) ? zero(complex(T)) : one(complex(T))
        end
Comment on lines +37 to +41
Member:
Formatting off.

        acc

Check warning (Codecov / codecov/patch): added lines #L36 - L42 in src/gpu/symmetry.jl were not covered by tests.

    end
    ρ .= ρtmp
    ρ

Check warning (Codecov / codecov/patch): added lines #L44 - L45 in src/gpu/symmetry.jl were not covered by tests.
end
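
The fusion suggested in the review comment on lowpass_for_symmetry! could look roughly like the following CPU-only sketch. It is not the PR's code: `toy_G`, `toy_index_G`, `apply_S` and `toy_symmetrize_fused!` are hypothetical stand-ins for DFTK internals, plain tuples and matrices replace DFTK's static vectors and device arrays, and the G vectors are assumed to be integer triples in reduced coordinates. Whether Julia actually constant-propagates `do_lowpass` into a GPU kernel is not checked here.

```julia
using LinearAlgebra

# Hypothetical stand-ins for DFTK internals, written for plain CPU arrays.

# Integer frequency stored at grid index i along an axis of length n.
toy_G(i, n) = i - 1 <= fld(n - 1, 2) ? i - 1 : i - 1 - n

# Index of G in the FFT grid, or `nothing` if G lies outside the grid.
function toy_index_G(fft_size, G)
    lo = .-cld.(fft_size .- 1, 2)
    hi = fld.(fft_size .- 1, 2)
    (all(lo .<= G) && all(G .<= hi)) || return nothing
    CartesianIndex(mod1.(G .+ 1, fft_size))
end

# Apply a 3×3 integer matrix to an integer triple.
apply_S(S, G) = ntuple(i -> S[i, 1] * G[1] + S[i, 2] * G[2] + S[i, 3] * G[3], 3)

# One pass over all G: accumulate over symmetries and optionally low-pass.
function toy_symmetrize_fused!(ρout, ρin, fft_size, Ss, τs; do_lowpass=true)
    Gs = [ntuple(d -> toy_G(ci[d], fft_size[d]), 3) for ci in CartesianIndices(ρin)]
    invSs = [round.(Int, inv(S)) for S in Ss]
    map!(ρout, Gs) do G
        acc = zero(eltype(ρout))
        keep = true
        for i in eachindex(Ss)
            idx = toy_index_G(fft_size, apply_S(invSs[i], G))
            if !isnothing(idx)
                acc += cis(-2π * sum(G .* τs[i])) * ρin[idx]
            end
            # Low-pass: drop G if any S * G leaves the FFT grid.
            if do_lowpass && isnothing(toy_index_G(fft_size, apply_S(Ss[i], G)))
                keep = false
            end
        end
        keep ? acc : zero(acc)
    end
    ρout
end

# Usage on a 4×4×4 grid with identity and inversion, no translations.
fft_size = (4, 4, 4)
Ss = [Matrix{Int}(I, 3, 3), -Matrix{Int}(I, 3, 3)]
τs = [(0.0, 0.0, 0.0), (0.0, 0.0, 0.0)]
ρin = ones(ComplexF64, fft_size)
ρ_lp = toy_symmetrize_fused!(similar(ρin), ρin, fft_size, Ss, τs)
ρ_nolp = toy_symmetrize_fused!(similar(ρin), ρin, fft_size, Ss, τs; do_lowpass=false)
```

On the even-sized grid used here, the frequency -2 has no partner +2, so inversion maps it outside the grid. With `do_lowpass=true` those entries are zeroed. With `do_lowpass=false` they keep only the contribution from the identity. Because the low-pass step is a pointwise mask, applying it inside the same loop should give the same result as running it as a second pass.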
4 changes: 2 additions & 2 deletions src/symmetry.jl
@@ -320,15 +320,15 @@ Symmetrize a density by applying all the basis (by default) symmetries and formi
"""
 @views @timing function symmetrize_ρ(basis, ρ::AbstractArray{T};
                                      symmetries=basis.symmetries, do_lowpass=true) where {T}
-    ρin_fourier = to_cpu(fft(basis, ρ))
+    ρin_fourier = fft(basis, ρ)
     ρout_fourier = zero(ρin_fourier)
     for σ = 1:size(ρ, 4)
         accumulate_over_symmetries!(ρout_fourier[:, :, :, σ],
                                     ρin_fourier[:, :, :, σ], basis, symmetries)
         do_lowpass && lowpass_for_symmetry!(ρout_fourier[:, :, :, σ], basis; symmetries)
     end
     inv_fft = T <: Real ? irfft : ifft
-    inv_fft(basis, to_device(basis.architecture, ρout_fourier) ./ length(symmetries))
+    inv_fft(basis, ρout_fourier ./ length(symmetries))
 end
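
Read off from accumulate_over_symmetries! and symmetrize_ρ as they appear in this diff, the Fourier-space operation can be summarized as below. The notation is an assumption made for illustration: $G$ is an integer reciprocal-lattice vector in reduced coordinates, $N$ is the number of symmetry operations $(S_k, \tau_k)$, and any term whose rotated vector $S_k^{-1} G$ falls outside the FFT grid is skipped.

```latex
\rho_{\mathrm{out}}(G)
  = \frac{1}{N} \sum_{k=1}^{N}
    e^{-2\pi \mathrm{i}\, G \cdot \tau_k}\,
    \rho_{\mathrm{in}}\!\left(S_k^{-1} G\right)
```

With do_lowpass=true, the coefficient at $G$ is additionally set to zero whenever some $S_k G$ lies outside the FFT grid.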

"""