Models

SRModel

class SRModel(model_type, generator, generator_optimizer, generator_optimizer_config=None, discriminator=None, discriminator_optimizer=None, discriminator_optimizer_config=None, image_metrics=None, early_stop_metric='psnr', early_stop_patience=100, epoch_train_summary_writer=None, batch_train_summary_writer=None, epoch_validation_summary_writer=None, batch_validation_summary_writer=None, resnet_checkpoint=None, config=None)[source]
SRModel encapsulates a Generator and optionally a Discriminator.
If no Discriminator is supplied, training runs in non-adversarial mode; otherwise it runs in adversarial mode.
The SRModel class is the main interface to interact with the Generator and Discriminator during training.
It manages the training process: it delegates data batches to the Generator/Discriminator, calculates gradients, and updates the weights of the Generator/Discriminator accordingly.
Additionally, the SRModel logs metrics to tensorboard/stdout, saves checkpoints, and keeps track of whether the early stopping criterion has been reached.
Parameters
  • model_type – Whether to train in ‘gan’ (adversarial) mode or ‘resnet’ (non-adversarial) mode.

  • generator – Initialized Generator object.

  • generator_optimizer – Optimizer for Generator.

  • generator_optimizer_config

    Optimizer config for the Generator; it can specify things like the learning rate or a learning-rate schedule. The optimizer config needs to be applicable to the supplied optimizer.
    See the tensorflow docs and examples/complex_example.yaml for more details on this.
    If no optimizer config is supplied the optimizer will be initialized with default values.

  • discriminator – Optional Discriminator to train in adversarial mode.

  • discriminator_optimizer – Optimizer for Discriminator, needs to be supplied if discriminator is supplied.

  • discriminator_optimizer_config – Optimizer config for Discriminator, same things apply as for generator_optimizer_config.

  • image_metrics – Dictionary containing pairs of name: f(img1, img2). f will be calculated after every processed batch and logged to tensorboard. Average will be logged to stdout after each epoch. Defaults to {"PSNR": simple_sr.utils.image.metrics.psnr}.

  • early_stop_metric – Metric to track as trigger for early stopping.

  • early_stop_patience – Defines how many epochs may pass without an increase in early_stop_metric before early stopping is triggered.

  • epoch_train_summary_writer – Tensorflow summary writer for epoch training metrics.

  • batch_train_summary_writer – Tensorflow summary writer for batch training metrics.

  • epoch_validation_summary_writer – Tensorflow summary writer for epoch validation metrics.

  • batch_validation_summary_writer – Tensorflow summary writer for batch validation metrics.

  • resnet_checkpoint

    Checkpoint of pretrained resnet model, the model of the generator will be set to the restored model.

    Note

    The Generator still needs to be initialized and supplied to the SRModel beforehand, since only the Generator’s (Keras) model will be restored. A Generator with initialized loss functions is therefore still required.

  • config – Used to define save directories; if not supplied, the base save directory defaults to “./”.
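
The image_metrics parameter above expects a dictionary that maps a metric name to a callable f(img1, img2). A minimal sketch of such a callable follows, assuming images are flat sequences of pixel values in [0, 255]. The names psnr_flat and max_value are hypothetical and written for illustration only; this is not the library's simple_sr.utils.image.metrics.psnr, which presumably operates on tensors.

```python
import math


def psnr_flat(img1, img2, max_value=255.0):
    """Peak signal-to-noise ratio (in dB) of two equally sized flat pixel sequences."""
    if len(img1) != len(img2):
        raise ValueError("images must have the same number of pixels")
    mse = sum((a - b) ** 2 for a, b in zip(img1, img2)) / len(img1)
    if mse == 0:
        return math.inf  # identical images have no noise
    return 10.0 * math.log10(max_value ** 2 / mse)


# Hypothetical metrics dictionary in the name: f(img1, img2) shape described above.
image_metrics = {"PSNR": psnr_flat}
```

Any additional metric can be added to the dictionary under its own name, as long as it accepts two images and returns a single value.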

latest_checkpoint()[source]

Get the latest checkpoint that was saved.

A checkpoint contains:
  • The current iteration step

  • The tracked early stop metric value

  • The Generator model

  • The Discriminator model (if training in GAN mode)

Returns

Tensorflow checkpoint object.

save_model(save_path, postfix=None)[source]

Saves the Generator model in hdf5 format to disk.

Parameters
  • save_path – Path to save Generator model to.

  • postfix – Optional postfix for the filename; if None, the current epoch will be used instead.

stop_early()[source]

Check whether early stopping criterion is reached.

Returns

True if early stopping is reached, otherwise False.
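
Patience-based early stopping, as described by early_stop_metric and early_stop_patience, can be sketched as follows. EarlyStopTracker is a hypothetical stand-in, not the library's implementation, and it assumes a metric where higher is better (such as PSNR).

```python
class EarlyStopTracker:
    """Patience-based early stopping for a metric where higher is better."""

    def __init__(self, patience):
        self.patience = patience
        self.best = float("-inf")
        self.epochs_without_improvement = 0

    def update(self, value):
        """Record the metric value of the epoch that just finished."""
        if value > self.best:
            self.best = value
            self.epochs_without_improvement = 0
        else:
            self.epochs_without_improvement += 1

    def stop_early(self):
        """True once `patience` epochs have passed without an improvement."""
        return self.epochs_without_improvement >= self.patience


tracker = EarlyStopTracker(patience=2)
for psnr in [20.0, 21.5, 21.4, 21.0]:
    tracker.update(psnr)
    if tracker.stop_early():
        break
```

In this run the best value is reached in the second epoch, and the criterion triggers after two further epochs without improvement.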

generator()[source]

Retrieve the current Generator model.

Note: This only returns the Keras model of the Generator, not the Generator object itself.

Returns

tf.keras.Model instance of the current Generator model.

generator_optimizer()[source]

Retrieve the initialized and configured generator optimizer.

Returns

Tensorflow/Keras optimizer object of Generator.

discriminator()[source]

Retrieve the current Discriminator.

Note: This only returns the Keras model of the Discriminator, not the Discriminator object itself.

Returns

tf.keras.Model instance of the current Discriminator model, or None if training in non-adversarial mode.

discriminator_optimizer()[source]

Retrieve the initialized and configured discriminator optimizer.

Returns

Tensorflow/Keras optimizer object of Discriminator or None if training in non-adversarial mode.

epoch_metrics(train=True)[source]

Retrieve training or validation epoch metrics of current epoch. Metrics will be reset after each epoch.

Parameters

train – If train is true training metrics will be returned, otherwise validation metrics.

Returns

A dictionary containing the current epoch’s training or validation metrics.

batch_metrics()[source]

Retrieve the current batch metrics. Since metrics will be reset after each batch there is no need to make a distinction between training and validation batch metrics.

Returns

A dictionary containing the current batch metrics.

epoch_history(train=True)[source]

Retrieve epoch history of collected metrics.

Parameters

train – Whether to retrieve training or validation epoch history.

Returns

List of collected metrics.

batch_history(train=True)[source]

Retrieve batch history of collected metrics.

Parameters

train – Whether to retrieve training or validation batch history.

Returns

List of collected metrics.
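
The relationship between batch metrics, epoch metrics, and their histories can be illustrated with a small accumulator. MetricTracker and its methods are hypothetical names; the sketch assumes that epoch metrics are the mean of the batch values seen during the epoch, which matches the averaging described for image_metrics but is not confirmed for every metric.

```python
class MetricTracker:
    """Keeps the latest batch values, running epoch means, and an epoch history."""

    def __init__(self):
        self.batch = {}
        self._epoch_sums = {}
        self._epoch_counts = {}
        self.epoch_history = []

    def record_batch(self, name, value):
        """Store a batch value and fold it into the running epoch mean."""
        self.batch[name] = value
        self._epoch_sums[name] = self._epoch_sums.get(name, 0.0) + value
        self._epoch_counts[name] = self._epoch_counts.get(name, 0) + 1

    def reset_batch(self):
        """Batch metrics are cleared after every batch."""
        self.batch = {}

    def epoch_metrics(self):
        """Mean of all batch values recorded in the current epoch."""
        return {
            name: self._epoch_sums[name] / self._epoch_counts[name]
            for name in self._epoch_sums
        }

    def end_epoch(self):
        """Append the epoch means to the history and reset the epoch state."""
        self.epoch_history.append(self.epoch_metrics())
        self._epoch_sums = {}
        self._epoch_counts = {}


metrics = MetricTracker()
for value in [30.0, 32.0]:
    metrics.record_batch("psnr", value)
    metrics.reset_batch()
epoch_mean = metrics.epoch_metrics()
metrics.end_epoch()
```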

epoch_summary_writer(train=True)[source]

Retrieve the training or validation epoch summary writer.

Parameters

train – If train is True the training epoch summary writer will be returned, otherwise the validation epoch summary writer.

Returns

Tensorflow summary writer object.

batch_summary_writer(train=True)[source]

Retrieve the training or validation batch summary writer.

Parameters

train – If train is True the training batch summary writer will be returned, otherwise the validation batch summary writer.

Returns

Tensorflow summary writer object.

train_step(lr_batch, hr_batch)[source]

Train for one iteration on the supplied batches. Generation of images and calculation of the loss are delegated to the Generator; afterwards the SRModel calculates the Generator’s gradients and updates the Generator’s model accordingly.

If training in adversarial mode, the critique and calculation of its loss will be delegated to the Discriminator and weights of the discriminator model will be updated by SRModel afterwards.

Parameters
  • lr_batch – Batch of low resolution samples.

  • hr_batch – Batch of corresponding high resolution ground truth samples.

validation_step(lr_batch, hr_batch)[source]

Validate for one iteration on supplied batches. Only loss and image metrics will be calculated, models will not be updated.

Parameters
  • lr_batch – batch of low resolution samples

  • hr_batch – batch of corresponding high resolution ground truth samples

test_and_plot(lr_batch, save_dir, step, hr_batch=None, file_path=None)[source]

Generate high resolution samples with the generator from a supplied low resolution batch and save the resulting images as an image grid to monitor progress during training. Additionally, each sample will be upsampled with bicubic interpolation for comparison.

This can either be done for batches from the test set where no ground truth is available or for batches from train/validation data with a corresponding ground truth. If the ground truth is supplied it will also be plotted on the image grid.

Parameters
  • lr_batch – batch of low resolution samples

  • save_dir – save dir for saving the resulting image grid

  • step – epoch to manage/identify file names of saved image grids

  • hr_batch – optional batch of high resolution ground truth samples

  • file_path – optional save dir suffix for grouping image grids in specific folders

after_train_batch()[source]
Called after each training batch (if you’re using SimpleSR training utils).
Updates number of iterations the model has trained for, logs training batch metrics to tensorboard and resets batch metrics afterwards.
after_validation_batch()[source]
Called after each validation batch (if you’re using SimpleSR training utils).
Logs validation batch metrics to tensorboard and resets batch metrics afterwards.
before_epoch()[source]
Called before each epoch (if you’re using SimpleSR training utils).
Resets epoch metrics (training and validation) and increments number of trained epochs.
after_epoch()[source]

Called after each epoch (if you’re using SimpleSR training utils).

  • saves (generator) model

  • logs epoch metrics to tensorboard and updates epoch metrics history

  • evaluates whether early stopping criterion is triggered

after_training()[source]
Called after training finishes (if you’re using SimpleSR training utils).
Restores the best model and saves it with “best” postfix for identification.
Plots metrics afterwards.
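
For readers who write their own training loop instead of using the SimpleSR training utils, the hooks above could be called in an order like the one sketched here. RecordingModel is a stand-in that only records calls, and fit is a hypothetical function; the exact order used by the library's own utils is an assumption.

```python
class RecordingModel:
    """Stand-in that records which hook was called; not the real SRModel."""

    def __init__(self, stop_after_epoch):
        self.calls = []
        self.epoch = 0
        self._stop_after_epoch = stop_after_epoch

    def before_epoch(self):
        self.epoch += 1
        self.calls.append("before_epoch")

    def train_step(self, lr_batch, hr_batch):
        self.calls.append("train_step")

    def after_train_batch(self):
        self.calls.append("after_train_batch")

    def validation_step(self, lr_batch, hr_batch):
        self.calls.append("validation_step")

    def after_validation_batch(self):
        self.calls.append("after_validation_batch")

    def after_epoch(self):
        self.calls.append("after_epoch")

    def stop_early(self):
        return self.epoch >= self._stop_after_epoch

    def after_training(self):
        self.calls.append("after_training")


def fit(model, train_batches, validation_batches, max_epochs):
    """Drive the hooks for up to max_epochs, stopping early if requested."""
    for _ in range(max_epochs):
        model.before_epoch()
        for lr_batch, hr_batch in train_batches:
            model.train_step(lr_batch, hr_batch)
            model.after_train_batch()
        for lr_batch, hr_batch in validation_batches:
            model.validation_step(lr_batch, hr_batch)
            model.after_validation_batch()
        model.after_epoch()
        if model.stop_early():
            break
    model.after_training()


model = RecordingModel(stop_after_epoch=1)
fit(model, [("lr", "hr")], [("lr", "hr")], max_epochs=5)
```
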
formatted_epoch_metrics()[source]

Retrieve formatted epoch metrics/losses for logging

static init(config, generator, generator_optimizer, generator_optimizer_config=None, discriminator=None, discriminator_optimizer=None, discriminator_optimizer_config=None, image_metrics=None)[source]

Convenience method to initialize an SRModel. The model type will be inferred, and the early stopping settings as well as the Tensorflow summary writers will be taken from the initialized config.

Returns

Initialized Instance of type SRModel, ready for training.

Generator

class Generator(upsample_factor, architecture, loss_functions, num_blocks=16, num_dense_blocks=3, num_filters=64, num_convs=4, kernel_size=3, residual_scaling=0.2, kernel_initializer=None, batch_norm=False, input_dims=(None, None), pretrained_model_path=None, pretrained_model=None)[source]
Generator for SRModel.
The Generator’s job is to generate high-resolution images from supplied low-resolution images.
It can be used either in adversarial mode in combination with a Discriminator or in non-GAN mode without adversarial loss.
The Generator also keeps track of metrics for batches and epochs.
Parameters
  • upsample_factor

    Factor for increase in image resolution.
    E.g. if upsample_factor = 2 a 64x64 input image will be upsampled to 128x128.

  • architecture

    Architecture for the generator; currently only SRResnet and RRDB are available. For more info, see the papers linked under the srresnet and rrdb convenience methods of this class.

  • loss_functions

    Loss functions to calculate the generator’s loss.
    Available loss functions can be found in simple_sr.utils.models.loss_functions. An arbitrary amount of loss functions can be combined.

  • num_blocks – Number of residual building blocks for SRResnet/RRDB (see papers for info/architecture of these blocks).

  • num_filters – Number of filters for convolutional layers in generator architecture.

  • kernel_size – Kernel size for convolutional layers in generator architecture.

  • residual_scaling – Residual scaling for shortcut connections (will only take effect in RRDB architecture).

  • kernel_initializer – Weight initializer for convolutional layers.

  • batch_norm

    Whether to apply batch normalization in the Generator’s architecture.
    This option is only applicable to SRResnet (see the RRDB/ESRGAN paper; the authors conclude that batch normalization might not be helpful at all times).

  • input_dims

    Dimensions of input images to the generator.
    Since the generator is fully convolutional this may be (None, None) -> generator can handle arbitrary sizes.

  • pretrained_model_path

    Path to a pretrained model, the model will be loaded and training will be resumed.
    Can also be used for pretraining a model in non-adversarial mode on pixel loss and then continue training in adversarial mode with different loss functions. (like VGG loss as the authors in SRGAN/ESRGAN do)

  • pretrained_model – Already loaded Keras model. The same things apply as for pretrained_model_path, except that the model is expected to be loaded already.
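
The residual_scaling parameter above refers to the residual scaling used in the RRDB/ESRGAN architecture, where the output of a residual branch is multiplied by a constant between 0 and 1 before it is added back to the main path. A minimal sketch on plain lists follows; scaled_residual and the toy branch double are hypothetical and stand in for the convolutional layers of a real block.

```python
def scaled_residual(x, branch, residual_scaling=0.2):
    """Return x + residual_scaling * branch(x), element-wise."""
    residual = branch(x)
    return [xi + residual_scaling * ri for xi, ri in zip(x, residual)]


def double(values):
    """Toy branch standing in for a stack of convolutions."""
    return [2.0 * v for v in values]


out = scaled_residual([1.0, -2.0], double, residual_scaling=0.2)
```

With a scaling of 0.2 the branch contributes only a fifth of its output, which keeps the shortcut path dominant.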

model()[source]

Retrieve the generator’s model.

Returns

Instance of type tf.keras.Model.

loss_functions()[source]

Retrieve registered loss functions of generator.

Returns

List of initialized loss function objects from simple_sr.utils.models.loss_functions module.

formatted_epoch_metrics(train=True)[source]

Retrieve formatted string of epoch metrics for logging.

Parameters

train – Request either training or validation metrics.

Returns

Formatted metrics string of training or validation metrics, depending on supplied parameter.

generate(lr_batch, training=True)[source]

Generate batch of high-resolution images based on supplied low-resolution training images.

Parameters
  • lr_batch – Low resolution input images.

  • training – Whether currently training or validating (batch normalization will be off during validation).

Returns

Tensor containing upsampled images.

calculate_train_loss(sr_batch, hr_batch, sr_critic, hr_critic)[source]
Delegates calculation of loss to loss functions and calculates total training loss.
Loss functions will record training loss metrics.
Parameters
  • sr_batch – Batch of generated high-resolution samples.

  • hr_batch – Batch of corresponding high-resolution ground truth samples.

  • sr_critic – Critique of discriminator for synthetic/generated data training samples (only applicable if training in GAN mode, otherwise will be None).

  • hr_critic – Critique of discriminator for real data training samples (only applicable if training in GAN mode, otherwise will be None).

Returns

Integer representing total loss for training batch.
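
Combining several loss functions into one total loss can be sketched as a weighted sum. This is an assumption based on the loss weights exposed by the convenience methods (for example vgg_loss_weight and l1_loss_weight); total_loss and the two toy losses below are hypothetical and operate on flat lists rather than tensors.

```python
def mean_absolute_error(sr, hr):
    """Mean absolute pixel difference between generated and ground truth values."""
    return sum(abs(s - h) for s, h in zip(sr, hr)) / len(sr)


def mean_squared_error(sr, hr):
    """Mean squared pixel difference between generated and ground truth values."""
    return sum((s - h) ** 2 for s, h in zip(sr, hr)) / len(sr)


def total_loss(weighted_losses, sr, hr):
    """Sum weight * loss(sr, hr) over all (weight, loss) pairs."""
    return sum(weight * loss(sr, hr) for weight, loss in weighted_losses)


sr = [0.0, 1.0, 2.0]
hr = [0.0, 2.0, 4.0]
loss = total_loss([(1.0, mean_absolute_error), (0.5, mean_squared_error)], sr, hr)
```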

calculate_validation_loss(sr_batch, hr_batch, sr_critic, hr_critic)[source]
Delegates calculation of validation loss to loss functions and calculates total validation loss.
Loss functions will record validation loss metrics.
Parameters
  • sr_batch – Batch of generated high-resolution samples.

  • hr_batch – Batch of corresponding high-resolution ground truth samples.

  • sr_critic – Critique of discriminator for synthetic/generated data validation samples (only applicable if training in GAN mode, otherwise will be None).

  • hr_critic – Critique of discriminator for real data validation samples (only applicable if training in GAN mode, otherwise will be None).

Returns

Integer representing total loss for validation batch.

static srresnet(upsample_factor, loss_function=None, num_blocks=16, num_filters=64, kernel_size=3, batch_norm=True, input_dims=(None, None), pretrained_model_path=None, pretrained_model=None)[source]
Convenience method for initializing SRResnet Generator in non-adversarial mode.
Default parameters are set according to SRResnet/SRGAN paper (https://arxiv.org/abs/1609.04802).
Returns

Initialized Generator instance.

static rrdb(upsample_factor, loss_functions=<class 'simple_sr.utils.models.loss_functions.mean_absolute_error.MeanAbsoluteError'>, loss_weight=1.0, num_blocks=16, num_dense_blocks=3, num_filters=64, num_convs=4, kernel_size=3, residual_scaling=0.2, kernel_initializer=None, batch_norm=False, input_dims=(None, None), pretrained_model_path=None, pretrained_model=None)[source]
Convenience method for initializing RRDB Generator in non-adversarial mode.
Default parameters are set according to RRDB/ESRGAN paper (https://arxiv.org/abs/1809.00219).
Returns

Initialized Generator instance.

static srgan_generator(upsample_factor, vgg_loss, vgg_layer, vgg_feature_scaling=0.0784313725490196, vgg_loss_weight=1.0, adversarial_loss_weight=0.001, num_blocks=16, num_filters=64, kernel_size=3, batch_norm=True, input_dims=(None, None), pretrained_model_path=None, pretrained_model=None)[source]
Convenience method for initializing SRResnet Generator in adversarial mode.
Default parameters are set according to SRResnet/SRGAN paper (https://arxiv.org/abs/1609.04802).
Returns

Initialized Generator instance.
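
The default value of vgg_feature_scaling in srgan_generator appears to match 1/12.75, the factor by which the SRGAN paper reports rescaling the VGG feature maps. That the default was chosen for this reason is an assumption; the arithmetic itself can be checked quickly:

```python
# Default of vgg_feature_scaling taken from the srgan_generator signature.
DEFAULT_VGG_FEATURE_SCALING = 0.0784313725490196

# Rescaling factor for VGG feature maps reported in the SRGAN paper.
PAPER_FACTOR = 1 / 12.75

# The two values agree up to floating point precision.
matches_paper = abs(DEFAULT_VGG_FEATURE_SCALING - PAPER_FACTOR) < 1e-12
```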

static esrgan_generator(upsample_factor, vgg_layer='block5_conv4', vgg_feature_scaling=1.0, vgg_loss_weight=1.0, adversarial_loss_weight=0.005, l1_loss_weight=0.01, num_blocks=16, num_dense_blocks=3, num_filters=64, num_convs=4, kernel_size=3, input_dims=(None, None), pretrained_model_path=None, pretrained_model=None)[source]
Convenience method for initializing RRDB Generator in adversarial mode.
Default parameters are set according to RRDB/ESRGAN paper (https://arxiv.org/abs/1809.00219).
Returns

Initialized Generator instance.

static from_yaml(config_yaml)[source]

Initialize generator from supplied yaml config.

Parameters

config_yaml – yaml file containing the specification for the generator; see examples for the yaml structure.

Returns

Initialized Generator instance.

Discriminator

class Discriminator(loss_function, relativistic, label_smoothing=False, smoothing_offset=0.3, num_filters=64, alpha=0.2, kernel_size=3, momentum=0.8, initializer=None, input_dims=(None, None))[source]
Discriminator to train SRModel in adversarial mode.
The Discriminator’s job is to critique synthetic and real samples.
Also calculates its own loss and keeps track of metrics for epochs and batches.
Parameters
  • loss_function

    Loss function to calculate loss of Discriminator, currently available loss functions are standard GAN discriminator loss and relativistic average GAN discriminator loss.
    More info on different types of losses for GANs can be found here: https://arxiv.org/abs/1807.00734.

  • relativistic

    Whether the Discriminator is relativistic. If relativistic is true, the discriminator’s architecture will not end with a sigmoid layer.

  • label_smoothing

    Whether to apply label smoothing. This can help to stabilize the training by making the Discriminator’s job a little harder.
    • If false, target labels for the discriminator will be either 0 or 1 for synthetic and real samples respectively.

    • If true, target labels will be in the range [0, smoothing_offset] for fake labels and [1 - smoothing_offset, 1 + smoothing_offset] for real labels.

    For more tweaks to make GANs more stable see this github repo: https://github.com/soumith/ganhacks

  • smoothing_offset – Sets upper and lower bound for random noise in target labels (if label_smoothing is True).

  • num_filters – Number of filters in convolutional layers of the discriminator’s architecture.

  • alpha – Negative slope coefficient for Leaky ReLU activation function.

  • kernel_size – Kernel size in convolutional layers of the discriminator’s architecture.

  • momentum – Momentum for batch normalization.

  • initializer – Initializer for weights initialization of discriminator.

  • input_dims – Dimensions of input images to the discriminator.
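
The label ranges given for label_smoothing and smoothing_offset above can be sketched as follows. target_labels is a hypothetical helper, and drawing the noisy labels uniformly from the stated ranges is an assumption; the library may sample them differently.

```python
import random


def target_labels(batch_size, real, label_smoothing=False, smoothing_offset=0.3, rng=random):
    """Build target labels for one batch of real or synthetic samples."""
    if not label_smoothing:
        # Hard labels: 1 for real samples, 0 for synthetic samples.
        return [1.0 if real else 0.0] * batch_size
    if real:
        low, high = 1.0 - smoothing_offset, 1.0 + smoothing_offset
    else:
        low, high = 0.0, smoothing_offset
    # Assumed: noisy labels are drawn uniformly from the documented range.
    return [rng.uniform(low, high) for _ in range(batch_size)]


rng = random.Random(0)  # seeded so the example is reproducible
real_labels = target_labels(4, real=True, label_smoothing=True, rng=rng)
fake_labels = target_labels(4, real=False, label_smoothing=True, rng=rng)
hard_labels = target_labels(4, real=True)
```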

model()[source]

Retrieve the Discriminator’s model.

Returns

Instance of type tf.keras.Model.

loss_function()[source]

Retrieve loss function of Discriminator.

Returns

Initialized loss function object from simple_sr.utils.models.loss_functions module.

reset_epoch_metrics()[source]

Reset all training and validation metrics.

reset_batch_metrics()[source]

Reset all batch metrics.

formatted_epoch_metrics(train=True)[source]

Return formatted string of epoch metrics for logging.

Parameters

train – Request either training or validation metrics.

Returns

Formatted metrics string of training or validation metrics, depending on supplied parameter.

critic_train_batch(sr_batch, hr_batch)[source]

Critique synthetic and real data training batches and keep track of metrics.

Parameters
  • sr_batch – Batch of synthetic training data.

  • hr_batch – Batch of real training data.

Returns

Tuple of tensors containing the likelihood of samples being real for sr_batch and hr_batch respectively.

critic_validation_batch(sr_batch, hr_batch)[source]

Critique synthetic and real data validation batches and keep track of metrics.

Parameters
  • sr_batch – Batch of synthetic validation data.

  • hr_batch – Batch of real validation data.

Returns

Tuple of tensors containing the likelihood of samples being real for sr_batch and hr_batch respectively.

calculate_train_loss(sr_critic, hr_critic)[source]
Delegates calculation of training loss to loss function.
Target labels for loss calculation will be generated according to the parameters label_smoothing and smoothing_offset.
Parameters
  • sr_critic – Discriminator’s critique of a synthetic training data batch.

  • hr_critic – Discriminator’s critique of a real training data batch.

Returns

Loss calculated by the discriminator’s loss function.

calculate_validation_loss(sr_critic, hr_critic)[source]
Delegates calculation of validation loss to loss function.
Target labels for loss calculation will be generated according to the parameters label_smoothing and smoothing_offset.
Parameters
  • sr_critic – Discriminator’s critique of a synthetic validation data batch.

  • hr_critic – Discriminator’s critique of a real validation data batch.

Returns

Loss calculated by the discriminator’s loss function.

static initialize_relativistic(weighted_loss=False, loss_weight=1.0, num_filters=64, alpha=0.2, kernel_size=3, momentum=0.8, initializer=None, input_dims=(None, None))[source]

Convenience method to initialize a relativistic average GAN discriminator with corresponding loss function.

Parameters
  • weighted_loss – Whether the loss function should be weighted.

  • loss_weight – Factor for weighted loss.

  • num_filters – Number of filters in convolutional layers of the discriminator’s architecture.

  • alpha – Negative slope coefficient for Leaky ReLU activation function.

  • kernel_size – Kernel size in convolutional layers of the discriminator’s architecture.

  • momentum – Momentum for batch normalization.

  • initializer – Initializer for weights initialization of discriminator.

  • input_dims – Dimensions of input images to the discriminator.

Returns

Initialized Discriminator object.
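
A relativistic average discriminator loss can be sketched on raw critic outputs as shown here. The formulation follows the relativistic GAN paper linked in the Discriminator parameters (https://arxiv.org/abs/1807.00734); whether the library's loss function matches it term for term is an assumption, and all function names are hypothetical. The sigmoid is applied to the difference between critic outputs, which is why a relativistic discriminator does not need a final sigmoid layer of its own.

```python
import math


def softplus(x):
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def mean(values):
    return sum(values) / len(values)


def relativistic_average_d_loss(hr_critic, sr_critic):
    """Relativistic average discriminator loss on raw (pre-sigmoid) critic outputs."""
    mean_real = mean(hr_critic)
    mean_fake = mean(sr_critic)
    # -log(sigmoid(C(x_r) - E[C(x_f)])): real samples should score above the average fake.
    real_term = mean([softplus(-(c - mean_fake)) for c in hr_critic])
    # -log(1 - sigmoid(C(x_f) - E[C(x_r)])): fakes should score below the average real.
    fake_term = mean([softplus(c - mean_real) for c in sr_critic])
    return real_term + fake_term


undecided = relativistic_average_d_loss([0.0, 0.0], [0.0, 0.0])
confident = relativistic_average_d_loss([5.0], [-5.0])
```

When the critic cannot tell real from synthetic samples, the loss equals 2·ln 2; it approaches 0 as real samples score far above synthetic ones.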

static initialize_standard(weighted_loss=False, loss_weight=1.0, label_smoothing=False, smoothing_offset=0.0, num_filters=64, alpha=0.2, kernel_size=3, momentum=0.8, initializer=None, input_dims=(None, None))[source]

Convenience method to initialize a standard GAN discriminator with corresponding loss function.

Parameters
  • weighted_loss – Whether the loss function should be weighted.

  • loss_weight – Factor for weighted loss.

  • label_smoothing

    Whether to apply label smoothing. This can help to stabilize the training by making the Discriminator’s job a little harder.
    • If false, target labels for the discriminator will be either 0 or 1 for synthetic and real samples respectively.

    • If true, target labels will be in the range [0, smoothing_offset] for fake labels and [1 - smoothing_offset, 1 + smoothing_offset] for real labels.

    For more tweaks to make GANs more stable see this github repo: https://github.com/soumith/ganhacks

  • smoothing_offset – Sets upper and lower bound for random noise in target labels.

  • num_filters – Number of filters in convolutional layers of the discriminator’s architecture.

  • alpha – Negative slope coefficient for Leaky ReLU activation function.

  • kernel_size – Kernel size in convolutional layers of the discriminator’s architecture.

  • momentum – Momentum for batch normalization.

  • initializer – Initializer for weights initialization of discriminator.

  • input_dims – Dimensions of input images to the discriminator.

Returns

Initialized Discriminator object.
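
The standard GAN discriminator loss is commonly written as binary cross-entropy between target labels and the discriminator's sigmoid outputs. The sketch below assumes that formulation; binary_cross_entropy and standard_d_loss are hypothetical names, and the optional label arguments show where smoothed target labels would enter.

```python
import math


def binary_cross_entropy(labels, probabilities, eps=1e-7):
    """Mean binary cross-entropy between target labels and predicted probabilities."""
    total = 0.0
    for y, p in zip(labels, probabilities):
        p = min(max(p, eps), 1.0 - eps)  # clip to avoid log(0)
        total += -(y * math.log(p) + (1.0 - y) * math.log(1.0 - p))
    return total / len(labels)


def standard_d_loss(hr_critic, sr_critic, real_labels=None, fake_labels=None):
    """Standard discriminator loss; critic values are probabilities of being real."""
    if real_labels is None:
        real_labels = [1.0] * len(hr_critic)  # hard labels for real samples
    if fake_labels is None:
        fake_labels = [0.0] * len(sr_critic)  # hard labels for synthetic samples
    real_term = binary_cross_entropy(real_labels, hr_critic)
    fake_term = binary_cross_entropy(fake_labels, sr_critic)
    return real_term + fake_term


undecided = standard_d_loss([0.5], [0.5])
confident = standard_d_loss([0.99], [0.01])
```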

static from_yaml(config_yaml)[source]

Initialize discriminator from supplied yaml config.

Parameters

config_yaml – yaml file containing the specification for the discriminator; see examples for the yaml structure.

Returns

Initialized Discriminator object.