"""Functions to interact/analyze with numpy dtypes."""
from __future__ import print_function, division, absolute_import
import numpy as np
import six.moves as sm
import imgaug as ia
KIND_TO_DTYPES = {
"i": ["int8", "int16", "int32", "int64"],
"u": ["uint8", "uint16", "uint32", "uint64"],
"b": ["bool"],
"f": ["float16", "float32", "float64", "float128"]
}
# Added in 0.5.0.
_DTYPE_STR_TO_DTYPES_CACHE = dict()
_UINT8_DTYPE = np.dtype("uint8") # Added in 0.5.0.
_UINT16_DTYPE = np.dtype("uint16") # Added in 0.5.0.
_UINT32_DTYPE = np.dtype("uint32") # Added in 0.5.0.
_UINT64_DTYPE = np.dtype("uint64") # Added in 0.5.0.
_INT8_DTYPE = np.dtype("int8") # Added in 0.5.0.
_INT16_DTYPE = np.dtype("int16") # Added in 0.5.0.
_INT32_DTYPE = np.dtype("int32") # Added in 0.5.0.
_INT64_DTYPE = np.dtype("int64") # Added in 0.5.0.
_FLOAT16_DTYPE = np.dtype("float16") # Added in 0.5.0.
_FLOAT32_DTYPE = np.dtype("float32") # Added in 0.5.0.
_FLOAT64_DTYPE = np.dtype("float64") # Added in 0.5.0.
_BOOL_DTYPE = np.dtype("bool") # Added in 0.5.0.
# Added in 0.5.0.
try:
_FLOAT128_DTYPE = np.dtype("float128")
except TypeError:
_FLOAT128_DTYPE = None
# Added in 0.5.0.
_DTYPE_NAME_TO_DTYPE = {
"uint8": _UINT8_DTYPE,
"uint16": _UINT16_DTYPE,
"uint32": _UINT32_DTYPE,
"uint64": _UINT64_DTYPE,
"int8": _INT8_DTYPE,
"int16": _INT16_DTYPE,
"int32": _INT32_DTYPE,
"int64": _INT64_DTYPE,
"float16": _FLOAT16_DTYPE,
"float32": _FLOAT32_DTYPE,
"float64": _FLOAT64_DTYPE,
"float128": _FLOAT128_DTYPE,
"bool": _BOOL_DTYPE
}
[docs]
def normalize_dtypes(dtypes):
if not isinstance(dtypes, list):
return [normalize_dtype(dtypes)]
return [normalize_dtype(dtype) for dtype in dtypes]
[docs]
def normalize_dtype(dtype):
assert not isinstance(dtype, list), (
"Expected a single dtype-like, got a list instead.")
return (
dtype.dtype
if ia.is_np_array(dtype) or ia.is_np_scalar(dtype)
else np.dtype(dtype)
)
[docs]
def change_dtype_(arr, dtype, clip=True, round=True):
# pylint: disable=redefined-builtin
assert ia.is_np_array(arr), (
"Expected array as input, got type %s." % (type(arr),))
dtype = normalize_dtype(dtype)
if arr.dtype.name == dtype.name:
return arr
if round and arr.dtype.kind == "f" and dtype.kind in ["u", "i", "b"]:
arr = np.round(arr)
if clip:
min_value, _, max_value = get_value_range_of_dtype(dtype)
arr = clip_(arr, min_value, max_value)
return arr.astype(dtype, copy=False)
[docs]
def change_dtypes_(images, dtypes, clip=True, round=True):
# pylint: disable=redefined-builtin
if ia.is_np_array(images):
if ia.is_iterable(dtypes):
dtypes = normalize_dtypes(dtypes)
n_distinct_dtypes = len({dt.name for dt in dtypes})
assert len(dtypes) == len(images), (
"If an iterable of dtypes is provided to "
"change_dtypes_(), it must contain as many dtypes as "
"there are images. Got %d dtypes and %d images." % (
len(dtypes), len(images))
)
assert n_distinct_dtypes == 1, (
"If an image array is provided to change_dtypes_(), the "
"provided 'dtypes' argument must either be a single dtype "
"or an iterable of N times the *same* dtype for N images. "
"Got %d distinct dtypes." % (n_distinct_dtypes,)
)
dtype = dtypes[0]
else:
dtype = normalize_dtype(dtypes)
result = change_dtype_(images, dtype, clip=clip, round=round)
elif ia.is_iterable(images):
dtypes = (
[normalize_dtype(dtypes)] * len(images)
if not isinstance(dtypes, list)
else normalize_dtypes(dtypes)
)
assert len(images) == len(dtypes), (
"Expected the provided images and dtypes to match, but got "
"iterables of size %d (images) %d (dtypes)." % (
len(images), len(dtypes)))
result = images
for i, (image, dtype) in enumerate(zip(images, dtypes)):
assert ia.is_np_array(image), (
"Expected each image to be an ndarray, got type %s "
"instead." % (type(image),))
result[i] = change_dtype_(image, dtype, clip=clip, round=round)
else:
raise Exception("Expected numpy array or iterable of numpy arrays, "
"got type '%s'." % (type(images),))
return result
# TODO replace this everywhere in the library with change_dtypes_
# TODO mark as deprecated
[docs]
def restore_dtypes_(images, dtypes, clip=True, round=True):
# pylint: disable=redefined-builtin
return change_dtypes_(images, dtypes, clip=clip, round=round)
[docs]
def copy_dtypes_for_restore(images, force_list=False):
if ia.is_np_array(images):
if force_list:
return [images.dtype for _ in sm.xrange(len(images))]
return images.dtype
return [image.dtype for image in images]
[docs]
def increase_itemsize_of_dtype(dtype, factor):
dtype = normalize_dtype(dtype)
assert ia.is_single_integer(factor), (
"Expected 'factor' to be an integer, got type %s instead." % (
type(factor),))
# int8 -> int64 = factor 8
# uint8 -> uint64 = factor 8
# float16 -> float128 = factor 8
assert factor in [1, 2, 4, 8], (
"The itemsize may only be increased any of the following factors: "
"1, 2, 4 or 8. Got factor %d." % (factor,))
assert dtype.kind != "b", "Cannot increase the itemsize of boolean."
dt_high_name = "%s%d" % (dtype.kind, dtype.itemsize * factor)
try:
dt_high = np.dtype(dt_high_name)
return dt_high
except TypeError:
raise TypeError(
"Unable to create a numpy dtype matching the name '%s'. "
"This error was caused when trying to find a dtype "
"that increases the itemsize of dtype '%s' by a factor of %d."
"This error can be avoided by choosing arrays with lower "
"resolution dtypes as inputs, e.g. by reducing "
"float32 to float16." % (
dt_high_name,
dtype.name,
factor
)
)
[docs]
def get_minimal_dtype(arrays, increase_itemsize_factor=1):
assert isinstance(arrays, list), (
"Expected a list of arrays or dtypes, got type %s." % (type(arrays),))
assert len(arrays) > 0, (
"Cannot estimate minimal dtype of an empty iterable.")
input_dts = normalize_dtypes(arrays)
# This loop construct handles (1) list of a single dtype, (2) list of two
# dtypes and (3) list of 3+ dtypes. Note that promote_dtypes() always
# expects exactly two dtypes.
promoted_dt = input_dts[0]
input_dts = input_dts[1:]
while len(input_dts) >= 1:
promoted_dt = np.promote_types(promoted_dt, input_dts[0])
input_dts = input_dts[1:]
if increase_itemsize_factor > 1:
assert isinstance(promoted_dt, np.dtype), (
"Expected numpy.dtype output from numpy.promote_dtypes, got type "
"%s." % (type(promoted_dt),))
return increase_itemsize_of_dtype(promoted_dt,
increase_itemsize_factor)
return promoted_dt
# TODO rename to: promote_arrays_to_minimal_dtype_
[docs]
def increase_array_resolutions_(arrays, factor):
dts = normalize_dtypes(arrays)
dts = [increase_itemsize_of_dtype(dt, factor) for dt in dts]
return change_dtypes_(arrays, dts, round=False, clip=False)
[docs]
def get_value_range_of_dtype(dtype):
dtype = normalize_dtype(dtype)
if dtype.kind == "f":
finfo = np.finfo(dtype)
return finfo.min, 0.0, finfo.max
if dtype.kind == "u":
iinfo = np.iinfo(dtype)
return iinfo.min, iinfo.min + 0.5 * iinfo.max, iinfo.max
if dtype.kind == "i":
iinfo = np.iinfo(dtype)
return iinfo.min, -0.5, iinfo.max
if dtype.kind == "b":
return 0, None, 1
raise Exception("Cannot estimate value range of dtype '%s' "
"(type: %s)" % (str(dtype), type(dtype)))
# TODO call this function wherever data is clipped
[docs]
def clip_(array, min_value, max_value):
# uint64 is disallowed, because numpy's clip seems to convert it to float64
# int64 is disallowed, because numpy's clip converts it to float64 since
# 1.17
# TODO find the cause for that
gate_dtypes_strs(
{array.dtype},
allowed="bool uint8 uint16 uint32 int8 int16 int32 "
"float16 float32 float64 float128",
disallowed="uint64 int64"
)
# If the min of the input value range is above the allowed min, we do not
# have to clip to the allowed min as we cannot exceed it anyways.
# Analogous for max. In fact, we must not clip then to min/max as that can
# lead to errors in numpy's clip. E.g.
# >>> arr = np.zeros((1,), dtype=np.int32)
# >>> np.clip(arr, 0, np.iinfo(np.dtype("uint32")).max)
# will return
# array([-1], dtype=int32)
# (observed on numpy version 1.15.2).
min_value_arrdt, _, max_value_arrdt = get_value_range_of_dtype(array.dtype)
if min_value is not None and min_value < min_value_arrdt:
min_value = None
if max_value is not None and max_value_arrdt < max_value:
max_value = None
if min_value is not None or max_value is not None:
# for scalar arrays, i.e. with shape = (), "out" is not a valid
# argument
if len(array.shape) == 0:
array = np.clip(array, min_value, max_value)
elif array.dtype.name == "int32":
# Since 1.17 (before maybe too?), numpy.clip() turns int32
# to float64. float64 should cover the whole value range of int32,
# so the dtype is not rejected here.
# TODO Verify this. Is rounding needed before conversion?
array = np.clip(array, min_value, max_value).astype(array.dtype)
else:
array = np.clip(array, min_value, max_value, out=array)
return array
[docs]
def clip_to_dtype_value_range_(array, dtype, validate=True,
validate_values=None):
dtype = normalize_dtype(dtype)
min_value, _, max_value = get_value_range_of_dtype(dtype)
if validate:
array_val = array
if ia.is_single_integer(validate):
assert validate >= 1, (
"If 'validate' is an integer, it must have a value >=1, "
"got %d instead." % (validate,))
assert validate_values is None, (
"If 'validate' is an integer, 'validate_values' must be "
"None. Got type %s instead." % (type(validate_values),))
array_val = array.flat[0:validate]
if validate_values is not None:
min_value_found, max_value_found = validate_values
else:
min_value_found = np.min(array_val)
max_value_found = np.max(array_val)
assert min_value <= min_value_found <= max_value, (
"Minimum value of array is outside of allowed value range (%.4f "
"vs %.4f to %.4f)." % (min_value_found, min_value, max_value))
assert min_value <= max_value_found <= max_value, (
"Maximum value of array is outside of allowed value range (%.4f "
"vs %.4f to %.4f)." % (max_value_found, min_value, max_value))
return clip_(array, min_value, max_value)
[docs]
def gate_dtypes_strs(dtypes, allowed, disallowed, augmenter=None):
"""Verify that input dtypes match allowed/disallowed dtype strings.
Added in 0.5.0.
Parameters
----------
dtypes : numpy.ndarray or iterable of numpy.ndarray or iterable of numpy.dtype
One or more input dtypes to verify.
allowed : str
Names of one or more allowed dtypes, separated by single spaces.
disallowed : str
Names of disallowed dtypes, separated by single spaces.
Must not intersect with allowed dtypes.
augmenter : None or imgaug.augmenters.meta.Augmenter, optional
If the gating happens for an augmenter, it should be provided
here. This information will be used to improve output error
messages and warnings.
"""
allowed, disallowed = _convert_gate_dtype_strs_to_types(
allowed, disallowed
)
return _gate_dtypes(dtypes, allowed, disallowed, augmenter=augmenter)
# Added in 0.5.0.
def _convert_gate_dtype_strs_to_types(allowed, disallowed):
allowed_types = _convert_dtype_strs_to_types(allowed)
disallowed_types = _convert_dtype_strs_to_types(disallowed)
intersection = allowed_types.intersection(disallowed_types)
nb_overlapping = len(intersection)
assert nb_overlapping == 0, (
"Expected 'allowed' and 'disallowed' dtypes to not contain the same "
"dtypes, but %d appeared in both arguments. Got allowed: %s, "
"disallowed: %s, intersection: %s" % (
nb_overlapping,
allowed,
disallowed,
intersection
)
)
return allowed_types, disallowed_types
# Added in 0.5.0.
def _convert_dtype_strs_to_types_cached(dtypes):
dtypes_parsed = _DTYPE_STR_TO_DTYPES_CACHE.get(dtypes, None)
if dtypes_parsed is None:
dtypes_parsed = _convert_dtype_strs_to_types_cached(dtypes)
_DTYPE_STR_TO_DTYPES_CACHE[dtypes] = dtypes_parsed
return dtypes_parsed
# Added in 0.5.0.
def _convert_dtype_strs_to_types(dtypes):
result = set()
for name in dtypes.split(" "):
name = name.strip()
if name:
dtype = _DTYPE_NAME_TO_DTYPE[name]
# this if ignores float128 if it is not available on the user
# system
if dtype is not None:
result.add(dtype)
return result
# Deprecated since 0.5.0.
[docs]
@ia.deprecated("imgaug.dtypes.gate_dtypes_strs")
def gate_dtypes(dtypes, allowed, disallowed, augmenter=None):
def _cvt(dts):
normalized = set()
if not isinstance(dts, list):
dts = [dts]
for dtype in dts:
try:
dtype = normalize_dtype(dtype)
normalized.add(dtype)
except TypeError:
pass
return normalized
dtypes_norm = _cvt(dtypes)
allowed_norm = _cvt(allowed)
disallowed_norm = _cvt(disallowed)
return _gate_dtypes(
dtypes_norm, allowed_norm, disallowed_norm, augmenter=augmenter
)
def _gate_dtypes(dtypes, allowed, disallowed, augmenter=None):
"""Verify that input dtypes are among allowed and not disallowed dtypes.
Added in 0.5.0.
Parameters
----------
dtypes : numpy.ndarray or iterable of numpy.ndarray or iterable of numpy.dtype
One or more input dtypes to verify.
Must not be a dtype function (like ``np.int64``), only a proper
dtype (like ``np.dtype("int64")``). For performance reasons this is
not validated.
allowed : set of numpy.dtype
One or more allowed dtypes.
disallowed : None or set of numpy.dtype
Any number of disallowed dtypes. Should not intersect with allowed
dtypes.
augmenter : None or imgaug.augmenters.meta.Augmenter, optional
If the gating happens for an augmenter, it should be provided
here. This information will be used to improve output error
messages and warnings.
"""
if isinstance(dtypes, np.ndarray) or ia.is_np_scalar(dtypes):
dtypes = set([dtypes.dtype])
elif isinstance(dtypes, list):
dtypes = {arr.dtype for arr in dtypes}
dts_not_explicitly_allowed = dtypes - allowed
all_allowed = (not dts_not_explicitly_allowed)
if all_allowed:
return
if disallowed is None:
disallowed = set()
dts_explicitly_disallowed = dts_not_explicitly_allowed.intersection(
disallowed
)
dts_undefined = dts_not_explicitly_allowed - disallowed
if dts_explicitly_disallowed:
for dtype in dts_explicitly_disallowed:
if augmenter is None:
raise ValueError(
"Got dtype '%s', which is a forbidden dtype (%s)." % (
np.dtype(dtype).name,
_dtype_names_to_string(disallowed)
))
raise ValueError(
"Got dtype '%s' in augmenter '%s' (class '%s'), which "
"is a forbidden dtype (%s)." % (
np.dtype(dtype).name,
augmenter.name,
augmenter.__class__.__name__,
_dtype_names_to_string(disallowed),
))
if dts_undefined:
for dtype in dts_undefined:
if augmenter is None:
ia.warn(
"Got dtype '%s', which was neither explicitly allowed "
"(%s), nor explicitly disallowed (%s). Generated "
"outputs may contain errors." % (
dtype.name,
_dtype_names_to_string(allowed),
_dtype_names_to_string(disallowed),
)
)
else:
ia.warn(
"Got dtype '%s' in augmenter '%s' (class '%s'), which was "
"neither explicitly allowed (%s), nor explicitly "
"disallowed (%s). Generated outputs may contain "
"errors." % (
dtype.name,
augmenter.name,
augmenter.__class__.__name__,
_dtype_names_to_string(allowed),
_dtype_names_to_string(disallowed),
)
)
# Added in 0.5.0.
def _dtype_names_to_string(dtypes):
if isinstance(dtypes, set):
dtypes = list(sorted(dtypes))
return ", ".join([np.dtype(dt).name for dt in dtypes])
[docs]
def allow_only_uint8(dtypes, augmenter=None):
"""Verify that input dtypes are uint8.
Added in 0.5.0.
Parameters
----------
dtypes : numpy.ndarray or iterable of numpy.ndarray or iterable of numpy.dtype
One or more input dtypes to verify.
augmenter : None or imgaug.augmenters.meta.Augmenter, optional
If the gating happens for an augmenter, it should be provided
here. This information will be used to improve output error
messages and warnings.
"""
return gate_dtypes_strs(
dtypes,
allowed="uint8",
disallowed="uint16 uint32 uint64 "
"int8 int16 int32 int64 "
"float16 float32 float64 float128 "
"bool",
augmenter=augmenter
)