From e7e6ce13880a18e881c5ccbf6243866edf7dc0d9 Mon Sep 17 00:00:00 2001 From: Yu Feng Date: Thu, 2 Mar 2017 18:11:44 -0800 Subject: [PATCH 1/6] Add convolution based interp and gradient. It shall give idential result to numpy's version, but this computes the matrix, and thus is differentiable. --- autograd/numpy/__init__.py | 1 + autograd/numpy/numpy_interp.py | 53 ++++++++++++++++++++++++++++++++++ tests/test_numpy_interp.py | 19 ++++++++++++ 3 files changed, 73 insertions(+) create mode 100644 autograd/numpy/numpy_interp.py create mode 100644 tests/test_numpy_interp.py diff --git a/autograd/numpy/__init__.py b/autograd/numpy/__init__.py index e1e22fb2c..0f75767d8 100644 --- a/autograd/numpy/__init__.py +++ b/autograd/numpy/__init__.py @@ -6,3 +6,4 @@ from . import linalg from . import fft from . import random +from .numpy_interp import interp diff --git a/autograd/numpy/numpy_interp.py b/autograd/numpy/numpy_interp.py new file mode 100644 index 000000000..244b42bcc --- /dev/null +++ b/autograd/numpy/numpy_interp.py @@ -0,0 +1,53 @@ +import numpy as np +from autograd import numpy as anp +from autograd import primitive + +def interp(x, xp, yp, left=None, right=None): + """ Differentiable against yp """ + if left is None: + left = yp[0] + if right is None: right = yp[-1] + + xp = anp.concatenate([[xp[0]], xp, [xp[-1]]]) + yp = anp.concatenate([anp.array([left]), yp, anp.array([right])]) + + m = make_matrix(x, xp) + y = anp.inner(m, yp) + return y + +def W(r, D): + """ Convolution kernel for linear interpolation. + D is the differences of xp. + """ + mask = D == 0 + D[mask] = 1.0 + Wleft = 1.0 + r[1:] / D + Wright = 1.0 - r[:-1] / D + # edges + Wleft = np.where(mask, 0, Wleft) + Wright = np.where(mask, 0, Wright) + Wleft = np.concatenate([[0], Wleft]) + Wright = np.concatenate([Wright, [0]]) + W = np.where(r < 0, Wleft, Wright) + W = np.where(r == 0, 1.0, W) + W = np.where(W < 0, 0, W) + return W + +def make_matrix(x, xp): + D = np.diff(xp) + w = [] + v0 = np.zeros(len(xp)) + v0[0] = 1.0 + v1 = np.zeros(len(xp)) + v1[-1] = 1.0 + for xi in x: + # left, use left + if xi < xp[0]: v = v0 + # right , use right + elif xi > xp[-1]: v = v1 + else: + v = W(xi - xp, D) + v[0] = 0 + v[-1] = 0 + w.append(v) + return np.array(w) diff --git a/tests/test_numpy_interp.py b/tests/test_numpy_interp.py new file mode 100644 index 000000000..2ff8bde26 --- /dev/null +++ b/tests/test_numpy_interp.py @@ -0,0 +1,19 @@ +from __future__ import absolute_import +import warnings + +import autograd.numpy as np +import autograd.numpy.random as npr +from autograd.util import * +from autograd import grad + +npr.seed(1) + +def test_interp(): + x = np.arange(10) * 1.0 + xp = np.arange(10) * 1.0 + yp = np.arange(10) * 1.0 + def fun(yp): return to_scalar(np.interp(x, xp, yp)) + def dfun(yp): return to_scalar(grad(fun)(yp)) + print(fun(yp), dfun(yp)) + check_grads(fun, yp) + check_grads(dfun, yp) From 1349f3448bb307debc24bc82baccfc8bc0b289d9 Mon Sep 17 00:00:00 2001 From: Yu Feng Date: Thu, 2 Mar 2017 18:18:57 -0800 Subject: [PATCH 2/6] Update test cases to include correctness tests. --- tests/test_numpy_interp.py | 39 +++++++++++++++++++++++++++++++++++--- 1 file changed, 36 insertions(+), 3 deletions(-) diff --git a/tests/test_numpy_interp.py b/tests/test_numpy_interp.py index 2ff8bde26..bf8ddb6bc 100644 --- a/tests/test_numpy_interp.py +++ b/tests/test_numpy_interp.py @@ -6,14 +6,47 @@ from autograd.util import * from autograd import grad +from numpy import interp as ninterp +from numpy.testing import assert_allclose + npr.seed(1) def test_interp(): - x = np.arange(10) * 1.0 + x = np.arange(-20, 20, 0.1) xp = np.arange(10) * 1.0 - yp = np.arange(10) * 1.0 + yp = xp ** 0.5 + npr.normal(size=xp.shape) def fun(yp): return to_scalar(np.interp(x, xp, yp)) def dfun(yp): return to_scalar(grad(fun)(yp)) - print(fun(yp), dfun(yp)) + + check_grads(fun, yp) + check_grads(dfun, yp) + +def test_interp_edge(): + x = np.arange(-20, 20, 0.1) + xp = np.arange(10) * 1.0 + yp = xp ** 0.5 + npr.normal(size=xp.shape) + def fun(yp): return to_scalar(np.interp(x, xp, yp, left=-1, right=-1)) + def dfun(yp): return to_scalar(grad(fun)(yp)) + check_grads(fun, yp) check_grads(dfun, yp) + +def test_interp_correctness(): + x = np.arange(-100, 100, 0.1) + xp = np.arange(10) * 1.0 + yp = xp ** 2 + + y1 = ninterp(x, xp, yp) + y2 = np.interp(x, xp, yp) + + assert_allclose(y1, y2) + +def test_interp_correctness_edge(): + x = np.arange(-100, 100, 0.1) + xp = np.arange(10) * 1.0 + yp = xp ** 2 + + y1 = ninterp(x, xp, yp, left=-1, right=-1) + y2 = np.interp(x, xp, yp, left=-1, right=-1) + + assert_allclose(y1, y2) From a311c637ec976baaecb14f85de3a22b5fe7ba3b4 Mon Sep 17 00:00:00 2001 From: Yu Feng Date: Thu, 2 Mar 2017 18:19:43 -0800 Subject: [PATCH 3/6] fix randomness per test --- tests/test_numpy_interp.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/test_numpy_interp.py b/tests/test_numpy_interp.py index bf8ddb6bc..064bdd7e5 100644 --- a/tests/test_numpy_interp.py +++ b/tests/test_numpy_interp.py @@ -9,11 +9,10 @@ from numpy import interp as ninterp from numpy.testing import assert_allclose -npr.seed(1) - def test_interp(): x = np.arange(-20, 20, 0.1) xp = np.arange(10) * 1.0 + npr.seed(1) yp = xp ** 0.5 + npr.normal(size=xp.shape) def fun(yp): return to_scalar(np.interp(x, xp, yp)) def dfun(yp): return to_scalar(grad(fun)(yp)) @@ -24,6 +23,7 @@ def dfun(yp): return to_scalar(grad(fun)(yp)) def test_interp_edge(): x = np.arange(-20, 20, 0.1) xp = np.arange(10) * 1.0 + npr.seed(1) yp = xp ** 0.5 + npr.normal(size=xp.shape) def fun(yp): return to_scalar(np.interp(x, xp, yp, left=-1, right=-1)) def dfun(yp): return to_scalar(grad(fun)(yp)) @@ -34,7 +34,8 @@ def dfun(yp): return to_scalar(grad(fun)(yp)) def test_interp_correctness(): x = np.arange(-100, 100, 0.1) xp = np.arange(10) * 1.0 - yp = xp ** 2 + npr.seed(1) + yp = xp ** 0.5 + npr.normal(size=xp.shape) y1 = ninterp(x, xp, yp) y2 = np.interp(x, xp, yp) @@ -44,7 +45,8 @@ def test_interp_correctness(): def test_interp_correctness_edge(): x = np.arange(-100, 100, 0.1) xp = np.arange(10) * 1.0 - yp = xp ** 2 + npr.seed(1) + yp = xp ** 0.5 + npr.normal(size=xp.shape) y1 = ninterp(x, xp, yp, left=-1, right=-1) y2 = np.interp(x, xp, yp, left=-1, right=-1) From 7c80e2c283db78221238a4f5f0da291dbf769095 Mon Sep 17 00:00:00 2001 From: Yu Feng Date: Fri, 3 Mar 2017 09:24:52 -0800 Subject: [PATCH 4/6] attempt to fix Python 2.7 errors. import as done in numpy.numpy_grads. --- autograd/numpy/numpy_interp.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/autograd/numpy/numpy_interp.py b/autograd/numpy/numpy_interp.py index 244b42bcc..be45a5144 100644 --- a/autograd/numpy/numpy_interp.py +++ b/autograd/numpy/numpy_interp.py @@ -1,11 +1,9 @@ -import numpy as np -from autograd import numpy as anp -from autograd import primitive +from autograd.core import primitive +from . import numpy_wrapper as anp def interp(x, xp, yp, left=None, right=None): """ Differentiable against yp """ - if left is None: - left = yp[0] + if left is None: left = yp[0] if right is None: right = yp[-1] xp = anp.concatenate([[xp[0]], xp, [xp[-1]]]) @@ -15,6 +13,10 @@ def interp(x, xp, yp, left=None, right=None): y = anp.inner(m, yp) return y + +# The following are internal functions +import numpy as np + def W(r, D): """ Convolution kernel for linear interpolation. D is the differences of xp. From 6929c94896f9dacd5ef6c3dd746c3944e5bb34ac Mon Sep 17 00:00:00 2001 From: Yu Feng Date: Tue, 7 Mar 2017 15:38:55 -0800 Subject: [PATCH 5/6] cleanup and add periodic support. --- autograd/numpy/__init__.py | 2 +- ...{numpy_interp.py => numpy_grads_interp.py} | 21 +++++++++++--- tests/test_numpy_interp.py | 29 +++++++------------ 3 files changed, 28 insertions(+), 24 deletions(-) rename autograd/numpy/{numpy_interp.py => numpy_grads_interp.py} (63%) diff --git a/autograd/numpy/__init__.py b/autograd/numpy/__init__.py index 0f75767d8..ff49fb35e 100644 --- a/autograd/numpy/__init__.py +++ b/autograd/numpy/__init__.py @@ -1,9 +1,9 @@ from __future__ import absolute_import from . import numpy_wrapper from . import numpy_grads +from . import numpy_grads_interp from . import numpy_extra from .numpy_wrapper import * from . import linalg from . import fft from . import random -from .numpy_interp import interp diff --git a/autograd/numpy/numpy_interp.py b/autograd/numpy/numpy_grads_interp.py similarity index 63% rename from autograd/numpy/numpy_interp.py rename to autograd/numpy/numpy_grads_interp.py index be45a5144..144da3e5f 100644 --- a/autograd/numpy/numpy_interp.py +++ b/autograd/numpy/numpy_grads_interp.py @@ -1,20 +1,33 @@ -from autograd.core import primitive from . import numpy_wrapper as anp -def interp(x, xp, yp, left=None, right=None): - """ Differentiable against yp """ +def _interp_vjp(x, xp, yp, left, right, period, g): + from autograd import vector_jacobian_product + func = vector_jacobian_product(_interp, argnum=2) + return func(x, xp, yp, left, right, period, g) + +def _interp(x, xp, yp, left=None, right=None, period=None): + """ A partial rewrite of interp that is differentiable against yp """ + if period is not None: + xp = anp.concatenate([[xp[-1] - period], xp, [xp[0] + period]]) + yp = anp.concatenate([anp.array([yp[-1]]), yp, anp.array([yp[0]])]) + return _interp(x % period, xp, yp, left, right, None) + if left is None: left = yp[0] if right is None: right = yp[-1] xp = anp.concatenate([[xp[0]], xp, [xp[-1]]]) - yp = anp.concatenate([anp.array([left]), yp, anp.array([right])]) + yp = anp.concatenate([anp.array([left]), yp, anp.array([right])]) m = make_matrix(x, xp) y = anp.inner(m, yp) return y +anp.interp.defvjp(lambda g, ans, vs, gvs, x, xp, yp, left=None, right=None, period=None: + _interp_vjp(x, xp, yp, left, right, period, g), argnum=2) + # The following are internal functions + import numpy as np def W(r, D): diff --git a/tests/test_numpy_interp.py b/tests/test_numpy_interp.py index 064bdd7e5..0e7f7a77b 100644 --- a/tests/test_numpy_interp.py +++ b/tests/test_numpy_interp.py @@ -6,7 +6,6 @@ from autograd.util import * from autograd import grad -from numpy import interp as ninterp from numpy.testing import assert_allclose def test_interp(): @@ -27,28 +26,20 @@ def test_interp_edge(): yp = xp ** 0.5 + npr.normal(size=xp.shape) def fun(yp): return to_scalar(np.interp(x, xp, yp, left=-1, right=-1)) def dfun(yp): return to_scalar(grad(fun)(yp)) - check_grads(fun, yp) check_grads(dfun, yp) -def test_interp_correctness(): - x = np.arange(-100, 100, 0.1) +def test_interp_period(): + x = np.arange(-20, 20, 0.5) xp = np.arange(10) * 1.0 npr.seed(1) yp = xp ** 0.5 + npr.normal(size=xp.shape) + def fun(yp): return to_scalar(np.interp(x, xp, yp, period=10)) + def dfun(yp): return to_scalar(grad(fun)(yp)) - y1 = ninterp(x, xp, yp) - y2 = np.interp(x, xp, yp) - - assert_allclose(y1, y2) - -def test_interp_correctness_edge(): - x = np.arange(-100, 100, 0.1) - xp = np.arange(10) * 1.0 - npr.seed(1) - yp = xp ** 0.5 + npr.normal(size=xp.shape) - - y1 = ninterp(x, xp, yp, left=-1, right=-1) - y2 = np.interp(x, xp, yp, left=-1, right=-1) - - assert_allclose(y1, y2) + print('xp', xp) + print('yp', yp) + print('x', x % 10) + print('y', np.interp(x, xp, yp, period=10)) + check_grads(fun, yp) + check_grads(dfun, yp) From 2e089ae5027265f73baaf47b74a64a2d90ba20fe Mon Sep 17 00:00:00 2001 From: Yu Feng Date: Tue, 7 Mar 2017 15:39:20 -0800 Subject: [PATCH 6/6] remove prints. --- tests/test_numpy_interp.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/test_numpy_interp.py b/tests/test_numpy_interp.py index 0e7f7a77b..2e43db595 100644 --- a/tests/test_numpy_interp.py +++ b/tests/test_numpy_interp.py @@ -37,9 +37,5 @@ def test_interp_period(): def fun(yp): return to_scalar(np.interp(x, xp, yp, period=10)) def dfun(yp): return to_scalar(grad(fun)(yp)) - print('xp', xp) - print('yp', yp) - print('x', x % 10) - print('y', np.interp(x, xp, yp, period=10)) check_grads(fun, yp) check_grads(dfun, yp)