8
8
)
9
9
10
10
11
+ # @torch.jit.script
12
+ def _merge_sorted (
13
+ old_nd : Tensor ,
14
+ new_nd : Tensor ,
15
+ old_values : Tensor ,
16
+ new_values : Tensor ,
17
+ insertion_positions : Tensor ,
18
+ ) -> tuple [Tensor , Tensor ]:
19
+ """Merges two sorted sequences of sparse indices/values and return them in coalesced
20
+ order.
21
+
22
+ All input args must be on the same device.
23
+
24
+ Args:
25
+ old_nd (Tensor): [S x n_old] tensor of N-D indices
26
+ new_nd (Tensor): [S x n_new] tensor of N-D indices
27
+ old_values (Tensor): [n_old, ...] tensor of values
28
+ new_values (Tensor): [n_new, ...] tensor of values
29
+ insertion_positions (Tensor): [n_new] tensor of insertion positions in
30
+ old_linear for each element in new_linear
31
+
32
+ Returns:
33
+ merged_nd: [S, n_old+n_new]
34
+ merged_values: [n_old+n_new, ...]
35
+ """
36
+ device = old_nd .device
37
+ n_old , n_new = old_nd .size (1 ), new_nd .size (1 )
38
+ n_total = n_old + n_new
39
+
40
+ # determine final positions of new values
41
+ # account for previous insertions to get final positions of new rows
42
+ new_positions = insertion_positions + torch .arange (
43
+ n_new , device = device , dtype = insertion_positions .dtype
44
+ )
45
+
46
+ # determine final positions of old values by counting how many new values are
47
+ # inserted before each old value
48
+ hist = torch .bincount (insertion_positions , minlength = n_old + 1 )
49
+ old_shift = torch .cumsum (hist [:- 1 ], 0 )
50
+ old_positions = torch .arange (n_old , device = device ) + old_shift
51
+
52
+ # allocate output tensors
53
+ merged_nd = old_nd .new_empty (old_nd .size (0 ), n_total )
54
+ merged_values = old_values .new_empty ((n_total ,) + old_values .shape [1 :])
55
+
56
+ # insert values
57
+ merged_nd [:, old_positions ] = old_nd
58
+ merged_nd [:, new_positions ] = new_nd
59
+ merged_values [old_positions ] = old_values
60
+ merged_values [new_positions ] = new_values
61
+
62
+ return merged_nd , merged_values
63
+
64
+
65
+ # @torch.jit.script
11
66
def scatter_to_sparse_tensor (
12
67
sparse_tensor : Tensor ,
13
68
index_tensor : Tensor ,
@@ -21,12 +76,13 @@ def scatter_to_sparse_tensor(
21
76
sparse tensor.
22
77
23
78
Args:
24
- sparse_tensor (Tensor): Sparse tensor of dimension ..., M; where ... are
25
- S leading sparse dimensions and M is the dense dimension
26
- index_tensor (Tensor): Long tensor of dimension ..., S; where ... are
27
- leading batch dimensions.
28
- values (Tensor): Tensor of dimension ..., M; where ... are leading
29
- batch dimensions and M is the dense dimension
79
+ sparse_tensor (Tensor): Sparse tensor of dimension
80
+ [s0, s1, s2, ..., d0, d1, d2, ...]; where s0, s1, ... are
81
+ S leading sparse dimensions and d0, d1, d2, ... are D dense dimensions.
82
+ index_tensor (Tensor): Long tensor of dimension [b0, b1, b2, ..., S]; where
83
+ b0, b1, b2, ... are B leading batch dimensions.
84
+ values (Tensor): Tensor of dimension [b0, b1, b2, ... d0, d1, d2, ...], where
85
+ dimensions are as above.
30
86
check_all_specified (bool): If True, this function will throw a ValueError
31
87
if any of the indices specified in index_tensor are not already present
32
88
in sparse_tensor. Default: False.
@@ -105,8 +161,15 @@ def scatter_to_sparse_tensor(
105
161
index_tensor = torch .cat (index_tensor .unbind ())
106
162
values = torch .cat (values .unbind ())
107
163
108
- assert index_tensor .shape [:- 1 ] == values .shape [:- 1 ]
109
- assert sparse_tensor .dense_dim () == values .ndim - 1
164
+ dense_dim = sparse_tensor .dense_dim ()
165
+ sparse_dim = sparse_tensor .sparse_dim ()
166
+ values_batch_dims = values .shape [:- dense_dim ] if dense_dim else values .shape
167
+ if index_tensor .shape [:- 1 ] != values_batch_dims :
168
+ raise ValueError (
169
+ "Expected matching batch dims for `index_tensor` and `values`, but got "
170
+ f"batch dims { index_tensor .shape [:- 1 ]} and "
171
+ f"{ values_batch_dims } , respectively."
172
+ )
110
173
111
174
sparse_tensor = sparse_tensor .coalesce ()
112
175
sparse_tensor_values = sparse_tensor .values ()
@@ -151,19 +214,24 @@ def scatter_to_sparse_tensor(
151
214
new_values = values [~ is_specified_mask ]
152
215
153
216
# Get sparse shape info for linearization
154
- sparse_dim = sparse_tensor .sparse_dim ()
155
217
sparse_sizes = torch .tensor (
156
218
sparse_tensor .shape [:sparse_dim ], device = sparse_tensor .device
157
219
)
158
220
221
+ if (new_nd_indices >= sparse_sizes .unsqueeze (0 )).any ():
222
+ raise ValueError (
223
+ "`index_tensor` has indices that are out of bounds of the original "
224
+ f"sparse tensor's sparse shape ({ sparse_sizes } )."
225
+ )
226
+
159
227
# Obtain linearized versions of all indices for sorting
160
228
old_indices_nd = sparse_tensor .indices ()
161
229
linear_offsets = _make_linear_offsets (sparse_sizes )
162
230
new_indices_lin : Tensor = (new_nd_indices * linear_offsets ).sum (- 1 )
163
231
164
232
# Find duplicate linear indices
165
- unique_new_indices_lin , inverse = new_indices_lin .unique (
166
- sorted = True , return_inverse = True
233
+ unique_new_indices_lin , inverse = torch .unique (
234
+ new_indices_lin , sorted = True , return_inverse = True
167
235
)
168
236
169
237
# Use inverse of indices unique to write to new values tensor and tensor of
@@ -198,56 +266,3 @@ def scatter_to_sparse_tensor(
198
266
device = sparse_tensor .device ,
199
267
is_coalesced = True ,
200
268
)
201
-
202
-
203
- def _merge_sorted (
204
- old_nd : Tensor ,
205
- new_nd : Tensor ,
206
- old_values : Tensor ,
207
- new_values : Tensor ,
208
- insertion_positions : Tensor ,
209
- ) -> tuple [Tensor , Tensor ]:
210
- """Merges two sorted sequences of sparse indices/values and return them in coalesced
211
- order.
212
-
213
- All input args must be on the same device.
214
-
215
- Args:
216
- old_nd (Tensor): [S x n_old] tensor of N-D indices
217
- new_nd (Tensor): [S x n_new] tensor of N-D indices
218
- old_values (Tensor): [n_old, ...] tensor of values
219
- new_values (Tensor): [n_new, ...] tensor of values
220
- insertion_positions (Tensor): [n_new] tensor of insertion positions in
221
- old_linear for each element in new_linear
222
-
223
- Returns:
224
- merged_nd: [S, n_old+n_new]
225
- merged_values: [n_old+n_new, ...]
226
- """
227
- device = old_nd .device
228
- n_old , n_new = old_nd .size (1 ), new_nd .size (1 )
229
- n_total = n_old + n_new
230
-
231
- # determine final positions of new values
232
- # account for previous insertions to get final positions of new rows
233
- new_positions = insertion_positions + torch .arange (
234
- n_new , device = device , dtype = insertion_positions .dtype
235
- )
236
-
237
- # determine final positions of old values by counting how many new values are
238
- # inserted before each old value
239
- hist = torch .bincount (insertion_positions , minlength = n_old + 1 )
240
- old_shift = torch .cumsum (hist [:- 1 ], 0 )
241
- old_positions = torch .arange (n_old , device = device ) + old_shift
242
-
243
- # allocate output tensors
244
- merged_nd = old_nd .new_empty (old_nd .size (0 ), n_total )
245
- merged_values = old_values .new_empty ((n_total ,) + old_values .shape [1 :])
246
-
247
- # insert values
248
- merged_nd [:, old_positions ] = old_nd
249
- merged_nd [:, new_positions ] = new_nd
250
- merged_values [old_positions ] = old_values
251
- merged_values [new_positions ] = new_values
252
-
253
- return merged_nd , merged_values
0 commit comments