Source code for simple_sr.utils.models.loss_functions.vgg_loss

import tensorflow as tf
import os
import logging
from tensorflow.keras.applications.vgg19 import VGG19, preprocess_input
from tensorflow.keras.applications.vgg16 import VGG16, preprocess_input
from tensorflow.python.keras.models import Model

from simple_sr.utils.models import model_builder
from simple_sr.utils import logger

# Module-level logger: a child (named after this module) of the logger
# registered under the name stored in `logger.LIB_LOGGER`.
log = logging.getLogger(logger.LIB_LOGGER).getChild(__name__)


class VGGLoss:
    """
    | Loss function to calculate loss for Generator based on differences
      between activations in feature maps of a reference network.
      The reference network used here can be either the VGG19 (default)
      or VGG16 network.
    | Loss functions of this type are usually called Perceptual Loss and aim
      to drive the Generator towards producing visually more pleasing results,
      compared to pixelwise loss functions.
      See https://arxiv.org/abs/1609.04802 for more info on perceptual loss.
    | After initialization the `VGGLoss` object can be used as a functor
      for loss calculation:

    .. code::

        vgg_loss = VGGLoss(output_layers="block5_conv4", after_activation=False,
                           feature_scale=1.0, loss_weight=1.0,
                           total_variation_loss=False)
        ...
        loss = vgg_loss(hr_batch, sr_batch, hr_critic, sr_critic)

    :param output_layers:
        | Layers of the VGG network to compare feature maps of synthesized and
          real data samples on. May be a single layer name or a list/tuple of
          layer names, the loss of each layer will be summed up.
        | The names of each layer can be found here:
          https://github.com/keras-team/keras-applications/blob/master/keras_applications/vgg19.py
    :param feature_scale:
        Scaling factor for each feature map.
    :param loss_weight:
        | Factor to weight the VGG loss.
        | This can be useful if VGG loss is combined with an additional
          pixel-wise loss, which might be orders of magnitudes higher and
          could therefore outweigh VGG loss.
    :param total_variation_loss:
        | Whether to use an additional total variation loss as used in some
          variants of SRGAN (https://arxiv.org/abs/1609.04802).
    :param total_varation_weight:
        | Factor to weight total variation loss.
        | NOTE(review): the default ``2*10e-8`` evaluates to ``2e-7``; if
          ``2e-8`` was intended, the default needs adjusting - confirm before
          changing, since it alters training results.
    :param after_activation:
        | Whether calculate loss on feature maps after activation function
          or before.
        | See ESRGAN paper (https://arxiv.org/abs/1809.00219) for an
          explanation of the benefits and downsides.
    :param track_metrics:
        Whether the class should update the supplied metrics dictionaries.
    :param vgg16:
        Whether to use VGG16 instead of VGG19.
    :param custom_weights:
        Whether to initialize VGG network with custom weights.
    :param custom_weights_path:
        Path to custom weights file (.h5 file).
    :raises ValueError:
        If `after_activation` and `custom_weights` are set but
        `custom_weights_path` is missing or does not point to a file.
    """

    def __init__(self, output_layers, feature_scale=1.0, loss_weight=1.0,
                 total_variation_loss=False, total_varation_weight=2*10e-8,
                 after_activation=True, track_metrics=True, vgg16=False,
                 custom_weights=False, custom_weights_path=None):
        # Keep the preprocessing function that belongs to the selected
        # network so every code path of this class uses the same one.
        if vgg16:
            self.preprocess_func = tf.keras.applications.vgg16.preprocess_input
        else:
            self.preprocess_func = tf.keras.applications.vgg19.preprocess_input

        self.name = "vgg_loss"
        self.mse = tf.keras.losses.MeanSquaredError()
        self.track_metrics = track_metrics
        self.feature_scale = feature_scale
        self.loss_weight = loss_weight
        self.weighted = self.loss_weight != 1.0
        self.total_variation_loss = total_variation_loss
        self.total_variation_weight = total_varation_weight
        self.after_activation = after_activation
        self.loss = 0
        self.weighted_loss = 0

        if self.after_activation:
            log.debug("requested vgg features after activation - building standard vgg network")
            # Validate the weights file before the (potentially slow)
            # network construction so that bad arguments fail fast.
            if custom_weights:
                if custom_weights_path is None:
                    raise ValueError("no custom weights path supplied")
                if not os.path.isfile(custom_weights_path):
                    raise ValueError("can't locate custom weights")
            net = VGG16 if vgg16 else VGG19
            vgg = net(
                input_shape=(None, None, 3), include_top=False,
                weights="imagenet", pooling=None
            )
            if custom_weights:
                vgg.load_weights(custom_weights_path)
        else:
            # build custom vgg to allow loss calculation before activation
            log.debug("requested vgg features before activation - building custom vgg network")
            if vgg16:
                net = model_builder.build_vgg_16
            else:
                net = model_builder.build_vgg_19
            vgg = net(input_shape=(None, None, 3),
                      load_custom_weights=custom_weights,
                      custom_weights_path=custom_weights_path)

        # The reference network is only used as a fixed feature extractor.
        vgg.trainable = False

        # Accept a single layer name as well as a list/tuple of names.
        if isinstance(output_layers, (list, tuple)):
            self.output_layers = list(output_layers)
        else:
            self.output_layers = [output_layers]

        outputs = [vgg.get_layer(layer_name).output
                   for layer_name in self.output_layers]
        self.model = Model(inputs=[vgg.input], outputs=outputs)

        log.debug(f"initialized vgg loss - output layers: {self.output_layers}, "
                  f"feature scaling: {self.feature_scale}, "
                  f"loss_weight: {self.loss_weight}, after activation: {self.after_activation}")

    @staticmethod
    def _as_list(features):
        """Wrap a single model output in a list; pass lists/tuples through as a list."""
        if isinstance(features, (list, tuple)):
            return list(features)
        return [features]

    def _update_metrics(self, metrics):
        """
        Record the current loss in one metrics dictionary.

        The entry stored under `self.name` is required (a missing key raises
        `KeyError`); the entry under ``weighted_<name>`` is optional and
        skipped when absent.
        """
        if metrics is None:
            return
        metrics[self.name](self.loss)
        weighted_key = f"weighted_{self.name}"
        if weighted_key in metrics:
            metrics[weighted_key](self.weighted_loss)

    @tf.function
    def __call__(self, hr_batch, sr_batch, hr_critic, sr_critic,
                 batch_metrics=None, epoch_metrics=None, denormalize=True):
        """
        Calculate vgg loss for a batch of High-Resolution real data samples
        and synthesized High-Resolution samples.

        .. important::

            Pixel values for VGG need to be in [0, 255]. So you can either
            supply batches with pixels in that range, or if your pixels are
            in [-1, 1] you can use the `denormalize` flag for conversion.
            Any other combination will not work.

        .. note::

            The parameters `hr_critic` and `sr_critic` will not be used/needed
            for calculation of vgg loss, but the function needs to adhere to
            the (implicit) Generator loss function interface.

        :param hr_batch:
            | Tensor of real data High-Resolution samples.
            | Pixel values either need to be in [0, 255] or [-1, 1]
              with `denormalize=True`.
        :param sr_batch:
            | Tensor of synthesized High-Resolution samples with equal shape
              as `hr_batch`.
            | Pixel values either need to be in [0, 255] or [-1, 1]
              with `denormalize=True`.
        :param hr_critic:
            Not needed, may be `None`.
        :param sr_critic:
            Not needed, may be `None`.
        :param batch_metrics:
            Optional dictionary to store batch metrics. Skipped if `None`.
        :param epoch_metrics:
            Optional dictionary to store epoch metrics. Skipped if `None`.
        :param denormalize:
            Whether to denormalize from [-1, 1] to [0, 255].
        :return:
            (Weighted) vgg loss for batch.
        """
        if denormalize:
            hr_batch = (hr_batch + 1) * 127.5
            sr_batch = (sr_batch + 1) * 127.5

        hr_features = self.model(self.preprocess_func(hr_batch), training=False)
        sr_features = self.model(self.preprocess_func(sr_batch), training=False)
        # Raw extractor outputs are kept on the instance as in the original
        # implementation (e.g. for inspection while debugging).
        self._hr_feats = hr_features
        self._sr_feats = sr_features

        scaled_hr_features = [feature * self.feature_scale
                              for feature in self._as_list(hr_features)]
        scaled_sr_features = [feature * self.feature_scale
                              for feature in self._as_list(sr_features)]

        # Sum the weighted per-layer MSE between real and synthesized features.
        loss = 0
        for hr_feature, sr_feature in zip(scaled_hr_features, scaled_sr_features):
            loss += self.mse(hr_feature, sr_feature) * self.loss_weight

        if self.total_variation_loss:
            loss += self.total_variation_weight * tf.reduce_sum(
                tf.image.total_variation(sr_batch)
            )

        self.loss = loss
        # `loss` already contains `loss_weight`, so it is also the value the
        # "weighted_<name>" metrics should record (previously they were fed
        # a constant 0 because `weighted_loss` was never updated).
        self.weighted_loss = loss

        if self.track_metrics:
            self._update_metrics(batch_metrics)
            self._update_metrics(epoch_metrics)

        return self.loss

    def visualize_feature_maps(self, picture, denormalize=True):
        """
        Return the feature maps of the configured output layers for `picture`.

        :param picture:
            Image tensor, pixel values either in [0, 255] or in [-1, 1]
            with `denormalize=True`.
        :param denormalize:
            Whether to denormalize from [-1, 1] to [0, 255].
        :return:
            Output of the feature extraction model for the picture.
        """
        _picture = picture
        if denormalize:
            _picture = (picture + 1) * 127.5
        # Use the preprocessing function of the selected network instead of
        # whichever `preprocess_input` happens to be bound at module level.
        preprocessed = self.preprocess_func(_picture)
        return self.model(preprocessed, training=False)

    def __str__(self):
        return f"## Vgg Loss\n" \
               f"output layers: {self.output_layers}\n" \
               f"feature scaling: {self.feature_scale}\n" \
               f"after activation: {self.after_activation}\n" \
               f"loss weight: {self.loss_weight}\n"\
               f"total variation loss: {self.total_variation_loss}\n"\
               f"total variation loss weight: {self.total_variation_weight}\n"
if __name__ == "__main__":
    # This module offers no command line behaviour; running it directly
    # is intentionally a no-op.
    pass