Loss Functions

Generator Loss Functions

class MeanAbsoluteError(weighted=False, loss_weight=1.0, track_metrics=True)
Mean Absolute Error loss function based on pixels.
After initialization, the MeanAbsoluteError object can be used as a functor to calculate the pixelwise mean absolute error of generated images:
mae = MeanAbsoluteError()
...
loss = mae(hr_batch, sr_batch, hr_critic, sr_critic, batch_metrics, epoch_metrics)
If track_metrics is True, the supplied metrics dictionaries will be updated with the calculated loss.
Parameters
  • weighted – whether loss should be weighted

  • loss_weight – weight factor for loss

  • track_metrics – whether the class should update the supplied metrics dictionaries.

__call__(hr_batch, sr_batch, hr_critic, sr_critic, batch_metrics, epoch_metrics)

Calculate pixelwise mean absolute error for supplied batches of images.

Note

The parameters hr_critic and sr_critic are not needed to calculate the mean absolute error, but the function must accept them to adhere to the (implicit) Generator loss function interface.
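The implicit Generator loss function interface can be written down explicitly. The following is a minimal sketch, assuming NumPy arrays as stand-ins for tensors; the names GeneratorLoss, PixelL1, CriticOnly and total_generator_loss are hypothetical and not part of this library. It illustrates why every Generator loss accepts all six arguments: a training loop can then call a list of different losses uniformly and sum the results.

```python
from typing import Optional, Protocol, Sequence

import numpy as np


class GeneratorLoss(Protocol):
    """Hypothetical structural type for the implicit Generator loss interface."""

    def __call__(self, hr_batch, sr_batch, hr_critic, sr_critic,
                 batch_metrics: Optional[dict], epoch_metrics: Optional[dict]) -> float: ...


class PixelL1:
    """Hypothetical stand-in loss that only uses the image batches."""

    def __call__(self, hr_batch, sr_batch, hr_critic, sr_critic,
                 batch_metrics, epoch_metrics):
        return float(np.mean(np.abs(hr_batch - sr_batch)))


class CriticOnly:
    """Hypothetical stand-in loss that only uses the critique of generated samples."""

    def __call__(self, hr_batch, sr_batch, hr_critic, sr_critic,
                 batch_metrics, epoch_metrics):
        return float(np.mean(sr_critic))


def total_generator_loss(losses: Sequence[GeneratorLoss], hr_batch, sr_batch,
                         hr_critic, sr_critic, batch_metrics=None, epoch_metrics=None):
    # Every loss receives the same arguments and ignores the ones it does not need.
    return sum(loss(hr_batch, sr_batch, hr_critic, sr_critic, batch_metrics, epoch_metrics)
               for loss in losses)
```

For example, total_generator_loss([PixelL1(), CriticOnly()], hr, sr, None, sr_critic) adds a pixel loss that ignores the critiques and a critique-based loss that ignores the images.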

Parameters
  • hr_batch – Tensor of real data High-Resolution samples.

  • sr_batch – Tensor of synthesized High-Resolution samples with the same shape as hr_batch.

  • hr_critic – Not needed, may be None.

  • sr_critic – Not needed, may be None.

  • batch_metrics – Optional dictionary to store batch metrics.

  • epoch_metrics – Optional dictionary to store epoch metrics.

Returns

(Weighted) mean absolute error for batch.
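A minimal NumPy sketch of a functor with this interface follows. It is not the library's implementation: it assumes that weighted=True means the loss is multiplied by loss_weight, and that metrics are stored under a hypothetical key "mae" (one value per batch in epoch_metrics). The real class may weight the loss or store metrics differently.

```python
import numpy as np


class MeanAbsoluteErrorSketch:
    """Hypothetical re-implementation, for illustration only."""

    def __init__(self, weighted=False, loss_weight=1.0, track_metrics=True):
        self.weighted = weighted
        self.loss_weight = loss_weight
        self.track_metrics = track_metrics

    def __call__(self, hr_batch, sr_batch, hr_critic, sr_critic,
                 batch_metrics=None, epoch_metrics=None):
        # hr_critic and sr_critic are accepted only to satisfy the interface.
        loss = float(np.mean(np.abs(np.asarray(hr_batch) - np.asarray(sr_batch))))
        if self.weighted:
            # Assumption: weighting means scaling by loss_weight.
            loss *= self.loss_weight
        if self.track_metrics:
            if batch_metrics is not None:
                batch_metrics["mae"] = loss  # hypothetical metric key
            if epoch_metrics is not None:
                # Assumption: epoch metrics collect one value per batch.
                epoch_metrics.setdefault("mae", []).append(loss)
        return loss
```

With weighted=True and loss_weight=0.5, a batch whose pixels are all off by 0.5 yields a loss of 0.25, and both dictionaries receive that value.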

class MeanSquaredError(weighted=False, loss_weight=1.0, track_metrics=True)
Mean Squared Error loss function based on pixels.
After initialization, the MeanSquaredError object can be used as a functor to calculate the pixelwise mean squared error of generated images:
mse = MeanSquaredError()
...
loss = mse(hr_batch, sr_batch, hr_critic, sr_critic, batch_metrics, epoch_metrics)
If track_metrics is True, the supplied metrics dictionaries will be updated with the calculated loss.
Parameters
  • weighted – whether loss should be weighted

  • loss_weight – weight factor for loss

  • track_metrics – whether the class should update the supplied metrics dictionaries.

__call__(hr_batch, sr_batch, hr_critic, sr_critic, batch_metrics, epoch_metrics)

Calculate pixelwise mean squared error for supplied batches of images.

Note

The parameters hr_critic and sr_critic are not needed to calculate the mean squared error, but the function must accept them to adhere to the (implicit) Generator loss function interface.

Parameters
  • hr_batch – Tensor of real data High-Resolution samples.

  • sr_batch – Tensor of synthesized High-Resolution samples with the same shape as hr_batch.

  • hr_critic – Not needed, may be None.

  • sr_critic – Not needed, may be None.

  • batch_metrics – Optional dictionary to store batch metrics.

  • epoch_metrics – Optional dictionary to store epoch metrics.

Returns

(Weighted) mean squared error for batch.
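The sketch below (plain NumPy, not the library's code, and ignoring weighting and metric tracking) illustrates how MSE differs from MAE: two batches with the same mean absolute error get very different squared errors when the error is concentrated in a single pixel. The helper names are hypothetical.

```python
import numpy as np


def pixelwise_mse(hr_batch, sr_batch):
    # Mean over all pixels and channels of the squared difference.
    return float(np.mean((np.asarray(hr_batch) - np.asarray(sr_batch)) ** 2))


def pixelwise_mae(hr_batch, sr_batch):
    # Mean over all pixels and channels of the absolute difference.
    return float(np.mean(np.abs(np.asarray(hr_batch) - np.asarray(sr_batch))))


hr = np.zeros((1, 2, 2, 1))
small_errors = np.full((1, 2, 2, 1), 0.5)  # every pixel off by 0.5
one_outlier = np.zeros((1, 2, 2, 1))
one_outlier[0, 0, 0, 0] = 2.0              # a single pixel off by 2.0

print(pixelwise_mae(hr, small_errors), pixelwise_mse(hr, small_errors))  # 0.5 0.25
print(pixelwise_mae(hr, one_outlier), pixelwise_mse(hr, one_outlier))    # 0.5 1.0
```

Both batches have an MAE of 0.5, but the outlier batch has four times the MSE, which is why MSE pushes the Generator harder to remove large, localized errors.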

class AdversarialLoss(weighted=False, loss_weight=1.0, track_metrics=True)
Adversarial loss function for the Generator in the standard GAN setting.
After initialization, the AdversarialLoss object can be used as a functor to calculate the adversarial loss of the Generator:
adversarial_loss = AdversarialLoss()
...
loss = adversarial_loss(hr_batch, sr_batch, hr_critic, sr_critic, batch_metrics, epoch_metrics)
If track_metrics is True, the supplied metrics dictionaries will be updated with the calculated loss.
Parameters
  • weighted – whether loss should be weighted

  • loss_weight – weight factor for loss

  • track_metrics – whether the class should update the supplied metrics dictionaries.

__call__(hr_batch, sr_batch, hr_critic, sr_critic, batch_metrics, epoch_metrics)

Calculate the adversarial loss for the Generator from the Discriminator's critique.

Note

The parameters hr_batch, sr_batch and hr_critic are not needed to calculate the adversarial loss, but the function must accept them to adhere to the (implicit) Generator loss function interface.

Parameters
  • hr_batch – Not needed, may be None.

  • sr_batch – Not needed, may be None.

  • hr_critic – Not needed, may be None.

  • sr_critic – Discriminator's critique of generated High-Resolution samples.

  • batch_metrics – Optional dictionary to store batch metrics.

  • epoch_metrics – Optional dictionary to store epoch metrics.

Returns

(Weighted) Adversarial loss for batch.
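In the standard GAN setting, the Generator's adversarial loss is commonly the binary cross-entropy between the critique of generated samples and the "real" label, i.e. mean(-log(sigmoid(sr_critic))), the non-saturating loss. The NumPy sketch below assumes that sr_critic holds raw logits; whether this library's Discriminator outputs logits or probabilities is not stated here, so treat that as an assumption. The function name is hypothetical.

```python
import numpy as np


def generator_adversarial_loss(sr_critic):
    """Hypothetical sketch of the non-saturating Generator loss.

    Assumes sr_critic contains logits and computes mean(-log(sigmoid(sr_critic))),
    i.e. binary cross-entropy against the label 1 ("real").
    """
    logits = np.asarray(sr_critic, dtype=float)
    # -log(sigmoid(x)) == log(1 + exp(-x)) == logaddexp(0, -x); the logaddexp form
    # stays numerically stable for large |x|.
    return float(np.mean(np.logaddexp(0.0, -logits)))
```

An undecided critique (logit 0) gives log(2), a critique that confidently calls the samples real drives the loss towards 0, and a confident "fake" critique makes it large.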

class RaAdversarialLoss(weighted=False, loss_weight=1.0, track_metrics=True)
Relativistic average adversarial loss function for the Generator in the relativistic GAN setting.
After initialization, the RaAdversarialLoss object can be used as a functor to calculate the relativistic average adversarial loss of the Generator:
ra_loss = RaAdversarialLoss()
...
loss = ra_loss(hr_batch, sr_batch, hr_critic, sr_critic, batch_metrics, epoch_metrics)
If track_metrics is True, the supplied metrics dictionaries will be updated with the calculated loss.
Parameters
  • weighted – whether loss should be weighted

  • loss_weight – weight factor for loss

  • track_metrics – whether the class should update the supplied metrics dictionaries.

__call__(hr_batch, sr_batch, hr_critic, sr_critic, batch_metrics, epoch_metrics)

Calculate the relativistic average adversarial loss for the Generator from the Discriminator's critique.

Note

The parameters hr_batch and sr_batch are not needed to calculate the relativistic average adversarial loss, but the function must accept them to adhere to the (implicit) Generator loss function interface.

Parameters
  • hr_batch – Not needed, may be None.

  • sr_batch – Not needed, may be None.

  • hr_critic – Discriminator's critique of real data High-Resolution samples.

  • sr_critic – Discriminator's critique of generated High-Resolution samples.

  • batch_metrics – Optional dictionary to store batch metrics.

  • epoch_metrics – Optional dictionary to store epoch metrics.

Returns

(Weighted) relativistic average adversarial loss for batch.
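The ESRGAN paper (https://arxiv.org/abs/1809.00219) defines the relativistic average Discriminator as D_Ra(x_r, x_f) = sigmoid(C(x_r) - E[C(x_f)]) and the Generator loss as L_G = -E[log(1 - D_Ra(x_r, x_f))] - E[log(D_Ra(x_f, x_r))]. A NumPy sketch of that formula follows, assuming hr_critic and sr_critic are the raw critic outputs C(x) (logits). The function name is hypothetical, and the library may differ in details, for example by averaging the two terms instead of summing them.

```python
import numpy as np


def softplus(x):
    # Numerically stable log(1 + exp(x)).
    return np.logaddexp(0.0, x)


def ra_generator_loss(hr_critic, sr_critic):
    """Hypothetical sketch of the relativistic average Generator loss (ESRGAN).

    hr_critic, sr_critic: raw critic outputs C(x_r) and C(x_f), assumed to be logits.
    """
    c_real = np.asarray(hr_critic, dtype=float)
    c_fake = np.asarray(sr_critic, dtype=float)
    real_vs_fake = c_real - c_fake.mean()  # logit of D_Ra(x_r, x_f)
    fake_vs_real = c_fake - c_real.mean()  # logit of D_Ra(x_f, x_r)
    # -log(1 - sigmoid(z)) == softplus(z) and -log(sigmoid(z)) == softplus(-z)
    return float(np.mean(softplus(real_vs_fake)) + np.mean(softplus(-fake_vs_real)))
```

Unlike the standard adversarial loss, this loss needs the critique of the real samples as well, which is why hr_critic is required here while hr_batch and sr_batch are not.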

class VGGLoss(output_layers, feature_scale=1.0, loss_weight=1.0, total_variation_loss=False, total_varation_weight=2e-07, after_activation=True, track_metrics=True, vgg16=False, custom_weights=False, custom_weights_path=None)
Loss function for the Generator based on differences between the feature-map activations of a reference network. The reference network can be either VGG19 (default) or VGG16.
Loss functions of this type are usually called perceptual losses; they aim to drive the Generator towards producing visually more pleasing results than pixelwise loss functions do. See https://arxiv.org/abs/1609.04802 for more information on perceptual loss.
After initialization the VGGLoss object can be used as a functor for loss calculation:
vgg_loss = VGGLoss(output_layers="block5_conv4",
                   after_activation=False,
                   feature_scale=1.0,
                   loss_weight=1.0,
                   total_variation_loss=False)
...
loss = vgg_loss(hr_batch, sr_batch, hr_critic, sr_critic, batch_metrics, epoch_metrics)
Parameters
  • output_layers

    Layers of the VGG19 (or VGG16) network whose feature maps are compared between synthesized and real data samples. May be a list of layers, in which case the losses of all layers are summed.

  • feature_scale – Scaling factor for each feature map.

  • loss_weight

    Factor to weight the VGG loss.
    This can be useful if the VGG loss is combined with an additional pixelwise loss,
    which might be orders of magnitude larger and could therefore outweigh the VGG loss.

  • total_variation_loss

    Whether to use an additional total variation loss as used in some variants of SRGAN (https://arxiv.org/abs/1609.04802).

  • total_varation_weight – Factor to weight total variation loss.

  • after_activation

    Whether to calculate the loss on feature maps after or before the activation function.
    See ESRGAN paper (https://arxiv.org/abs/1809.00219) for an explanation
    of the benefits and downsides.

  • track_metrics – whether the class should update the supplied metrics dictionaries.

  • vgg16 – Whether to use VGG16 instead of VGG19.

  • custom_weights – Whether to initialize VGG network with custom weights.

  • custom_weights_path – Path to custom weights file (.h5 file).
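How the numeric parameters above might combine can be sketched as follows. This is an assumption-laden illustration, not the library's code: feature_extractor is a hypothetical stand-in for the VGG network that returns a dictionary with one feature map per layer name; the per-layer distance is assumed to be the mean squared error between feature maps multiplied by feature_scale (as in SRGAN); and the total variation term follows the definition used by tf.image.total_variation (sum of absolute differences between neighboring pixels), applied to the synthesized batch and averaged over the batch here. after_activation, vgg16 and the custom-weight options only affect which feature extractor is built and are not modeled.

```python
import numpy as np


def total_variation(images):
    """Sum of absolute differences between neighboring pixels, per image (NHWC)."""
    x = np.asarray(images, dtype=float)
    dh = np.abs(x[:, 1:, :, :] - x[:, :-1, :, :]).sum(axis=(1, 2, 3))
    dw = np.abs(x[:, :, 1:, :] - x[:, :, :-1, :]).sum(axis=(1, 2, 3))
    return dh + dw


def perceptual_loss_sketch(feature_extractor, hr_batch, sr_batch, output_layers,
                           feature_scale=1.0, loss_weight=1.0,
                           total_variation_loss=False, total_varation_weight=2e-07):
    """Hypothetical composition of the VGGLoss parameters (illustration only)."""
    if isinstance(output_layers, str):
        output_layers = [output_layers]
    hr_features = feature_extractor(hr_batch)  # assumed: dict of layer name -> array
    sr_features = feature_extractor(sr_batch)
    loss = 0.0
    for layer in output_layers:
        # Assumption: feature_scale rescales each feature map before comparison.
        diff = feature_scale * (hr_features[layer] - sr_features[layer])
        loss += float(np.mean(diff ** 2))  # losses of all layers are summed
    if total_variation_loss:
        # Assumption: the TV term regularizes the synthesized images only.
        loss += total_varation_weight * float(np.mean(total_variation(sr_batch)))
    return loss_weight * loss
```

With a toy extractor such as lambda x: {"identity": x, "double": 2 * x}, comparing an all-zero batch with an all-one batch on both layers gives 1.0 + 4.0 = 5.0, showing how deeper or more strongly activated layers can dominate the sum.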

__call__(hr_batch, sr_batch, hr_critic, sr_critic, batch_metrics, epoch_metrics, denormalize=True)

Calculate the VGG loss for a batch of real data High-Resolution samples and synthesized High-Resolution samples.

Important

Pixel values for VGG need to be in [0, 255]. Either supply batches whose pixels are already in that range and set denormalize=False, or supply batches with pixels in [-1, 1] and keep denormalize=True (the default) to convert them. Any other combination will not work.

Note

The parameters hr_critic and sr_critic are not needed to calculate the VGG loss, but the function must accept them to adhere to the (implicit) Generator loss function interface.

Parameters
  • hr_batch

    Tensor of real data High-Resolution samples.
    Pixel values need to be in [0, 255] with denormalize=False, or in [-1, 1] with denormalize=True.

  • sr_batch

    Tensor of synthesized High-Resolution samples with the same shape as hr_batch.
    Pixel values need to be in [0, 255] with denormalize=False, or in [-1, 1] with denormalize=True.

  • hr_critic – Not needed, may be None.

  • sr_critic – Not needed, may be None.

  • batch_metrics – Optional dictionary to store batch metrics.

  • epoch_metrics – Optional dictionary to store epoch metrics.

  • denormalize – Whether to denormalize from [-1, 1] to [0, 255].
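The conversion performed by the denormalize flag is the linear map from [-1, 1] to [0, 255], x -> (x + 1) * 127.5. A small NumPy sketch of that map follows; the helper name and its range check are hypothetical additions that make the "any other combination will not work" rule explicit, not part of the library.

```python
import numpy as np


def denormalize(batch):
    """Map pixel values from [-1, 1] to [0, 255] (hypothetical helper)."""
    x = np.asarray(batch, dtype=float)
    if x.min() < -1.0 or x.max() > 1.0:
        # Values outside [-1, 1] suggest the batch is already in [0, 255].
        raise ValueError("expected pixel values in [-1, 1]; use denormalize=False "
                         "for batches already in [0, 255]")
    return (x + 1.0) * 127.5
```

For instance, denormalize([-1.0, 0.0, 1.0]) maps the endpoints and midpoint to 0.0, 127.5 and 255.0, while passing a batch already in [0, 255] raises an error instead of silently producing values up to 32640.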

Returns

(Weighted) VGG loss for batch.

Discriminator Loss Functions

class DiscriminatorLoss(weighted=False, loss_weight=1.0, track_metrics=True)
Loss function for the Discriminator in the standard GAN setting.
After initialization, the DiscriminatorLoss object can be used as a functor to calculate the loss of the Discriminator:
discriminator_loss = DiscriminatorLoss()
...
loss = discriminator_loss(sr_critic, hr_critic, sr_labels, hr_labels, batch_metrics, epoch_metrics)
If track_metrics is True, the supplied metrics dictionaries will be updated with the calculated loss.
Parameters
  • weighted – whether loss should be weighted

  • loss_weight – weight factor for loss

  • track_metrics – whether the class should update the supplied metrics dictionaries.

__call__(sr_critic, hr_critic, sr_labels, hr_labels, batch_metrics, epoch_metrics)

Calculate Discriminator loss based on real data samples and synthesized samples.

Parameters
  • sr_critic – Discriminator's critique of synthesized High-Resolution samples from the Generator.

  • hr_critic – Discriminator's critique of the corresponding real data High-Resolution samples.

  • sr_labels – Labels for synthesized samples to compare to the Discriminator's critique.

  • hr_labels – Labels for real data samples to compare to the Discriminator's critique.

  • batch_metrics – Optional dictionary to store batch metrics.

  • epoch_metrics – Optional dictionary to store epoch metrics.

Returns

(Weighted) Discriminator loss for batch.
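A common form of the standard Discriminator loss is the binary cross-entropy of the real critique against hr_labels plus that of the synthesized critique against sr_labels. Passing the labels explicitly allows, for example, one-sided label smoothing (hr_labels of 0.9 instead of 1). The NumPy sketch below assumes the critiques are logits and that the two terms are summed; both are assumptions, and the function names are hypothetical.

```python
import numpy as np


def bce_with_logits(labels, logits):
    """Numerically stable binary cross-entropy on logits, averaged over the batch."""
    y = np.asarray(labels, dtype=float)
    z = np.asarray(logits, dtype=float)
    # max(z, 0) - z * y + log(1 + exp(-|z|)) equals
    # -[y * log(sigmoid(z)) + (1 - y) * log(1 - sigmoid(z))].
    return float(np.mean(np.maximum(z, 0.0) - z * y + np.log1p(np.exp(-np.abs(z)))))


def discriminator_loss_sketch(sr_critic, hr_critic, sr_labels, hr_labels):
    """Hypothetical standard GAN Discriminator loss (illustration only)."""
    return bce_with_logits(hr_labels, hr_critic) + bce_with_logits(sr_labels, sr_critic)
```

The argument order mirrors the documented __call__ signature (sr_critic first). An undecided Discriminator (all logits 0) with labels 1 and 0 scores 2 * log(2); a confident, correct one scores close to 0.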

class RaDiscriminatorLoss(weighted=False, loss_weight=1.0, track_metrics=True)
Relativistic average loss function for the Discriminator in the relativistic GAN setting.
After initialization, the RaDiscriminatorLoss object can be used as a functor to calculate the loss of the Discriminator:
ra_loss = RaDiscriminatorLoss()
...
loss = ra_loss(sr_critic, hr_critic, sr_labels, hr_labels, batch_metrics, epoch_metrics)
If track_metrics is True, the supplied metrics dictionaries will be updated with the calculated loss.
Parameters
  • weighted – whether loss should be weighted

  • loss_weight – weight factor for loss

  • track_metrics – whether the class should update the supplied metrics dictionaries.

__call__(sr_critic, hr_critic, sr_labels, hr_labels, batch_metrics, epoch_metrics)

Calculate relativistic average Discriminator loss based on real data samples and synthesized samples.

Parameters
  • sr_critic – Discriminator's critique of synthesized High-Resolution samples from the Generator.

  • hr_critic – Discriminator's critique of the corresponding real data High-Resolution samples.

  • sr_labels – Labels for synthesized samples to compare to the Discriminator's critique.

  • hr_labels – Labels for real data samples to compare to the Discriminator's critique.

  • batch_metrics – Optional dictionary to store batch metrics.

  • epoch_metrics – Optional dictionary to store epoch metrics.

Returns

(Weighted) relativistic average Discriminator loss for batch.
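The ESRGAN paper (https://arxiv.org/abs/1809.00219) defines the relativistic average Discriminator loss as L_D = -E[log(D_Ra(x_r, x_f))] - E[log(1 - D_Ra(x_f, x_r))], with D_Ra(x_r, x_f) = sigmoid(C(x_r) - E[C(x_f)]). The NumPy sketch below expresses it as two cross-entropies against the supplied labels, assuming the critiques are raw logits C(x) and that the two terms are summed; the function names are hypothetical.

```python
import numpy as np


def bce_with_logits(labels, logits):
    """Numerically stable binary cross-entropy on logits, averaged over the batch."""
    y = np.asarray(labels, dtype=float)
    z = np.asarray(logits, dtype=float)
    return float(np.mean(np.maximum(z, 0.0) - z * y + np.log1p(np.exp(-np.abs(z)))))


def ra_discriminator_loss_sketch(sr_critic, hr_critic, sr_labels, hr_labels):
    """Hypothetical relativistic average Discriminator loss (ESRGAN-style)."""
    c_real = np.asarray(hr_critic, dtype=float)
    c_fake = np.asarray(sr_critic, dtype=float)
    real_vs_fake = c_real - c_fake.mean()  # logit of D_Ra(x_r, x_f)
    fake_vs_real = c_fake - c_real.mean()  # logit of D_Ra(x_f, x_r)
    # With hr_labels = 1 and sr_labels = 0 this equals
    # -E[log D_Ra(x_r, x_f)] - E[log(1 - D_Ra(x_f, x_r))].
    return bce_with_logits(hr_labels, real_vs_fake) + bce_with_logits(sr_labels, fake_vs_real)
```

Because only differences between critiques enter the loss, adding the same constant to every critic output leaves it unchanged; the Discriminator is rewarded for rating real samples as more realistic than the average synthesized sample, not for absolute scores.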