Commit 1525115

v0.2: drop support for torch<2 (#6)

* feat: torch2 autograd scheme
* drop support for torch<2
* feat: use torchlpc for fast avg, drop torchaudio dependency
* feat: vmap
* test: vmap

1 parent d044041

File tree

5 files changed: +97 −44 lines changed


requirements.txt

Lines changed: 1 addition & 2 deletions
@@ -1,5 +1,4 @@
 numpy
 numba
-torch
-torchaudio
+torch>=2
 torchlpc

setup.py

Lines changed: 1 addition & 1 deletion
@@ -19,7 +19,7 @@
     long_description_content_type="text/markdown",
     url="https://github.com/yoyololicon/torchcomp",
     packages=["torchcomp"],
-    install_requires=["torch", "torchaudio", "torchlpc", "numpy", "numba"],
+    install_requires=["torch>=2", "torchlpc", "numpy", "numba"],
     classifiers=[
         "Programming Language :: Python :: 3",
         "License :: OSI Approved :: MIT License",

tests/test_vmap.py

Lines changed: 44 additions & 0 deletions
@@ -0,0 +1,44 @@
+import torch
+import torch.nn.functional as F
+from torch.func import jacfwd
+import pytest
+from torchcomp.core import compressor_core
+
+
+from .test_grad import create_test_inputs
+
+
+@pytest.mark.parametrize(
+    "device",
+    [
+        "cpu",
+        pytest.param(
+            "cuda",
+            marks=pytest.mark.skipif(
+                not torch.cuda.is_available(), reason="CUDA not available"
+            ),
+        ),
+    ],
+)
+def test_vmap(device: str):
+    batch_size = 4
+    samples = 128
+    x, zi, at, rt = tuple(x.to(device) for x in create_test_inputs(batch_size, samples))
+    y = torch.randn_like(x)
+
+    x.requires_grad = True
+    zi.requires_grad = True
+    at.requires_grad = True
+    rt.requires_grad = True
+
+    args = (x, zi, at, rt)
+
+    def func(x, zi, at, rt):
+        return F.mse_loss(compressor_core(x, zi, at, rt), y)
+
+    jacs = jacfwd(func, argnums=tuple(range(len(args))))(*args)
+
+    loss = func(*args)
+    loss.backward()
+    for jac, arg in zip(jacs, args):
+        assert torch.allclose(jac, arg.grad)
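
`jacfwd` computes forward-mode Jacobians by `vmap`-ing `jvp` over basis tangents, so this test exercises both of the new staticmethods at once; comparing against `loss.backward()` then checks that forward- and reverse-mode derivatives agree. A minimal sketch of calling `torch.func.vmap` on `compressor_core` directly (shapes are assumptions inferred from the test: `x` is `(B, T)`; `zi`, `at`, `rt` are `(B,)`):

import torch
from torch.func import vmap
from torchcomp.core import compressor_core

outer, B, T = 3, 4, 128
x = torch.randn(outer, B, T)
zi = torch.rand(outer, B)
at = torch.rand(outer, B)  # attack coefficients, assumed in (0, 1)
rt = torch.rand(outer, B)  # release coefficients, assumed in (0, 1)

# The custom vmap rule folds the mapped leading dim into the kernel's
# existing batch dim, runs the op once, and reshapes the result back.
y = vmap(compressor_core)(x, zi, at, rt)
assert y.shape == (outer, B, T)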

torchcomp/__init__.py

Lines changed: 4 additions & 6 deletions
@@ -1,7 +1,7 @@
 import torch
 import torch.nn.functional as F
 from typing import Union
-from torchaudio.functional import lfilter
+from torchlpc import sample_wise_lpc
 
 from .core import compressor_core
 
@@ -88,11 +88,9 @@ def avg(rms: torch.Tensor, avg_coef: Union[torch.Tensor, float]):
     ).broadcast_to(rms.shape[0])
     assert torch.all(avg_coef > 0) and torch.all(avg_coef <= 1)
 
-    return lfilter(
-        rms,
-        torch.stack([torch.ones_like(avg_coef), avg_coef - 1], 1),
-        torch.stack([avg_coef, torch.zeros_like(avg_coef)], 1),
-        False,
+    return sample_wise_lpc(
+        rms * avg_coef,
+        avg_coef[:, None, None].broadcast_to(rms.shape + (1,)) - 1,
     )
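
Both versions implement the same one-pole smoother y[n] = α·x[n] + (1 − α)·y[n−1]: the old `lfilter` call used a = [1, α − 1], b = [α, 0], and `sample_wise_lpc` computes y[n] = x[n] − a₁[n]·y[n−1] from per-sample denominator coefficients, so feeding it `rms * α` with a₁ = α − 1 yields the identical recursion without torchaudio. A minimal sketch checking this against the recursion written out by hand (shapes follow the call above):

import torch
from torchlpc import sample_wise_lpc

B, T = 2, 16
alpha = torch.rand(B) * 0.9 + 0.1  # smoothing coefficients in (0, 1]
rms = torch.rand(B, T)

# a1 = alpha - 1 realizes y[n] = alpha * x[n] + (1 - alpha) * y[n-1]
y = sample_wise_lpc(
    rms * alpha[:, None],
    (alpha - 1)[:, None, None].broadcast_to(B, T, 1),
)

# Reference: the one-pole recursion, looped explicitly.
ref = []
state = torch.zeros(B)
for n in range(T):
    state = alpha * rms[:, n] + (1 - alpha) * state
    ref.append(state)
assert torch.allclose(y, torch.stack(ref, dim=1), atol=1e-6)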

torchcomp/core.py

Lines changed: 47 additions & 35 deletions
@@ -18,8 +18,8 @@ def compressor_cuda_kernel(
     B: int,
     T: int,
 ):
-    b = cuda.blockIdx.x
-    i = cuda.threadIdx.x
+    b: int = cuda.blockIdx.x
+    i: int = cuda.threadIdx.x
 
     if b >= B or i > 0:
         return
@@ -93,8 +93,8 @@ def compressor_cuda(
 class CompressorFunction(Function):
     @staticmethod
     def forward(
-        ctx: Any, x: torch.Tensor, zi: torch.Tensor, at: torch.Tensor, rt: torch.Tensor
-    ) -> torch.Tensor:
+        x: torch.Tensor, zi: torch.Tensor, at: torch.Tensor, rt: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         if x.is_cuda:
             y, at_mask = compressor_cuda(
                 x.detach(), zi.detach(), at.detach(), rt.detach()
@@ -108,19 +108,21 @@ def forward(
             )
             y = torch.from_numpy(y).to(x.device)
             at_mask = torch.from_numpy(at_mask).to(x.device)
-        ctx.save_for_backward(x, y, zi, at, rt, at_mask)
+        return y, at_mask
 
-        # for jvp
-        ctx.x = x
-        ctx.y = y
-        ctx.zi = zi
-        ctx.at = at
-        ctx.rt = rt
-        ctx.at_mask = at_mask
-        return y
+    @staticmethod
+    def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> Any:
+        x, zi, at, rt = inputs
+        y, at_mask = output
+        ctx.mark_non_differentiable(at_mask)
+        ctx.save_for_backward(x, y, zi, at, rt, at_mask)
+        ctx.save_for_forward(x, y, zi, at, rt, at_mask)
+        return ctx
 
     @staticmethod
-    def backward(ctx: Any, grad_y: torch.Tensor) -> Tuple[Optional[torch.Tensor], ...]:
+    def backward(
+        ctx: Any, grad_y: torch.Tensor, _
+    ) -> Tuple[Optional[torch.Tensor], ...]:
         x, y, zi, at, rt, at_mask = ctx.saved_tensors
         grad_x = grad_zi = grad_at = grad_rt = None
 
@@ -153,19 +155,6 @@ def backward(ctx: Any, grad_y: torch.Tensor) -> Tuple[Optional[torch.Tensor], ..
         if ctx.needs_input_grad[3]:
             grad_rt = torch.where(~at_mask, grad_combined, 0.0).sum(1)
 
-        if hasattr(ctx, "y"):
-            del ctx.y
-        if hasattr(ctx, "x"):
-            del ctx.x
-        if hasattr(ctx, "zi"):
-            del ctx.zi
-        if hasattr(ctx, "at"):
-            del ctx.at
-        if hasattr(ctx, "rt"):
-            del ctx.rt
-        if hasattr(ctx, "at_mask"):
-            del ctx.at_mask
-
         return grad_x, grad_zi, grad_at, grad_rt
 
     @staticmethod
@@ -175,12 +164,13 @@ def jvp(
         grad_zi: torch.Tensor,
         grad_at: torch.Tensor,
         grad_rt: torch.Tensor,
-    ) -> torch.Tensor:
-        x, y, zi, at, rt, at_mask = ctx.x, ctx.y, ctx.zi, ctx.at, ctx.rt, ctx.at_mask
+    ) -> Tuple[torch.Tensor, None]:
+        x, y, zi, at, rt, at_mask = ctx.saved_tensors
         coeffs = torch.where(at_mask, at.unsqueeze(1), rt.unsqueeze(1))
 
         fwd_x = 0 if grad_x is None else grad_x * coeffs
 
+        fwd_combined: torch.Tensor
         if grad_at is None and grad_rt is None:
             fwd_combined = fwd_x
         else:
@@ -192,13 +182,35 @@
             fwd_combined = fwd_x + grad_beta * (
                 x - torch.cat([zi.unsqueeze(1), y[:, :-1]], dim=1)
             )
+        return (
+            sample_wise_lpc(
+                fwd_combined,
+                coeffs.unsqueeze(2) - 1,
+                grad_zi if grad_zi is None else grad_zi.unsqueeze(1),
+            ),
+            None,
+        )
 
-        del ctx.x, ctx.y, ctx.zi, ctx.at, ctx.rt, ctx.at_mask
-        return sample_wise_lpc(
-            fwd_combined,
-            coeffs.unsqueeze(2) - 1,
-            grad_zi if grad_zi is None else grad_zi.unsqueeze(1),
+    @staticmethod
+    def vmap(info, in_dims, *args):
+        def maybe_expand_bdim_at_front(x, x_bdim):
+            if x_bdim is None:
+                return x.expand(info.batch_size, *x.shape)
+            return x.movedim(x_bdim, 0)
+
+        x, zi, at, rt = tuple(
+            map(
+                lambda x: x.reshape(-1, *x.shape[2:]),
+                map(maybe_expand_bdim_at_front, args, in_dims),
+            )
         )
 
+        y, at_mask = CompressorFunction.apply(x, zi, at, rt)
+        return (
+            y.reshape(info.batch_size, -1, *y.shape[1:]),
+            at_mask.reshape(info.batch_size, -1, *at_mask.shape[1:]),
+        ), 0
 
 
-compressor_core: Callable = CompressorFunction.apply
+def compressor_core(*args, **kwargs) -> torch.Tensor:
+    return CompressorFunction.apply(*args, **kwargs)[0]
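
This restructuring follows the PyTorch 2 autograd scheme: `forward` no longer takes `ctx`, all bookkeeping moves into `setup_context` (with `save_for_forward` so `jvp` can read `ctx.saved_tensors`), and a handwritten `vmap` staticmethod folds the mapped dimension into the kernel's existing batch dimension so the op composes with `torch.func` transforms. That is also why the manual `del ctx.*` lifetime hacks could be dropped, and, since `at_mask` is now a returned output marked non-differentiable, why `backward` gains a third (ignored) argument and `jvp` returns `None` for it. A minimal sketch of the same pattern on a toy op (a hypothetical `ScaleFunction`, not part of torchcomp):

import torch
from torch.autograd import Function


class ScaleFunction(Function):
    # Pointwise op, so the autogenerated vmap rule suffices; compressor_core
    # is recursive over time, hence its handwritten vmap staticmethod above.
    generate_vmap_rule = True

    @staticmethod
    def forward(x, scale):  # torch>=2 scheme: no ctx argument here
        return x * scale

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, scale = inputs
        ctx.save_for_backward(x, scale)
        ctx.save_for_forward(x, scale)  # lets jvp use ctx.saved_tensors

    @staticmethod
    def backward(ctx, grad_y):
        x, scale = ctx.saved_tensors
        return grad_y * scale, (grad_y * x).sum()

    @staticmethod
    def jvp(ctx, grad_x, grad_scale):
        x, scale = ctx.saved_tensors
        out = 0.0 if grad_x is None else grad_x * scale
        if grad_scale is not None:
            out = out + grad_scale * x
        return out


x = torch.randn(3, requires_grad=True)
s = torch.tensor(2.0, requires_grad=True)
ScaleFunction.apply(x, s).sum().backward()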
