
Commit 0579a8c

fix stuff and add tests

Parent: 5de03b9

File tree: 5 files changed, +286 −32 lines


README.md

Lines changed: 1 addition & 1 deletion
@@ -7,7 +7,7 @@
 Low-level utilities for PyTorch sparse tensors and operations.
 
 ## Introduction
-PyTorch's implementation of sparse tensors is lacking full support for many common operations. This repository contains a set of utilities for making PyTorch sparse tensors into more usable general-purpose sparse data structures.
+PyTorch's implementation of sparse tensors is lacking full support for many common operations. This repository contains a set of utilities for making PyTorch sparse tensors into more usable general-purpose sparse data structures, particularly in the context of modern neural network architectures like Transformer-based models.
 
 For example, while the basic operation `index_select` has a sparse forward implementation, using it as part of an autograd graph alongside direct manipulation of the sparse tensor's values is not supported:
 ```python
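The README's own code example is truncated in this view. As a rough, hypothetical sketch of the kind of graph that sentence describes (not the README's actual snippet; exact failure behavior depends on the PyTorch version):

```python
import torch

# Hypothetical illustration of the limitation described above.
indices = torch.tensor([[0, 1, 2]])
values = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
sparse = torch.sparse_coo_tensor(indices, values, (3,)).coalesce()

selected = sparse.index_select(0, torch.tensor([0, 2]))  # sparse forward works
# Mixing that result with direct manipulation of the tensor's values in the
# same autograd graph is the unsupported combination the README refers to:
loss = selected.to_dense().sum() + sparse.values().sum()
loss.backward()  # per the README, graphs like this fail to backpropagate
```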

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -54,5 +54,5 @@ exclude_lines = [
     "raise NotImplementedError",
     "if __name__ == .__main__.:",
     "if TYPE_CHECKING:",
-    "if torch.jit.is_scriping():",
+    "if torch.jit.is_scripting():",
 ]
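For context: these strings are coverage.py exclude patterns, so lines matching them are dropped from coverage reports. The misspelled `is_scriping` pattern could never match a real guard, so such branches were counted as uncovered. A hypothetical example of the kind of guard the corrected pattern excludes (not code from this repo):

```python
import torch

def forward(x: torch.Tensor) -> torch.Tensor:
    # This branch only executes under TorchScript compilation, so the
    # coverage config excludes it from measurement.
    if torch.jit.is_scripting():
        return x
    return x * 2
```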

pytorch_sparse_utils/batching/batch_utils.py

Lines changed: 4 additions & 6 deletions
@@ -399,7 +399,7 @@ def concatenated_to_padded(
     >>> padded.shape
     torch.Size([3, 4, 3, 4]) # 3 batches, max length 4, 3x4 features
     """
-    validate_atleast_nd(tensor, 2)
+    validate_atleast_nd(tensor, 1)
     if not batch_offsets.ndim == 1:
         raise ValueError(f"Expected batch_offsets to be 1D, got {batch_offsets.ndim}")
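Relaxing the check from 2-D to 1-D means a concatenated tensor with no feature dimensions is now accepted. A usage sketch, assuming the function returns a `(padded, padding_mask)` pair as the fast path below suggests (the real return convention may differ):

```python
import torch
from pytorch_sparse_utils.batching.batch_utils import concatenated_to_padded

# Two sequences of scalars, lengths 3 and 2, concatenated along dim 0.
tensor = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])
batch_offsets = torch.tensor([0, 3, 5])
padded, padding_mask = concatenated_to_padded(tensor, batch_offsets)
# Expected: padded has shape (2, 3); padding_mask flags the one padded slot.
```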

@@ -418,9 +418,7 @@
 
     # Fast path: If all sequences are equal length can just return a view
     if torch.all(seq_lens == max_len):
-        if not tensor.is_contiguous():
-            tensor = tensor.contiguous()
-        out = tensor.view(out_shape)
+        out = tensor.reshape(out_shape)
         padding_mask = torch.zeros(
             batch_size, max_len, device=tensor.device, dtype=torch.bool
         )
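`reshape` subsumes the removed contiguity dance: it returns a view when the strides allow one and silently copies otherwise, which is exactly what the manual `contiguous()` + `view()` pair did. A quick demonstration of the difference:

```python
import torch

t = torch.arange(12).reshape(3, 4).t()  # transpose makes t non-contiguous
flat = t.reshape(12)                    # works: copies, since no view exists
# t.view(12)                            # would raise a RuntimeError here
```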
@@ -511,7 +509,7 @@ def padded_to_concatenated(
     >>> offsets
     tensor([0, 0, 1])
     """
-    validate_atleast_nd(tensor, 3)
+    validate_atleast_nd(tensor, 2)
     batch_size, max_len = tensor.shape[:2]
     feature_dims = tensor.shape[2:]
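Symmetrically, `padded_to_concatenated` now accepts a plain `(batch, seq)` tensor with no feature dimensions. A round-trip sketch, assuming a `(tensor, padding_mask)` signature and a `(concatenated, batch_offsets)` return, which the surrounding docstring hints at but does not confirm:

```python
import torch
from pytorch_sparse_utils.batching.batch_utils import padded_to_concatenated

padded = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 0.0]])  # row 2: 1 pad slot
padding_mask = torch.tensor([[False, False, False], [False, False, True]])
concatenated, offsets = padded_to_concatenated(padded, padding_mask)
# Expected: concatenated == tensor([1., 2., 3., 4., 5.]); offsets == tensor([0, 3, 5])
```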

@@ -656,7 +654,7 @@ def batch_offsets_from_sparse_tensor_indices(indices_tensor: Tensor) -> Tensor:
     """
     assert not torch.is_floating_point(indices_tensor)
 
-    if indices_tensor.shape[1] == 0: # empty case
+    if indices_tensor.numel() == 0: # empty case
         return torch.zeros(1, device=indices_tensor.device, dtype=indices_tensor.dtype)
 
     batch_indices = indices_tensor[0]
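`numel() == 0` is the more defensive empty check: it also covers degenerate index tensors where indexing `shape[1]` would itself fail. A quick comparison:

```python
import torch

empty_2d = torch.empty((2, 0), dtype=torch.long)  # case the old check caught
empty_1d = torch.empty((0,), dtype=torch.long)    # old check: IndexError on shape[1]
assert empty_2d.numel() == 0 and empty_1d.numel() == 0  # new check covers both
```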

pytorch_sparse_utils/imports.py

Lines changed: 2 additions & 2 deletions
@@ -8,7 +8,7 @@
     from MinkowskiEngine.MinkowskiNonlinearity import MinkowskiNonlinearityBase
 
     has_minkowskiengine = True
-except ImportError:
+except ImportError: # pragma: no cover
     class DummyME:
         SparseTensor = None
     ME = DummyME()

@@ -19,7 +19,7 @@ class DummyME:
     from spconv.pytorch import SparseConvTensor
 
     has_spconv = True
-except ImportError:
+except ImportError: # pragma: no cover
     spconv = None
     SparseConvTensor = None
     has_spconv = False
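`# pragma: no cover` is coverage.py's standard marker for excluding a line or branch from coverage measurement; here it stops the import fallbacks from being reported as missed lines in environments where MinkowskiEngine and spconv are installed. The general shape of the pattern, with a hypothetical dependency name for illustration:

```python
try:
    import some_optional_lib  # hypothetical optional dependency
    HAS_LIB = True
except ImportError:  # pragma: no cover
    # CI with the dependency installed can never reach this branch,
    # so it is excluded from coverage.
    some_optional_lib = None
    HAS_LIB = False
```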
