-
Notifications
You must be signed in to change notification settings - Fork 50
[Feature] Support SWD and MS-SSIM metrics #82
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
LeoXing1996
wants to merge
12
commits into
open-mmlab:main
Choose a base branch
from
LeoXing1996:leoxing/add-swd-and-ms-ssim
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
12 commits
Select commit
Hold shift + click to select a range
60a2037
support ms-ssim
LeoXing1996 04ba8d6
update docstring and compulate logic of ms-ssim
LeoXing1996 63a9b0d
support swd metric
LeoXing1996 204f41c
revise unit test of ms-ssim
LeoXing1996 80501b5
revise docstring for SWD and MS-SSIM
LeoXing1996 ea7696a
add more unit test for SWD
LeoXing1996 2988a63
add shape checking for MS-SSIM input and revise docstring as comment
LeoXing1996 fb7824d
support groundtruths input in MS-SSIM for reconstructive tasks
LeoXing1996 5c035cf
do not support groundtruths=None is MS-SSIM
LeoXing1996 a965fe2
revise unit test of swd
LeoXing1996 1f4675b
move word ignore setting to setup.cfg
LeoXing1996 c9bc955
revise ms-ssim as comment
LeoXing1996 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change | ||||||||
---|---|---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,295 @@ | ||||||||||
# Copyright (c) OpenMMLab. All rights reserved. | ||||||||||
import numpy as np | ||||||||||
from scipy import signal | ||||||||||
from typing import Dict, List, Sequence, Tuple | ||||||||||
|
||||||||||
from mmeval.core import BaseMetric | ||||||||||
from .utils.image_transforms import reorder_image | ||||||||||
|
||||||||||
|
||||||||||
class MultiScaleStructureSimilarity(BaseMetric): | ||||||||||
"""MS-SSIM (Multi-Scale Structure Similarity) metric. | ||||||||||
|
||||||||||
Ref: | ||||||||||
This class implements Multi-Scale Structural Similarity (MS-SSIM) Image | ||||||||||
Quality Assessment according to Zhou Wang's paper, "Multi-scale structural | ||||||||||
similarity for image quality assessment" (2003). | ||||||||||
Link: https://ece.uwaterloo.ca/~z70wang/publications/msssim.pdf | ||||||||||
|
||||||||||
Author's MATLAB implementation: | ||||||||||
http://www.cns.nyu.edu/~lcv/ssim/msssim.zip | ||||||||||
|
||||||||||
PGGAN's implementation: | ||||||||||
https://github.com/tkarras/progressive_growing_of_gans/blob/master/metrics/ms_ssim.py | ||||||||||
|
||||||||||
Args: | ||||||||||
input_order (str): Whether the input order is 'HWC' or 'CHW'. | ||||||||||
Defaults to 'HWC'. | ||||||||||
max_val (int): the dynamic range of the images (i.e., the difference | ||||||||||
between the maximum and the minimum allowed values). | ||||||||||
Defaults to 255. | ||||||||||
filter_size (int): Size of blur kernel to use (will be reduced for | ||||||||||
small images). Defaults to 11. | ||||||||||
filter_sigma (float): Standard deviation for Gaussian blur kernel (will | ||||||||||
be reduced for small images). Defaults to 1.5. | ||||||||||
k1 (float): Constant used to maintain stability in the SSIM calculation | ||||||||||
(0.01 in the original paper). Defaults to 0.01. | ||||||||||
k2 (float): Constant used to maintain stability in the SSIM calculation | ||||||||||
(0.03 in the original paper). Defaults to 0.03. | ||||||||||
weights (List[float]): List of weights for each level. Defaults to | ||||||||||
[0.0448, 0.2856, 0.3001, 0.2363, 0.1333]. Noted that the default | ||||||||||
weights don't sum to 1.0 but do match the paper / matlab code. | ||||||||||
**kwargs: Keyword parameters passed to :class:`BaseMetric`. | ||||||||||
|
||||||||||
Examples: | ||||||||||
|
||||||||||
>>> from mmeval import MultiScaleStructureSimilarity as MS_SSIM | ||||||||||
>>> import numpy as np | ||||||||||
>>> | ||||||||||
>>> ms_ssim = MS_SSIM() | ||||||||||
>>> preds = [np.random.randint(0, 255, size=(3, 32, 32)) for _ in range(4)] # noqa | ||||||||||
>>> gts = [np.random.randint(0, 255, size=(3, 32, 32)) for _ in range(4)] # noqa | ||||||||||
>>> ms_ssim(preds, gts) # doctest: +ELLIPSIS | ||||||||||
{'ms_ssim': ...} | ||||||||||
""" | ||||||||||
|
||||||||||
def __init__(self, | ||||||||||
input_order: str = 'CHW', | ||||||||||
max_val: int = 255, | ||||||||||
filter_size: int = 11, | ||||||||||
filter_sigma: float = 1.5, | ||||||||||
k1: float = 0.01, | ||||||||||
k2: float = 0.03, | ||||||||||
weights: List[float] = [ | ||||||||||
0.0448, 0.2856, 0.3001, 0.2363, 0.1333 | ||||||||||
], | ||||||||||
**kwargs) -> None: | ||||||||||
super().__init__(**kwargs) | ||||||||||
|
||||||||||
assert input_order.upper() in [ | ||||||||||
'CHW', 'HWC' | ||||||||||
], (f'Wrong input_order {input_order}. Supported input_orders are ' | ||||||||||
'"HWC" and "CHW"') | ||||||||||
self.input_order = input_order | ||||||||||
|
||||||||||
self.max_val = max_val | ||||||||||
self.filter_size = filter_size | ||||||||||
self.filter_sigma = filter_sigma | ||||||||||
self.k1 = k1 | ||||||||||
self.k2 = k2 | ||||||||||
self.weights = np.array(weights) | ||||||||||
|
||||||||||
def add(self, predictions: Sequence[np.ndarray], groundtruths: Sequence[np.ndarray]) -> None: # type: ignore # yapf: disable # noqa: E501 | ||||||||||
"""Add a bunch of images to calculate metric result. | ||||||||||
|
||||||||||
Args: | ||||||||||
predictions (Sequence[np.ndarray]): Predictions of the model. | ||||||||||
The length of `predictions` must be same as `groundtruths`. | ||||||||||
The width and height of each element must be divisible by 2 ** | ||||||||||
num_scale (`self.weights.size`). The channel order of each | ||||||||||
element should align with `self.input_order` and the range | ||||||||||
should be [0, 255]. | ||||||||||
groundtruths (Sequence[np.ndarray], optional): Groundtruth of the | ||||||||||
model. The number of elements in the Sequence must be same as | ||||||||||
`predictions`, and the width and height of each element must | ||||||||||
be divisible by 2 ** num_scale (`self.weights.size`). The | ||||||||||
channel order of each element should align with | ||||||||||
`self.input_order` and the range should be [0, 255]. | ||||||||||
Defaults to None. | ||||||||||
""" | ||||||||||
assert len(predictions) == len(groundtruths), ( | ||||||||||
'The length of "predictions" and "groundtruths" must be ' | ||||||||||
'same.') | ||||||||||
half1, half2 = predictions, groundtruths | ||||||||||
|
||||||||||
half1 = [reorder_image(samp, self.input_order) for samp in half1] | ||||||||||
half2 = [reorder_image(samp, self.input_order) for samp in half2] | ||||||||||
least_size = 2**self.weights.size | ||||||||||
assert all([ | ||||||||||
sample.shape[0] % least_size == 0 for sample in half1 | ||||||||||
]), ('The height and width of each sample must be divisible by ' | ||||||||||
f'{least_size} (2 ** len(self.weights.size)).') | ||||||||||
assert all([ | ||||||||||
sample.shape[0] % least_size == 0 for sample in half2 | ||||||||||
]), ('The height and width of each sample must be divisible by ' | ||||||||||
f'{least_size} (2 ** self.weights.size).') | ||||||||||
|
||||||||||
half1 = np.stack(half1, axis=0).astype(np.uint8) | ||||||||||
half2 = np.stack(half2, axis=0).astype(np.uint8) | ||||||||||
|
||||||||||
self._results += self.compute_ms_ssim(half1, half2) | ||||||||||
|
||||||||||
def compute_metric(self, results: List[np.float64]) -> Dict[str, float]: | ||||||||||
"""Compute the MS-SSIM metric. | ||||||||||
|
||||||||||
This method would be invoked in ``BaseMetric.compute`` after | ||||||||||
distributed synchronization. | ||||||||||
|
||||||||||
Args: | ||||||||||
results (List[np.float64]): A list that consisting the PSNR score. | ||||||||||
This list has already been synced across all ranks. | ||||||||||
|
||||||||||
Returns: | ||||||||||
Dict[str, float]: The computed PSNR metric. | ||||||||||
""" | ||||||||||
return {'ms-ssim': float(np.array(results).mean())} | ||||||||||
|
||||||||||
def compute_ms_ssim(self, img1: np.array, img2: np.array) -> List[float]: | ||||||||||
"""Calculate MS-SSIM (multi-scale structural similarity). | ||||||||||
|
||||||||||
Args: | ||||||||||
img1 (ndarray): Images with range [0, 255] and order "NHWC". | ||||||||||
img2 (ndarray): Images with range [0, 255] and order "NHWC". | ||||||||||
|
||||||||||
Returns: | ||||||||||
np.ndarray: MS-SSIM score between `img1` and `img2` of shape (N, ). | ||||||||||
""" | ||||||||||
if img1.shape != img2.shape: | ||||||||||
raise RuntimeError( | ||||||||||
'Input images must have the same shape (%s vs. %s).' % | ||||||||||
(img1.shape, img2.shape)) | ||||||||||
if img1.ndim != 4: | ||||||||||
raise RuntimeError( | ||||||||||
'Input images must have four dimensions, not %d' % img1.ndim) | ||||||||||
|
||||||||||
levels = self.weights.size | ||||||||||
im1, im2 = (x.astype(np.float32) for x in [img1, img2]) | ||||||||||
mssim = [] | ||||||||||
mcs = [] | ||||||||||
for _ in range(levels): | ||||||||||
ssim, cs = self._ssim_for_multi_scale( | ||||||||||
im1, | ||||||||||
im2, | ||||||||||
max_val=self.max_val, | ||||||||||
filter_size=self.filter_size, | ||||||||||
filter_sigma=self.filter_sigma, | ||||||||||
k1=self.k1, | ||||||||||
k2=self.k2) | ||||||||||
mssim.append(ssim) | ||||||||||
mcs.append(cs) | ||||||||||
im1, im2 = (self._hox_downsample(x) for x in [im1, im2]) | ||||||||||
|
||||||||||
# Clip to zero. Otherwise we get NaNs. | ||||||||||
mssim = np.clip(np.asarray(mssim), 0.0, np.inf) | ||||||||||
mcs = np.clip(np.asarray(mcs), 0.0, np.inf) | ||||||||||
|
||||||||||
results = np.prod( | ||||||||||
mcs[:-1, :]**self.weights[:-1, np.newaxis], axis=0) * ( | ||||||||||
mssim[-1, :]**self.weights[-1]) | ||||||||||
return results.tolist() | ||||||||||
|
||||||||||
@staticmethod | ||||||||||
def _f_special_gauss(size: int, sigma: float) -> np.ndarray: | ||||||||||
r"""Return a circular symmetric gaussian kernel. | ||||||||||
|
||||||||||
Ref: https://github.com/tkarras/progressive_growing_of_gans/blob/2504c3f3cb98ca58751610ad61fa1097313152bd/metrics/ms_ssim.py#L25-L36 # noqa | ||||||||||
|
||||||||||
Args: | ||||||||||
size (int): Size of Gaussian kernel. | ||||||||||
sigma (float): Standard deviation for Gaussian blur kernel. | ||||||||||
|
||||||||||
Returns: | ||||||||||
np.ndarray: Gaussian kernel. | ||||||||||
""" | ||||||||||
radius = size // 2 | ||||||||||
offset = 0.0 | ||||||||||
start, stop = -radius, radius + 1 | ||||||||||
if size % 2 == 0: | ||||||||||
offset = 0.5 | ||||||||||
stop -= 1 | ||||||||||
x, y = np.mgrid[offset + start:stop, # type: ignore # noqa | ||||||||||
offset + start:stop] # type: ignore # noqa | ||||||||||
assert len(x) == size | ||||||||||
g = np.exp(-((x**2 + y**2) / (2.0 * sigma**2))) | ||||||||||
return g / g.sum() | ||||||||||
|
||||||||||
@staticmethod | ||||||||||
def _hox_downsample(img: np.ndarray) -> np.ndarray: | ||||||||||
r"""Downsample images with factor equal to 0.5. | ||||||||||
|
||||||||||
Ref: https://github.com/tkarras/progressive_growing_of_gans/blob/2504c3f3cb98ca58751610ad61fa1097313152bd/metrics/ms_ssim.py#L110-L111 # noqa | ||||||||||
|
||||||||||
Args: | ||||||||||
img (np.ndarray): Images with order "NHWC". | ||||||||||
|
||||||||||
Returns: | ||||||||||
np.ndarray: Downsampled images with order "NHWC". | ||||||||||
""" | ||||||||||
return (img[:, 0::2, 0::2, :] + img[:, 1::2, 0::2, :] + | ||||||||||
img[:, 0::2, 1::2, :] + img[:, 1::2, 1::2, :]) * 0.25 | ||||||||||
C1rN09 marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||
|
||||||||||
def _ssim_for_multi_scale( | ||||||||||
self, | ||||||||||
img1: np.ndarray, | ||||||||||
img2: np.ndarray, | ||||||||||
max_val: int = 255, | ||||||||||
filter_size: int = 11, | ||||||||||
filter_sigma: float = 1.5, | ||||||||||
k1: float = 0.01, | ||||||||||
k2: float = 0.03) -> Tuple[np.ndarray, np.ndarray]: | ||||||||||
"""Calculate SSIM (structural similarity) and contrast sensitivity. | ||||||||||
|
||||||||||
Ref: | ||||||||||
Our implementation is based on PGGAN: | ||||||||||
https://github.com/tkarras/progressive_growing_of_gans/blob/2504c3f3cb98ca58751610ad61fa1097313152bd/metrics/ms_ssim.py#L38-L108 # noqa | ||||||||||
|
||||||||||
Args: | ||||||||||
img1 (np.ndarray): Images with range [0, 255] and order "NHWC". | ||||||||||
img2 (np.ndarray): Images with range [0, 255] and order "NHWC". | ||||||||||
max_val (int): the dynamic range of the images (i.e., the | ||||||||||
difference between the maximum the and minimum allowed | ||||||||||
values). Defaults to 255. | ||||||||||
filter_size (int): Size of blur kernel to use (will be reduced for | ||||||||||
small images). Defaults to 11. | ||||||||||
filter_sigma (float): Standard deviation for Gaussian blur kernel ( | ||||||||||
will be reduced for small images). Defaults to 1.5. | ||||||||||
k1 (float): Constant used to maintain stability in the SSIM | ||||||||||
calculation (0.01 in the original paper). Defaults to 0.01. | ||||||||||
k2 (float): Constant used to maintain stability in the SSIM | ||||||||||
calculation (0.03 in the original paper). Defaults to 0.03. | ||||||||||
|
||||||||||
Returns: | ||||||||||
tuple: Pair containing the mean SSIM and contrast sensitivity | ||||||||||
between `img1` and `img2`. | ||||||||||
Comment on lines
+252
to
+253
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||
""" | ||||||||||
img1 = img1.astype(np.float32) | ||||||||||
img2 = img2.astype(np.float32) | ||||||||||
_, height, width, _ = img1.shape | ||||||||||
|
||||||||||
# Filter size can't be larger than height or width of images. | ||||||||||
size = min(filter_size, height, width) | ||||||||||
|
||||||||||
# Scale down sigma if a smaller filter size is used. | ||||||||||
sigma = size * filter_sigma / filter_size if filter_size else 0 | ||||||||||
|
||||||||||
if filter_size: | ||||||||||
window = np.reshape( | ||||||||||
self._f_special_gauss(size, sigma), (1, size, size, 1)) | ||||||||||
mu1 = signal.fftconvolve(img1, window, mode='valid') | ||||||||||
mu2 = signal.fftconvolve(img2, window, mode='valid') | ||||||||||
sigma11 = signal.fftconvolve(img1 * img1, window, mode='valid') | ||||||||||
sigma22 = signal.fftconvolve(img2 * img2, window, mode='valid') | ||||||||||
sigma12 = signal.fftconvolve(img1 * img2, window, mode='valid') | ||||||||||
else: | ||||||||||
# Empty blur kernel so no need to convolve. | ||||||||||
mu1, mu2 = img1, img2 | ||||||||||
sigma11 = img1 * img1 | ||||||||||
sigma22 = img2 * img2 | ||||||||||
sigma12 = img1 * img2 | ||||||||||
|
||||||||||
mu11 = mu1 * mu1 | ||||||||||
mu22 = mu2 * mu2 | ||||||||||
mu12 = mu1 * mu2 | ||||||||||
sigma11 -= mu11 | ||||||||||
sigma22 -= mu22 | ||||||||||
sigma12 -= mu12 | ||||||||||
|
||||||||||
# Calculate intermediate values used by both ssim and cs_map. | ||||||||||
c1 = (k1 * max_val)**2 | ||||||||||
c2 = (k2 * max_val)**2 | ||||||||||
v1 = 2.0 * sigma12 + c2 | ||||||||||
v2 = sigma11 + sigma22 + c2 | ||||||||||
ssim = np.mean((((2.0 * mu12 + c1) * v1) / ((mu11 + mu22 + c1) * v2)), | ||||||||||
axis=(1, 2, 3)) # Return for each image individually. | ||||||||||
cs = np.mean(v1 / v2, axis=(1, 2, 3)) | ||||||||||
return ssim, cs |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.