Models
SRModel
-
class SRModel(model_type, generator, generator_optimizer, generator_optimizer_config=None, discriminator=None, discriminator_optimizer=None, discriminator_optimizer_config=None, image_metrics=None, early_stop_metric='psnr', early_stop_patience=100, epoch_train_summary_writer=None, batch_train_summary_writer=None, epoch_validation_summary_writer=None, batch_validation_summary_writer=None, resnet_checkpoint=None, config=None)
SRModel encapsulates a Generator and, optionally, a Discriminator. If no Discriminator is supplied, training runs in non-adversarial mode; otherwise it runs in adversarial mode. The SRModel class is the main interface for interacting with the Generator and Discriminator during training. It manages the training process by delegating data batches to the Generator/Discriminator, calculating gradients, and updating the weights of the Generator/Discriminator accordingly. Additionally, SRModel logs metrics to TensorBoard/stdout, saves checkpoints, and keeps track of whether the early stopping criterion has been reached.
- Parameters
model_type – Whether to train in ‘gan’ (adversarial) mode or ‘resnet’ (non-adversarial) mode.
generator – Initialized Generator object.
generator_optimizer – Optimizer for Generator.
generator_optimizer_config –
Optimizer config for the Generator; it can specify things like the learning rate, learning-rate schedule, etc. The optimizer config must be applicable to the supplied optimizer. See the TensorFlow docs and examples/complex_example.yaml for more details. If no optimizer config is supplied, the optimizer is initialized with default values.
discriminator – Optional Discriminator to train in adversarial mode.
discriminator_optimizer – Optimizer for the Discriminator; must be supplied if a Discriminator is supplied.
discriminator_optimizer_config – Optimizer config for the Discriminator; the same considerations apply as for generator_optimizer_config.
image_metrics – Dictionary of name: f(img1, img2) pairs. Each f is calculated after every processed batch and logged to TensorBoard; its average is logged to stdout after each epoch. Defaults to {"PSNR": simple_sr.utils.image.metrics.psnr}.
early_stop_metric – Metric to track as the trigger for early stopping.
early_stop_patience – Number of epochs that may pass without an increase in early_stop_metric before early stopping is triggered.
epoch_train_summary_writer – TensorFlow summary writer for epoch training metrics.
batch_train_summary_writer – TensorFlow summary writer for batch training metrics.
epoch_validation_summary_writer – TensorFlow summary writer for epoch validation metrics.
batch_validation_summary_writer – TensorFlow summary writer for batch validation metrics.
resnet_checkpoint –
Checkpoint of a pretrained resnet model; the Generator's model will be set to the restored model.
Note
The Generator still needs to be initialized and supplied to SRModel beforehand, since only the Generator's (Keras) model will be restored. A Generator with initialized loss functions is therefore still required.
config – Used to define save directories; if not supplied, the base save directory defaults to “./”.
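image_metrics expects a mapping from a metric name to a callable f(img1, img2). The sketch below shows two such callables in plain Python, assuming images are flat sequences of pixel values in [0, 1]. The library's own simple_sr.utils.image.metrics.psnr presumably operates on TensorFlow tensors; the max_val parameter and the extra mae metric are illustrative assumptions, not SimpleSR API.

```python
import math
from typing import Callable, Dict, Sequence


def psnr(img1: Sequence[float], img2: Sequence[float], max_val: float = 1.0) -> float:
    """Peak signal-to-noise ratio in dB: 10 * log10(max_val^2 / MSE)."""
    if len(img1) != len(img2):
        raise ValueError("images must have the same number of pixels")
    mse = sum((a - b) ** 2 for a, b in zip(img1, img2)) / len(img1)
    if mse == 0:
        return math.inf  # identical images
    return 10.0 * math.log10(max_val ** 2 / mse)


def mae(img1: Sequence[float], img2: Sequence[float]) -> float:
    """Mean absolute error; a hypothetical additional metric."""
    return sum(abs(a - b) for a, b in zip(img1, img2)) / len(img1)


# Hypothetical dictionary in the name: f(img1, img2) shape that image_metrics expects.
image_metrics: Dict[str, Callable[[Sequence[float], Sequence[float]], float]] = {
    "PSNR": psnr,
    "MAE": mae,
}

sr = [0.5, 0.5, 0.5, 0.5]
hr = [0.6, 0.4, 0.6, 0.4]
for name, f in image_metrics.items():
    print(name, round(f(sr, hr), 2))
```

Since every pixel here is off by 0.1, the MSE is 0.01 and the PSNR comes out at 20 dB.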
-
latest_checkpoint()
Get the latest saved checkpoint.
- A checkpoint contains:
The current iteration step
The tracked early stop metric value
The Generator model
The Discriminator model (if training in GAN mode)
- Returns
TensorFlow checkpoint object.
-
save_model(save_path, postfix=None)
Saves the Generator model to disk in HDF5 format.
- Parameters
save_path – Path to save Generator model to.
postfix – Optional postfix for the filename; if None, the current epoch is used in the filename instead.
-
stop_early()
Check whether the early stopping criterion has been reached.
- Returns
True if the early stopping criterion has been reached, otherwise False.
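How early_stop_metric, early_stop_patience and stop_early() interact can be pictured with a small patience counter. This is a sketch under assumptions: a higher metric value is better (as for PSNR), the counter resets whenever the best value improves, and EarlyStopping with its methods is a hypothetical class, not part of SimpleSR.

```python
class EarlyStopping:
    """Hypothetical patience-based early stopping tracker (higher metric is better)."""

    def __init__(self, patience: int = 100) -> None:
        self.patience = patience
        self.best = float("-inf")
        self.best_epoch = -1
        self.epochs_without_improvement = 0

    def update(self, epoch: int, value: float) -> None:
        # Record the tracked metric at the end of an epoch.
        if value > self.best:
            self.best = value
            self.best_epoch = epoch
            self.epochs_without_improvement = 0
        else:
            self.epochs_without_improvement += 1

    def stop_early(self) -> bool:
        # True once `patience` consecutive epochs brought no improvement.
        return self.epochs_without_improvement >= self.patience


tracker = EarlyStopping(patience=2)
for epoch, psnr in enumerate([24.1, 25.3, 25.2, 25.0, 25.9]):
    tracker.update(epoch, psnr)
    if tracker.stop_early():
        print(f"stopping after epoch {epoch}, best epoch {tracker.best_epoch}")
        break
```

Here training stops after epoch 3, two epochs after the best value at epoch 1. That best epoch corresponds to the model that after_training() restores.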
-
generator()
Retrieve the current Generator model.
Note: This only returns the Keras model of the Generator, not the Generator object itself.
- Returns
tf.keras.Model instance of the current Generator model.
-
generator_optimizer()
Retrieve the initialized and configured Generator optimizer.
- Returns
TensorFlow/Keras optimizer object of the Generator.
-
discriminator()
Retrieve the current Discriminator model.
Note: This only returns the Keras model of the Discriminator, not the Discriminator object itself.
- Returns
tf.keras.Model instance of the current Discriminator model, or None if training in non-adversarial mode.
-
discriminator_optimizer()
Retrieve the initialized and configured Discriminator optimizer.
- Returns
TensorFlow/Keras optimizer object of the Discriminator, or None if training in non-adversarial mode.
-
epoch_metrics(train=True)
Retrieve the training or validation metrics of the current epoch. Metrics are reset after each epoch.
- Parameters
train – If True, training metrics are returned; otherwise validation metrics.
- Returns
A dictionary containing the current epoch's training or validation metrics.
-
batch_metrics()
Retrieve the current batch metrics. Since metrics are reset after each batch, there is no need to distinguish between training and validation batch metrics.
- Returns
A dictionary containing the current batch metrics.
-
epoch_history(train=True)
Retrieve the epoch history of collected metrics.
- Parameters
train – Whether to retrieve the training or validation epoch history.
- Returns
List of collected metrics.
-
batch_history(train=True)
Retrieve the batch history of collected metrics.
- Parameters
train – Whether to retrieve the training or validation batch history.
- Returns
List of collected metrics.
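The metric accessors above describe two levels of bookkeeping: batch values that are reset after every batch, epoch averages that are reset at the start of every epoch, and histories that accumulate across batches and epochs. The sketch below reproduces that bookkeeping in plain Python for one split (train or validation). MetricBook and its methods are hypothetical; SimpleSR presumably uses TensorFlow metric objects instead.

```python
from collections import defaultdict
from typing import Dict, List


class MetricBook:
    """Hypothetical batch/epoch metric bookkeeping for one split (train or validation)."""

    def __init__(self) -> None:
        self.batch: Dict[str, float] = {}
        self.batch_history: List[Dict[str, float]] = []
        self.epoch_history: List[Dict[str, float]] = []
        self._sums: Dict[str, float] = defaultdict(float)
        self._counts: Dict[str, int] = defaultdict(int)

    def record(self, name: str, value: float) -> None:
        # Store the batch value and feed the running epoch average.
        self.batch[name] = value
        self._sums[name] += value
        self._counts[name] += 1

    def epoch_metrics(self) -> Dict[str, float]:
        # Average over all batches recorded since the last before_epoch().
        return {name: self._sums[name] / self._counts[name] for name in self._sums}

    def after_batch(self) -> None:
        # Archive, then reset the batch metrics.
        self.batch_history.append(dict(self.batch))
        self.batch = {}

    def before_epoch(self) -> None:
        # Reset the running epoch averages.
        self._sums.clear()
        self._counts.clear()

    def after_epoch(self) -> None:
        # Append this epoch's averages to the epoch history.
        self.epoch_history.append(self.epoch_metrics())


train = MetricBook()
for epoch_losses in ([0.4, 0.2], [0.1, 0.1]):
    train.before_epoch()
    for loss in epoch_losses:
        train.record("loss", loss)
        train.after_batch()
    train.after_epoch()
print(train.epoch_history)
```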
-
epoch_summary_writer(train=True)
Retrieve the training or validation epoch summary writer.
- Parameters
train – If True, the training epoch summary writer is returned; otherwise the validation epoch summary writer.
- Returns
TensorFlow summary writer object.
-
batch_summary_writer(train=True)
Retrieve the training or validation batch summary writer.
- Parameters
train – If True, the training batch summary writer is returned; otherwise the validation batch summary writer.
- Returns
TensorFlow summary writer object.
-
train_step(lr_batch, hr_batch)
Train for one iteration on the supplied batches. Image generation and loss calculation are delegated to the Generator; afterwards SRModel calculates the Generator's gradients and updates the Generator's model accordingly.
If training in adversarial mode, the critique and the calculation of its loss are delegated to the Discriminator, and SRModel updates the weights of the Discriminator's model afterwards.
- Parameters
lr_batch – Batch of low resolution samples.
hr_batch – Batch of corresponding high resolution ground truth samples.
-
validation_step(lr_batch, hr_batch)
Validate for one iteration on the supplied batches. Only the loss and image metrics are calculated; the models are not updated.
- Parameters
lr_batch – Batch of low resolution samples.
hr_batch – Batch of corresponding high resolution ground truth samples.
-
test_and_plot(lr_batch, save_dir, step, hr_batch=None, file_path=None)
Generate high resolution samples with the Generator from a supplied low resolution batch and save the resulting images as an image grid to monitor progress during training. Additionally, each sample is upsampled with bicubic interpolation for comparison.
This can be done either for batches from the test set, where no ground truth is available, or for batches from the train/validation data with a corresponding ground truth. If the ground truth is supplied, it is also plotted on the image grid.
- Parameters
lr_batch – Batch of low resolution samples.
save_dir – Directory for saving the resulting image grid.
step – Epoch, used to manage/identify the file names of saved image grids.
hr_batch – Optional batch of high resolution ground truth samples.
file_path – Optional save dir suffix for grouping image grids in specific folders.
-
after_train_batch()
Called after each training batch (if you’re using the SimpleSR training utils). Updates the number of iterations the model has trained for, logs training batch metrics to TensorBoard and resets the batch metrics afterwards.
-
after_validation_batch()
Called after each validation batch (if you’re using the SimpleSR training utils). Logs validation batch metrics to TensorBoard and resets the batch metrics afterwards.
-
before_epoch()
Called before each epoch (if you’re using the SimpleSR training utils). Resets the epoch metrics (training and validation) and increments the number of trained epochs.
-
after_epoch()
Called after each epoch (if you’re using the SimpleSR training utils):
saves the (Generator) model
logs epoch metrics to TensorBoard and updates the epoch metrics history
evaluates whether the early stopping criterion is triggered
-
after_training()
Called after training finishes (if you’re using the SimpleSR training utils). Restores the best model and saves it with the “best” postfix for identification, then plots the metrics.
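The hooks above are meant to be driven by SimpleSR's training utils. The sketch below shows one plausible order in which a training loop could call them. run_training is a hypothetical function and SimpleSR's actual loop may differ. RecordingModel is a stub that only records which methods were called, so the example runs without TensorFlow.

```python
from typing import Iterable, List, Tuple

Batch = Tuple[object, object]  # (lr_batch, hr_batch)


def run_training(model, train_data: Iterable[Batch], val_data: Iterable[Batch], epochs: int) -> None:
    """Hypothetical training loop driving the SRModel hooks (not SimpleSR's own code)."""
    for _ in range(epochs):
        model.before_epoch()  # reset epoch metrics, increment epoch counter
        for lr_batch, hr_batch in train_data:
            model.train_step(lr_batch, hr_batch)
            model.after_train_batch()  # log, then reset batch metrics
        for lr_batch, hr_batch in val_data:
            model.validation_step(lr_batch, hr_batch)
            model.after_validation_batch()
        model.after_epoch()  # save model, log epoch metrics, check early stopping
        if model.stop_early():
            break
    model.after_training()  # restore best model, save with "best" postfix, plot metrics


class RecordingModel:
    """Stand-in for SRModel that only records which hooks were called."""

    def __init__(self, stop_after_epochs: int) -> None:
        self.calls: List[str] = []
        self.epochs_done = 0
        self.stop_after_epochs = stop_after_epochs

    def before_epoch(self) -> None:
        self.calls.append("before_epoch")

    def train_step(self, lr_batch, hr_batch) -> None:
        self.calls.append("train_step")

    def after_train_batch(self) -> None:
        self.calls.append("after_train_batch")

    def validation_step(self, lr_batch, hr_batch) -> None:
        self.calls.append("validation_step")

    def after_validation_batch(self) -> None:
        self.calls.append("after_validation_batch")

    def after_epoch(self) -> None:
        self.calls.append("after_epoch")
        self.epochs_done += 1

    def stop_early(self) -> bool:
        self.calls.append("stop_early")
        return self.epochs_done >= self.stop_after_epochs

    def after_training(self) -> None:
        self.calls.append("after_training")


model = RecordingModel(stop_after_epochs=1)
run_training(model, train_data=[(None, None)], val_data=[(None, None)], epochs=5)
print(model.calls)
```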
-
static init(config, generator, generator_optimizer, generator_optimizer_config=None, discriminator=None, discriminator_optimizer=None, discriminator_optimizer_config=None, image_metrics=None)
Convenience method to initialize SRModel. The model type is inferred, and the early stopping settings and TensorFlow summary writers are taken from the initialized config.
- Returns
Initialized SRModel instance, ready for training.
Generator
-
class Generator(upsample_factor, architecture, loss_functions, num_blocks=16, num_dense_blocks=3, num_filters=64, num_convs=4, kernel_size=3, residual_scaling=0.2, kernel_initializer=None, batch_norm=False, input_dims=(None, None), pretrained_model_path=None, pretrained_model=None)
Generator for SRModel. The Generator's job is to generate high-resolution images from supplied low-resolution images. It can be used either in adversarial mode in combination with a Discriminator or in non-GAN mode without adversarial loss. The Generator also keeps track of metrics for batches and epochs.
- Parameters
upsample_factor –
Factor by which the image resolution is increased. E.g. if upsample_factor = 2, a 64x64 input image will be upsampled to 128x128.
architecture –
Architecture of the Generator; currently only SRResnet and RRDB are available. For more info see:
https://arxiv.org/abs/1609.04802 (SRResnet/SRGAN)
https://arxiv.org/abs/1809.00219 (RRDB/ESRGAN)
loss_functions –
Loss functions used to calculate the Generator's loss. Available loss functions can be found in simple_sr.utils.models.loss_functions. An arbitrary number of loss functions can be combined.
num_blocks – Number of residual building blocks for SRResnet/RRDB (see the papers for info on the architecture of these blocks).
num_filters – Number of filters for convolutional layers in generator architecture.
kernel_size – Kernel size for convolutional layers in generator architecture.
residual_scaling – Residual scaling for shortcut connections (will only take effect in RRDB architecture).
kernel_initializer – Weight initializer for convolutional layers.
batch_norm –
Whether to apply batch normalization in the Generator's architecture. This option is only applicable to SRResnet (see the RRDB/ESRGAN paper, whose authors conclude that batch normalization is not always helpful).
input_dims –
Dimensions of input images to the Generator. Since the Generator is fully convolutional, this may be (None, None), in which case the Generator can handle arbitrary sizes.
pretrained_model_path –
Path to a pretrained model; the model will be loaded and training will be resumed. This can also be used to pretrain a model in non-adversarial mode on pixel loss and then continue training in adversarial mode with different loss functions (like VGG loss, as the authors of SRGAN/ESRGAN do).
pretrained_model – Already loaded Keras model; the same applies as for pretrained_model_path, except that the model is expected to be loaded already.
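Because the Generator is fully convolutional, its output size follows directly from the input size and upsample_factor. The small helper below makes that arithmetic explicit. Like input_dims=(None, None), it treats None as an unknown dimension. output_dims is a hypothetical helper for illustration, not part of SimpleSR.

```python
from typing import Optional, Tuple

Dims = Tuple[Optional[int], Optional[int]]


def output_dims(input_dims: Dims, upsample_factor: int) -> Dims:
    """Return (height, width) after upsampling; None (unknown size) stays None."""
    if upsample_factor < 1:
        raise ValueError("upsample_factor must be a positive integer")
    height, width = input_dims
    return (
        None if height is None else height * upsample_factor,
        None if width is None else width * upsample_factor,
    )


print(output_dims((64, 64), 2))      # (128, 128), the example from upsample_factor
print(output_dims((None, None), 4))  # (None, None): arbitrary input sizes
```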
-
loss_functions()
Retrieve the registered loss functions of the Generator.
- Returns
List of initialized loss function objects from simple_sr.utils.models.loss_functions module.
-
formatted_epoch_metrics(train=True)
Retrieve a formatted string of epoch metrics for logging.
- Parameters
train – Request either training or validation metrics.
- Returns
Formatted metrics string of training or validation metrics, depending on the supplied parameter.
-
generate(lr_batch, training=True)
Generate a batch of high-resolution images from the supplied low-resolution images.
- Parameters
lr_batch – Low resolution input images.
training – Whether currently training or validating (during validation, batch normalization layers run in inference mode).
- Returns
Tensor containing upsampled images.
-
calculate_train_loss(sr_batch, hr_batch, sr_critic, hr_critic)
Delegates the calculation of loss to the loss functions and calculates the total training loss. The loss functions record training loss metrics.
- Parameters
sr_batch – Batch of generated high-resolution samples.
hr_batch – Batch of corresponding high-resolution ground truth samples.
sr_critic – Discriminator's critique of the synthetic/generated training samples (only applicable if training in GAN mode; otherwise None).
hr_critic – Discriminator's critique of the real training samples (only applicable if training in GAN mode; otherwise None).
- Returns
Scalar representing the total loss for the training batch.
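The total loss combines the individual loss terms. The sketch below assumes the total is a weighted sum of the per-term losses. It uses the default weights of esrgan_generator documented below (vgg_loss_weight=1.0, adversarial_loss_weight=0.005, l1_loss_weight=0.01). total_loss, the term names and the loss values are hypothetical.

```python
from typing import Dict


def total_loss(losses: Dict[str, float], weights: Dict[str, float]) -> float:
    """Weighted sum of named loss terms; a term without an explicit weight counts with 1.0."""
    return sum(weights.get(name, 1.0) * value for name, value in losses.items())


# ESRGAN-style weights (the esrgan_generator defaults); the loss values are made up.
esrgan_weights = {"vgg": 1.0, "adversarial": 0.005, "l1": 0.01}
gan_losses = {"vgg": 2.0, "adversarial": 0.7, "l1": 0.05}
print(total_loss(gan_losses, esrgan_weights))

# In non-adversarial mode (sr_critic/hr_critic are None) there is simply no adversarial term.
print(total_loss({"l1": 0.05}, {"l1": 1.0}))
```

Under this assumption the small adversarial and L1 weights keep the perceptual (VGG) term dominant, which is the balance the ESRGAN paper describes.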
-
calculate_validation_loss(sr_batch, hr_batch, sr_critic, hr_critic)
Delegates the calculation of validation loss to the loss functions and calculates the total validation loss. The loss functions record validation loss metrics.
- Parameters
sr_batch – Batch of generated high-resolution samples.
hr_batch – Batch of corresponding high-resolution ground truth samples.
sr_critic – Discriminator's critique of the synthetic/generated validation samples (only applicable if training in GAN mode; otherwise None).
hr_critic – Discriminator's critique of the real validation samples (only applicable if training in GAN mode; otherwise None).
- Returns
Scalar representing the total loss for the validation batch.
-
static srresnet(upsample_factor, loss_function=None, num_blocks=16, num_filters=64, kernel_size=3, batch_norm=True, input_dims=(None, None), pretrained_model_path=None, pretrained_model=None)
Convenience method for initializing an SRResnet Generator in non-adversarial mode. Default parameters are set according to the SRResnet/SRGAN paper (https://arxiv.org/abs/1609.04802).
- Returns
Initialized Generator instance.
-
static rrdb(upsample_factor, loss_functions=<class 'simple_sr.utils.models.loss_functions.mean_absolute_error.MeanAbsoluteError'>, loss_weight=1.0, num_blocks=16, num_dense_blocks=3, num_filters=64, num_convs=4, kernel_size=3, residual_scaling=0.2, kernel_initializer=None, batch_norm=False, input_dims=(None, None), pretrained_model_path=None, pretrained_model=None)
Convenience method for initializing an RRDB Generator in non-adversarial mode. Default parameters are set according to the RRDB/ESRGAN paper (https://arxiv.org/abs/1809.00219).
- Returns
Initialized Generator instance.
-
static srgan_generator(upsample_factor, vgg_loss, vgg_layer, vgg_feature_scaling=0.0784313725490196, vgg_loss_weight=1.0, adversarial_loss_weight=0.001, num_blocks=16, num_filters=64, kernel_size=3, batch_norm=True, input_dims=(None, None), pretrained_model_path=None, pretrained_model=None)
Convenience method for initializing an SRResnet Generator in adversarial mode. Default parameters are set according to the SRResnet/SRGAN paper (https://arxiv.org/abs/1609.04802).
- Returns
Initialized Generator instance.
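The default vgg_feature_scaling=0.0784313725490196 equals 1/12.75 to floating-point precision, which matches the VGG feature rescaling described in the SRGAN paper. Assume that the scaling multiplies the VGG feature maps φ before the mean squared error is taken, and that the Generator loss is a weighted sum of its terms. Both are assumptions about this implementation. Under them, the scaling enters the VGG loss quadratically, and vgg_loss_weight and adversarial_loss_weight act as the weights w_VGG and w_adv:

$$
s = \frac{1}{12.75} \approx 0.0784, \qquad
\ell_{\mathrm{VGG}} = \operatorname{MSE}\big(s\,\phi(I^{HR}),\, s\,\phi(G(I^{LR}))\big) = s^{2}\,\operatorname{MSE}\big(\phi(I^{HR}),\, \phi(G(I^{LR}))\big), \qquad s^{2} \approx 0.00615
$$

$$
\ell_{G} = w_{\mathrm{VGG}}\,\ell_{\mathrm{VGG}} + w_{\mathrm{adv}}\,\ell_{\mathrm{adv}}, \qquad w_{\mathrm{VGG}} = 1.0, \quad w_{\mathrm{adv}} = 10^{-3}
$$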
-
static esrgan_generator(upsample_factor, vgg_layer='block5_conv4', vgg_feature_scaling=1.0, vgg_loss_weight=1.0, adversarial_loss_weight=0.005, l1_loss_weight=0.01, num_blocks=16, num_dense_blocks=3, num_filters=64, num_convs=4, kernel_size=3, input_dims=(None, None), pretrained_model_path=None, pretrained_model=None)
Convenience method for initializing an RRDB Generator in adversarial mode. Default parameters are set according to the RRDB/ESRGAN paper (https://arxiv.org/abs/1809.00219).
- Returns
Initialized Generator instance.
Discriminator
-
class Discriminator(loss_function, relativistic, label_smoothing=False, smoothing_offset=0.3, num_filters=64, alpha=0.2, kernel_size=3, momentum=0.8, initializer=None, input_dims=(None, None))
Discriminator for training SRModel in adversarial mode. The Discriminator's job is to critique synthetic and real samples. It also calculates its own loss and keeps track of metrics for epochs and batches.
- Parameters
loss_function –
Loss function used to calculate the Discriminator's loss; currently available are the standard GAN discriminator loss and the relativistic average GAN discriminator loss. More info on different types of GAN losses can be found here: https://arxiv.org/abs/1807.00734.
relativistic –
Whether the Discriminator is relativistic. If True, there is no final sigmoid layer at the end of the Discriminator's architecture.
label_smoothing –
Whether to apply label smoothing; this can help stabilize training by making the Discriminator's job a little harder.
if False, the target labels for the Discriminator are 0 for synthetic and 1 for real samples
if True, the target labels are in the range [0, smoothing_offset] for fake samples and [1 - smoothing_offset, 1 + smoothing_offset] for real samples
For more tweaks to make GANs more stable see this github repo: https://github.com/soumith/ganhacks
smoothing_offset – Sets upper and lower bound for random noise in target labels (if label_smoothing is True).
num_filters – Number of filters in the convolutional layers of the Discriminator's architecture.
alpha – Negative slope coefficient for the Leaky ReLU activation function.
kernel_size – Kernel size of the convolutional layers in the Discriminator's architecture.
momentum – Momentum for batch normalization.
initializer – Initializer for the Discriminator's weights.
input_dims – Dimensions of input images to the discriminator.
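With label_smoothing enabled, the targets are no longer hard 0/1 labels but random values in the ranges given above. The sketch below shows how such targets could be drawn. It assumes uniform sampling, since the distribution is not specified above. smoothed_targets and its seed parameter are hypothetical.

```python
import random
from typing import List, Optional, Tuple


def smoothed_targets(batch_size: int, label_smoothing: bool, smoothing_offset: float = 0.3,
                     seed: Optional[int] = None) -> Tuple[List[float], List[float]]:
    """Return (fake_labels, real_labels) for one batch (hypothetical helper)."""
    if not label_smoothing:
        # Hard labels: 0 for synthetic, 1 for real samples.
        return [0.0] * batch_size, [1.0] * batch_size
    rng = random.Random(seed)
    # Assumed uniform noise: fake in [0, offset], real in [1 - offset, 1 + offset].
    fake = [rng.uniform(0.0, smoothing_offset) for _ in range(batch_size)]
    real = [rng.uniform(1.0 - smoothing_offset, 1.0 + smoothing_offset) for _ in range(batch_size)]
    return fake, real


fake, real = smoothed_targets(4, label_smoothing=True, smoothing_offset=0.3, seed=42)
print([round(v, 3) for v in fake], [round(v, 3) for v in real])
```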
-
loss_function()
Retrieve the loss function of the Discriminator.
- Returns
Initialized loss function object from the simple_sr.utils.models.loss_functions module.
-
formatted_epoch_metrics(train=True)
Return a formatted string of epoch metrics for logging.
- Parameters
train – Request either training or validation metrics.
- Returns
Formatted metrics string of training or validation metrics, depending on supplied parameter.
-
critic_train_batch(sr_batch, hr_batch)
Critique synthetic and real training data batches and keep track of metrics.
- Parameters
sr_batch – Batch of synthetic training data.
hr_batch – Batch of real training data.
- Returns
Tuple of tensors containing the critique of sr_batch and hr_batch, respectively: the likelihood of samples being real for a standard Discriminator, or raw (pre-sigmoid) scores for a relativistic Discriminator.
-
critic_validation_batch(sr_batch, hr_batch)
Critique synthetic and real validation data batches and keep track of metrics.
- Parameters
sr_batch – Batch of synthetic validation data.
hr_batch – Batch of real validation data.
- Returns
Tuple of tensors containing the critique of sr_batch and hr_batch, respectively: the likelihood of samples being real for a standard Discriminator, or raw (pre-sigmoid) scores for a relativistic Discriminator.
-
calculate_train_loss(sr_critic, hr_critic)
Delegates the calculation of training loss to the loss function. Target labels for the loss calculation are generated according to the parameters label_smoothing and smoothing_offset.
- Parameters
sr_critic – Discriminator's critique of a synthetic training data batch.
hr_critic – Discriminator's critique of a real training data batch.
- Returns
Loss calculated by the Discriminator's loss function.
-
calculate_validation_loss(sr_critic, hr_critic)
Delegates the calculation of validation loss to the loss function. Target labels for the loss calculation are generated according to the parameters label_smoothing and smoothing_offset.
- Parameters
sr_critic – Discriminator's critique of a synthetic validation data batch.
hr_critic – Discriminator's critique of a real validation data batch.
- Returns
Loss calculated by the Discriminator's loss function.
-
static initialize_relativistic(weighted_loss=False, loss_weight=1.0, num_filters=64, alpha=0.2, kernel_size=3, momentum=0.8, initializer=None, input_dims=(None, None))
Convenience method to initialize a relativistic average GAN discriminator with the corresponding loss function.
- Parameters
weighted_loss – Whether the loss function should be weighted.
loss_weight – Factor for weighted loss.
num_filters – Number of filters in the convolutional layers of the Discriminator's architecture.
alpha – Negative slope coefficient for the Leaky ReLU activation function.
kernel_size – Kernel size of the convolutional layers in the Discriminator's architecture.
momentum – Momentum for batch normalization.
initializer – Initializer for the Discriminator's weights.
input_dims – Dimensions of input images to the discriminator.
- Returns
Initialized Discriminator object.
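For reference, the relativistic average discriminator loss of the ESRGAN paper compares each critique with the mean critique of the opposite set. It uses D_Ra(x_r, x_f) = sigmoid(C(x_r) - E[C(x_f)]) and minimizes -E[log D_Ra(x_r, x_f)] - E[log(1 - D_Ra(x_f, x_r))]. The pure-Python sketch below implements that formula. It assumes sr_critic/hr_critic are the raw (pre-sigmoid) scores of a relativistic Discriminator and that loss_weight simply multiplies the loss. ragan_discriminator_loss is a hypothetical name, and the library's loss function presumably works on tensors and may differ in such details.

```python
import math
from typing import Sequence


def _log_sigmoid(x: float) -> float:
    # Numerically stable log(sigmoid(x)).
    return -math.log1p(math.exp(-x)) if x >= 0 else x - math.log1p(math.exp(x))


def ragan_discriminator_loss(sr_critic: Sequence[float], hr_critic: Sequence[float],
                             loss_weight: float = 1.0) -> float:
    """Hypothetical RaGAN discriminator loss on raw critic scores."""
    mean_sr = sum(sr_critic) / len(sr_critic)
    mean_hr = sum(hr_critic) / len(hr_critic)
    # Real samples should look "more real" than the average fake sample.
    real_term = -sum(_log_sigmoid(c - mean_sr) for c in hr_critic) / len(hr_critic)
    # log(1 - sigmoid(z)) == log(sigmoid(-z)): fakes should look "less real" than the average real.
    fake_term = -sum(_log_sigmoid(-(c - mean_hr)) for c in sr_critic) / len(sr_critic)
    return loss_weight * (real_term + fake_term)


print(ragan_discriminator_loss([0.0, 0.0], [0.0, 0.0]))    # undecided critic: 2 * ln 2
print(ragan_discriminator_loss([-5.0, -5.0], [5.0, 5.0]))  # confident, correct critic: near 0
```

Because no sigmoid is applied inside the Discriminator (relativistic=True), the sigmoid appears here in the loss instead, applied to the difference of critiques.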
-
static initialize_standard(weighted_loss=False, loss_weight=1.0, label_smoothing=False, smoothing_offset=0.0, num_filters=64, alpha=0.2, kernel_size=3, momentum=0.8, initializer=None, input_dims=(None, None))
Convenience method to initialize a standard GAN discriminator with the corresponding loss function.
- Parameters
weighted_loss – Whether the loss function should be weighted.
loss_weight – Factor for weighted loss.
label_smoothing –
Whether to apply label smoothing; this can help stabilize training by making the Discriminator's job a little harder.
if False, the target labels for the Discriminator are 0 for synthetic and 1 for real samples
if True, the target labels are in the range [0, smoothing_offset] for fake samples and [1 - smoothing_offset, 1 + smoothing_offset] for real samples
For more tweaks to make GANs more stable see this github repo: https://github.com/soumith/ganhacks
smoothing_offset – Sets upper and lower bound for random noise in target labels.
num_filters – Number of filters in the convolutional layers of the Discriminator's architecture.
alpha – Negative slope coefficient for the Leaky ReLU activation function.
kernel_size – Kernel size of the convolutional layers in the Discriminator's architecture.
momentum – Momentum for batch normalization.
initializer – Initializer for the Discriminator's weights.
input_dims – Dimensions of input images to the discriminator.
- Returns
Initialized Discriminator object.
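The standard GAN discriminator loss is the binary cross-entropy between the Discriminator's sigmoid outputs and the target labels. These are hard 0/1 labels, or smoothed labels when label_smoothing is enabled. The sketch below assumes sr_critic and hr_critic are probabilities in (0, 1) and that the fake and real terms are summed (some implementations average them instead). standard_discriminator_loss is a hypothetical name, and clipping with eps is an assumed safeguard against log(0).

```python
import math
from typing import Optional, Sequence


def bce(targets: Sequence[float], probs: Sequence[float], eps: float = 1e-7) -> float:
    """Mean binary cross-entropy between target labels and predicted probabilities."""
    total = 0.0
    for t, p in zip(targets, probs):
        p = min(max(p, eps), 1.0 - eps)  # clip so log() never sees 0
        total += -(t * math.log(p) + (1.0 - t) * math.log(1.0 - p))
    return total / len(probs)


def standard_discriminator_loss(sr_critic: Sequence[float], hr_critic: Sequence[float],
                                fake_labels: Optional[Sequence[float]] = None,
                                real_labels: Optional[Sequence[float]] = None,
                                loss_weight: float = 1.0) -> float:
    """Hypothetical standard GAN discriminator loss: BCE(fake, 0) + BCE(real, 1)."""
    if fake_labels is None:
        fake_labels = [0.0] * len(sr_critic)  # hard labels without label smoothing
    if real_labels is None:
        real_labels = [1.0] * len(hr_critic)
    return loss_weight * (bce(fake_labels, sr_critic) + bce(real_labels, hr_critic))


hard = standard_discriminator_loss([0.01], [0.99])
smoothed = standard_discriminator_loss([0.01], [0.99], fake_labels=[0.1], real_labels=[0.9])
print(hard, smoothed)
```

With smoothed targets, the over-confident critiques (0.01 and 0.99) are penalized more than with hard labels. That is the effect label smoothing relies on to make the Discriminator's job a little harder.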