Skip to content

Commit a1e7ec5

Browse files
authored
[NDTensors] Fix scalar indexing issue for Diag broadcast on GPU (#1497)
1 parent 984d814 commit a1e7ec5

File tree

11 files changed

+117
-44
lines changed

11 files changed

+117
-44
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
module NDTensorsGPUArraysCoreExt
22
include("contract.jl")
3+
include("blocksparsetensor.jl")
34
end
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
using GPUArraysCore: @allowscalar, AbstractGPUArray
2+
using NDTensors: NDTensors, BlockSparseTensor, dense, diag, map_diag!
3+
using NDTensors.DiagonalArrays: diaglength
4+
using NDTensors.Expose: Exposed, unexpose
5+
6+
## TODO to circumvent issues with blocksparse and scalar indexing
7+
## convert blocksparse GPU tensors to dense tensors and call diag
8+
## copying will probably have some impact on timing but this code
9+
## currently isn't used in the main code, just in tests.
10+
function NDTensors.diag(ETensor::Exposed{<:AbstractGPUArray,<:BlockSparseTensor})
11+
return diag(dense(unexpose(ETensor)))
12+
end
13+
14+
## TODO scalar indexing is slow here
15+
function NDTensors.map_diag!(
16+
f::Function,
17+
exposed_t_destination::Exposed{<:AbstractGPUArray,<:BlockSparseTensor},
18+
exposed_t_source::Exposed{<:AbstractGPUArray,<:BlockSparseTensor},
19+
)
20+
t_destination = unexpose(exposed_t_destination)
21+
t_source = unexpose(exposed_t_source)
22+
@allowscalar for i in 1:diaglength(t_destination)
23+
NDTensors.setdiagindex!(t_destination, f(NDTensors.getdiagindex(t_source, i)), i)
24+
end
25+
return t_destination
26+
end

NDTensors/src/blocksparse/blocksparsetensor.jl

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,7 @@ end
256256
# Returns the offset of the new block added.
257257
# XXX rename to insertblock!, no need to return offset
258258
using .TypeParameterAccessors: unwrap_array_type
259-
using .Expose: expose
259+
using .Expose: Exposed, expose, unexpose
260260
function insertblock_offset!(T::BlockSparseTensor{ElT,N}, newblock::Block{N}) where {ElT,N}
261261
newdim = blockdim(T, newblock)
262262
newoffset = nnz(T)
@@ -356,6 +356,30 @@ function dense(T::TensorT) where {TensorT<:BlockSparseTensor}
356356
return tensor(Dense(r), inds(T))
357357
end
358358

359+
function diag(ETensor::Exposed{<:AbstractArray,<:BlockSparseTensor})
360+
tensor = unexpose(ETensor)
361+
tensordiag = NDTensors.similar(
362+
dense(typeof(tensor)), eltype(tensor), (diaglength(tensor),)
363+
)
364+
for j in 1:diaglength(tensor)
365+
@inbounds tensordiag[j] = getdiagindex(tensor, j)
366+
end
367+
return tensordiag
368+
end
369+
370+
## TODO currently this fails on GPU with scalar indexing
371+
function map_diag!(
372+
f::Function,
373+
exposed_t_destination::Exposed{<:AbstractArray,<:BlockSparseTensor},
374+
exposed_t_source::Exposed{<:AbstractArray,<:BlockSparseTensor},
375+
)
376+
t_destination = unexpose(exposed_t_destination)
377+
t_source = unexpose(exposed_t_source)
378+
for i in 1:diaglength(t_destination)
379+
NDTensors.setdiagindex!(t_destination, f(NDTensors.getdiagindex(t_source, i)), i)
380+
end
381+
return t_destination
382+
end
359383
#
360384
# Operations
361385
#

NDTensors/src/dense/densetensor.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,12 @@ convert(::Type{Array}, T::DenseTensor) = reshape(data(storage(T)), dims(inds(T))
6868
# Useful for using Base Array functions
6969
array(T::DenseTensor) = convert(Array, T)
7070

71+
using .DiagonalArrays: DiagonalArrays, diagview
72+
73+
function DiagonalArrays.diagview(T::DenseTensor)
74+
return diagview(array(T))
75+
end
76+
7177
function Array{ElT,N}(T::DenseTensor{ElT,N}) where {ElT,N}
7278
return copy(array(T))
7379
end

NDTensors/src/diag/diagtensor.jl

Lines changed: 10 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using .DiagonalArrays: diaglength
1+
using .DiagonalArrays: diaglength, diagview
22

33
const DiagTensor{ElT,N,StoreT,IndsT} = Tensor{ElT,N,StoreT,IndsT} where {StoreT<:Diag}
44
const NonuniformDiagTensor{ElT,N,StoreT,IndsT} =
@@ -9,9 +9,7 @@ const UniformDiagTensor{ElT,N,StoreT,IndsT} =
99
function diag(tensor::DiagTensor)
1010
tensor_diag = NDTensors.similar(dense(typeof(tensor)), (diaglength(tensor),))
1111
# TODO: Define `eachdiagindex`.
12-
for j in 1:diaglength(tensor)
13-
tensor_diag[j] = getdiagindex(tensor, j)
14-
end
12+
diagview(tensor_diag) .= diagview(tensor)
1513
return tensor_diag
1614
end
1715

@@ -33,6 +31,10 @@ function Array(T::DiagTensor{ElT,N}) where {ElT,N}
3331
return Array{ElT,N}(T)
3432
end
3533

34+
function DiagonalArrays.diagview(T::NonuniformDiagTensor)
35+
return data(T)
36+
end
37+
3638
function zeros(tensortype::Type{<:DiagTensor}, inds)
3739
return tensor(generic_zeros(storagetype(tensortype), mindim(inds)), inds)
3840
end
@@ -110,32 +112,11 @@ end
110112
using .TypeParameterAccessors: unwrap_array_type
111113
# convert to Dense
112114
function dense(T::DiagTensor)
113-
return dense(unwrap_array_type(T), T)
114-
end
115-
116-
# CPU version
117-
function dense(::Type{<:Array}, T::DiagTensor)
118115
R = zeros(dense(typeof(T)), inds(T))
119-
for i in 1:diaglength(T)
120-
setdiagindex!(R, getdiagindex(T, i), i)
121-
end
116+
diagview(R) .= diagview(T)
122117
return R
123118
end
124119

125-
# GPU version
126-
function dense(::Type{<:AbstractArray}, T::DiagTensor)
127-
D_cpu = dense(Array, cpu(T))
128-
return adapt(unwrap_array_type(T), D_cpu)
129-
end
130-
131-
# UniformDiag version
132-
# TODO: Delete once new DiagonalArray is designed.
133-
# TODO: This creates a tensor on CPU by default so may cause
134-
# problems for GPU.
135-
function dense(::Type{<:Number}, T::DiagTensor)
136-
return dense(Tensor(Diag(fill(getdiagindex(T, 1), diaglength(T))), inds(T)))
137-
end
138-
139120
denseblocks(T::DiagTensor) = dense(T)
140121

141122
function permutedims!(
@@ -145,16 +126,14 @@ function permutedims!(
145126
f::Function=(r, t) -> t,
146127
) where {N}
147128
# TODO: check that inds(R)==permute(inds(T),perm)?
148-
for i in 1:diaglength(R)
149-
@inbounds setdiagindex!(R, f(getdiagindex(R, i), getdiagindex(T, i)), i)
150-
end
129+
diagview(R) .= f.(diagview(R), diagview(T))
151130
return R
152131
end
153132

154133
function permutedims(
155134
T::DiagTensor{<:Number,N}, perm::NTuple{N,Int}, f::Function=identity
156135
) where {N}
157-
R = NDTensors.similar(T, permute(inds(T), perm))
136+
R = NDTensors.similar(T)
158137
g(r, t) = f(t)
159138
permutedims!(R, T, perm, g)
160139
return R
@@ -193,9 +172,7 @@ end
193172
function permutedims!(
194173
R::DenseTensor{ElR,N}, T::DiagTensor{ElT,N}, perm::NTuple{N,Int}, f::Function=(r, t) -> t
195174
) where {ElR,ElT,N}
196-
for i in 1:diaglength(T)
197-
@inbounds setdiagindex!(R, f(getdiagindex(R, i), getdiagindex(T, i)), i)
198-
end
175+
diagview(R) .= f.(diagview(R), diagview(T))
199176
return R
200177
end
201178

NDTensors/src/linearalgebra/linearalgebra.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -369,7 +369,6 @@ matrix is unique. Returns a tuple (Q,R).
369369
function qr_positive(M::AbstractMatrix)
370370
sparseQ, R = qr(M)
371371
Q = convert(typeof(R), sparseQ)
372-
nc = size(Q, 2)
373372
signs = nonzero_sign.(diag(R))
374373
Q = Q * Diagonal(signs)
375374
R = Diagonal(conj.(signs)) * R

NDTensors/src/tensor/tensor.jl

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -361,16 +361,18 @@ function getdiagindex(T::Tensor{<:Number,N}, ind::Int) where {N}
361361
return getindex(T, CartesianIndex(ntuple(_ -> ind, Val(N))))
362362
end
363363

364+
using .Expose: Exposed, expose, unexpose
364365
# TODO: add support for off-diagonals, return
365366
# block sparse vector instead of dense.
366-
function diag(tensor::Tensor)
367+
diag(tensor::Tensor) = diag(expose(tensor))
368+
369+
function diag(ETensor::Exposed)
370+
tensor = unexpose(ETensor)
367371
## d = NDTensors.similar(T, ElT, (diaglength(T),))
368372
tensordiag = NDTensors.similar(
369373
dense(typeof(tensor)), eltype(tensor), (diaglength(tensor),)
370374
)
371-
for n in 1:diaglength(tensor)
372-
tensordiag[n] = tensor[n, n]
373-
end
375+
array(tensordiag) .= diagview(tensor)
374376
return tensordiag
375377
end
376378

@@ -384,6 +386,12 @@ function setdiagindex!(T::Tensor{<:Number,N}, val, ind::Int) where {N}
384386
return T
385387
end
386388

389+
function map_diag!(f::Function, exposed_t_destination::Exposed, exposed_t_source::Exposed)
390+
diagview(unexpose(exposed_t_destination)) .= f.(diagview(unexpose(exposed_t_source)))
391+
return unexpose(exposed_t_destination)
392+
end
393+
map_diag(f::Function, t::Tensor) = map_diag!(f, expose(copy(t)), expose(t))
394+
387395
#
388396
# Some generic contraction functionality
389397
#

NDTensors/test/test_blocksparse.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ using NDTensors:
1010
blockview,
1111
data,
1212
dense,
13+
diag,
14+
diaglength,
1315
dims,
1416
eachnzblock,
1517
inds,
@@ -52,6 +54,8 @@ using Test: @test, @test_throws, @testset
5254
@test isblocknz(A, (1, 2))
5355
@test !isblocknz(A, (1, 1))
5456
@test !isblocknz(A, (2, 2))
57+
dA = diag(A)
58+
@test @allowscalar dA diag(dense(A))
5559

5660
# Test different ways of getting nnz
5761
@test nnz(blockoffsets(A), inds(A)) == nnz(A)
@@ -104,6 +108,10 @@ using Test: @test, @test_throws, @testset
104108
@allowscalar for I in eachindex(C)
105109
@test C[I] == A[I] + B[I]
106110
end
111+
Cp = NDTensors.map_diag(i -> 2 * i, C)
112+
@allowscalar for i in 1:diaglength(Cp)
113+
@test Cp[i, i] == 2 * C[i, i]
114+
end
107115

108116
Ap = permutedims(A, (2, 1))
109117

NDTensors/test/test_dense.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,10 @@ NDTensors.dim(i::MyInd) = i.dim
4848
randn!(B)
4949
C = copy(A)
5050
C = permutedims!!(C, B, (1, 2), +)
51+
Cp = NDTensors.map_diag(i -> 2 * i, C)
52+
@allowscalar for i in 1:diaglength(Cp)
53+
@test Cp[i, i] == 2 * C[i, i]
54+
end
5155

5256
Ap = permutedims(A, (2, 1))
5357
@allowscalar begin

NDTensors/test/test_diag.jl

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,19 +34,39 @@ using LinearAlgebra: dot
3434
D = Tensor(Diag(1), (2, 2))
3535
@test norm(D) == 2
3636
d = 3
37+
## TODO this fails because uniform diag tensors are immutable
38+
#S = NDTensors.map_diag((i->i * 2), dev(D))
39+
# @allowscalar for i in 1:diaglength(S)
40+
# @test S[i,i] == 2.0 * D[i,i]
41+
# end
42+
3743
vr = rand(elt, d)
3844
D = dev(tensor(Diag(vr), (d, d)))
3945
Da = Array(D)
4046
Dm = Matrix(D)
47+
Da = permutedims(D, (2, 1))
4148
@allowscalar begin
4249
@test Da == NDTensors.LinearAlgebra.diagm(0 => vr)
4350
@test Da == NDTensors.LinearAlgebra.diagm(0 => vr)
4451

45-
## TODO Currently this permutedims requires scalar indexing on GPU.
46-
Da = permutedims(D, (2, 1))
4752
@test Da == D
4853
end
4954

55+
# This if statement corresponds to the reported bug:
56+
# https://github.com/JuliaGPU/Metal.jl/issues/364
57+
if !(dev == NDTensors.mtl && elt === ComplexF32)
58+
S = permutedims(dev(D), (1, 2), sqrt)
59+
@allowscalar begin
60+
for i in 1:diaglength(S)
61+
@test S[i, i] sqrt(D[i, i])
62+
end
63+
end
64+
end
65+
S = NDTensors.map_diag(i -> 2 * i, dev(D))
66+
@allowscalar for i in 1:diaglength(S)
67+
@test S[i, i] == 2 * D[i, i]
68+
end
69+
5070
# Regression test for https://github.com/ITensor/ITensors.jl/issues/1199
5171
S = dev(tensor(Diag(randn(elt, 2)), (2, 2)))
5272
## This was creating a `Dense{ReshapedArray{Adjoint{Matrix}}}` which, in mul!, was

jenkins/Jenkinsfile

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ pipeline {
2727
}
2828
steps {
2929
sh '''
30-
julia -e 'using Pkg; Pkg.activate(temp=true); Pkg.develop(path="./NDTensors"); Pkg.develop(path="."); Pkg.test("NDTensors"; test_args=["cuda"])'
30+
julia -e 'using Pkg; Pkg.Registry.update(); Pkg.update(); Pkg.activate(temp=true); Pkg.develop(path="./NDTensors"); Pkg.develop(path="."); Pkg.test("NDTensors"; test_args=["cuda"])'
3131
'''
3232
}
3333
}
@@ -51,7 +51,7 @@ pipeline {
5151
}
5252
steps {
5353
sh '''
54-
julia -e 'using Pkg; Pkg.activate(temp=true); Pkg.develop(path="./NDTensors"); Pkg.develop(path="."); Pkg.test("NDTensors"; test_args=["cuda"])'
54+
julia -e 'using Pkg; Pkg.Registry.update(); Pkg.update(); Pkg.activate(temp=true); Pkg.develop(path="./NDTensors"); Pkg.develop(path="."); Pkg.test("NDTensors"; test_args=["cuda"])'
5555
'''
5656
}
5757
}
@@ -75,7 +75,7 @@ pipeline {
7575
}
7676
steps {
7777
sh '''
78-
julia -e 'using Pkg; Pkg.activate(temp=true); Pkg.develop(path="./NDTensors"); Pkg.develop(path="."); Pkg.test("NDTensors"; test_args=["cutensor"])'
78+
julia -e 'using Pkg; Pkg.Registry.update(); Pkg.update(); Pkg.activate(temp=true); Pkg.develop(path="./NDTensors"); Pkg.develop(path="."); Pkg.test("NDTensors"; test_args=["cutensor"])'
7979
'''
8080
}
8181
}

0 commit comments

Comments
 (0)