diff --git a/src/terms/nonlocal.jl b/src/terms/nonlocal.jl index 7add0bf9e..b981f3457 100644 --- a/src/terms/nonlocal.jl +++ b/src/terms/nonlocal.jl @@ -46,7 +46,7 @@ end (; E, term.ops) end -@timing "forces: nonlocal" function compute_forces(::TermAtomicNonlocal, +@timing "forces: nonlocal" function compute_forces(term::TermAtomicNonlocal, basis::PlaneWaveBasis{TT}, ψ, occupation; kwargs...) where {TT} T = promote_type(TT, real(eltype(ψ[1]))) @@ -62,31 +62,25 @@ end # P(G) = form_factor(G) * structure_factor(G). forces = Vec3{T}[zero(Vec3{T}) for _ = 1:length(model.positions)] + group_offset = 0 # offset for the projection vectors from the TermAtomicNonlocal for group in psp_groups element = model.atoms[first(group)] C = to_device(basis.architecture, build_projection_coefficients(T, element.psp)) for (ik, kpt) in enumerate(basis.kpoints) # We compute the forces from the irreductible BZ; they are symmetrized later. - G_plus_k_cart = to_cpu(Gplusk_vectors_cart(basis, kpt)) G_plus_k = Gplusk_vectors(basis, kpt) occupationk = to_cpu(occupation[ik]) - form_factors = to_device(basis.architecture, - build_projector_form_factors(element.psp, G_plus_k_cart)) # Pre-allocation of large arrays (Noticable performance improvements on # CPU and GPU here) δHψk = similar(ψ[ik]) - P = similar(form_factors) - dPdR = similar(form_factors) - twoπp = similar(form_factors, length(G_plus_k)) - structure_factors = similar(form_factors, length(G_plus_k)) + dPdR = similar(term.ops[ik].P, length(G_plus_k), count_n_proj(element.psp)) + twoπp = similar(dPdR, length(G_plus_k)) + offset = group_offset for idx in group - r = model.positions[idx] - map!(p -> cis2pi(-dot(p, r)), structure_factors, G_plus_k) - P .= structure_factors .* form_factors ./ sqrt(unit_cell_volume) - + P = @view term.ops[ik].P[:, offset + 1:offset + count_n_proj(element.psp)] forces[idx] += map(1:3) do α map!(p -> -2π*im*p[α], twoπp, G_plus_k) dPdR .= twoπp .* P @@ -95,8 +89,10 @@ end 2real(dot(ψ[ik][:, iband], δHψk[:, iband])) for iband=1:size(ψ[ik], 2)) end # α + offset += count_n_proj(element.psp) end # r end # kpt + group_offset += count_n_proj(element.psp) * length(group) end # group mpi_sum!(forces, basis.comm_kpts)