[BlockSparseArrays] Use new SparseArrayDOK type in BlockSparseArrays #1272


Merged — 34 commits, merged Dec 2, 2023

Commits
6401807
[NDTensors] Start SparseArrayDOKs module
mtfishman Nov 25, 2023
9e4142c
Reorganization
mtfishman Nov 26, 2023
0677dfa
Format
mtfishman Nov 26, 2023
995c3f9
Start making SparseArrayInterface module, to be used by SparseArrayDO…
mtfishman Nov 28, 2023
b696a7e
Update interface
mtfishman Nov 28, 2023
e984cbe
Add another test
mtfishman Nov 28, 2023
475ee31
Improved map
mtfishman Nov 28, 2023
36961f8
New SparseArrayDOKs using SparseArrayInterface
mtfishman Nov 28, 2023
25f8229
Fix namespace issues
mtfishman Nov 29, 2023
3b3c50d
One more namespace issue
mtfishman Nov 29, 2023
a20e819
More namespace issues
mtfishman Nov 29, 2023
d571bd9
Julia 1.6 backwards compatibility
mtfishman Nov 29, 2023
410df32
Use SparseArrayInterface in DiagonalArrays
mtfishman Nov 29, 2023
41c9816
Format
mtfishman Nov 29, 2023
420692a
Fix loading issue
mtfishman Nov 29, 2023
c6dcefe
Missing include, improve README
mtfishman Nov 29, 2023
2478166
[BlockSparseArrays] Start using SparseArrayDOK
mtfishman Nov 29, 2023
43929be
Small fixes
mtfishman Nov 30, 2023
f70a520
Merge branch 'main' into NDTensors_new_BlockSparseArrays
mtfishman Nov 30, 2023
41b92b8
Change SparseArray to SparseArrayDOK
mtfishman Nov 30, 2023
af5ff02
Format
mtfishman Nov 30, 2023
a1733f6
Temporarily remove broken tests
mtfishman Nov 30, 2023
1be1d00
Introduct AbstractSparseArray, start rewriting BlockSparseArray
mtfishman Dec 1, 2023
97f3df4
Move AbstractSparseArray to SparseArrayInterface
mtfishman Dec 1, 2023
32d375a
Improve testing and organization
mtfishman Dec 1, 2023
0ea3eee
DiagonalArrays reorganization and simplification
mtfishman Dec 1, 2023
f918aee
Get more BlockSparseArrays tests passing
mtfishman Dec 1, 2023
fc0ff14
Move arraytensor code to backup files
mtfishman Dec 1, 2023
8487d0c
Move arraystorage code to backup files
mtfishman Dec 1, 2023
fc9ff82
Try fixing tests
mtfishman Dec 1, 2023
b5b643d
Comment
mtfishman Dec 1, 2023
796f33d
Merge main
mtfishman Dec 1, 2023
4d1453d
Fix namespace issue
mtfishman Dec 1, 2023
1f04d9c
Remove arraytensor test
mtfishman Dec 1, 2023
3 changes: 2 additions & 1 deletion NDTensors/Project.toml
@@ -11,7 +11,7 @@ Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4"
FLoops = "cc61a311-1640-44b5-9fba-1b764f453329"
Folds = "41a02a25-b8f0-4f67-bc48-60067656b558"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
GPUArraysCore="46192b85-c4d5-4398-a991-12ede77f4527"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
HDF5 = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f"
InlineStrings = "842dd82b-1e85-43dc-bf29-5d0ee9dffc48"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -25,6 +25,7 @@ Strided = "5e0ebb24-38b0-5f93-81fe-25c709ecae67"
StridedViews = "4db3bf67-4bd7-4b4e-b153-31dc3fb37143"
TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
TupleTools = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6"
VectorInterface = "409d34a3-91d5-4945-b6ec-7529ddf182d8"

[weakdeps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
46 changes: 0 additions & 46 deletions NDTensors/src/NDTensors.jl
Expand Up @@ -137,52 +137,6 @@ include("empty/EmptyTensor.jl")
include("empty/tensoralgebra/contract.jl")
include("empty/adapt.jl")

#####################################
# Array Tensor (experimental)
#

# TODO: Turn this into a module `CombinerArray`.
include("arraystorage/combiner/storage/combinerarray.jl")

include("arraystorage/arraystorage/storage/arraystorage.jl")
include("arraystorage/arraystorage/storage/conj.jl")
include("arraystorage/arraystorage/storage/permutedims.jl")
include("arraystorage/arraystorage/storage/contract.jl")

include("arraystorage/arraystorage/tensor/arraystorage.jl")
include("arraystorage/arraystorage/tensor/zeros.jl")
include("arraystorage/arraystorage/tensor/indexing.jl")
include("arraystorage/arraystorage/tensor/permutedims.jl")
include("arraystorage/arraystorage/tensor/mul.jl")
include("arraystorage/arraystorage/tensor/contract.jl")
include("arraystorage/arraystorage/tensor/qr.jl")
include("arraystorage/arraystorage/tensor/eigen.jl")
include("arraystorage/arraystorage/tensor/svd.jl")

# DiagonalArray storage
include("arraystorage/diagonalarray/storage/contract.jl")

include("arraystorage/diagonalarray/tensor/contract.jl")

# BlockSparseArray storage
include("arraystorage/blocksparsearray/storage/unwrap.jl")
include("arraystorage/blocksparsearray/storage/contract.jl")

include("arraystorage/blocksparsearray/tensor/contract.jl")

# Combiner storage
include("arraystorage/combiner/storage/promote_rule.jl")
include("arraystorage/combiner/storage/contract_utils.jl")
include("arraystorage/combiner/storage/contract.jl")

include("arraystorage/combiner/tensor/to_arraystorage.jl")
include("arraystorage/combiner/tensor/contract.jl")

include("arraystorage/blocksparsearray/storage/combiner/contract.jl")
include("arraystorage/blocksparsearray/storage/combiner/contract_utils.jl")
include("arraystorage/blocksparsearray/storage/combiner/contract_combine.jl")
include("arraystorage/blocksparsearray/storage/combiner/contract_uncombine.jl")

#####################################
# Deprecations
#
@@ -1,6 +1,5 @@
# TODO: Change to:
# using .SparseArrayDOKs: SparseArrayDOK
using .BlockSparseArrays: SparseArray
# TODO: Define in `SparseArrayInterface`.
using ..SparseArrayDOKs: SparseArrayDOK

# TODO: This is inefficient, need to optimize.
# Look at `contract_labels`, `contract_blocks` and `maybe_contract_blocks!` in:
@@ -39,11 +38,11 @@ function default_contract_muladd(a1, labels1, a2, labels2, a_dest, labels_dest)
end

function contract!(
a_dest::SparseArray,
a_dest::SparseArrayDOK,
labels_dest,
a1::SparseArray,
a1::SparseArrayDOK,
labels1,
a2::SparseArray,
a2::SparseArrayDOK,
labels2;
muladd=default_contract_muladd,
)
2 changes: 2 additions & 0 deletions NDTensors/src/diag/diagtensor.jl
@@ -1,3 +1,5 @@
using .DiagonalArrays: diaglength

const DiagTensor{ElT,N,StoreT,IndsT} = Tensor{ElT,N,StoreT,IndsT} where {StoreT<:Diag}
const NonuniformDiagTensor{ElT,N,StoreT,IndsT} =
Tensor{ElT,N,StoreT,IndsT} where {StoreT<:NonuniformDiag}
4 changes: 4 additions & 0 deletions NDTensors/src/lib/BlockSparseArrays/examples/Project.toml
@@ -0,0 +1,4 @@
[deps]
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
NDTensors = "23ae76d9-e61a-49c4-8f12-3f1a16adf9cf"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
59 changes: 38 additions & 21 deletions NDTensors/src/lib/BlockSparseArrays/examples/README.jl
@@ -6,11 +6,13 @@
# to store non-zero values, specifically a `Dictionary` from `Dictionaries.jl`.
# `BlockArrays` reinterprets the `SparseArray` as a blocked data structure.

using NDTensors.BlockSparseArrays
using BlockArrays: BlockArrays, blockedrange
using Test
using BlockArrays: BlockArrays, PseudoBlockVector, blockedrange
using NDTensors.BlockSparseArrays: BlockSparseArray, block_nstored
using Test: @test, @test_broken

function main()
Block = BlockArrays.Block

## Block dimensions
i1 = [2, 3]
i2 = [2, 3]
@@ -22,39 +24,54 @@ function main()
end

## Data
nz_blocks = BlockArrays.Block.([(1, 1), (2, 2)])
nz_blocks = Block.([(1, 1), (2, 2)])
nz_block_sizes = [block_size(i_axes, nz_block) for nz_block in nz_blocks]
nz_block_lengths = prod.(nz_block_sizes)

## Blocks with contiguous underlying data
d_data = PseudoBlockVector(randn(sum(nz_block_lengths)), nz_block_lengths)
d_blocks = [
reshape(@view(d_data[Block(i)]), block_size(i_axes, nz_blocks[i])) for
i in 1:length(nz_blocks)
]
b = BlockSparseArray(nz_blocks, d_blocks, i_axes)

@test block_nstored(b) == 2

## Blocks with discontiguous underlying data
d_blocks = randn.(nz_block_sizes)
b = BlockSparseArray(nz_blocks, d_blocks, i_axes)

## Blocks with contiguous underlying data
## d_data = PseudoBlockVector(randn(sum(nz_block_lengths)), nz_block_lengths)
## d_blocks = [reshape(@view(d_data[Block(i)]), block_size(i_axes, nz_blocks[i])) for i in 1:length(nz_blocks)]

B = BlockSparseArray(nz_blocks, d_blocks, i_axes)
@test block_nstored(b) == 2

## Access a block
B[BlockArrays.Block(1, 1)]
@test b[Block(1, 1)] == d_blocks[1]

## Access a non-zero block, returns a zero matrix
B[BlockArrays.Block(1, 2)]
## Access a zero block, returns a zero matrix
@test b[Block(1, 2)] == zeros(2, 3)

## Set a zero block
B[BlockArrays.Block(1, 2)] = randn(2, 3)
a₁₂ = randn(2, 3)
b[Block(1, 2)] = a₁₂
@test b[Block(1, 2)] == a₁₂

## Matrix multiplication (not optimized for sparsity yet)
@test B * B ≈ Array(B) * Array(B)
@test b * b ≈ Array(b) * Array(b)

permuted_B = permutedims(B, (2, 1))
@test permuted_B isa BlockSparseArray
@test permuted_B == permutedims(Array(B), (2, 1))
permuted_b = permutedims(b, (2, 1))
## TODO: Fix this, broken.
@test_broken permuted_b isa BlockSparseArray
@test permuted_b == permutedims(Array(b), (2, 1))

@test B + B ≈ Array(B) + Array(B)
@test 2B ≈ 2Array(B)
@test b + b ≈ Array(b) + Array(b)

@test reshape(B, ([4, 6, 6, 9],)) isa BlockSparseArray{<:Any,1}
scaled_b = 2b
@test scaled_b ≈ 2Array(b)
## TODO: Fix this, broken.
@test_broken scaled_b isa BlockSparseArray

## TODO: Fix this, broken.
@test_broken reshape(b, ([4, 6, 6, 9],)) isa BlockSparseArray{<:Any,1}

return nothing
end
@@ -63,8 +80,8 @@ main()

# # BlockSparseArrays.jl and BlockArrays.jl interface

using NDTensors.BlockSparseArrays
using BlockArrays: BlockArrays
using NDTensors.BlockSparseArrays: BlockSparseArray

i1 = [2, 3]
i2 = [2, 3]
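The README example above relies on `SparseArrayDOK` storing only the nonzero blocks in a dictionary keyed by block position. As a rough illustration of that dictionary-of-keys (DOK) storage strategy, here is a minimal standalone sketch, written in Python for brevity; the actual implementation is Julia code in `SparseArrayDOKs`/`SparseArrayInterface` built on `Dictionaries.jl`, and all names below are hypothetical analogues, not the real API.

```python
# Minimal dictionary-of-keys (DOK) sparse array sketch.
# Unstored entries read as zero; storing zero removes the entry,
# so the dictionary always holds exactly the structurally nonzero values.

class SparseArrayDOK:
    """Sparse N-dimensional array storing only nonzero entries in a dict."""

    def __init__(self, shape, zero=0.0):
        self.shape = shape
        self.zero = zero
        self.data = {}  # maps index tuples -> stored values

    def __getitem__(self, index):
        # Unstored entries read as the zero element.
        return self.data.get(index, self.zero)

    def __setitem__(self, index, value):
        if value == self.zero:
            self.data.pop(index, None)  # keep storage strictly sparse
        else:
            self.data[index] = value

    def nstored(self):
        # Number of explicitly stored (nonzero) entries.
        return len(self.data)

a = SparseArrayDOK((4, 4))
a[0, 0] = 1.5
a[2, 3] = -2.0
print(a[0, 0], a[1, 1], a.nstored())  # 1.5 0.0 2
```

In the block-sparse setting, the same idea applies with whole blocks as the stored values and block positions as the keys, which is what `block_nstored` counts in the tests above.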
@@ -0,0 +1,14 @@
using BlockArrays: Block
using Dictionaries: Dictionary, Indices

# TODO: Use `Tuple` conversion once
# BlockArrays.jl PR is merged.
block_to_cartesianindex(b::Block) = CartesianIndex(b.n)

function blocks_to_cartesianindices(i::Indices{<:Block})
return block_to_cartesianindex.(i)
end

function blocks_to_cartesianindices(d::Dictionary{<:Block})
return Dictionary(blocks_to_cartesianindices(eachindex(d)), d)
end
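The file above converts `Block` labels (whose `n` field holds the per-dimension block numbers) into `CartesianIndex` keys, then re-keys a `Dictionary` accordingly. A sketch of the same transformation, in Python with plain tuples standing in for `Block` and `CartesianIndex` (illustrative only, not the Julia API):

```python
# Analogue of block_to_cartesianindex / blocks_to_cartesianindices:
# a block label carries a tuple of per-dimension block numbers, and a
# dictionary keyed by block labels is re-keyed by those index tuples.

def block_to_cartesianindex(block):
    # `block` stands in for BlockArrays.Block; here it is already a tuple.
    return tuple(block)

def blocks_to_cartesianindices(d):
    # Re-key a dict of block -> data by Cartesian-style index tuples.
    return {block_to_cartesianindex(b): v for b, v in d.items()}

d = {(1, 1): "A", (2, 2): "B"}
print(blocks_to_cartesianindices(d))  # {(1, 1): 'A', (2, 2): 'B'}
```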
48 changes: 6 additions & 42 deletions NDTensors/src/lib/BlockSparseArrays/src/BlockSparseArrays.jl
@@ -1,44 +1,8 @@
module BlockSparseArrays
using ..AlgorithmSelection: Algorithm, @Algorithm_str
using BlockArrays:
AbstractBlockArray,
BlockArrays,
BlockVector,
Block,
BlockIndex,
BlockRange,
BlockedUnitRange,
findblockindex,
block,
blockaxes,
blockcheckbounds,
blockfirsts,
blocklasts,
blocklength,
blocklengths,
blockedrange,
blocks
using Compat: Returns, allequal
using Dictionaries: Dictionary, Indices, getindices, set! # TODO: Move to `SparseArraysExtensions`.
using LinearAlgebra: Hermitian
using SplitApplyCombine: groupcount

export BlockSparseArray, SparseArray

include("tensor_product.jl")
include("base.jl")
include("axes.jl")
include("abstractarray.jl")
include("permuteddimsarray.jl")
include("blockarrays.jl")
# TODO: Split off into `SparseArraysExtensions` module, rename to `SparseArrayDOK`.
include("sparsearray.jl")
include("blocksparsearray.jl")
include("allocate_output.jl")
include("subarray.jl")
include("broadcast.jl")
include("fusedims.jl")
include("gradedrange.jl")
include("LinearAlgebraExt/LinearAlgebraExt.jl")

include("blocksparsearrayinterface/blocksparsearrayinterface.jl")
include("blocksparsearrayinterface/blockzero.jl")
include("abstractblocksparsearray/abstractblocksparsearray.jl")
include("blocksparsearray/defaults.jl")
include("blocksparsearray/blocksparsearray.jl")
include("BlockArraysExtensions/BlockArraysExtensions.jl")
end
@@ -0,0 +1,46 @@
using BlockArrays: BlockArrays, AbstractBlockArray, Block, BlockIndex

# TODO: Delete this. This function was replaced
# by `nstored` but is still used in `NDTensors`.
function nonzero_keys end

abstract type AbstractBlockSparseArray{T,N} <: AbstractBlockArray{T,N} end

# Base `AbstractArray` interface
Base.axes(::AbstractBlockSparseArray) = error("Not implemented")

# BlockArrays `AbstractBlockArray` interface
BlockArrays.blocks(::AbstractBlockSparseArray) = error("Not implemented")

blocktype(a::AbstractBlockSparseArray) = eltype(blocks(a))

# Base `AbstractArray` interface
function Base.getindex(a::AbstractBlockSparseArray{<:Any,N}, I::Vararg{Int,N}) where {N}
return blocksparse_getindex(a, I...)
end

function Base.setindex!(
a::AbstractBlockSparseArray{<:Any,N}, value, I::Vararg{Int,N}
) where {N}
blocksparse_setindex!(a, value, I...)
return a
end

function Base.setindex!(
a::AbstractBlockSparseArray{<:Any,N}, value, I::BlockIndex{N}
) where {N}
blocksparse_setindex!(a, value, I)
return a
end

function Base.setindex!(a::AbstractBlockSparseArray{<:Any,N}, value, I::Block{N}) where {N}
blocksparse_setindex!(a, value, I)
return a
end

# `BlockArrays` interface
function BlockArrays.viewblock(
a::AbstractBlockSparseArray{<:Any,N}, I::Block{N,Int}
) where {N}
return blocksparse_viewblock(a, I)
end
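The new abstract type above wires `getindex`/`setindex!` on any `AbstractBlockSparseArray` through generic `blocksparse_*` functions, so a concrete subtype only has to supply `axes` and `blocks`. A rough Python analogue of that delegation pattern (all names are hypothetical stand-ins for the Julia code, and the flat dict-of-blocks storage is a deliberate simplification):

```python
# Delegation pattern: element access on the abstract type forwards to
# generic blocksparse_* functions, which only use the subtype-provided
# block storage. Concrete subtypes override blocks() and nothing else.

def blocksparse_getitem(a, index):
    # Generic fallback: unstored blocks read as zero.
    return a.blocks().get(index, 0.0)

def blocksparse_setitem(a, index, value):
    a.blocks()[index] = value

class AbstractBlockSparseArray:
    def blocks(self):
        raise NotImplementedError("Not implemented")  # subtypes must override

    def __getitem__(self, index):
        return blocksparse_getitem(self, index)

    def __setitem__(self, index, value):
        blocksparse_setitem(self, index, value)

class DictBlockSparseArray(AbstractBlockSparseArray):
    """Concrete subtype: only provides the block storage."""

    def __init__(self):
        self._blocks = {}

    def blocks(self):
        return self._blocks

a = DictBlockSparseArray()
a[(1, 1)] = 2.5
print(a[(1, 1)], a[(1, 2)])  # 2.5 0.0
```

The payoff of this layout, mirrored in the Julia file, is that the generic `blocksparse_*` layer can later be shared by several array types without duplicating the indexing logic.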
@@ -0,0 +1,45 @@
module BlockSparseArrays
using ..AlgorithmSelection: Algorithm, @Algorithm_str
using BlockArrays:
AbstractBlockArray,
BlockArrays,
BlockVector,
Block,
BlockIndex,
BlockRange,
BlockedUnitRange,
findblockindex,
block,
blockaxes,
blockcheckbounds,
blockfirsts,
blocklasts,
blocklength,
blocklengths,
blockedrange,
blocks
using Compat: Returns, allequal
using Dictionaries: Dictionary, Indices, getindices, set! # TODO: Move to `SparseArraysExtensions`.
using LinearAlgebra: Hermitian
using SplitApplyCombine: groupcount

export BlockSparseArray # , SparseArray

include("defaults.jl")
include("tensor_product.jl")
include("base.jl")
include("axes.jl")
include("abstractarray.jl")
include("permuteddimsarray.jl")
include("blockarrays.jl")
# TODO: Split off into `SparseArraysExtensions` module, rename to `SparseArrayDOK`.
# include("sparsearray.jl")
include("blocksparsearray.jl")
include("allocate_output.jl")
include("subarray.jl")
include("broadcast.jl")
include("fusedims.jl")
include("gradedrange.jl")
include("LinearAlgebraExt/LinearAlgebraExt.jl")

end
@@ -1,7 +1,7 @@
module LinearAlgebraExt
using ...AlgorithmSelection: Algorithm, @Algorithm_str
using BlockArrays: BlockArrays, blockedrange, blocks
using ..BlockSparseArrays: SparseArray, nonzero_keys # TODO: Move to `SparseArraysExtensions` module, rename `SparseArrayDOK`.
using ..BlockSparseArrays: nonzero_keys # TODO: Move to `SparseArraysExtensions` module, rename `SparseArrayDOK`.
using ..BlockSparseArrays: BlockSparseArrays, BlockSparseArray, nonzero_blockkeys
using LinearAlgebra: LinearAlgebra, Hermitian, Transpose, I, eigen, qr
using SparseArrays: SparseArrays, SparseMatrixCSC, spzeros, sparse