"""Source code for simple_sr.operations.evaluation: evaluation of super-resolution models on validation and test data."""

import sys
import os
import logging
from pathlib import Path
import tensorflow as tf

from simple_sr.data_pipeline.data_pipeline import DataPipeline
from simple_sr.utils.image import image_transforms, metrics, image_utils
from simple_sr.utils import logger

log = logging.getLogger(logger.RESULTS_LOGGER)


def evaluate_on_validationdata(config, model_name="", pipeline=None, model=None, save_grid=False,
                               combine_halfs=False, save_single=False, save_prefix="", calc_stats=False):
    """
    Evaluates supplied model(s) on the validation set of the supplied data-pipeline, optionally
    stores result images and logs average PSNR, PSNR on the y-channel and SSIM.

    :param config:
        | Initialized object of type :code:`ConfigUtil`.
        | Will be used for save paths of resulting images and to load model/pipeline in case they are not supplied.
    :param model_name: Name of model, used as label. Only used when :code:`model` is a single model object.
    :param pipeline:
        | Initialized object of type :code:`simple_sr.DataPipeline`.
        | Will be initialized from :code:`config` if None.
    :param model:
        | Object of type `Keras.model` or dict of name: model.
        | Will be loaded from the model path(s) in config if None.
    :param save_grid: Whether to save an image grid. (useful if image grid contains patches)
    :param combine_halfs: Whether to save an image that's upscaled with supplied models on one half and
        interpolated with the pipeline's resize filter on the other half.
    :param save_single: Whether to save single images.
    :param save_prefix: Prefix for filenames.
    :param calc_stats: Whether to calculate PSNR, PSNR on the y-channel and SSIM and log their averages.
    :raises ValueError: If no model is supplied and :code:`config.model_path` is None.
    """
    if pipeline is None:
        pipeline = DataPipeline.from_config(config)

    if model is None:
        if config.model_path is None:
            raise ValueError("No model was supplied and config does not contain path to model")
        models = {Path(path).stem: _load_model(path) for path in config.model_path}
    else:
        models = model if isinstance(model, dict) else {model_name: model}

    validation_batch_generator = pipeline.validation_batch_generator()
    ground_truth_key = "GT"
    low_res_key = "LR"
    interpolated_key = pipeline.resize_filter
    psnr_y_key = "psnr-y"
    pixel_overlap = 32

    result_pics = dict()
    # per-image metrics accumulated over the whole validation set, one store per model + interpolation
    metrics_res = {name: _empty_metric_store(psnr_y_key) for name in models}
    metrics_res[interpolated_key] = _empty_metric_store(psnr_y_key)

    for idx, (lr_batch, hr_batch) in enumerate(validation_batch_generator):
        # hr_batch, sr_batch are normalized to [-1, 1], whereas lr_batch is normalized to [0, 1]
        # -> need to adjust lr_batch for comparison
        result_pics[ground_truth_key] = hr_batch
        result_pics[low_res_key] = lr_batch
        _lr_batch = image_transforms.normalize_11(lr_batch * 255)
        interpolated = image_transforms.resize(
            _lr_batch,
            (_lr_batch.shape[-3] * config.scale, _lr_batch.shape[-2] * config.scale),
            resize_filter=pipeline.resize_filter
        )
        # adjust hr image slightly to account for integer calculation of resizing
        hr_batch = image_transforms.resize(hr_batch, (interpolated.shape[-3], interpolated.shape[-2]))

        # metrics of the current batch only; used to annotate the image grid
        batch_metrics = dict()
        if calc_stats:
            batch_metrics[interpolated_key] = _append_metrics(
                metrics_res[interpolated_key], hr_batch, interpolated, psnr_y_key
            )

        if save_single:
            image_utils.save_single(
                interpolated, os.path.join(config.pic_dir, "interpolated"), f"{save_prefix}{idx}"
            )
            image_utils.save_single(
                _lr_batch, os.path.join(config.pic_dir, "low_res"), f"{save_prefix}{idx}"
            )
        result_pics[interpolated_key] = interpolated

        batch_height, batch_width = _get_tensor_height_width(lr_batch)
        segmented = _eligible_efficient_inference(lr_batch)
        model_input = lr_batch
        if segmented:
            # large single images are upscaled as overlapping patches to limit memory usage
            model_input, padding = image_utils.segment_into_patches(
                lr_batch, patch_width=128, patch_height=128, pixel_overlap=pixel_overlap
            )
            model_input = tf.convert_to_tensor(model_input)
            scaled_pixel_overlap = pixel_overlap * config.scale

        for name, sr_model in models.items():
            sr_batch = _upscale(sr_model, model_input)
            if segmented:
                sr_batch = image_utils.reconstruct_from_overlapping_patches(
                    sr_batch,
                    image_height=(batch_height * config.scale),
                    image_width=(batch_width * config.scale),
                    pixel_overlap=scaled_pixel_overlap,
                    horizontal_padding=padding[0][1] * config.scale - scaled_pixel_overlap,
                    vertical_padding=padding[1][1] * config.scale - scaled_pixel_overlap
                )
            if sr_batch.shape.rank == 3:
                sr_batch = tf.reshape(sr_batch, (1, *sr_batch.shape))
            result_pics[name] = sr_batch

            if calc_stats:
                batch_metrics[name] = _append_metrics(metrics_res[name], hr_batch, sr_batch, psnr_y_key)
            if save_single:
                image_utils.save_single(
                    sr_batch, os.path.join(config.pic_dir, name, "single"), f"{save_prefix}{idx}"
                )
            if combine_halfs:
                image_utils.combine_halfs(
                    left_tensor=sr_batch,
                    # use the [-1, 1]-normalized LR batch and the pipeline's filter so the right half
                    # matches both the value range of sr_batch and its label
                    right_tensor=image_transforms.resize(
                        _lr_batch, _get_tensor_height_width(sr_batch), resize_filter=pipeline.resize_filter
                    ),
                    left_label=name,
                    right_label=interpolated_key,
                    save_dir=os.path.join(config.pic_dir, name, "half"),
                    fname=f"{save_prefix}{idx}"
                )

        if calc_stats:
            # psnr/ssim values of the batch just processed, in the same key order as metrics_res
            num_gt = result_pics[ground_truth_key].shape[0]
            num_lr = result_pics[low_res_key].shape[0]
            batch_psnr = {name: batch_metrics[name]["psnr"] for name in metrics_res}
            batch_psnr[ground_truth_key] = tf.constant([float("inf")] * num_gt)
            batch_psnr[low_res_key] = tf.constant([-1] * num_lr)
            batch_ssim = {name: batch_metrics[name]["ssim"] for name in metrics_res}
            batch_ssim[ground_truth_key] = tf.constant([1.0] * num_gt)
            batch_ssim[low_res_key] = tf.constant([-1] * num_lr)
        else:
            batch_psnr = None
            batch_ssim = None

        if save_grid:
            image_utils.prepare_image_grid(
                save_dir=os.path.join(config.pic_dir, "grids"),
                fname=f"{save_prefix}{idx}",
                low_res_key=low_res_key,
                psnr=batch_psnr,
                ssim=batch_ssim,
                **result_pics
            )

    if calc_stats:
        for name, result in metrics_res.items():
            if name not in [interpolated_key, ground_truth_key, low_res_key]:
                log.info(f"Average PSNR for {name}: {tf.reduce_mean(result['psnr'])}")
                log.info(f"Average PSNR on y-channel for {name}: {tf.reduce_mean(result[psnr_y_key])}")
                log.info(f"Average SSIM for {name}: {tf.reduce_mean(result['ssim'])}")
                log.info("")
        log.info(f"Average PSNR for {interpolated_key}: {tf.reduce_mean(metrics_res[interpolated_key]['psnr'])}")
        log.info(f"Average PSNR on y-channel for {interpolated_key}: {tf.reduce_mean(metrics_res[interpolated_key][psnr_y_key])}")
        log.info(f"Average SSIM for {interpolated_key}: {tf.reduce_mean(metrics_res[interpolated_key]['ssim'])}")


def _empty_metric_store(psnr_y_key):
    """Returns a fresh dict of empty float32 tensors for PSNR, PSNR on the y-channel and SSIM."""
    return {
        "psnr": tf.constant([], dtype=tf.float32),
        psnr_y_key: tf.constant([], dtype=tf.float32),
        "ssim": tf.constant([], dtype=tf.float32),
    }


def _append_metrics(metric_store, hr_batch, candidate_batch, psnr_y_key):
    """
    Computes PSNR, PSNR on the y-channel and SSIM of ``candidate_batch`` against ``hr_batch``,
    concatenates them onto ``metric_store`` in place and returns the metrics of this batch.
    """
    # max_val needs to be 2.0 for images in range [-1, 1]
    batch_metrics = {
        "psnr": metrics.psnr(hr_batch, candidate_batch, max_val=2.0),
        psnr_y_key: metrics.psnr_on_y(hr_batch, candidate_batch, max_val=2.0),
        "ssim": metrics.ssim(hr_batch, candidate_batch, max_val=2.0),
    }
    for key, values in batch_metrics.items():
        metric_store[key] = tf.concat([metric_store[key], values], axis=0)
    return batch_metrics
def evaluate_on_testdata(config, save_single=True, grid=False, interpolate=False, with_original=False,
                         combine_halfs=False, pipeline=None, models=None, save_prefix="",
                         segmentation_min_width=1000, segmentation_min_height=1000):
    """
    Evaluates model(s) supplied via config on test data set of supplied data-pipeline.
    Images will be plotted for each supplied model according to supplied parameters.
    :code:`save_single`, :code:`grid` and :code:`combine_halfs` can be combined and images will be
    produced and stored for each option. :code:`interpolate` and :code:`with_original` are options
    to define the appearance of :code:`grid`.

    :param config:
        | Initialized object of type :code:`ConfigUtil` from :code:`simple_sr.utils.config` module.
        | Models will be loaded from :code:`model_path` in config object if :code:`models` is None.
    :param save_single: Whether to save upscaled images as single images.
    :param grid: Whether to save an image grid. Supplied models will be plotted along rows, images across rows.
    :param interpolate: Whether to add a row with (nearest-neighbour) interpolated images in grid for comparison.
    :param with_original: Whether to plot original image next to image grid (useful if image grid contains patches)
    :param combine_halfs: Whether to save an image that's upscaled with supplied models on one half
        and interpolated on the other half.
    :param pipeline:
        | Initialized object of :code:`DataPipeline` from :code:`simple_sr.data_pipeline` package.
        | Will be initialized from config if None.
    :param models:
        | Dict of name: model where model is of type Keras.model.
        | If no dict is supplied, models will be loaded from path in `config.model_path`.
    :param save_prefix: Prefix for filenames.
    :param segmentation_min_width: Minimal width of images to be segmented for memory efficient inference.
    :param segmentation_min_height: Minimal height of images to be segmented for memory efficient inference.
    :raises ValueError: If no models are supplied and :code:`config.model_path` is None.
    """
    if pipeline is None:
        pipeline = DataPipeline.from_config(config)
    if models is None:
        if config.model_path is None:
            raise ValueError("No model was supplied and config does not contain path to model")
        models = {Path(path).stem: _load_model(path) for path in config.model_path}

    pixel_overlap = 32
    results = dict()
    test_batch_generator = pipeline.test_batch_generator(config.batch_size)
    for idx, (lr_batch, file_path) in enumerate(test_batch_generator):
        # output files are named after the parent directory of the first file in the batch
        original_name = Path(file_path.numpy()[0].decode("utf-8")).parent.stem

        model_input = lr_batch
        batch_height, batch_width = _get_tensor_height_width(model_input)
        segmented = _eligible_efficient_inference(
            model_input, min_width=segmentation_min_width, min_height=segmentation_min_height
        )
        if segmented:
            # large single images are upscaled as overlapping patches to limit memory usage
            model_input, padding = image_utils.segment_into_patches(
                model_input, patch_width=128, patch_height=128, pixel_overlap=pixel_overlap
            )
            model_input = tf.convert_to_tensor(model_input)
            scaled_pixel_overlap = pixel_overlap * config.scale

        for name, model in models.items():
            sr_batch = _upscale(model, model_input)
            if segmented:
                sr_batch = image_utils.reconstruct_from_overlapping_patches(
                    sr_batch,
                    image_height=(batch_height * config.scale),
                    image_width=(batch_width * config.scale),
                    pixel_overlap=scaled_pixel_overlap,
                    horizontal_padding=padding[0][1] * config.scale - scaled_pixel_overlap,
                    vertical_padding=padding[1][1] * config.scale - scaled_pixel_overlap
                )
            if sr_batch.shape.rank == 3:
                sr_batch = tf.reshape(sr_batch, (1, *sr_batch.shape))
            results[name] = sr_batch

            if save_single:
                image_utils.save_single(
                    sr_batch,
                    os.path.join(config.pic_dir, original_name, "single"),
                    f"{save_prefix}{idx}_{original_name}_{name}"
                )
            if combine_halfs:
                image_utils.combine_halfs(
                    left_tensor=sr_batch,
                    right_tensor=image_transforms.resize(lr_batch, _get_tensor_height_width(sr_batch)),
                    left_label=name,
                    right_label="interpolated",
                    save_dir=os.path.join(config.pic_dir, original_name, "half"),
                    fname=f"{save_prefix}{idx}_{original_name}_{name}"
                )

        if interpolate:
            results["interpolated"] = image_transforms.resize(
                lr_batch,
                _get_tensor_height_width(sr_batch),
                resize_filter=tf.image.ResizeMethod.NEAREST_NEIGHBOR
            )
            if save_single:
                # save the whole batch at once (like every other save_single call); saving image by
                # image under one filename made each image overwrite the previous one
                image_utils.save_single(
                    results["interpolated"],
                    os.path.join(config.pic_dir, "interpolated"),
                    f"{save_prefix}{idx}"
                )

        original = None
        if with_original:
            try:
                original = config.test_originals[original_name]
            except KeyError:
                original = None

        if grid:
            image_utils.prepare_image_grid(
                save_dir=os.path.join(config.pic_dir, "grids"),
                fname=f"{save_prefix}{idx}_{original_name}",
                low_res_key=None,
                psnr=None,
                ssim=None,
                original=original,
                **results
            )
def _load_model(model_path): try: model = tf.keras.models.load_model(model_path) except OSError: print(f"Error could not locate model at path: {model_path}, exiting") sys.exit(1) return model def _get_tensor_height_width(tensor): if tensor.shape.rank == 4: return tensor.shape[1], tensor.shape[2] elif tensor.shape.rank == 3: return tensor.shape[0], tensor.shape[1] else: raise ValueError(f"Received tensor with unexpected rank: {tensor.shape.rank}") def _eligible_efficient_inference(tensor, min_width=1000, min_height=1000): if tensor.shape.rank != 3 and tensor.shape.rank != 4: return False if tensor.shape.rank == 4 and tensor.shape[0] != 1: return False batch_width, batch_height = _get_tensor_height_width(tensor) if batch_width > min_width and batch_height > min_height: return True return False def _upscale(model, lr_batch): _sr_batch = list() _lr_batch = lr_batch if _lr_batch.shape.rank == 4: _lr_batch = tf.reshape(_lr_batch, (_lr_batch.shape[0], 1, *_lr_batch.shape[1:])) for batch in _lr_batch: _sr_batch.append(model(batch, training=False)) _sr_batch = tf.convert_to_tensor(_sr_batch) return tf.reshape(_sr_batch, (-1, *_sr_batch.shape[2:])) if __name__ == "__main__": pass