1- from math import floor
1+ from math import ceil , floor
22
33import torch
44import torch .nn .functional as F
55from torch import Tensor , nn
66
77
8+ def next_power_of_2 (x : Tensor ) -> int :
9+ return 2 ** ceil (x .item ()).bit_length ()
10+
11+
812def get_center_frequencies (
913 num_octaves : int , num_bins_per_octave : int , sample_rate : int # C # B # Xi_s
1014) -> Tensor : # Xi_k for k in [1, 2*K+1]
@@ -51,25 +55,27 @@ def get_bandwidths(
5155 return bandwidths_all
5256
5357
54- def get_windows_range_indices (lengths : Tensor , positions : Tensor ) -> Tensor :
58+ def get_windows_range_indices (
59+ lengths : Tensor , positions : Tensor , power_of_2_length : bool
60+ ) -> Tensor :
5561 """Compute windowing tensor of indices"""
5662 num_bins = lengths .shape [0 ] // 2
57- max_length = lengths .max ()
63+ max_length = next_power_of_2 ( lengths . max ()) if power_of_2_length else lengths .max ()
5864 ranges = []
5965 for i in range (num_bins ):
6066 start = positions [i ] - max_length
6167 ranges += [torch .arange (start = start , end = start + max_length )] # type: ignore
6268 return torch .stack (ranges , dim = 0 ).long ()
6369
6470
65- def get_windows (lengths : Tensor ) -> Tensor :
71+ def get_windows (lengths : Tensor , power_of_2_length : bool ) -> Tensor :
6672 """Compute tensor of stacked (centered) windows"""
6773 num_bins = lengths .shape [0 ] // 2
68- max_length = lengths .max ()
74+ max_length = next_power_of_2 ( lengths . max ()) if power_of_2_length else lengths .max ()
6975 windows = []
7076 for length in lengths [:num_bins ]:
7177 # Pad windows left and right to center them
72- pad_left = floor (max_length / 2 - length / 2 )
78+ pad_left = floor (max_length / 2 - length / 2 ) # type: ignore
7379 pad_right = int (max_length - length - pad_left )
7480 windows += [F .pad (torch .hann_window (int (length )), pad = (pad_left , pad_right ))]
7581 return torch .stack (windows , dim = 0 )
@@ -87,6 +93,7 @@ def __init__(
8793 num_bins_per_octave : int ,
8894 sample_rate : int ,
8995 block_length : int ,
96+ power_of_2_length : bool = False ,
9097 ):
9198 super ().__init__ ()
9299 self .block_length = block_length
@@ -111,10 +118,14 @@ def __init__(
111118 get_windows_range_indices (
112119 lengths = window_lengths ,
113120 positions = torch .round (frequencies * block_length / sample_rate ),
121+ power_of_2_length = power_of_2_length ,
114122 ),
115123 )
116124
117- self .register_buffer ("windows" , get_windows (lengths = window_lengths ))
125+ self .register_buffer (
126+ "windows" ,
127+ get_windows (lengths = window_lengths , power_of_2_length = power_of_2_length ),
128+ )
118129
119130 self .register_buffer (
120131 "windows_inverse" ,
0 commit comments