Skip to content

Commit 5ca1ae0

Browse files
committed
normalize signal's shapes to #signal x #features x #nodes
1 parent cc9f508 commit 5ca1ae0

File tree

3 files changed

+95
-4
lines changed

3 files changed

+95
-4
lines changed

pygsp/filters/approximations.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,41 @@ def _evaluate(self, x, method):
213213

214214
def _filter(self, s, method, _):
215215

216+
# TODO: signal normalization will move to Filter.filter()
217+
218+
# Dimension 3: number of nodes.
219+
if s.shape[-1] != self.G.N:
220+
raise ValueError('The last dimension should be {}, '
221+
'the number of nodes. '
222+
'Got instead a signal of shape '
223+
'{}.'.format(self.G.N, s.shape))
224+
225+
# Dimension 2: number of input features.
226+
if s.ndim == 1:
227+
s = np.expand_dims(s, 0)
228+
if s.shape[-2] != self.n_features_in:
229+
if self.n_features_in == 1:
230+
# Dimension can be omitted if there's 1 input feature.
231+
s = np.expand_dims(s, -2)
232+
else:
233+
raise ValueError('The second to last dimension should be {}, '
234+
'the number of input features. '
235+
'Got instead a signal of shape '
236+
'{}.'.format(self.n_features_in, s.shape))
237+
238+
# Dimension 1: number of independent signals.
239+
if s.ndim < 3:
240+
s = np.expand_dims(s, 0)
241+
242+
if s.ndim > 3:
243+
raise ValueError('Signals should have at most 3 dimensions: '
244+
'#signals x #features x #nodes.')
245+
246+
assert s.ndim == 3
247+
assert s.shape[2] == self.G.N # Number of nodes.
248+
assert s.shape[1] == self.n_features_in # Number of input features.
249+
# n_signals = s.shape[0]
250+
216251
L = self.scale_operator(self.G.L, self.G.lmax)
217252

218253
# Recursive and clenshaw are similarly fast.

pygsp/filters/filter.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,11 @@ def filter(self, s, method=None, order=30):
273273
"""
274274
if not sparse.issparse(s):
275275
s = np.asarray(s) # For iterables.
276-
return self._filter(s, method, order)
276+
277+
s = self._filter(s, method, order)
278+
279+
# Return a 1D signal if e.g. a 1D signal was filtered by one filter.
280+
return s.squeeze()
277281

278282
def _filter(self, s, method='chebyshev', order=30):
279283
r"""Default implementation for filters defined as kernel functions."""
@@ -344,8 +348,7 @@ def _filter(self, s, method='chebyshev', order=30):
344348
else:
345349
raise ValueError('Unknown method {}.'.format(method))
346350

347-
# Return a 1D signal if e.g. a 1D signal was filtered by one filter.
348-
return s.squeeze()
351+
return s
349352

350353
def analyze(self, s, method='chebyshev', order=30):
351354
r"""Convenience alias to :meth:`filter`."""

pygsp/tests/test_filters.py

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,7 @@ def test_evaluation_methods(self, K=30, F=5, N=100):
280280

281281
def test_filter_identity(self, M=10, c=2.3):
282282
r"""Test that filtering with c0 only scales the signal."""
283-
x = self._rs.uniform(size=(M, 1, self._G.N))
283+
x = self._rs.uniform(size=(M, self._G.N))
284284
f = filters.Chebyshev(self._G, c)
285285
y = f.filter(x, method='recursive')
286286
np.testing.assert_equal(y, c * x)
@@ -331,3 +331,56 @@ def test_approximations(self, N=100, K=20):
331331
y1 = f1.filter(x.T).T
332332
y2 = f2.filter(x)
333333
np.testing.assert_allclose(y2.squeeze(), y1)
334+
335+
def test_shape_normalization(self):
336+
"""Test that signal's shapes are properly normalized."""
337+
# TODO: should also test filters which are not approximations.
338+
339+
def test_normalization(M, Fin, Fout, K=7):
340+
341+
def test_shape(y, M, Fout, N=self._G.N):
342+
"""Test that filtered signals are squeezed."""
343+
if Fout == 1 and M == 1:
344+
self.assertEqual(y.shape, (N,))
345+
elif Fout == 1:
346+
self.assertEqual(y.shape, (M, N))
347+
elif M == 1:
348+
self.assertEqual(y.shape, (Fout, N))
349+
else:
350+
self.assertEqual(y.shape, (M, Fout, N))
351+
352+
coefficients = self._rs.uniform(size=(K, Fout, Fin))
353+
f = filters.Chebyshev(self._G, coefficients)
354+
assert f.shape == (Fin, Fout)
355+
assert (f.n_features_in, f.n_features_out) == (Fin, Fout)
356+
357+
x = self._rs.uniform(size=(M, Fin, self._G.N))
358+
y = f.filter(x)
359+
test_shape(y, M, Fout)
360+
361+
if Fin == 1 or M == 1:
362+
# It only makes sense to squeeze if one dimension is unitary.
363+
x = x.squeeze()
364+
y = f.filter(x)
365+
test_shape(y, M, Fout)
366+
367+
# Test all possible correct combinations of input and output signals.
368+
for M in [1, 9]:
369+
for Fin in [1, 3]:
370+
for Fout in [1, 5]:
371+
test_normalization(M, Fin, Fout)
372+
373+
# Test failure cases.
374+
M, Fin, Fout, K = 9, 3, 5, 7
375+
coefficients = self._rs.uniform(size=(K, Fout, Fin))
376+
f = filters.Chebyshev(self._G, coefficients)
377+
x = self._rs.uniform(size=(M, Fin, 2))
378+
self.assertRaises(ValueError, f.filter, x)
379+
x = self._rs.uniform(size=(M, 2, self._G.N))
380+
self.assertRaises(ValueError, f.filter, x)
381+
x = self._rs.uniform(size=(2, self._G.N))
382+
self.assertRaises(ValueError, f.filter, x)
383+
x = self._rs.uniform(size=(self._G.N))
384+
self.assertRaises(ValueError, f.filter, x)
385+
x = self._rs.uniform(size=(2, M, Fin, self._G.N))
386+
self.assertRaises(ValueError, f.filter, x)

0 commit comments

Comments
 (0)