Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions luminoth/datasets/object_detection_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,16 @@
from luminoth.datasets.base_dataset import BaseDataset
from luminoth.utils.image import (
resize_image_fixed, resize_image, flip_image, random_patch, random_resize,
random_distortion, expand
random_distortion, expand, fisheye
)

DATA_AUGMENTATION_STRATEGIES = {
'flip': flip_image,
'patch': random_patch,
'resize': random_resize,
'distortion': random_distortion,
'expand': expand
'expand': expand,
'fisheye': fisheye
}


Expand Down
129 changes: 128 additions & 1 deletion luminoth/utils/image.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import tensorflow as tf

from luminoth.utils.bbox_transform_tf import clip_boxes


Expand Down Expand Up @@ -618,3 +617,131 @@ def expand(image, bboxes=None, fill=0, min_ratio=1, max_ratio=4, seed=None):
if bboxes is not None:
return_dict['bboxes'] = bbox_adjusted
return return_dict


def fisheye(image, bboxes=None, min_top_padding=0.1, max_top_padding=1.0, min_bottom_padding=0.1,
max_bottom_padding=1.0, min_left_padding=0., max_left_padding=0.5, min_right_padding=0.,
max_right_padding=0.5):
"""
Applies a fisheye effect transformation, taking the top-left corner as the fisheye
center. This generates a quarter-fisheye image.

The fisheye center can be changed by flipping the image.

Args:
image: Tensor with image of shape (H, W, 3).
bboxes: Optional Tensor with bounding boxes with shape (num_bboxes, 5).
where we have (x_min, y_min, x_max, y_max, label) for each one.
min_top_padding: Min fraction of image height to determine 0-padding on top of it.
max_top_padding: Max fraction of image height to determine 0-padding on top of it.
min_bottom_padding: Min fraction of image height to determine 0-padding below of it.
max_bottom_padding: Max fraction of image height to determine 0-padding below of it.
min_left_padding: Min fraction of image width to determine 0-padding on left side.
max_left_padding: Max fraction of image widtht to determine 0-padding on left side.
min_right_padding: Min fraction of image width to determine 0-padding on right side.
max_right_padding: Max fraction of image width to determine 0-padding on right side.


Returns:
Dictionary containing:
image: Tensor with transformed out image.
bboxes: Tensor with transformed out bounding boxes with shape
(num_bboxes, 5).
"""

def interpolate_bilinear(img, y, x):
return img[tf.floor(y), tf.floor(x)] * (1 - (y - tf.floor(y))) * (1 - (x - tf.floor(x)))
+ img[tf.floor(y), tf.floor(x) + 1] * (1 - (y - tf.floor(y))) * (x - tf.floor(x))
+ img[tf.floor(y) + 1, tf.floor(x)] * (y - tf.floor(y)) * (1 - (x - tf.floor(x)))
+ img[tf.floor(y) + 1, tf.floor(x) + 1] * (y - tf.floor(y)) * (x - tf.floor(x))

# Generate padding values
top_padding = tf.random_uniform([], min_top_padding, max_top_padding)
bottom_padding = tf.random_uniform([], min_bottom_padding, max_bottom_padding)
left_padding = tf.random_uniform([], min_left_padding, max_left_padding)
right_padding = tf.random_uniform([], min_right_padding, max_right_padding)

image_shape = tf.to_float(tf.shape(image))
height = image_shape[0]
width = image_shape[1]

# Compute absolute length of paddings
top_padding, bottom_padding, left_padding, right_padding = (
tf.cast(height * top_padding, tf.int32),
tf.cast(height * bottom_padding, tf.int32),
tf.cast(width * left_padding, tf.int32),
tf.cast(width * right_padding, tf.int32)
)
vertical_paddings_tensor = tf.stack([top_padding, bottom_padding])
horizontal_paddings_tensor = tf.stack([left_padding, right_padding])

# Apply paddings on top, bottom and sides of image
image = tf.pad(
image, tf.stack([vertical_paddings_tensor, horizontal_paddings_tensor, [0, 0]]),
'constant', constant_values=0
)

# Adjust bboxes to paddings
x_min, y_min, x_max, y_max, label = tf.unstack(bboxes, axis=1)
y_min += top_padding
y_max += top_padding
x_min += left_padding
x_max += left_padding
bboxes = tf.stack([x_min, y_min, x_max, y_max, label], axis=1)

# Recompute image height and width with added padings
height = tf.cast(height, tf.int32) + top_padding + bottom_padding
width = width + tf.cast(left_padding + right_padding, tf.float32)

out_width = out_height = tf.cast(height - bottom_padding, tf.float32)

# Apply fisheye transformation
# Map each pixel from source image, taking y as radius and x as cos(theta) times its width.
X = tf.range(out_width, dtype=tf.float32)
Y = tf.range(out_height, dtype=tf.float32)

X_2 = tf.square(X)
Y_2 = tf.square(Y)
# Compute radius of each coordinate
# One-pad vectors to compute sum of squares as a matrix product
X_2 = tf.stack([X_2, tf.ones(tf.cast(out_width, tf.int32), tf.float32)])
Y_2 = tf.stack([tf.ones(tf.cast(out_height, tf.int32), tf.float32), Y_2], axis=1)
yS = tf.sqrt(tf.matmul(Y_2, X_2))

# Compute angle of each coordinate
theta = tf.atan(tf.matmul(tf.expand_dims(Y, 1), tf.expand_dims(1/X, 0)))
xS = tf.cos(theta) * width

coords = tf.stack([xS, yS], axis=2)
# Expand image dims since resampler expects shape [batch_size, data_height, data_width,
# data_channels]
out_image = tf.contrib.resampler.resampler(tf.expand_dims(image, 0), tf.expand_dims(coords, 0))

# Transform bboxes. The transformed bboxes are not squares since they get deformed,
# so we generate square bboxes that contains the transformed ones
x_min, y_min, x_max, y_max, label = tf.unstack(bboxes, axis=1)
x_min, y_min, x_max, y_max = [tf.cast(x, tf.float32) for x in [x_min, y_min, x_max, y_max]]

# Apply transformation to bboxes, getting polar coordinates
r_min_dst = y_min
r_max_dst = y_max
theta_min_dst = tf.acos(x_min / width)
theta_max_dst = tf.acos(x_max / width)

# Transform to cartesian
x_min_dst = r_min_dst * (x_min / width) # Simplify r_min_dst * tf.cos(tf.acos(x_min / width))
x_max_dst = r_max_dst * (x_max / width)
y_min_dst = r_min_dst * tf.sin(theta_max_dst)
y_max_dst = r_max_dst * tf.sin(theta_min_dst)

x_min, y_min, x_max, y_max = [tf.cast(x, tf.int32) for x in [x_min_dst,
y_min_dst, x_max_dst, y_max_dst]]
bboxes = tf.stack([x_min, y_min, x_max, y_max, label], axis=1)

return_dict = {'image': tf.reshape(
out_image, [height - bottom_padding, height - bottom_padding, 3]
)
}
if bboxes is not None:
return_dict['bboxes'] = bboxes
return return_dict