
Commit 5fdc2bb

update scatter func
1 parent 610c582 commit 5fdc2bb

File tree

2 files changed: +361 −84 lines changed


pytorch_sparse_utils/indexing/scatter.py

Lines changed: 79 additions & 64 deletions
@@ -8,6 +8,61 @@
 )
 
 
+# @torch.jit.script
+def _merge_sorted(
+    old_nd: Tensor,
+    new_nd: Tensor,
+    old_values: Tensor,
+    new_values: Tensor,
+    insertion_positions: Tensor,
+) -> tuple[Tensor, Tensor]:
+    """Merges two sorted sequences of sparse indices/values and returns them in
+    coalesced order.
+
+    All input args must be on the same device.
+
+    Args:
+        old_nd (Tensor): [S x n_old] tensor of N-D indices
+        new_nd (Tensor): [S x n_new] tensor of N-D indices
+        old_values (Tensor): [n_old, ...] tensor of values
+        new_values (Tensor): [n_new, ...] tensor of values
+        insertion_positions (Tensor): [n_new] tensor of insertion positions in
+            old_linear for each element in new_linear
+
+    Returns:
+        merged_nd: [S, n_old+n_new]
+        merged_values: [n_old+n_new, ...]
+    """
+    device = old_nd.device
+    n_old, n_new = old_nd.size(1), new_nd.size(1)
+    n_total = n_old + n_new
+
+    # determine final positions of new values:
+    # account for previous insertions to get final positions of new rows
+    new_positions = insertion_positions + torch.arange(
+        n_new, device=device, dtype=insertion_positions.dtype
+    )
+
+    # determine final positions of old values by counting how many new values are
+    # inserted before each old value
+    hist = torch.bincount(insertion_positions, minlength=n_old + 1)
+    old_shift = torch.cumsum(hist[:-1], 0)
+    old_positions = torch.arange(n_old, device=device) + old_shift
+
+    # allocate output tensors
+    merged_nd = old_nd.new_empty(old_nd.size(0), n_total)
+    merged_values = old_values.new_empty((n_total,) + old_values.shape[1:])
+
+    # insert values
+    merged_nd[:, old_positions] = old_nd
+    merged_nd[:, new_positions] = new_nd
+    merged_values[old_positions] = old_values
+    merged_values[new_positions] = new_values
+
+    return merged_nd, merged_values
+
+
+# @torch.jit.script
 def scatter_to_sparse_tensor(
     sparse_tensor: Tensor,
     index_tensor: Tensor,
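
To see what the relocated `_merge_sorted` computes, here is a minimal sketch on a 1-D sparse layout (S = 1) with scalar values. The `insertion_positions` argument is assumed here to come from `torch.searchsorted` over the old (already sorted) indices; this diff does not show how the caller actually produces it.

```python
import torch

old_nd = torch.tensor([[1, 4, 7]])            # sorted 1-D indices, n_old = 3
new_nd = torch.tensor([[2, 5]])               # sorted 1-D indices, n_new = 2
old_values = torch.tensor([10.0, 40.0, 70.0])
new_values = torch.tensor([20.0, 50.0])

# Where each new index would be inserted into the old sequence: [1, 2]
# (assumed to be computed this way; not shown in the diff)
insertion_positions = torch.searchsorted(old_nd[0], new_nd[0])

merged_nd, merged_values = _merge_sorted(
    old_nd, new_nd, old_values, new_values, insertion_positions
)
# merged_nd     -> [[1, 2, 4, 5, 7]]
# merged_values -> [10., 20., 40., 50., 70.]
```

The bincount/cumsum pair computes, for each old element, how many new elements land before it, so both sequences can be written into the output in one vectorized pass with no Python-level loop.
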
@@ -21,12 +76,13 @@ def scatter_to_sparse_tensor(
     sparse tensor.
 
     Args:
-        sparse_tensor (Tensor): Sparse tensor of dimension ..., M; where ... are
-            S leading sparse dimensions and M is the dense dimension
-        index_tensor (Tensor): Long tensor of dimension ..., S; where ... are
-            leading batch dimensions.
-        values (Tensor): Tensor of dimension ..., M; where ... are leading
-            batch dimensions and M is the dense dimension
+        sparse_tensor (Tensor): Sparse tensor of dimension
+            [s0, s1, s2, ..., d0, d1, d2, ...]; where s0, s1, ... are
+            S leading sparse dimensions and d0, d1, d2, ... are D dense dimensions.
+        index_tensor (Tensor): Long tensor of dimension [b0, b1, b2, ..., S]; where
+            b0, b1, b2, ... are B leading batch dimensions.
+        values (Tensor): Tensor of dimension [b0, b1, b2, ..., d0, d1, d2, ...], where
+            dimensions are as above.
         check_all_specified (bool): If True, this function will throw a ValueError
             if any of the indices specified in index_tensor are not already present
             in sparse_tensor. Default: False.
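
A hedged usage sketch of the call the updated docstring describes, with S = 2 sparse dims, D = 1 dense dim, and a single batch dim. The import path is taken from the file header of this diff; the exact shapes are illustrative only.

```python
import torch
from pytorch_sparse_utils.indexing.scatter import scatter_to_sparse_tensor

# Sparse tensor of shape [4, 4, 3]: sparse dims (4, 4), one dense dim of size 3
indices = torch.tensor([[0, 1], [0, 2]])       # [S=2, nnz=2]
values = torch.randn(2, 3)                     # [nnz, d0]
sparse = torch.sparse_coo_tensor(indices, values, size=(4, 4, 3))

# index_tensor: [b0=2, S=2]; values to scatter: [b0=2, d0=3]
index_tensor = torch.tensor([[0, 0], [3, 3]])  # one existing, one new location
new_values = torch.randn(2, 3)

out = scatter_to_sparse_tensor(sparse, index_tensor, new_values)
```

Because `check_all_specified` defaults to False, the write to the previously unspecified location (3, 3) inserts a new element rather than raising.
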
@@ -105,8 +161,15 @@ def scatter_to_sparse_tensor(
         index_tensor = torch.cat(index_tensor.unbind())
         values = torch.cat(values.unbind())
 
-    assert index_tensor.shape[:-1] == values.shape[:-1]
-    assert sparse_tensor.dense_dim() == values.ndim - 1
+    dense_dim = sparse_tensor.dense_dim()
+    sparse_dim = sparse_tensor.sparse_dim()
+    values_batch_dims = values.shape[:-dense_dim] if dense_dim else values.shape
+    if index_tensor.shape[:-1] != values_batch_dims:
+        raise ValueError(
+            "Expected matching batch dims for `index_tensor` and `values`, but got "
+            f"batch dims {index_tensor.shape[:-1]} and "
+            f"{values_batch_dims}, respectively."
+        )
 
     sparse_tensor = sparse_tensor.coalesce()
     sparse_tensor_values = sparse_tensor.values()
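
The replaced asserts become an explicit `ValueError` with a descriptive message. The `if dense_dim else` fallback exists because with zero dense dims, `values.shape[:-0]` would slice to an empty tuple instead of the full shape. A standalone illustration of the shape arithmetic, with assumed example shapes:

```python
import torch

# index_tensor: [b0, b1, S], values: [b0, b1, d0], so dense_dim = 1
index_tensor = torch.zeros(2, 5, 3, dtype=torch.long)  # batch dims (2, 5), S = 3
values = torch.randn(2, 5, 8)                          # batch dims (2, 5), d0 = 8
dense_dim = 1

# shape[:-0] is shape[:0] == (), hence the explicit zero-dense-dim fallback
values_batch_dims = values.shape[:-dense_dim] if dense_dim else values.shape
assert index_tensor.shape[:-1] == values_batch_dims    # (2, 5) == (2, 5)
```
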
@@ -151,19 +214,24 @@ def scatter_to_sparse_tensor(
     new_values = values[~is_specified_mask]
 
     # Get sparse shape info for linearization
-    sparse_dim = sparse_tensor.sparse_dim()
     sparse_sizes = torch.tensor(
         sparse_tensor.shape[:sparse_dim], device=sparse_tensor.device
     )
 
+    if (new_nd_indices >= sparse_sizes.unsqueeze(0)).any():
+        raise ValueError(
+            "`index_tensor` has indices that are out of bounds of the original "
+            f"sparse tensor's sparse shape ({sparse_sizes})."
+        )
+
     # Obtain linearized versions of all indices for sorting
     old_indices_nd = sparse_tensor.indices()
     linear_offsets = _make_linear_offsets(sparse_sizes)
     new_indices_lin: Tensor = (new_nd_indices * linear_offsets).sum(-1)
 
     # Find duplicate linear indices
-    unique_new_indices_lin, inverse = new_indices_lin.unique(
-        sorted=True, return_inverse=True
+    unique_new_indices_lin, inverse = torch.unique(
+        new_indices_lin, sorted=True, return_inverse=True
     )
 
     # Use inverse of indices unique to write to new values tensor and tensor of
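
`_make_linear_offsets` is not part of this diff; the sketch below assumes a row-major implementation in order to illustrate the linearization and the deduplication that `torch.unique(..., return_inverse=True)` performs on the new indices.

```python
import torch

def make_linear_offsets(sparse_sizes: torch.Tensor) -> torch.Tensor:
    # Row-major strides: for sizes [s0, s1, s2] returns [s1*s2, s2, 1].
    # Assumed behavior of the library's _make_linear_offsets, for illustration.
    return torch.cat(
        [sparse_sizes.flip(0).cumprod(0).flip(0)[1:], sparse_sizes.new_ones(1)]
    )

sparse_sizes = torch.tensor([4, 4])
offsets = make_linear_offsets(sparse_sizes)           # tensor([4, 1])

new_nd_indices = torch.tensor([[3, 3], [0, 1], [3, 3]])
new_indices_lin = (new_nd_indices * offsets).sum(-1)  # tensor([15, 1, 15])

# Duplicate writes collapse to one slot; `inverse` maps each row to its slot
unique_lin, inverse = torch.unique(new_indices_lin, sorted=True, return_inverse=True)
# unique_lin -> tensor([1, 15]); inverse -> tensor([1, 0, 1])
```

The new bounds check matters precisely because of this linearization: an out-of-range N-D index would otherwise silently alias a different valid linear position.
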
@@ -198,56 +266,3 @@ def scatter_to_sparse_tensor(
         device=sparse_tensor.device,
         is_coalesced=True,
     )
-
-
-def _merge_sorted(
-    old_nd: Tensor,
-    new_nd: Tensor,
-    old_values: Tensor,
-    new_values: Tensor,
-    insertion_positions: Tensor,
-) -> tuple[Tensor, Tensor]:
-    """Merges two sorted sequences of sparse indices/values and return them in
-    coalesced order.
-
-    All input args must be on the same device.
-
-    Args:
-        old_nd (Tensor): [S x n_old] tensor of N-D indices
-        new_nd (Tensor): [S x n_new] tensor of N-D indices
-        old_values (Tensor): [n_old, ...] tensor of values
-        new_values (Tensor): [n_new, ...] tensor of values
-        insertion_positions (Tensor): [n_new] tensor of insertion positions in
-            old_linear for each element in new_linear
-
-    Returns:
-        merged_nd: [S, n_old+n_new]
-        merged_values: [n_old+n_new, ...]
-    """
-    device = old_nd.device
-    n_old, n_new = old_nd.size(1), new_nd.size(1)
-    n_total = n_old + n_new
-
-    # determine final positions of new values
-    # account for previous insertions to get final positions of new rows
-    new_positions = insertion_positions + torch.arange(
-        n_new, device=device, dtype=insertion_positions.dtype
-    )
-
-    # determine final positions of old values by counting how many new values are
-    # inserted before each old value
-    hist = torch.bincount(insertion_positions, minlength=n_old + 1)
-    old_shift = torch.cumsum(hist[:-1], 0)
-    old_positions = torch.arange(n_old, device=device) + old_shift
-
-    # allocate output tensors
-    merged_nd = old_nd.new_empty(old_nd.size(0), n_total)
-    merged_values = old_values.new_empty((n_total,) + old_values.shape[1:])
-
-    # insert values
-    merged_nd[:, old_positions] = old_nd
-    merged_nd[:, new_positions] = new_nd
-    merged_values[old_positions] = old_values
-    merged_values[new_positions] = new_values
-
-    return merged_nd, merged_values
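
Finally, a sketch of the `check_all_specified` failure mode documented above: scattering to an index that is not already present should raise a `ValueError` (the exact message is not shown in this diff).

```python
import torch
from pytorch_sparse_utils.indexing.scatter import scatter_to_sparse_tensor

sparse = torch.sparse_coo_tensor(
    torch.tensor([[0], [0]]), torch.randn(1, 3), size=(4, 4, 3)
)
try:
    scatter_to_sparse_tensor(
        sparse,
        torch.tensor([[2, 2]]),   # (2, 2) is not a specified index
        torch.randn(1, 3),
        check_all_specified=True,
    )
except ValueError as e:
    print(e)
```
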
