
Commit 2d5a37f

✨ New features

* Add call_and_ladj method to transformations to improve flow efficiency
* Add continuous normalizing flow (CNF)
* Add free-form Jacobian (FFJ) transformation
* Add odeint solver
* Make bisection auto-differentiable
1 parent 2dd249f commit 2d5a37f

9 files changed: +818 −194 lines
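Before the per-file diffs, a minimal usage sketch of the headline feature, based on the calls exercised in tests/test_flows.py below; the import path and the one-step training loop are assumptions, not part of this commit.

# Illustrative sketch (not from the commit): a conditional continuous
# normalizing flow with 3 sample features and 5 context features.
import torch
from zuko.flows import CNF  # assumed import path

flow = CNF(3, 5)

x, y = torch.randn(256, 3), torch.randn(256, 5)

# Density evaluation; log_prob accumulates log-abs-det-Jacobians through
# each transform's new call_and_ladj (see zuko/distributions.py below).
loss = -flow(y).log_prob(x).mean()
loss.backward()

# Sampling for a single context vector, as in the test: shape (32, 3).
samples = flow(torch.randn(5)).sample((32,))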

setup.py

Lines changed: 1 addition & 1 deletion
@@ -10,7 +10,7 @@

 setuptools.setup(
     name='zuko',
-    version='0.0.6',
+    version='0.0.7',
     packages=setuptools.find_packages(),
     description='Normalizing flows in PyTorch',
     keywords=[

tests/test_flows.py

Lines changed: 17 additions & 3 deletions
@@ -14,6 +14,7 @@ def test_flows(tmp_path):
         SOSPF(3, 5),
         NAF(3, 5),
         NAF(3, 5, unconstrained=True),
+        CNF(3, 5),
     ]

     for flow in flows:
@@ -24,16 +25,27 @@ def test_flows(tmp_path):
         assert log_p.shape == (256,), flow
         assert log_p.requires_grad, flow

+        flow.zero_grad(set_to_none=True)
         loss = -log_p.mean()
         loss.backward()

         for p in flow.parameters():
-            assert hasattr(p, 'grad'), flow
+            assert p.grad is not None, flow

         # Sampling
-        z = flow(y).sample((32,))
+        x = flow(y).sample((32,))

-        assert z.shape == (32, 3), flow
+        assert x.shape == (32, 3), flow
+
+        # Reparameterization trick
+        x = flow(y).rsample()
+
+        flow.zero_grad(set_to_none=True)
+        loss = x.square().sum().sqrt()
+        loss.backward()
+
+        for p in flow.parameters():
+            assert p.grad is not None, flow

         # Invertibility
         x, y = randn(256, 3), randn(256, 5)
@@ -58,7 +70,9 @@ def test_flows(tmp_path):

         x, y = randn(3), randn(5)

+        seed = torch.seed()
         log_p = flow(y).log_prob(x)
+        torch.manual_seed(seed)
         log_p_bis = flow_bis(y).log_prob(x)

         assert torch.allclose(log_p, log_p_bis), flow

tests/test_transforms.py

Lines changed: 36 additions & 2 deletions
@@ -21,15 +21,17 @@ def test_univariate_transforms():
     ]

     for t in ts:
+        # Call
         if hasattr(t.domain, 'lower_bound'):
-            x = torch.linspace(t.domain.lower_bound, t.domain.upper_bound, 256)
+            x = torch.linspace(t.domain.lower_bound + 1e-2, t.domain.upper_bound - 1e-2, 256)
         else:
             x = torch.linspace(-5.0, 5.0, 256)

         y = t(x)

         assert x.shape == y.shape, t

+        # Inverse
         z = t.inv(y)

         assert torch.allclose(x, z, atol=1e-4), t
@@ -42,7 +44,39 @@ def test_univariate_transforms():

         ladj = torch.diag(J).abs().log()

-        assert torch.allclose(ladj, t.log_abs_det_jacobian(x, y), atol=1e-4), t
+        assert torch.allclose(t.log_abs_det_jacobian(x, y), ladj, atol=1e-4), t
+
+        # Inverse Jacobian
+        J = torch.autograd.functional.jacobian(t.inv, y)
+
+        assert (torch.triu(J, diagonal=1) == 0).all(), t
+        assert (torch.tril(J, diagonal=-1) == 0).all(), t
+
+        ladj = torch.diag(J).abs().log()
+
+        assert torch.allclose(t.inv.log_abs_det_jacobian(y, z), ladj, atol=1e-4), t
+
+
+def test_FFJTransform():
+    a = torch.randn(3)
+    f = lambda x, t: a * x
+    t = FFJTransform(f, time=torch.tensor(1.0))
+
+    # Call
+    x = randn(256, 3)
+    y = t(x)
+
+    assert x.shape == y.shape
+
+    # Inverse
+    z = t.inv(y)
+
+    assert torch.allclose(x, z, atol=1e-4)
+
+    # Jacobian
+    ladj = t.log_abs_det_jacobian(x, y)
+
+    assert ladj.shape == x.shape[:-1]


 def test_PermutationTransform():
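For background on the new FFJTransform test (standard continuous-flow theory, not something stated in the diff): a free-form Jacobian transform pushes x through an ODE dx/dt = f(x, t), and the instantaneous change-of-variables formula

\frac{\mathrm{d}}{\mathrm{d}t} \log p(x(t)) = -\operatorname{tr}\!\left(\frac{\partial f}{\partial x}(x(t), t)\right)

makes the log-abs-det-Jacobian a time integral of one scalar per sample, which is why the test only checks ladj.shape == x.shape[:-1] rather than comparing against a dense Jacobian.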

tests/test_utils.py

Lines changed: 31 additions & 10 deletions
@@ -10,14 +10,15 @@

 def test_bisection():
     f = torch.cos
+    y = torch.tensor(0.0)
     a = torch.rand(256, 1) + 2.0
     b = torch.rand(16)

-    x = bisection(f, a, b, n=18)
+    x = bisection(f, y, a, b, n=18)

     assert x.shape == (256, 16)
     assert torch.allclose(x, torch.tensor(math.pi / 2), atol=1e-4)
-    assert torch.allclose(f(x), torch.tensor(0.0), atol=1e-4)
+    assert torch.allclose(f(x), y, atol=1e-4)


 def test_broadcast():
@@ -59,17 +60,37 @@ def test_gauss_legendre():
     # Polynomial
     f = lambda x: x**5 - x**2
     F = lambda x: x**6 / 6 - x**3 / 3
-    a, b = randn(2, 256)
+    a, b = randn(2, 256, requires_grad=True)

     area = gauss_legendre(f, a, b, n=3)

-    assert torch.allclose(F(b) - F(a), area, atol=1e-4)
+    assert torch.allclose(area, F(b) - F(a), atol=1e-4)

     # Gradients
-    grad_a, grad_b = torch.autograd.functional.jacobian(
-        lambda a, b: gauss_legendre(f, a, b).sum(),
-        (a, b),
-    )
+    grad_a, grad_b = torch.autograd.grad(area.sum(), (a, b))

-    assert torch.allclose(-f(a), grad_a)
-    assert torch.allclose(f(b), grad_b)
+    assert torch.allclose(grad_a, -f(a), atol=1e-4)
+    assert torch.allclose(grad_b, f(b), atol=1e-4)
+
+
+def test_odeint():
+    # Linear
+    alpha = torch.tensor(1.0, requires_grad=True)
+    t = torch.tensor(3.0, requires_grad=True)
+
+    f = lambda x, t: -alpha * x
+    F = lambda x, t: x * (-alpha * t).exp()
+
+    x0 = randn(256, 1, requires_grad=True)
+    xt = odeint(f, x0, torch.zeros_like(t), t, phi=(alpha,))
+
+    assert xt.shape == x0.shape
+    assert torch.allclose(xt, F(x0, t), atol=1e-4)
+
+    # Gradients
+    grad_x0, grad_t, grad_alpha = torch.autograd.grad(xt.sum(), (x0, t, alpha))
+    g_x0, g_t, g_alpha = torch.autograd.grad(F(x0, t).sum(), (x0, t, alpha))
+
+    assert torch.allclose(grad_x0, g_x0, atol=1e-4)
+    assert torch.allclose(grad_t, g_t, atol=1e-4)
+    assert torch.allclose(grad_alpha, g_alpha, atol=1e-4)
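Two facts behind the new assertions (standard calculus, not stated in the diff): the expected gauss_legendre gradients follow from the Leibniz integral rule, and the odeint reference F is the closed-form solution of the linear ODE,

\frac{\partial}{\partial a} \int_a^b f(x)\,\mathrm{d}x = -f(a), \qquad \frac{\partial}{\partial b} \int_a^b f(x)\,\mathrm{d}x = f(b), \qquad \dot{x} = -\alpha x \;\Rightarrow\; x(t) = x_0\, e^{-\alpha t}.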

zuko/distributions.py

Lines changed: 67 additions & 7 deletions
@@ -1,4 +1,4 @@
-r"""Parametrizable probability distributions."""
+r"""Parameterizable probability distributions."""

 import math
 import torch
@@ -7,14 +7,15 @@
 from torch import Tensor, Size
 from torch.distributions import *
 from torch.distributions import constraints
+from torch.distributions.utils import _sum_rightmost
 from typing import *


 Distribution._validate_args = False
 Distribution.arg_constraints = {}


-class NormalizingFlow(TransformedDistribution):
+class NormalizingFlow(Distribution):
     r"""Creates a normalizing flow for a random variable :math:`X` towards a base
     distribution :math:`p(Z)` through a series of :math:`n` invertible and differentiable
     transformations :math:`f_1, f_2, \dots, f_n`.
@@ -49,18 +50,77 @@ def __init__(
         transforms: List[Transform],
         base: Distribution,
     ):
-        super().__init__(base, [t.inv for t in reversed(transforms)])
+        super().__init__()
+
+        codomain_dim = ComposeTransform(transforms).codomain.event_dim
+        reinterpreted = codomain_dim - len(base.event_shape)
+
+        if reinterpreted > 0:
+            base = Independent(base, reinterpreted)
+
+        self.transforms = transforms
+        self.base = base

     def __repr__(self) -> str:
-        lines = [f'({i+1}): {t.inv}' for i, t in enumerate(reversed(self.transforms))]
-        lines.append(f'(base): {self.base_dist}')
+        lines = [f'({i + 1}): {t}' for i, t in enumerate(self.transforms)]
+        lines.append(f'(base): {self.base}')
         lines = indent('\n'.join(lines), '  ')

         return self.__class__.__name__ + '(\n' + lines + '\n)'

-    def expand(self, batch_shape: Size, new: Distribution = None) -> Distribution:
+    @property
+    def batch_shape(self) -> Size:
+        return self.base.batch_shape
+
+    @property
+    def event_shape(self) -> Size:
+        shape = self.base.event_shape
+
+        for t in reversed(self.transforms):
+            shape = t.inverse_shape(shape)
+
+        return shape
+
+    def expand(self, batch_shape: Size, new: Distribution = None):
         new = self._get_checked_instance(NormalizingFlow, new)
-        return super().expand(batch_shape, new)
+        new.transforms = self.transforms
+        new.base = self.base.expand(batch_shape)
+
+        Distribution.__init__(new, batch_shape=batch_shape, validate_args=False)
+
+        return new
+
+    def log_prob(self, x: Tensor) -> Tensor:
+        acc = 0
+        event_dim = len(self.event_shape)
+
+        for t in self.transforms:
+            x, ladj = t.call_and_ladj(x)
+            acc = acc + _sum_rightmost(ladj, event_dim - t.domain.event_dim)
+            event_dim += t.codomain.event_dim - t.domain.event_dim
+
+        return self.base.log_prob(x) + acc
+
+    @property
+    def has_rsample(self) -> bool:
+        return self.base.has_rsample
+
+    def rsample(self, shape: Size = ()):
+        x = self.base.rsample(shape)
+
+        for t in reversed(self.transforms):
+            x = t.inv(x)
+
+        return x
+
+    def sample(self, shape: Size = ()):
+        with torch.no_grad():
+            x = self.base.sample(shape)
+
+            for t in reversed(self.transforms):
+                x = t.inv(x)
+
+        return x


 class Joint(Distribution):
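The new log_prob above consumes each transform's call_and_ladj, which returns the transformed value and its log-abs-det-Jacobian in a single pass. Below is a minimal sketch of that interface, written as a free function over a generic transform t (an assumption about the contract, not code from this commit); concrete transforms can fuse the two steps instead of recomputing the forward pass, which is the efficiency gain named in the commit message.

# Hypothetical reference behaviour of the call_and_ladj contract; actual
# zuko transforms may compute both quantities jointly and more cheaply.
def call_and_ladj(t, x):
    y = t(x)                              # forward transform
    ladj = t.log_abs_det_jacobian(x, y)   # log |det dy/dx|
    return y, ladj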
