Gradients of the resampler do not match finite differences on an integer-pixel grid #2535

Open
@andrevitorelli

System information

  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): CentOS Linux release 8.3.2011
  • TensorFlow version and how it was installed (source or binary): binary
  • TensorFlow-Addons version and how it was installed (source or binary): binary
  • Python version: 3.8.3
  • Is GPU used? (yes/no): yes (results are the same without a GPU)

Describe the bug
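The gradients of `tensorflow_addons.image.resampler` with respect to the warp coordinates do not match finite-difference estimates from numdifftools when the warp points fall exactly on integer pixel positions. They do match when the points are offset to non-integer (half-pixel) positions.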
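A possible explanation, offered as an assumption rather than a confirmed diagnosis: the resampler is documented as bilinear, and bilinear interpolation along one axis is piecewise linear with kinks at the integer sample positions, so it has no well-defined derivative there. The 1-D sketch below is plain Python; `interp_linear`, `one_sided_derivatives` and the sample row are hypothetical. It places the interpolation corners at `floor(x)` and `floor(x) + 1`, which is also assumed (not confirmed) to be the convention of the TFA gradient kernel. At an integer position the left and right slopes differ, while halfway between pixels they coincide.

```python
import math

def interp_linear(samples, x):
    """Linear interpolation at x with corners floor(x) and floor(x) + 1."""
    x0 = math.floor(x)
    t = x - x0
    return (1.0 - t) * samples[x0] + t * samples[x0 + 1]

def one_sided_derivatives(samples, x, h=1e-3):
    """Left and right difference quotients of the interpolant at x."""
    f = lambda u: interp_linear(samples, u)
    return (f(x) - f(x - h)) / h, (f(x + h) - f(x)) / h

row = [10.0, 40.0, 20.0, 35.0, 5.0]  # hypothetical pixel row

# Integer position: a kink, so the left and right slopes differ.
print([round(d, 6) for d in one_sided_derivatives(row, 2.0)])  # [-20.0, 15.0]
# Half-pixel position: locally linear, so both slopes agree.
print([round(d, 6) for d in one_sided_derivatives(row, 2.5)])  # [15.0, 15.0]
```

If the gradient kernel really uses these floor-based corners, at an integer coordinate it would report only the right-hand slope, whereas a symmetric finite difference averages the two.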

Code to reproduce the issue
If we run:

```python
import numpy as np
from matplotlib import pyplot as plt
from tensorflow_addons.image import resampler
try:
    from scipy.datasets import face  # SciPy >= 1.10 (requires pooch)
except ImportError:
    from scipy.misc import face      # older SciPy, as in the original report
import numdifftools
import tensorflow as tf

# get a 128x128 grayscale crop of an image
image = face(gray=True)[-512:-512+128, -512:-512+128].astype('float32')
image_tf = tf.convert_to_tensor(image.reshape([1, 128, 128, 1]))

# set a warp on the integer pixel grid (warp[..., 0] is x, warp[..., 1] is y)
warp = np.stack(np.meshgrid(np.arange(128), np.arange(128)), axis=-1).astype('float32')
warp_tf = tf.convert_to_tensor(warp.reshape([1, 128, 128, 2]))

# define a global shift of the warp
shift = tf.zeros([1, 2])

# calculate derivatives via tf.GradientTape
with tf.GradientTape() as tape:
    tape.watch(shift)
    ws = tf.reshape(shift, [1, 1, 1, 2]) + warp_tf
    o = resampler(image_tf, ws)
autodiff_jacobian = tape.batch_jacobian(o, shift)  # shape [1, 128, 128, 1, 2]

# calculate derivatives via numdifftools
def fn(shift):
    shift = tf.convert_to_tensor(shift.astype('float32'))
    ws = tf.reshape(shift, [1, 1, 1, 2]) + warp_tf
    o = resampler(image_tf, ws)
    return o.numpy().flatten()

numdiff_jacobian = numdifftools.Jacobian(fn, order=4, step=0.04)
numdiff_jacobian = numdiff_jacobian(np.zeros([2])).reshape([128, 128, 2])

# display residuals, excluding a 2-pixel border
plt.figure(figsize=(15, 5))
plt.subplot(121)
residual1 = abs(autodiff_jacobian[0, :, :, 0, 0] - numdiff_jacobian[:, :, 0])
plt.imshow(residual1[2:-2, 2:-2]); plt.colorbar()
plt.subplot(122)
residual2 = abs(autodiff_jacobian[0, :, :, 0, 1] - numdiff_jacobian[:, :, 1])
plt.imshow(residual2[2:-2, 2:-2]); plt.colorbar()
```

We see large residuals, of the same order as the pixel values:

[Image: integer_pixels, residual maps for the integer-pixel warp]
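Under the same assumptions, the size of these residuals can be estimated. At an integer pixel `n`, a floor-based gradient returns `I[n+1] - I[n]`, while a central difference with any step below one pixel returns `(I[n+1] - I[n-1]) / 2` (assumed here to be what numdifftools converges to, since its default method is central differencing). Their difference is `(I[n+1] - 2*I[n] + I[n-1]) / 2`, half the discrete second difference, which is naturally of the order of the pixel-value variations. A hypothetical check in plain Python, with illustrative helper names:

```python
import math

def interp_linear(samples, x):
    """Linear interpolation at x with corners floor(x) and floor(x) + 1."""
    x0 = math.floor(x)
    t = x - x0
    return (1.0 - t) * samples[x0] + t * samples[x0 + 1]

def floor_gradient(samples, x):
    """Slope a floor-based gradient kernel would report (assumed convention)."""
    x0 = math.floor(x)
    return samples[x0 + 1] - samples[x0]

def central_difference(samples, x, h=0.04):
    """Central finite difference of the interpolant, with step h < 1 pixel."""
    return (interp_linear(samples, x + h) - interp_linear(samples, x - h)) / (2.0 * h)

def predicted_residual(samples, n):
    """Half the discrete second difference at integer pixel n."""
    return 0.5 * (samples[n + 1] - 2.0 * samples[n] + samples[n - 1])

row = [10.0, 40.0, 20.0, 35.0, 5.0]  # hypothetical pixel row
for n in range(1, len(row) - 1):
    residual = floor_gradient(row, n) - central_difference(row, n)
    print(n, round(residual, 6), predicted_residual(row, n))
```

At a half-pixel position such as `x = 1.5`, the same two functions return the same slope, which matches the much smaller residuals reported for the half-pixel warp.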

But if we offset the warp by half a pixel (reusing `image_tf` and the imports from above):

```python
# set a warp shifted by half a pixel in both x and y
warp = np.stack(np.meshgrid(np.arange(128), np.arange(128)), axis=-1).astype('float32')
warp_tf = tf.convert_to_tensor(warp.reshape([1, 128, 128, 2]) + .5)  # add a half-step

# define a global shift of the warp
shift = tf.zeros([1, 2])

# calculate derivatives via tf.GradientTape
with tf.GradientTape() as tape:
    tape.watch(shift)
    ws = tf.reshape(shift, [1, 1, 1, 2]) + warp_tf
    o = resampler(image_tf, ws)
autodiff_jacobian = tape.batch_jacobian(o, shift)

# calculate derivatives via numdifftools
def fn(shift):
    shift = tf.convert_to_tensor(shift.astype('float32'))
    ws = tf.reshape(shift, [1, 1, 1, 2]) + warp_tf
    o = resampler(image_tf, ws)
    return o.numpy().flatten()

numdiff_jacobian = numdifftools.Jacobian(fn, order=4, step=0.04)
numdiff_jacobian = numdiff_jacobian(np.zeros([2])).reshape([128, 128, 2])

# display residuals, excluding a 2-pixel border
plt.figure(figsize=(15, 5))
plt.subplot(121)
residual1 = abs(autodiff_jacobian[0, :, :, 0, 0] - numdiff_jacobian[:, :, 0])
plt.imshow(residual1[2:-2, 2:-2]); plt.colorbar()
plt.subplot(122)
residual2 = abs(autodiff_jacobian[0, :, :, 0, 1] - numdiff_jacobian[:, :, 1])
plt.imshow(residual2[2:-2, 2:-2]); plt.colorbar()
```

The residuals are now about three orders of magnitude smaller:

[Image: half_pixels, residual maps for the half-pixel warp]

Other info / logs
We (@EiffL, @Dr-Zero, and I) are currently re-implementing the resampler kernels to provide better interpolators for our project autometacal and its dependency GalFlow, so these kernels may change. Any insights here would be helpful. Thank you.
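For the planned re-implementation, one option worth noting (an illustration only, not necessarily what autometacal or GalFlow will use) is a kernel with a continuous first derivative, such as Keys' cubic convolution kernel with `a = -0.5`. With it, the interpolant's slope at an integer node `n` is well defined and equals `-a * (I[n+1] - I[n-1])`, which for `a = -0.5` is exactly the central difference of the samples, so analytic gradients and finite differences would be expected to agree there. A 1-D sketch with hypothetical helper names:

```python
import math

def keys_kernel(s, a=-0.5):
    """Keys cubic convolution kernel: continuous, with a continuous first derivative."""
    s = abs(s)
    if s < 1.0:
        return (a + 2.0) * s**3 - (a + 3.0) * s**2 + 1.0
    if s < 2.0:
        return a * s**3 - 5.0 * a * s**2 + 8.0 * a * s - 4.0 * a
    return 0.0

def interp_cubic(samples, x, a=-0.5):
    """Interpolate at x from the four samples floor(x) - 1 .. floor(x) + 2."""
    x0 = math.floor(x)
    return sum(samples[k] * keys_kernel(x - k, a) for k in range(x0 - 1, x0 + 3))

def one_sided_derivatives(samples, x, h=1e-6):
    """Left and right difference quotients of the cubic interpolant at x."""
    f = lambda u: interp_cubic(samples, u)
    return (f(x) - f(x - h)) / h, (f(x + h) - f(x)) / h

row = [10.0, 40.0, 20.0, 35.0, 5.0, 25.0]  # hypothetical pixel row
left, right = one_sided_derivatives(row, 2.0)
print(round(left, 3), round(right, 3))     # left and right slopes at n = 2
print(0.5 * (row[3] - row[1]))             # central difference (I[n+1] - I[n-1]) / 2
```

The trade-off is a four-sample footprint per axis (sixteen samples in 2-D) instead of bilinear's two (four in 2-D).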
