Skip to content

Commit c78c19f

Browse files
feat: add power of 2 length, spectrogram example
1 parent e9629a6 commit c78c19f

File tree

4 files changed

+24
-8
lines changed

4 files changed

+24
-8
lines changed

IMAGE.png

329 KB
Loading

README.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11

22
# CQT - PyTorch
33

4+
45
An invertible and differentiable implementation of the Constant-Q Transform (CQT) using Non-stationary Gabor Transform (NSGT), in PyTorch.
56

67
```bash
@@ -27,7 +28,11 @@ z = transform.encode(x) # [1, 2, 455, 2796] = [batch_size, channels, frequencies
2728
y = transform.decode(z) # [1, 1, 262144]
2829
```
2930

31+
### Example CQT spectrogram (z)
32+
<img src="./IMAGE.png"></img>
33+
3034
## TODO
35+
* [x] Power of 2 length (with `power_of_2_length` constructor arg).
3136
* [ ] Understand why/if inverse window is necessary.
3237
* [ ] Allow variable audio lengths by chunking.
3338

cqt_pytorch/cqt.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
1-
from math import floor
1+
from math import ceil, floor
22

33
import torch
44
import torch.nn.functional as F
55
from torch import Tensor, nn
66

77

8+
def next_power_of_2(x: Tensor) -> int:
9+
return 2 ** ceil(x.item()).bit_length()
10+
11+
812
def get_center_frequencies(
913
num_octaves: int, num_bins_per_octave: int, sample_rate: int # C # B # Xi_s
1014
) -> Tensor: # Xi_k for k in [1, 2*K+1]
@@ -51,25 +55,27 @@ def get_bandwidths(
5155
return bandwidths_all
5256

5357

54-
def get_windows_range_indices(lengths: Tensor, positions: Tensor) -> Tensor:
58+
def get_windows_range_indices(
59+
lengths: Tensor, positions: Tensor, power_of_2_length: bool
60+
) -> Tensor:
5561
"""Compute windowing tensor of indices"""
5662
num_bins = lengths.shape[0] // 2
57-
max_length = lengths.max()
63+
max_length = next_power_of_2(lengths.max()) if power_of_2_length else lengths.max()
5864
ranges = []
5965
for i in range(num_bins):
6066
start = positions[i] - max_length
6167
ranges += [torch.arange(start=start, end=start + max_length)] # type: ignore
6268
return torch.stack(ranges, dim=0).long()
6369

6470

65-
def get_windows(lengths: Tensor) -> Tensor:
71+
def get_windows(lengths: Tensor, power_of_2_length: bool) -> Tensor:
6672
"""Compute tensor of stacked (centered) windows"""
6773
num_bins = lengths.shape[0] // 2
68-
max_length = lengths.max()
74+
max_length = next_power_of_2(lengths.max()) if power_of_2_length else lengths.max()
6975
windows = []
7076
for length in lengths[:num_bins]:
7177
# Pad windows left and right to center them
72-
pad_left = floor(max_length / 2 - length / 2)
78+
pad_left = floor(max_length / 2 - length / 2) # type: ignore
7379
pad_right = int(max_length - length - pad_left)
7480
windows += [F.pad(torch.hann_window(int(length)), pad=(pad_left, pad_right))]
7581
return torch.stack(windows, dim=0)
@@ -87,6 +93,7 @@ def __init__(
8793
num_bins_per_octave: int,
8894
sample_rate: int,
8995
block_length: int,
96+
power_of_2_length: bool = False,
9097
):
9198
super().__init__()
9299
self.block_length = block_length
@@ -111,10 +118,14 @@ def __init__(
111118
get_windows_range_indices(
112119
lengths=window_lengths,
113120
positions=torch.round(frequencies * block_length / sample_rate),
121+
power_of_2_length=power_of_2_length,
114122
),
115123
)
116124

117-
self.register_buffer("windows", get_windows(lengths=window_lengths))
125+
self.register_buffer(
126+
"windows",
127+
get_windows(lengths=window_lengths, power_of_2_length=power_of_2_length),
128+
)
118129

119130
self.register_buffer(
120131
"windows_inverse",

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name="cqt-pytorch",
55
packages=find_packages(exclude=[]),
6-
version="0.0.2",
6+
version="0.0.3",
77
license="MIT",
88
description="CQT Pytorch",
99
long_description_content_type="text/markdown",

0 commit comments

Comments
 (0)