Source code for simple_sr.utils.image.image_utils

import os
import tensorflow as tf
from PIL import Image, ImageDraw, ImageFont


def read_img(fpath, normalize_func=None, yield_path=False):
    """
    Load an image file from disk as a float32 tensor.

    The file is read with `tf.io.read_file`, decoded with
    `tf.image.decode_png` and cast to `tf.float32`.

    :param fpath: File path to image file.
    :param normalize_func:
        | Normalization function to apply, None means no normalization takes place.
        | Functions to either normalize to [0, 1] or [-1, 1] are in
          `simple_sr.utils.image.image_transforms`.
    :param yield_path: Whether image file path should be returned as well.
    :return: The image tensor, or a tuple `(image tensor, fpath)` if
        `yield_path` is truthy.
    """
    raw_bytes = tf.io.read_file(fpath)
    image = tf.cast(tf.image.decode_png(raw_bytes), tf.float32)
    if normalize_func is not None:
        image = normalize_func(image)
    return (image, fpath) if yield_path else image
def tensor_to_img(tensor):
    """
    Convert a tensor into a `PIL.Image` object.

    A rank 4 tensor whose first dimension is 1 is unwrapped to its single
    element before conversion; every other tensor is converted as supplied.

    :param tensor: Tensor to be converted.
    :return: `PIL.Image` from input tensor.
    """
    is_single_image_batch = tensor.shape.rank == 4 and tensor.shape[0] == 1
    image_tensor = _extract_tensor(tensor) if is_single_image_batch else tensor
    return tf.keras.preprocessing.image.array_to_img(image_tensor)
def reconstruct_from_overlapping_patches(patches, image_height, image_width,
                                         pixel_overlap, horizontal_padding, vertical_padding):
    """
    Reconstructs an image from supplied overlapping patches.

    The overlap border is cropped from every patch, then the remaining patch
    cores are stitched together.

    :param patches: Tensor of rank 4 containing the patches that will be stitched together.
    :param image_height: Height of the resulting image.
    :param image_width: Width of the resulting image.
    :param pixel_overlap:
        Pixel overlap per patch per direction (north, east, south, west) in number of pixels.
        A value of 0 leaves the patches uncropped.
    :param horizontal_padding:
        Number of pixel rows that were appended to the bottom of the image
        before extraction of patches.
    :param vertical_padding:
        Number of pixel columns that were appended on the right side of the image
        before extraction of patches.
    :return: A Tensor of rank 3 containing the reconstructed image.
    :raises ValueError: If `patches` is not of rank 4, or if `pixel_overlap`
        or one of the paddings is negative.
    """
    if patches.shape.rank != 4:
        raise ValueError("Tensor with patches needs to be of rank 4")
    if pixel_overlap < 0:
        raise ValueError("Pixel overlap can't be negative")
    if horizontal_padding < 0 or vertical_padding < 0:
        raise ValueError("Padding can't be negative")

    # Use explicit end indices: a `-pixel_overlap` end index turns into `-0`
    # (i.e. 0) when there is no overlap, which would slice every patch to empty.
    row_end = patches.shape[1] - pixel_overlap
    col_end = patches.shape[2] - pixel_overlap
    _patches = patches[:, pixel_overlap:row_end, pixel_overlap:col_end, :]

    return _reconstruct(
        _patches,
        image_height, image_width,
        image_height + horizontal_padding, image_width + vertical_padding
    )
def reconstruct_from_patches(patches, original_height, original_width,
                             horizontal_padding=0, vertical_padding=0):
    """
    Stitch non-overlapping patches back together into a single image.

    :param patches: Tensor of rank 4 containing patches.
    :param original_height: Height of image to be reconstructed.
    :param original_width: Width of image to be reconstructed.
    :param horizontal_padding:
        Number of rows that were appended to the image before extraction of patches.
    :param vertical_padding:
        Number of columns that were appended to the image before extraction of patches.
    :return: Tensor of rank 3 containing the combined patches.
    :raises ValueError: If `patches` is not of rank 4 or a padding is negative.
    """
    if patches.shape.rank != 4:
        raise ValueError("Tensor with patches needs to be of rank 4")

    has_negative_padding = horizontal_padding < 0 or vertical_padding < 0
    if has_negative_padding:
        raise ValueError("Padding can't be negative")

    # size of the image including the rows/columns added before patch extraction
    padded_height = original_height + horizontal_padding
    padded_width = original_width + vertical_padding
    return _reconstruct(patches, original_height, original_width,
                        padded_height, padded_width)
def segment_into_patches(tensor, patch_width=32, patch_height=32, pixel_overlap=0):
    """
    Segments input tensor into patches for memory efficient inference on large pictures.

    .. note:
        For a typical use-case of:

        1. segment image into patches
        2. upscale patches
        3. reconstruct full image from upscaled patches

        you should consider using `pixel_overlap` > 0 for better results.

    :param tensor:
        Tensor to segment into patches. Must be of rank 3, or of rank 4 with
        a first dimension of exactly 1 (a batch holding a single image).
    :param patch_width: Width of extracted patches.
    :param patch_height: Height of extracted patches.
    :param pixel_overlap:
        Amount of overlap per patch per direction (north, east, south, west) in pixels
    :return:
        A tuple containing tensor of rank 4 with patches of input tensor,
        and a nested list with two rows of two entries each containing
        the padding that was used/needed.
        Structure of padding:
        :code:`[[number of padded rows top, number of padded rows bottom],`
        :code:`[number of padded columns left, number of padded columns right]]`
    :raises ValueError:
        If the tensor has an unsupported rank, if `pixel_overlap` is negative,
        or if the patch dimensions exceed the image dimensions.
    """
    rank = tensor.shape.rank
    is_single_image_batch = rank == 4 and tensor.shape[0] == 1
    # Reject everything that is neither a single image nor a batch of one;
    # this includes rank 2 / rank 5 tensors as well as larger batches.
    if rank != 3 and not is_single_image_batch:
        raise ValueError("Tensor must be of rank 3")
    if pixel_overlap < 0:
        raise ValueError("Pixel overlap can't be negative")

    _tensor = _extract_tensor(tensor) if rank == 4 else tensor

    if _tensor.shape[0] < patch_height or _tensor.shape[1] < patch_width:
        raise ValueError("Patch dimensions are larger than image size")

    if pixel_overlap != 0:
        return _segment_with_overlap(_tensor, patch_width, patch_height,
                                     pixel_overlap=pixel_overlap)
    return _segment(_tensor, patch_width, patch_height)
def _segment_with_overlap(tensor, patch_width, patch_height, pixel_overlap=5):
    """
    Cut a rank 3 tensor into patches that overlap their neighbours.

    The tensor is zero-padded by `pixel_overlap` on every side and, on the
    bottom/right, additionally up to the next multiple of the patch size.
    Every extracted patch then measures
    `(patch_height + 2 * pixel_overlap, patch_width + 2 * pixel_overlap)`.

    :param tensor: Tensor of rank 3 to segment.
    :param patch_width: Width of a patch core (without overlap).
    :param patch_height: Height of a patch core (without overlap).
    :param pixel_overlap: Overlap in pixels added on each side of a patch.
    :return:
        Tuple of the stacked patches (row-major patch order) and the padding
        applied to the first two dimensions as
        `[[top, bottom], [left, right]]`.
    """
    # `(patch - size) % patch` is the amount missing to the next multiple of
    # the patch size and evaluates to 0 when the size already is a multiple.
    horizontal_padding = [
        pixel_overlap,
        pixel_overlap + (patch_height - tensor.shape[0]) % patch_height
    ]
    vertical_padding = [
        pixel_overlap,
        pixel_overlap + (patch_width - tensor.shape[1]) % patch_width
    ]

    padding = [horizontal_padding, vertical_padding, [0, 0]]
    padded = tf.pad(
        tensor, mode="CONSTANT",
        paddings=padding, constant_values=0
    )

    patches = list()
    # Rows advance by the patch height and columns by the patch width;
    # stepping the other way round skips or duplicates pixels as soon as
    # patches are not square.
    for row in range(pixel_overlap, padded.shape[0] - pixel_overlap, patch_height):
        for col in range(pixel_overlap, padded.shape[1] - pixel_overlap, patch_width):
            x_start = col - pixel_overlap
            x_end = col + patch_width + pixel_overlap
            y_start = row - pixel_overlap
            y_end = row + patch_height + pixel_overlap
            patches.append(padded[y_start: y_end, x_start: x_end, :])

    # the channel padding entry is always [0, 0] and is not reported
    return tf.convert_to_tensor(patches), padding[:-1]


def _segment(tensor, patch_width, patch_height):
    """
    Cut a rank 3 tensor into non-overlapping patches.

    The tensor is padded on the bottom/right to a multiple of the patch size
    as part of the `tf.space_to_batch` call.

    :param tensor: Tensor of rank 3 to segment.
    :param patch_width: Width of extracted patches.
    :param patch_height: Height of extracted patches.
    :return:
        Tuple of a rank 4 tensor of patches and the padding that was applied
        as `[[top, bottom], [left, right]]`.
    """
    padding = [
        [0, (patch_height - tensor.shape[0]) % patch_height],
        [0, (patch_width - tensor.shape[1]) % patch_width]
    ]

    segments = tf.space_to_batch([tensor], [patch_height, patch_width], padding)
    segments = tf.split(segments, patch_height * patch_width, 0)
    segments = tf.stack(segments, 3)
    # take the channel count from the input instead of assuming 3 channels
    segments = tf.reshape(segments, [-1, patch_height, patch_width, tensor.shape[-1]])
    return segments, padding


def _reconstruct(patches, original_height, original_width, padded_height, padded_width):
    """
    Reassemble an image from patches; inverse of the layout produced by `_segment`.

    :param patches: Rank 4 tensor of patches.
    :param original_height: Height the result is cropped to.
    :param original_width: Width the result is cropped to.
    :param padded_height: Height of the image including padding rows.
    :param padded_width: Width of the image including padding columns.
    :return: Rank 3 tensor cropped to `(original_height, original_width)`.
    """
    patch_height = patches.shape[1]
    patch_width = patches.shape[2]
    patch_channels = patches.shape[3]
    patches_per_column = padded_height // patch_height
    patches_per_row = padded_width // patch_width
    pixels_per_patch = patch_height * patch_width

    reconstructed = tf.reshape(
        patches,
        [1, patches_per_column, patches_per_row, pixels_per_patch, patch_channels]
    )
    reconstructed = tf.split(reconstructed, pixels_per_patch, 3)
    reconstructed = tf.stack(reconstructed, 0)
    reconstructed = tf.reshape(
        reconstructed,
        [pixels_per_patch, patches_per_column, patches_per_row, patch_channels]
    )
    reconstructed = tf.batch_to_space(
        reconstructed, [patch_height, patch_width], [[0, 0], [0, 0]]
    )[0]
    # drop the rows/columns that only exist because of padding
    return reconstructed[0: original_height, 0: original_width, :]
def save_single(tensor, save_dir, fname, label=None):
    """
    Write the image(s) held by a rank 3 or rank 4 tensor to disk as PNG.

    A rank 3 tensor is saved as a single file named `fname`; for a rank 4
    tensor every element is saved separately as `{fname}_{position}`.

    :param tensor: Tensor containing the images to be saved.
    :param save_dir: Folder to save images to.
    :param fname:
        Name of files to save images in (will be suffixed with
        position in tensor if multiple images are saved)
    :param label:
        Optional labels, will be shown in the bottom left of the saved image.
    :raises ValueError: If the tensor is neither of rank 3 nor of rank 4.
    """
    rank = tensor.shape.rank
    if not 3 <= rank <= 4:
        raise ValueError("Tensor must be of rank 3 or rank 4")

    if rank == 3:
        _save_as_img(tensor, save_dir, fname, label)
        return

    for position, image_tensor in enumerate(tensor):
        _save_as_img(image_tensor, save_dir, f"{fname}_{position}", label)
def _save_as_img(tensor, save_dir, fname, label=None):
    """
    Convert one tensor to an image, optionally label it, and save it as PNG.

    :param tensor: Tensor to convert and save.
    :param save_dir: Folder to save the image to, created if missing.
    :param fname: File name without extension.
    :param label: Optional text drawn onto the image in green.
    """
    label_color = (0, 255, 0)
    picture = tensor_to_img(tensor)
    if label is not None:
        _annotate_img(picture, label, label_color)

    os.makedirs(save_dir, exist_ok=True)
    target_path = f"{save_dir}/{fname}.png"
    picture.save(target_path)
def combine_halfs(left_tensor, right_tensor, left_label, save_dir, fname,
                  right_label="interpolated", grid=False):
    """
    Creates a combined image from two images by stitching together left and right half
    respectively of each supplied image.

    A typical use case would be comparing an image upscaled with Super-Resolution
    to an image upscaled with interpolation.

    Every stitched image is saved as `{fname}_{index}.png`; the optional grid
    is saved as `{fname}_grid.png`.

    :param left_tensor:
        | Tensor containing elements that will be on the left side of the resulting image.
        | Tensor may be of rank 3 (single image) or rank 4.
    :param right_tensor:
        | Tensor containing elements that will be on the right side of the resulting image.
        | Tensor may be of rank 3 or rank 4 and must hold as many images as `left_tensor`;
          each image is resized to the size of its left counterpart.
    :param left_label: String to label left side of resulting image.
    :param right_label: String to label right side of resulting image.
    :param save_dir: Path to save resulting image to.
    :param fname: File name to save resulting image under.
    :param grid:
        | Option to additionally plot all stitched images together in a grid
          with 4 images per row.
        | Requires an even number of at least 4 images.
    :raises ValueError:
        If the image counts differ, or if a grid is requested for an
        unsupported number of images.
    """
    # Promote a single rank 3 image to a batch of one; iterating a rank 3
    # tensor directly would walk over its pixel rows instead of over images.
    if left_tensor.shape.rank == 3:
        left_tensor = tf.expand_dims(left_tensor, axis=0)
    if right_tensor.shape.rank == 3:
        right_tensor = tf.expand_dims(right_tensor, axis=0)

    if left_tensor.shape[0] != right_tensor.shape[0]:
        raise ValueError("number of sr and lr images does not match")

    if grid and (left_tensor.shape[0] % 2 != 0 or left_tensor.shape[0] < 4):
        raise ValueError("can only prepare image grid for an even number of at least 4 images")

    os.makedirs(save_dir, exist_ok=True)

    imgs = list()
    for idx, (sr, lr) in enumerate(zip(left_tensor, right_tensor)):
        sr_img = tensor_to_img(sr)
        _annotate_img(sr_img, left_label, (0, 255, 0))

        lr_img = tensor_to_img(lr)
        lr_img = lr_img.resize(sr_img.size)
        _annotate_img(lr_img, right_label, (255, 0, 0), loc="right")

        main_img = Image.new("RGB", sr_img.size, (255, 255, 255))
        half = sr_img.width // 2
        end = sr_img.width
        bottom = sr_img.height

        # left half comes from the first image, right half from the second
        main_img.paste(sr_img.crop((0, 0, half, bottom)), (0, 0))
        main_img.paste(lr_img.crop((half, 0, end, bottom)), (half, 0))

        # vertical separator between the two halves
        draw = ImageDraw.Draw(main_img)
        draw.line((half, 0, half, bottom), fill=128)

        imgs.append(main_img)
        main_img.save(f"{save_dir}/{fname}_{idx}.png")

    if grid:
        num_cols = 4
        num_rows = -(-len(imgs) // num_cols)  # ceiling division
        cell_width = max(img.width for img in imgs)
        cell_height = max(img.height for img in imgs)

        grid_img = Image.new(
            "RGB", (num_cols * cell_width, num_rows * cell_height), (255, 255, 255)
        )
        # Position every image from its index so that any number of rows is
        # laid out correctly, not only the first wrap-around.
        for idx, img in enumerate(imgs):
            row, col = divmod(idx, num_cols)
            grid_img.paste(img, (col * cell_width, row * cell_height))

        grid_img.save(f"{save_dir}/{fname}_grid.png")
def prepare_image_grid(save_dir, fname, low_res_key=None, original=None,
                       psnr=None, ssim=None, **kwargs):
    """
    Prepares and saves an image grid for comparison.
    Supplied images must have the same dimensions.

    If every sized tensor holds exactly one image, the images are plotted next
    to each other in a single row. Otherwise every keyword argument gets its
    own row in the grid.

    The keys :code:`"hr"` and :code:`"ground truth"` are treated specially:
    tensors stored under them are ignored when the cell size of the grid is
    determined, and the :code:`"ground truth"` tensor is resized (bicubic) to
    the resulting cell size.

    :param save_dir: Path to save image grid to.
    :param fname: File name to save image grid in.
    :param original:
        Optionally an original image can be plotted next to the image grid
        (for instance if image in grid are patches from some image).
    :param psnr:
        Optional dictionary with keys corresponding to keys in :code:`kwargs` and
        values containing tensors of PSNR values.
        PSNR values will be annotated in corresponding pictures.
    :param ssim:
        Optional dictionary with keys corresponding to keys in :code:`kwargs` and
        values containing tensors of SSIM values.
        SSIM values will be annotated in corresponding pictures.
    :param low_res_key:
        Optional key for low-resolution images in :code:`kwargs`.
        Images corresponding to this key will be padded and centered to align
        with the largest images in grid.
    :param kwargs:
        A dictionary containing image tensors as values.
        Keys will be used as label strings inside plotted images.
    :raises ValueError:
        If no images are supplied, if the tensors hold differing numbers of
        images, if no tensor is available to size the grid, or if the supplied
        metrics fail verification.
    """
    if not kwargs:
        raise ValueError("no images supplied - can't produce grid")

    image_counts = {tensor.shape[0] for tensor in kwargs.values()}
    if len(image_counts) > 1:
        raise ValueError("received differing amount of images per supplied model - can't produce grid")

    if psnr is not None:
        _verfify_supplied_metrics(psnr, kwargs)
    if ssim is not None:
        _verfify_supplied_metrics(ssim, kwargs)

    num_rows = len(kwargs)
    num_cols = 0
    max_height = -1
    max_width = -1
    for label, tensors in kwargs.items():
        # high resolution references do not take part in sizing the cells
        if label not in ("hr", "ground truth"):
            num_cols = max(num_cols, tensors.shape[-4])
            max_height = max(max_height, tensors.shape[-3])
            max_width = max(max_width, tensors.shape[-2])

    if max_height < 0 or max_width < 0:
        raise ValueError(
            "need at least one image tensor besides 'hr'/'ground truth' to size the grid"
        )

    if "ground truth" in kwargs:
        kwargs["ground truth"] = tf.image.resize(
            kwargs["ground truth"], size=(max_height, max_width),
            method=tf.image.ResizeMethod.BICUBIC
        )

    if num_cols == 1:
        # plot images next to each other
        grid_width = num_rows * max_width
        grid_height = num_cols * max_height
        grid = Image.new("RGB", (grid_width, grid_height), (255, 255, 255))

        _labels = list()
        _pics = list()
        _psnr = list()
        _ssim = list()
        for label, tensor in kwargs.items():
            _labels.append(label)
            _t = tensor
            if _t.shape.rank == 4:
                # drop the batch dimension of the single image
                _t = tf.reshape(_t, (_t.shape[1:]))
            if label == low_res_key:
                _t = _pad_image(_t, height=max_height, width=max_width)
            _pics.append(_t)
            # labels without metric values are simply left unannotated,
            # matching the behaviour of the multi-row layout below
            _psnr.append(_metric_for_label(psnr, label))
            _ssim.append(_metric_for_label(ssim, label))

        _prepare_img_row(
            _pics, grid, _labels, (0, 255, 0), y_loc=0,
            resize=False, resize_dims=None,
            psnr_values=_psnr, ssim_values=_ssim
        )
    else:
        grid_width = num_cols * max_width
        grid_height = num_rows * max_height

        # account for label on left side if no original is supplied
        if original is None:
            column_label_width = int(grid_width * 0.05)
        else:
            column_label_width = 0
        grid_width += column_label_width

        grid = Image.new("RGB", (grid_width, grid_height), (255, 255, 255))

        y_location = 0
        for idx, (label, tensors) in enumerate(kwargs.items()):
            _psnr = _metric_for_label(psnr, label)
            _ssim = _metric_for_label(ssim, label)

            if original is None:
                # the row label goes into the reserved column on the left
                _annotate_column(
                    grid, label, (0, 255, 0),
                    column_label_width, max_height, ypos=max_height * idx
                )
                labels = None
            else:
                # no label column available, label every image instead
                labels = label

            _t = tensors
            if label == low_res_key:
                _t = _pad_image(_t, height=max_height, width=max_width)

            _prepare_img_row(
                _t, grid, labels=labels, color=(0, 255, 0), y_loc=y_location,
                resize=False, psnr_values=_psnr,
                x_axis_offset=column_label_width, ssim_values=_ssim
            )
            y_location += max_height

    if original is None:
        result = grid
    else:
        # `original` may be an already loaded image or something to open
        if isinstance(original, Image.Image):
            origin = original
        else:
            try:
                origin = Image.open(original)
            except AttributeError:
                origin = original

        # scale the original to the height of the grid, keeping its aspect ratio
        origin_aspect_ratio = origin.width / origin.height
        original_height = grid.height
        original_width = int(origin_aspect_ratio * original_height)
        original = origin.resize((original_width, original_height))
        _annotate_img(original, "original", (255, 0, 255))

        combined_width = grid_width + original_width
        result = Image.new("RGB", (combined_width, grid.height), (255, 255, 255))
        result.paste(original, (0, 0))
        result.paste(grid, (original_width, 0))

    os.makedirs(save_dir, exist_ok=True)
    result.save(f"{save_dir}/{fname}.png")


def _metric_for_label(metrics, label):
    """
    Look up the metric values stored for `label`.

    :param metrics: Dictionary of metric values, or None.
    :param label: Key to look up.
    :return: `metrics[label]`, or None if `metrics` is None or has no such key.
    """
    try:
        return metrics[label]
    except (TypeError, KeyError):
        return None
def _pad_image(tensor, height, width):
    """
    Center image(s) on a zero-filled canvas of the requested size.

    :param tensor: Tensor of rank 3 (single image) or rank 4 (batch of images).
    :param height: Target height.
    :param width: Target width.
    :return: Tensor of the same rank as the input with spatial size `(height, width)`.
    :raises ValueError: If the tensor is neither of rank 3 nor of rank 4.
    """
    if tensor.shape.rank > 4 or tensor.shape.rank < 3:
        raise ValueError("tensor must be of rank 3 or rank 4")

    _tensor = tensor
    if tensor.shape.rank == 3:
        _tensor = tf.reshape(tensor, (1, *tensor.shape))

    horz_pad = (height - tensor.shape[-3]) // 2
    vert_pad = (width - tensor.shape[-2]) // 2

    padded = list()
    for t in _tensor:
        _t = tf.pad(t, [[horz_pad, horz_pad], [vert_pad, vert_pad], [0, 0]])
        # resize additionally to account for integer calculation differences
        padded.append(tf.image.resize(_t, size=(height, width)))
    # stack once instead of concatenating onto a growing tensor per image
    res = tf.stack(padded, axis=0)

    if tensor.shape.rank == 3:
        return res[0]
    return res


def _verfify_supplied_metrics(metrics_dict, img_dict):
    """
    Check that metric values match the supplied images.

    :param metrics_dict: Dictionary of metric tensors.
    :param img_dict: Dictionary of image tensors.
    :raises ValueError:
        If the number of entries differs between both dictionaries, or if the
        metric tensors differ in their first dimension.
    """
    if len(metrics_dict) != len(img_dict):
        raise ValueError("did not receive psnr values for every supplied model result")

    value_counts = {values.shape[0] for values in metrics_dict.values()}
    if len(value_counts) > 1:
        raise ValueError("count of supplied psnr values does not match count of supplied images")


def _text_size(font, text):
    """
    Measure rendered text as `(width, height)`.

    `getsize` was removed from Pillow's font classes in Pillow 10; the right
    and bottom edge of `getbbox` describe the same extent, so it is preferred
    whenever the font offers it.

    :param font: Font object providing `getbbox` or `getsize`.
    :param text: Text to measure.
    :return: Tuple `(width, height)` in pixels.
    """
    if hasattr(font, "getbbox"):
        _, _, right, bottom = font.getbbox(text)
        return int(right), int(bottom)
    width, height = font.getsize(text)
    return int(width), int(height)


def _annotate_img(img, text, color, loc=None):
    """
    Draw `text` on a black box onto `img` (in place).

    :param img: `PIL.Image` to draw on.
    :param text: Text to draw.
    :param color: Text color.
    :param loc:
        None for bottom left, "right" for bottom right, or "ssim" for one
        text line above bottom right.
    """
    draw = ImageDraw.Draw(img)
    font = _load_font(font_size=(int(max(6, (16 - (1024 // img.width))))))
    width, height = _text_size(font, text)

    if loc is None:
        loc = (5, img.size[1] - (5 + height))
    elif loc == "right":
        loc = (img.width - (width + 5), img.height - (5 + height))
    elif loc == "ssim":
        loc = (img.width - (width + 5), img.height - (2 * (5 + height)))

    draw.rectangle((*loc, loc[0] + width, loc[1] + height), fill="black")
    draw.text(loc, text, font=font, fill=color)


def _annotate_column(img, text, color, width, height, ypos, xpos=0):
    """
    Paste a vertical (rotated by 90 degrees) text label onto `img` (in place).

    :param img: `PIL.Image` to paste the label onto.
    :param text: Label text.
    :param color: Text color.
    :param width: Width of the label column.
    :param height: Height of the label column.
    :param ypos: Vertical paste position.
    :param xpos: Horizontal paste position.
    """
    # draw/annotate horizontal first in tmp image, then rotate and paste to original
    tmp_img = Image.new("RGB", (height, width), (0, 0, 0))
    draw = ImageDraw.Draw(tmp_img)
    font = _load_font(font_size=(int(max(6, (16 - (1024 // img.width))))))
    _, font_height = _text_size(font, text)
    draw.text((5, width - (5 + font_height)), text, font=font, fill=color)
    rot = tmp_img.rotate(90, expand=1)
    img.paste(rot, (xpos, ypos))


def _prepare_img_row(tensors, main_img, labels, color, y_loc, resize=False, resize_dims=None,
                     psnr_values=None, ssim_values=None, x_axis_offset=0):
    """
    Convert tensors to images, annotate them and paste them side by side into `main_img`.

    :param tensors: Iterable of image tensors making up the row.
    :param main_img: `PIL.Image` the row is pasted into.
    :param labels:
        Either a list with one label per image, or a single label (possibly
        None) that is used for every image.
    :param color: Color of the label text.
    :param y_loc: Vertical position of the row inside `main_img`.
    :param resize: Whether images are resized to `resize_dims` before pasting.
    :param resize_dims: Target size used when `resize` is set.
    :param psnr_values: Optional indexable of PSNR values, one per image.
    :param ssim_values: Optional indexable of SSIM values, one per image.
    :param x_axis_offset: Horizontal offset of the first image.
    """
    _labels = labels
    if not isinstance(labels, list):
        # a single label (or None) applies to every image of the row
        _labels = [labels] * tensors.shape[0]

    for idx, tensor in enumerate(tensors):
        img = tensor_to_img(tensor)
        if resize:
            img = img.resize(resize_dims)
        if _labels[idx] is not None:
            _annotate_img(img, _labels[idx], color)
        if psnr_values is not None and psnr_values[idx] is not None:
            psnr_value = _extract_metric(psnr_values, idx)
            _annotate_img(img, f"psnr: {psnr_value}", (255, 0, 0), loc="right")
        if ssim_values is not None and ssim_values[idx] is not None:
            ssim_value = _extract_metric(ssim_values, idx)
            _annotate_img(img, f"ssim: {ssim_value}", (255, 0, 0), loc="ssim")
        main_img.paste(img, (x_axis_offset + img.size[0] * idx, y_loc))


def _extract_metric(metric_values, idx):
    """
    Format the metric at position `idx` for display.

    :param metric_values: Indexable of metric values offering `.numpy()`.
    :param idx: Position of the metric to format.
    :return:
        The infinity sign for an infinite value, "N/A" for -1, otherwise the
        value formatted with two decimals.
    """
    metric = metric_values[idx]
    if metric.numpy() == float("inf"):
        return u"\u221E"
    if metric.numpy() == -1:
        return "N/A"
    try:
        return f"{metric:.2f}"
    except TypeError:
        # value does not support float formatting directly, format its first element
        return f"{metric[0]:.2f}"


def _load_font(font_size=10):
    """
    Load the annotation font, falling back to Pillow's default font.

    :param font_size: Font size for the TrueType font.
    :return: Font object usable with `ImageDraw`.
    """
    try:
        font = ImageFont.truetype("./resources/NotoSansMono-Bold.ttf", size=font_size)
    except OSError:
        print("cannot locate font, using default font as fallback")
        font = ImageFont.load_default()
    return font


def _extract_tensor(batch):
    """
    Return the last element of `batch`, or None if `batch` is empty.

    Callers in this module use it to unwrap batches holding a single image.
    """
    t = None
    for t in batch:
        pass
    return t


if __name__ == "__main__":
    pass