Skip to content

Commit 03a6ecc

Browse files
authored
Merge pull request #7 from DiffAPF:fix-avg_coef-shape
fix: correct tensor broadcasting in avg function
2 parents 70b6590 + 7ab8ed1 commit 03a6ecc

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

torchcomp/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def avg(rms: torch.Tensor, avg_coef: Union[torch.Tensor, float]):
8989
assert torch.all(avg_coef > 0) and torch.all(avg_coef <= 1)
9090

9191
return sample_wise_lpc(
92-
rms * avg_coef,
92+
rms * avg_coef.unsqueeze(1),
9393
avg_coef[:, None, None].broadcast_to(rms.shape + (1,)) - 1,
9494
)
9595

0 commit comments

Comments
 (0)