-
Hi - the PyTorch
-
Even though there hasn't been an update in a while, I'm sharing my implementation of grid_sample in JAX here. I have tested it, and it produces the same results as the PyTorch version. It currently supports only 2D inputs. It is about 5× faster than the PyTorch version on CPU but roughly 10× slower on GPU. Please check it out and let me know if you have any feedback on the code.
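For reference, a 2D-only port might look roughly like the minimal sketch below (this is not the implementation mentioned above). It assumes the `torch.nn.functional.grid_sample` conventions for `mode="bilinear"` and `padding_mode="zeros"`: a normalized `(x, y)` grid in `[-1, 1]`, an image laid out as `(C, H, W)` and a grid as `(Ho, Wo, 2)`. The name `grid_sample_2d` is hypothetical. It gathers the four corners for the whole grid at once instead of looping per output point.

```python
import jax
import jax.numpy as jnp


def grid_sample_2d(img, grid, align_corners=False):
    """Bilinear sampling of img (C, H, W) at grid (Ho, Wo, 2) with zero padding.

    grid holds normalized (x, y) coordinates in [-1, 1], following the
    torch.nn.functional.grid_sample convention. Returns (C, Ho, Wo).
    """
    C, H, W = img.shape
    gx, gy = grid[..., 0], grid[..., 1]
    # Map normalized coordinates to pixel indices (pixel centres at integers).
    if align_corners:
        x = (gx + 1) / 2 * (W - 1)
        y = (gy + 1) / 2 * (H - 1)
    else:
        x = ((gx + 1) * W - 1) / 2
        y = ((gy + 1) * H - 1) / 2
    x0 = jnp.floor(x).astype(jnp.int32)
    y0 = jnp.floor(y).astype(jnp.int32)
    wx1 = x - x0  # weight of the right-hand neighbour
    wy1 = y - y0  # weight of the lower neighbour

    def corner(yi, xi, w):
        # Gather one corner with clipped indices, then zero out-of-bounds corners.
        valid = (xi >= 0) & (xi < W) & (yi >= 0) & (yi < H)
        v = img[:, jnp.clip(yi, 0, H - 1), jnp.clip(xi, 0, W - 1)]  # (C, Ho, Wo)
        return jnp.where(valid, v, 0.0) * w

    return (
        corner(y0, x0, (1 - wy1) * (1 - wx1))
        + corner(y0, x0 + 1, (1 - wy1) * wx1)
        + corner(y0 + 1, x0, wy1 * (1 - wx1))
        + corner(y0 + 1, x0 + 1, wy1 * wx1)
    )
```

For a batch in torch's `(N, C, H, W)` layout with grids of shape `(N, Ho, Wo, 2)`, `jax.vmap(grid_sample_2d)(imgs, grids)` maps it over the leading axis (using the default `align_corners=False`).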
-
This seems like a good use case for JAX's FFI (foreign function interface), which could call a custom C++ / CUDA kernel.
-
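Here's an implementation that works for any number of dimensions. However, it supports only bilinear (multilinear) interpolation, and it uses un-normalized (pixel-index) coordinates, unlike torch.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import Array
from jax.typing import ArrayLike


def _retrieve_value_at(img, loc, out_of_bound_value=0):
    # Integer corner of the cell containing `loc` and the fractional offset inside it.
    iloc = jnp.floor(loc).astype(int)
    res = loc - iloc
    # All 2**k corner offsets of a k-dimensional cell, e.g. [[0,0],[1,0],[0,1],[1,1]] for k=2.
    offsets = jnp.asarray(
        [[(i >> j) % 2 for j in range(len(loc))] for i in range(2 ** len(loc))]
    )
    ilocs = jnp.swapaxes(iloc + offsets, 0, 1)  # shape (k, 2**k)
    # Multilinear weight of each corner: product over axes of res or (1 - res).
    weight = jnp.prod(res * (offsets == 1) + (1 - res) * (offsets == 0), axis=1)
    max_indices = jnp.asarray(img.shape)[: len(loc), None]
    # Corners outside the image contribute `out_of_bound_value` instead of a pixel.
    values = jnp.where(
        (ilocs >= 0).all(axis=0) & (ilocs < max_indices).all(axis=0),
        jnp.swapaxes(img[tuple(ilocs)], 0, -1),
        out_of_bound_value,
    )
    value = (values * weight).sum(axis=-1)
    return value


def sub_pixel_samples(
    img: ArrayLike,
    locs: ArrayLike,
    out_of_bound_value: float = 0,
    align_corners: bool = False,
) -> Array:
    """Retrieve image values at non-integer locations by interpolation.

    Args:
        img: Array of shape [D1, D2, ..., Dk, ...]
        locs: Array of shape [d1, d2, ..., dn, k], in un-normalized pixel coordinates
        out_of_bound_value: value used for neighbours that fall outside the image
        align_corners: if True, locs are shifted by +0.5 before sampling

    Returns:
        values: [d1, d2, ..., dn, ...], float
    """
    # Convert first so that `.shape` also works for list inputs.
    locs = jnp.asarray(locs)
    img = jnp.asarray(img)
    loc_shape = locs.shape
    img_shape = img.shape
    d_loc = loc_shape[-1]
    if align_corners:
        locs = locs + 0.5
    # Flatten trailing (channel) dims and all query points so one vmap covers everything.
    img = img.reshape(img_shape[:d_loc] + (-1,))
    locs = locs.reshape(-1, d_loc)
    op = partial(_retrieve_value_at, out_of_bound_value=out_of_bound_value)
    values = jax.vmap(op, in_axes=(None, 0))(img, locs)
    out_shape = loc_shape[:-1] + img_shape[d_loc:]
    values = values.reshape(out_shape)
    return values
```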
-
It would be great to have an equivalent of torch.nn.functional.grid_sample in JAX. It is widely used in 3D vision (view synthesis, re-projection to other camera positions, etc.). My understanding is that, currently, doing something similar in JAX requires implementing it from scratch, as in this TensorFlow example, which seems verbose and slow, whereas the PyTorch version seems to do this in a single CUDA kernel.
Thank you for your consideration!
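As a possible starting point, `jax.scipy.ndimage.map_coordinates` already performs linear interpolation with constant padding in JAX (it supports only `order=0` and `order=1`). Below is a minimal sketch of a bilinear, zero-padded `grid_sample` built on it, assuming torch's `(N, C, H, W)` input and `(N, Ho, Wo, 2)` grid layouts; `grid_sample` here is a hypothetical name, and whether it is fast enough compared with the PyTorch CUDA kernel would need benchmarking.

```python
import jax
import jax.numpy as jnp
from jax.scipy.ndimage import map_coordinates


def grid_sample(img, grid, align_corners=False):
    """Approximates torch grid_sample(mode="bilinear", padding_mode="zeros").

    img:  (N, C, H, W) input images.
    grid: (N, Ho, Wo, 2) normalized (x, y) coordinates in [-1, 1].
    Returns (N, C, Ho, Wo).
    """
    _, _, H, W = img.shape
    gx, gy = grid[..., 0], grid[..., 1]
    # Map normalized coordinates to pixel indices (pixel centres at integers).
    if align_corners:
        x = (gx + 1) / 2 * (W - 1)
        y = (gy + 1) / 2 * (H - 1)
    else:
        x = ((gx + 1) * W - 1) / 2
        y = ((gy + 1) * H - 1) / 2

    def sample_channel(channel, yy, xx):
        # order=1 is linear interpolation; mode="constant" uses cval outside the image.
        return map_coordinates(channel, [yy, xx], order=1, mode="constant", cval=0.0)

    # Map over channels (sharing one set of coordinates), then over the batch.
    per_image = jax.vmap(sample_channel, in_axes=(0, None, None))
    return jax.vmap(per_image)(img, y, x)
```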