Source code for magnet.utils.images

import numpy as np
import matplotlib.pyplot as plt

from skimage.transform import resize as imresize
from arghandle import arghandle

[docs]@arghandle def show_images(images, **kwargs): r"""A nifty helper function to show images represented by tensors. Args: images (list or numpy.ndarray or str or torch.Tensor): The images to show * :attr:`images` can be anything which from you could conceivable harvest an image. If it's a :py:class:`torch.Tensor`, it is converted to a :py:class:`numpy.ndarray`. The first dimension of the tensor is treated as a batch dimension. If it's a ``str``, it is treated as a glob path from which all images are extracted. More commonly, a list of numpy arrays can be given. Keyword Arguments: pixel_range (tuple or ``'auto'``): The range of pixel values to be expected. Default: ``'auto'`` cmap (str or None): The color map for the plots. Default: ``'gray'`` merge (bool): If ``True``, all images are merged into one giant image. Default: ``True`` titles (list or None): The titles for each image. Default: ``None`` shape (str): The shape of the merge tile. Default: ``'square'`` resize (str): The common shape to which images are resized. Default: ``'smean'`` retain (bool): If ``True``, the plot is retained. Default: ``False`` savepath (str or None): If given, the image is saved to this path. Default: ``None`` * :attr:`pixel_range` default to the range in the image. * :attr:`titles` should only be given if :attr:`merge` is ``True``. .. note:: The merge shape is controlled by :attr:`shape` which can be either ``'square'``, ``'row'``, ``'column'`` or a ``tuple`` which explicitly specifies this shape. ``'square'`` automatically finds a shape with least difference between the number of rows and columns. This is aesthetically pleasing. In the explicit case, the product of the tuple needs to equal the number of images. """ titles = kwargs.pop('titles', None) pixel_range = kwargs.pop('pixel_range', 'auto') cmap = kwargs.pop('cmap', 'gray') shape = kwargs.pop('shape', 'square') resize = kwargs.pop('resize', 'smean') merge = kwargs.pop('merge', True) retain = kwargs.pop('retain', False) savepath = kwargs.pop('savepath', None) images = _resize(images, resize) if merge: _show_image(_merge(images, shape), titles, cmap, None, pixel_range, retain) return fig, axes = plt.subplots(shape[0], shape[1]) for ax, title, image in zip(axes.flat, titles, images): _show_image(image, title, cmap, ax, pixel_range, retain=True) fig.tight_layout() if not retain: plt.show() if savepath is not None: plt.savefig(Path(savepath), dpi=400, bbox_inches='tight')
def _resize(images, size='smean'): return np.stack([imresize(image, size, order=1, mode='constant', anti_aliasing=False, preserve_range=True) for image in images]) def _merge(images, shape): images = images.reshape((*shape, *images.shape[1:])) for _ in range(2): images = np.concatenate([img for img in images], axis=1) return images @arghandle def _show_image(image, title=None, cmap='gray', ax=None, pixel_range='auto', retain=False): image = (image - pixel_range[0]) * 255 / (pixel_range[1] - pixel_range[0]) ax.imshow(image.astype(np.uint8), cmap) ax.set_xticks([]); ax.set_yticks([]); ax.grid(False) if title is not None: ax.set_title(title) if not retain: plt.show()