import copy
import inspect
import mmcv
import numpy as np
from numpy import random
from mmdet.core import PolygonMasks
from mmdet.core.evaluation.bbox_overlaps import bbox_overlaps
from ..builder import PIPELINES
try:
from imagecorruptions import corrupt
except ImportError:
corrupt = None
try:
import albumentations
from albumentations import Compose
except ImportError:
albumentations = None
Compose = None
@PIPELINES.register_module()
class Resize(object):
"""Resize images & bbox & mask.
This transform resizes the input image to some scale. Bboxes and masks are
then resized with the same scale factor. If the input dict contains the key
"scale", then the scale in the input dict is used, otherwise the specified
scale in the init method is used. If the input dict contains the key
"scale_factor" (if MultiScaleFlipAug does not give img_scale but
scale_factor), the actual scale will be computed by image shape and
scale_factor.
`img_scale` can either be a tuple (single-scale) or a list of tuple
(multi-scale). There are 3 multiscale modes:
- ``ratio_range is not None``: randomly sample a ratio from the ratio \
range and multiply it with the image scale.
- ``ratio_range is None`` and ``multiscale_mode == "range"``: randomly \
sample a scale from the multiscale range.
- ``ratio_range is None`` and ``multiscale_mode == "value"``: randomly \
sample a scale from multiple scales.
Args:
img_scale (tuple or list[tuple]): Images scales for resizing.
multiscale_mode (str): Either "range" or "value".
ratio_range (tuple[float]): (min_ratio, max_ratio)
keep_ratio (bool): Whether to keep the aspect ratio when resizing the
image.
bbox_clip_border (bool, optional): Whether to clip the objects outside
the border of the image. Defaults to True.
backend (str): Image resize backend, choices are 'cv2' and 'pillow'.
These two backends generate slightly different results. Defaults
to 'cv2'.
override (bool, optional): Whether to override `scale` and
`scale_factor` so as to call resize twice. If True, after the
first resizing, the existing `scale` and `scale_factor` will be
ignored so that a second resizing is allowed. This option is a
work-around for resizing multiple times in DETR.
Defaults to False.
"""
def __init__(self,
img_scale=None,
multiscale_mode='range',
ratio_range=None,
keep_ratio=True,
bbox_clip_border=True,
backend='cv2',
override=False):
if img_scale is None:
self.img_scale = None
else:
if isinstance(img_scale, list):
self.img_scale = img_scale
else:
self.img_scale = [img_scale]
assert mmcv.is_list_of(self.img_scale, tuple)
if ratio_range is not None:
# mode 1: given a scale and a range of image ratio
assert len(self.img_scale) == 1
else:
# mode 2: given multiple scales or a range of scales
assert multiscale_mode in ['value', 'range']
self.backend = backend
self.multiscale_mode = multiscale_mode
self.ratio_range = ratio_range
self.keep_ratio = keep_ratio
# TODO: refactor the override option in Resize
self.override = override
self.bbox_clip_border = bbox_clip_border
@staticmethod
def random_select(img_scales):
    """Pick one image scale at random from a list of candidates.

    Args:
        img_scales (list[tuple]): Candidate image scales.

    Returns:
        (tuple, int): A pair ``(img_scale, scale_idx)``, where \
            ``img_scale`` is the chosen candidate and ``scale_idx`` is \
            its position in ``img_scales``.
    """
    assert mmcv.is_list_of(img_scales, tuple)
    # Draw an index in [0, len(img_scales)) and return it with its scale.
    chosen_idx = np.random.randint(len(img_scales))
    return img_scales[chosen_idx], chosen_idx
@staticmethod
def random_sample(img_scales):
    """Sample an image scale between two bounds (``multiscale_mode=='range'``).

    The long and the short edge are sampled independently: each is an
    integer drawn between the smallest and the largest corresponding edge
    of the two given scales (both ends inclusive).

    Args:
        img_scales (list[tuple]): Exactly two tuples giving the lower and
            upper bound of the image scale.

    Returns:
        (tuple, None): A pair ``(img_scale, None)``, where ``img_scale`` \
            is the sampled ``(long_edge, short_edge)`` and None is a \
            placeholder keeping the return shape consistent with \
            :func:`random_select`.
    """
    assert mmcv.is_list_of(img_scales, tuple) and len(img_scales) == 2
    long_candidates = tuple(max(scale) for scale in img_scales)
    short_candidates = tuple(min(scale) for scale in img_scales)
    # ``randint`` excludes its upper bound, hence the ``+ 1``.
    # The long edge is drawn before the short edge.
    sampled_long = np.random.randint(
        min(long_candidates), max(long_candidates) + 1)
    sampled_short = np.random.randint(
        min(short_candidates), max(short_candidates) + 1)
    return (sampled_long, sampled_short), None
@staticmethod
def random_sample_ratio(img_scale, ratio_range):
"""Randomly sample an img_scale when ``ratio_range`` is specified.
A ratio will be randomly sampled from the range specified by
``ratio_range``. Then it would be multiplied with ``img_scale`` to
generate sampled scale.
Args:
img_scale (tuple): Images scale base to multiply with ratio.
ratio_range (tuple[float]): The minimum and maximum ratio to scale
the ``img_scale``.
Returns:
(tuple, None): Returns a tuple ``(scale, None)``, where \
``scale`` is sampled ratio multiplied with ``img_scale`` and \
None is just a placeholder to be consistent with \
:func:`random_select`.
"""
assert isinstance(img_scale, tuple) and len(img_scale) == 2
min_ratio, max_ratio = ratio_range
assert min_ratio <= max_ratio
ratio = np.random.random_sample() * (max_ratio - min_ratio) + min_ratio
scale = int(img_scale[0] * ratio), int(img_scale[1] * ratio)
return scale, None
def _random_scale(self, results):
"""Randomly sample an img_scale according to ``ratio_range`` and
``multiscale_mode``.
If ``ratio_range`` is specified, a ratio will be sampled and be
multiplied with ``img_scale``.
If multiple scales are specified by ``img_scale``, a scale will be
sampled according to ``multiscale_mode``.
Otherwise, single scale will be used.
Args:
results (dict): Result dict from :obj:`dataset`.
Returns:
dict: Two new keys 'scale` and 'scale_idx` are added into \
``results``, which would be used by subsequent pipelines.
"""
if self.ratio_range is not None:
scale, scale_idx = self.random_sample_ratio(
self.img_scale[0], self.ratio_range)
elif len(self.img_scale) == 1:
scale, scale_idx = self.img_scale[0], 0
elif self.multiscale_mode == 'range':
scale, scale_idx = self.random_sample(self.img_scale)
elif self.multiscale_mode == 'value':
scale, scale_idx = self.random_select(self.img_scale)
else:
raise NotImplementedError
results['scale'] = scale
results['scale_idx'] = scale_idx
def _resize_img(self, results):
"""Resize images with ``results['scale']``."""
for key in results.get('img_fields', ['img']):
if self.keep_ratio:
img, scale_factor = mmcv.imrescale(
results