@@ -18,8 +18,8 @@ def compressor_cuda_kernel(
    B: int,
    T: int,
):
-    b = cuda.blockIdx.x
-    i = cuda.threadIdx.x
+    b: int = cuda.blockIdx.x
+    i: int = cuda.threadIdx.x

    if b >= B or i > 0:
        return
@@ -93,8 +93,8 @@ def compressor_cuda(
class CompressorFunction(Function):
    @staticmethod
    def forward(
-        ctx: Any, x: torch.Tensor, zi: torch.Tensor, at: torch.Tensor, rt: torch.Tensor
-    ) -> torch.Tensor:
+        x: torch.Tensor, zi: torch.Tensor, at: torch.Tensor, rt: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if x.is_cuda:
            y, at_mask = compressor_cuda(
                x.detach(), zi.detach(), at.detach(), rt.detach()
@@ -108,19 +108,21 @@ def forward(
            )
            y = torch.from_numpy(y).to(x.device)
            at_mask = torch.from_numpy(at_mask).to(x.device)
-        ctx.save_for_backward(x, y, zi, at, rt, at_mask)
+        return y, at_mask

-        # for jvp
-        ctx.x = x
-        ctx.y = y
-        ctx.zi = zi
-        ctx.at = at
-        ctx.rt = rt
-        ctx.at_mask = at_mask
-        return y
+    @staticmethod
+    def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> Any:
+        x, zi, at, rt = inputs
+        y, at_mask = output
+        ctx.mark_non_differentiable(at_mask)
+        ctx.save_for_backward(x, y, zi, at, rt, at_mask)
+        ctx.save_for_forward(x, y, zi, at, rt, at_mask)
+        return ctx

    @staticmethod
-    def backward(ctx: Any, grad_y: torch.Tensor) -> Tuple[Optional[torch.Tensor], ...]:
+    def backward(
+        ctx: Any, grad_y: torch.Tensor, _
+    ) -> Tuple[Optional[torch.Tensor], ...]:
        x, y, zi, at, rt, at_mask = ctx.saved_tensors
        grad_x = grad_zi = grad_at = grad_rt = None
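
(Aside, not part of the commit: a minimal sketch of the structured autograd.Function pattern the hunk above adopts, assuming PyTorch 2.x. ScaleFunction, k, and mask are hypothetical names used only for illustration. forward no longer takes ctx and returns every output; setup_context does all the saving; mark_non_differentiable flags the auxiliary output so autograd never requires a gradient for it, although backward still receives a placeholder argument for that slot.)

from typing import Any, Tuple

import torch
from torch.autograd import Function


class ScaleFunction(Function):  # hypothetical toy, mirroring the structure above
    @staticmethod
    def forward(x: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        mask = x > 0  # auxiliary, non-differentiable output
        return x * k, mask

    @staticmethod
    def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
        x, k = inputs
        y, mask = output
        ctx.mark_non_differentiable(mask)  # no gradient is requested for mask
        ctx.save_for_backward(x, k)        # tensors for reverse-mode backward()
        ctx.save_for_forward(x, k)         # tensors for forward-mode jvp()

    @staticmethod
    def backward(ctx: Any, grad_y: torch.Tensor, _) -> Tuple[torch.Tensor, torch.Tensor]:
        # the extra `_` is the unused gradient slot for the non-differentiable mask
        x, k = ctx.saved_tensors
        return grad_y * k, (grad_y * x).sum()  # k assumed scalar here
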
@@ -153,19 +155,6 @@ def backward(ctx: Any, grad_y: torch.Tensor) -> Tuple[Optional[torch.Tensor], ..
        if ctx.needs_input_grad[3]:
            grad_rt = torch.where(~at_mask, grad_combined, 0.0).sum(1)

-        if hasattr(ctx, "y"):
-            del ctx.y
-        if hasattr(ctx, "x"):
-            del ctx.x
-        if hasattr(ctx, "zi"):
-            del ctx.zi
-        if hasattr(ctx, "at"):
-            del ctx.at
-        if hasattr(ctx, "rt"):
-            del ctx.rt
-        if hasattr(ctx, "at_mask"):
-            del ctx.at_mask
-
        return grad_x, grad_zi, grad_at, grad_rt

    @staticmethod
@@ -175,12 +164,13 @@ def jvp(
        grad_zi: torch.Tensor,
        grad_at: torch.Tensor,
        grad_rt: torch.Tensor,
-    ) -> torch.Tensor:
-        x, y, zi, at, rt, at_mask = ctx.x, ctx.y, ctx.zi, ctx.at, ctx.rt, ctx.at_mask
+    ) -> Tuple[torch.Tensor, None]:
+        x, y, zi, at, rt, at_mask = ctx.saved_tensors
        coeffs = torch.where(at_mask, at.unsqueeze(1), rt.unsqueeze(1))

        fwd_x = 0 if grad_x is None else grad_x * coeffs

+        fwd_combined: torch.Tensor
        if grad_at is None and grad_rt is None:
            fwd_combined = fwd_x
        else:
@@ -192,13 +182,35 @@ def jvp(
            fwd_combined = fwd_x + grad_beta * (
                x - torch.cat([zi.unsqueeze(1), y[:, :-1]], dim=1)
            )
+        return (
+            sample_wise_lpc(
+                fwd_combined,
+                coeffs.unsqueeze(2) - 1,
+                grad_zi if grad_zi is None else grad_zi.unsqueeze(1),
+            ),
+            None,
+        )

-        del ctx.x, ctx.y, ctx.zi, ctx.at, ctx.rt, ctx.at_mask
-        return sample_wise_lpc(
-            fwd_combined,
-            coeffs.unsqueeze(2) - 1,
-            grad_zi if grad_zi is None else grad_zi.unsqueeze(1),
+    @staticmethod
+    def vmap(info, in_dims, *args):
+        def maybe_expand_bdim_at_front(x, x_bdim):
+            if x_bdim is None:
+                return x.expand(info.batch_size, *x.shape)
+            return x.movedim(x_bdim, 0)
+
+        x, zi, at, rt = tuple(
+            map(
+                lambda x: x.reshape(-1, *x.shape[2:]),
+                map(maybe_expand_bdim_at_front, args, in_dims),
+            )
        )

+        y, at_mask = CompressorFunction.apply(x, zi, at, rt)
+        return (
+            y.reshape(info.batch_size, -1, *y.shape[1:]),
+            at_mask.reshape(info.batch_size, -1, *at_mask.shape[1:]),
+        ), 0
+


-compressor_core: Callable = CompressorFunction.apply
+def compressor_core(*args, **kwargs) -> torch.Tensor:
+    return CompressorFunction.apply(*args, **kwargs)[0]
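
(Aside, not part of the commit: a hedged usage sketch of what the refactor enables, assuming compressor_core from the patched module is in scope, x has shape (B, T), zi/at/rt have shape (B,), and the attack/release coefficients lie in (0, 1); the tensor values below are placeholders. With forward returning both y and the attack mask, the custom vmap and jvp rules let the op compose with torch.func, while the new wrapper keeps the public return value a single tensor.)

import torch
from torch.func import jvp, vmap

B, T = 4, 1024
x = torch.rand(B, T, dtype=torch.double)
zi = torch.rand(B, dtype=torch.double)
at = torch.rand(B, dtype=torch.double).clamp(0.01, 0.99)  # assumed valid coefficient range
rt = torch.rand(B, dtype=torch.double).clamp(0.01, 0.99)

# Plain call: the wrapper drops the at_mask output, so callers still get one tensor.
y = compressor_core(x, zi, at, rt)

# The vmap staticmethod folds an extra leading batch dimension into the first axis
# and unfolds it afterwards, so torch.func.vmap can map over stacked inputs.
x_stack = torch.rand(3, B, T, dtype=torch.double)
y_stack = vmap(compressor_core, in_dims=(0, None, None, None))(x_stack, zi, at, rt)

# save_for_forward plus the jvp staticmethod make forward-mode AD available.
tangents = (torch.randn_like(x), torch.randn_like(zi), torch.randn_like(at), torch.randn_like(rt))
y_primal, y_tangent = jvp(compressor_core, (x, zi, at, rt), tangents)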