Loss Functions
Generator Loss Functions
class MeanAbsoluteError(weighted=False, loss_weight=1.0, track_metrics=True)
Mean Absolute Error function based on pixels. After initialization, the MeanAbsoluteError object can be used as a functor to calculate the pixelwise mean absolute error of generated images:
    mae = MeanAbsoluteError()
    ...
    loss = mae(hr_batch, sr_batch, hr_critic, sr_critic)
If track_metrics is True, the supplied metrics dictionaries will be updated with the calculated loss.
- Parameters
weighted – whether loss should be weighted
loss_weight – weight factor for loss
track_metrics – whether the class should update the supplied metrics dictionaries.
__call__(hr_batch, sr_batch, hr_critic, sr_critic, batch_metrics, epoch_metrics)
Calculate the pixelwise mean absolute error for the supplied batches of images.
Note
The parameters hr_critic and sr_critic are not used to calculate the mean absolute error, but the function needs to accept them to adhere to the (implicit) Generator loss function interface.
- Parameters
hr_batch – Tensor of real data High-Resolution samples.
sr_batch – Tensor of synthesized High-Resolution samples with equal shape as hr_batch.
hr_critic – Not needed, may be None.
sr_critic – Not needed, may be None.
batch_metrics – Optional dictionary to store batch metrics.
epoch_metrics – Optional dictionary to store epoch metrics.
- Returns
(Weighted) mean absolute error for batch.
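For reference, the pixelwise mean absolute error is the average of |hr - sr| over all pixels. The following minimal sketch computes it in plain Python on nested lists instead of tensors, so it is not the library's implementation. It assumes that weighted=True means the loss is multiplied by loss_weight; the PixelMAE class and its "mae" metrics key are hypothetical and only illustrate the functor interface and the role of the unused critic arguments.

```python
def _flatten(values):
    """Yield every scalar from an arbitrarily nested list or tuple."""
    if isinstance(values, (list, tuple)):
        for item in values:
            yield from _flatten(item)
    else:
        yield values


def pixelwise_mae(hr_batch, sr_batch, weighted=False, loss_weight=1.0):
    """Mean absolute error over all pixels of two batches of equal shape."""
    hr = list(_flatten(hr_batch))
    sr = list(_flatten(sr_batch))
    if len(hr) != len(sr):
        raise ValueError("hr_batch and sr_batch must contain the same number of pixels")
    mae = sum(abs(h - s) for h, s in zip(hr, sr)) / len(hr)
    # Assumption: weighting means multiplying the loss by loss_weight.
    return mae * loss_weight if weighted else mae


class PixelMAE:
    """Hypothetical functor mirroring the documented MeanAbsoluteError interface."""

    def __init__(self, weighted=False, loss_weight=1.0, track_metrics=True):
        self.weighted = weighted
        self.loss_weight = loss_weight
        self.track_metrics = track_metrics

    def __call__(self, hr_batch, sr_batch, hr_critic=None, sr_critic=None,
                 batch_metrics=None, epoch_metrics=None):
        # hr_critic and sr_critic are accepted only to satisfy the shared interface.
        loss = pixelwise_mae(hr_batch, sr_batch, self.weighted, self.loss_weight)
        if self.track_metrics:
            for metrics in (batch_metrics, epoch_metrics):
                if metrics is not None:
                    # Hypothetical key name; the library may use a different one.
                    metrics.setdefault("mae", []).append(loss)
        return loss


batch_metrics = {}
mae = PixelMAE()
# Per-pixel differences are 1, 0, 0 and 2, so the mean is 0.75.
loss = mae([[0.0, 1.0], [2.0, 3.0]], [[1.0, 1.0], [2.0, 1.0]], None, None, batch_metrics, None)
print(loss)  # 0.75
```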
class MeanSquaredError(weighted=False, loss_weight=1.0, track_metrics=True)
Mean Squared Error function based on pixels. After initialization, the MeanSquaredError object can be used as a functor to calculate the pixelwise mean squared error of generated images:
    mse = MeanSquaredError()
    ...
    loss = mse(hr_batch, sr_batch, hr_critic, sr_critic)
If track_metrics is True, the supplied metrics dictionaries will be updated with the calculated loss.
- Parameters
weighted – whether loss should be weighted
loss_weight – weight factor for loss
track_metrics – whether the class should update the supplied metrics dictionaries.
__call__(hr_batch, sr_batch, hr_critic, sr_critic, batch_metrics, epoch_metrics)
Calculate the pixelwise mean squared error for the supplied batches of images.
Note
The parameters hr_critic and sr_critic are not used to calculate the mean squared error, but the function needs to accept them to adhere to the (implicit) Generator loss function interface.
- Parameters
hr_batch – Tensor of real data High-Resolution samples.
sr_batch – Tensor of synthesized High-Resolution samples with equal shape as hr_batch.
hr_critic – Not needed, may be None.
sr_critic – Not needed, may be None.
batch_metrics – Optional dictionary to store batch metrics.
epoch_metrics – Optional dictionary to store epoch metrics.
- Returns
(Weighted) mean squared error for batch.
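The pixelwise mean squared error averages (hr - sr)^2 over all pixels, so large per-pixel deviations are penalized more strongly than under the mean absolute error. A minimal plain-Python sketch on nested lists (not the library's tensor implementation), again assuming that weighted=True multiplies the loss by loss_weight:

```python
def _flatten(values):
    """Yield every scalar from an arbitrarily nested list or tuple."""
    if isinstance(values, (list, tuple)):
        for item in values:
            yield from _flatten(item)
    else:
        yield values


def pixelwise_mse(hr_batch, sr_batch, weighted=False, loss_weight=1.0):
    """Mean squared error over all pixels of two batches of equal shape."""
    hr = list(_flatten(hr_batch))
    sr = list(_flatten(sr_batch))
    if len(hr) != len(sr):
        raise ValueError("hr_batch and sr_batch must contain the same number of pixels")
    mse = sum((h - s) ** 2 for h, s in zip(hr, sr)) / len(hr)
    # Assumption: weighting means multiplying the loss by loss_weight.
    return mse * loss_weight if weighted else mse


# Both pairs have a mean absolute error of 1.0, but a very different squared error:
print(pixelwise_mse([0.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0]))  # 1.0 (error spread evenly)
print(pixelwise_mse([0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 4.0]))  # 4.0 (one large outlier)
```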
class AdversarialLoss(weighted=False, loss_weight=1.0, track_metrics=True)
Adversarial loss function for the Generator in a standard GAN setting. After initialization, the AdversarialLoss object can be used as a functor to calculate the adversarial loss of the Generator:
    adversarial_loss = AdversarialLoss()
    ...
    loss = adversarial_loss(hr_batch, sr_batch, hr_critic, sr_critic)
If track_metrics is True, the supplied metrics dictionaries will be updated with the calculated loss.
- Parameters
weighted – whether loss should be weighted
loss_weight – weight factor for loss
track_metrics – whether the class should update the supplied metrics dictionaries.
__call__(hr_batch, sr_batch, hr_critic, sr_critic, batch_metrics, epoch_metrics)
Calculate the adversarial loss for the Generator from the Discriminator's critique.
Note
The parameters hr_batch, sr_batch and hr_critic are not used to calculate the adversarial loss, but the function needs to accept them to adhere to the (implicit) Generator loss function interface.
- Parameters
hr_batch – Not needed, may be None.
sr_batch – Not needed, may be None.
hr_critic – Not needed, may be None.
sr_critic – Discriminator's critique of the generated High-Resolution samples.
batch_metrics – Optional dictionary to store batch metrics.
epoch_metrics – Optional dictionary to store epoch metrics.
- Returns
(Weighted) Adversarial loss for batch.
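In the standard (non-saturating) GAN formulation, the Generator minimizes the binary cross-entropy between the Discriminator's critique of its samples and the "real" label 1, i.e. -mean(log(sigmoid(C(x_f)))). The sketch below is written in plain Python and assumes that sr_critic holds raw logits (values before a sigmoid) and that loss_weight simply scales the result; whether AdversarialLoss expects logits or probabilities is not stated here.

```python
import math


def softplus(x):
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def bce_with_logits(logits, targets):
    """Mean binary cross-entropy between sigmoid(logits) and targets."""
    # -[y * log(sigmoid(x)) + (1 - y) * log(1 - sigmoid(x))]
    #   = y * softplus(-x) + (1 - y) * softplus(x)
    total = sum(y * softplus(-x) + (1.0 - y) * softplus(x)
                for x, y in zip(logits, targets))
    return total / len(logits)


def generator_adversarial_loss(sr_critic, loss_weight=1.0):
    """Non-saturating generator loss: -mean(log(sigmoid(C(x_f))))."""
    # The Generator wants its samples to be classified as real (label 1).
    return loss_weight * bce_with_logits(sr_critic, [1.0] * len(sr_critic))


print(generator_adversarial_loss([0.0, 0.0]))  # log(2), about 0.693: critic undecided
print(generator_adversarial_loss([5.0, 5.0]))  # close to 0: critic fooled
```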
class RaAdversarialLoss(weighted=False, loss_weight=1.0, track_metrics=True)
Relativistic average adversarial loss function for the Generator in a relativistic GAN setting. After initialization, the RaAdversarialLoss object can be used as a functor to calculate the adversarial loss of the Generator:
    ra_loss = RaAdversarialLoss()
    ...
    loss = ra_loss(hr_batch, sr_batch, hr_critic, sr_critic)
If track_metrics is True, the supplied metrics dictionaries will be updated with the calculated loss.
- Parameters
weighted – whether loss should be weighted
loss_weight – weight factor for loss
track_metrics – whether the class should update the supplied metrics dictionaries.
__call__(hr_batch, sr_batch, hr_critic, sr_critic, batch_metrics, epoch_metrics)
Calculate the relativistic average adversarial loss for the Generator from the Discriminator's critique.
Note
The parameters hr_batch and sr_batch are not used to calculate the relativistic average adversarial loss, but the function needs to accept them to adhere to the (implicit) Generator loss function interface.
- Parameters
hr_batch – Not needed, may be None.
sr_batch – Not needed, may be None.
hr_critic – Discriminator's critique of the real data High-Resolution samples.
sr_critic – Discriminator's critique of the generated High-Resolution samples.
batch_metrics – Optional dictionary to store batch metrics.
epoch_metrics – Optional dictionary to store epoch metrics.
- Returns
(Weighted) relativistic average adversarial loss for batch.
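Unlike the standard adversarial loss, the relativistic average variant uses both critiques. Following the formulation in the ESRGAN paper (https://arxiv.org/abs/1809.00219), D_Ra(x_r, x_f) = sigmoid(C(x_r) - E[C(x_f)]), and the Generator minimizes -E[log(1 - D_Ra(x_r, x_f))] - E[log(D_Ra(x_f, x_r))]. The plain-Python sketch below implements that objective under the assumption that hr_critic and sr_critic are raw logits C(x); it is an illustration, not RaAdversarialLoss itself.

```python
import math


def softplus(x):
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def bce_with_logits(logits, targets):
    """Mean binary cross-entropy between sigmoid(logits) and targets."""
    total = sum(y * softplus(-x) + (1.0 - y) * softplus(x)
                for x, y in zip(logits, targets))
    return total / len(logits)


def ra_generator_loss(hr_critic, sr_critic, loss_weight=1.0):
    """Relativistic average generator loss (ESRGAN formulation) on logits."""
    mean_hr = sum(hr_critic) / len(hr_critic)
    mean_sr = sum(sr_critic) / len(sr_critic)
    # Relativistic logits: C(x_r) - E[C(x_f)] and C(x_f) - E[C(x_r)].
    real_rel = [c - mean_sr for c in hr_critic]
    fake_rel = [c - mean_hr for c in sr_critic]
    # The Generator wants real samples to look less realistic than the average
    # fake (target 0) and fakes more realistic than the average real (target 1).
    loss = (bce_with_logits(real_rel, [0.0] * len(real_rel))
            + bce_with_logits(fake_rel, [1.0] * len(fake_rel)))
    return loss_weight * loss


print(ra_generator_loss([3.0, 3.0], [0.0, 0.0]))  # large: real samples clearly win
print(ra_generator_loss([0.0, 0.0], [3.0, 3.0]))  # small: fakes look more realistic
```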
class VGGLoss(output_layers, feature_scale=1.0, loss_weight=1.0, total_variation_loss=False, total_varation_weight=2e-07, after_activation=True, track_metrics=True, vgg16=False, custom_weights=False, custom_weights_path=None)
Loss function for the Generator based on the differences between activations in the feature maps of a reference network. The reference network can be either VGG19 (default) or VGG16. Loss functions of this type are usually called perceptual losses; they aim to drive the Generator towards visually more pleasing results than pixelwise loss functions produce. See https://arxiv.org/abs/1609.04802 for more information on perceptual loss. After initialization, the VGGLoss object can be used as a functor for loss calculation:
    vgg_loss = VGGLoss(output_layers="block5_conv4", after_activation=False, feature_scale=1.0, loss_weight=1.0, total_variation_loss=False)
    ...
    loss = vgg_loss(hr_batch, sr_batch, hr_critic, sr_critic)
- Parameters
output_layers – Layers of the reference network (VGG19 by default) on which the feature maps of synthesized and real data samples are compared. May be a list of layers, in which case the losses of all layers are summed. The names of the VGG19 layers can be found here: https://github.com/keras-team/keras-applications/blob/master/keras_applications/vgg19.py
feature_scale – Scaling factor for each feature map.
loss_weight – Factor to weight the VGG loss. This can be useful if the VGG loss is combined with an additional pixel-wise loss, which might be orders of magnitude higher and could therefore outweigh the VGG loss.
total_variation_loss – Whether to use an additional total variation loss, as used in some variants of SRGAN (https://arxiv.org/abs/1609.04802).
total_varation_weight – Factor to weight the total variation loss.
after_activation – Whether to calculate the loss on feature maps after or before the activation function. See the ESRGAN paper (https://arxiv.org/abs/1809.00219) for an explanation of the benefits and downsides.
track_metrics – Whether the class should update the supplied metrics dictionaries.
vgg16 – Whether to use VGG16 instead of VGG19.
custom_weights – Whether to initialize the VGG network with custom weights.
custom_weights_path – Path to the custom weights file (.h5 file).
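Running a real VGG network is out of scope for a self-contained example, so the sketch below assumes the feature maps have already been extracted and are given as flat lists keyed by layer name. It sums one loss per layer, as the output_layers description states, and optionally adds a total variation term on the synthesized image. Using a mean squared error per layer, multiplying activations by feature_scale, using an anisotropic L1 total variation, and applying loss_weight to the whole sum are all assumptions rather than confirmed details of VGGLoss; the hypothetical tv_weight argument plays the role of total_varation_weight.

```python
def feature_mse(hr_features, sr_features, feature_scale=1.0):
    """Mean squared error between two equally long lists of activations."""
    # Assumption: feature_scale multiplies every activation before comparison.
    return sum((feature_scale * (h - s)) ** 2
               for h, s in zip(hr_features, sr_features)) / len(hr_features)


def total_variation(image):
    """Anisotropic L1 total variation of a 2-D image given as a list of rows."""
    tv = 0.0
    for i, row in enumerate(image):
        for j, value in enumerate(row):
            if i + 1 < len(image):  # vertical neighbour
                tv += abs(image[i + 1][j] - value)
            if j + 1 < len(row):  # horizontal neighbour
                tv += abs(row[j + 1] - value)
    return tv


def perceptual_loss(hr_maps, sr_maps, sr_image=None, feature_scale=1.0,
                    loss_weight=1.0, tv_weight=2e-7):
    """Sum of per-layer feature losses plus an optional total variation term."""
    # hr_maps / sr_maps map a layer name (e.g. "block5_conv4") to its activations.
    loss = sum(feature_mse(hr_maps[layer], sr_maps[layer], feature_scale)
               for layer in hr_maps)
    if sr_image is not None:
        loss += tv_weight * total_variation(sr_image)
    return loss_weight * loss


hr_maps = {"block5_conv4": [1.0, 2.0], "block4_conv4": [0.0]}
sr_maps = {"block5_conv4": [1.0, 4.0], "block4_conv4": [1.0]}
print(perceptual_loss(hr_maps, sr_maps))  # 3.0 (2.0 + 1.0 from the two layers)
print(total_variation([[0.0, 1.0], [2.0, 3.0]]))  # 6.0
```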
__call__(hr_batch, sr_batch, hr_critic, sr_critic, batch_metrics, epoch_metrics, denormalize=True)
Calculate the VGG loss for a batch of real data High-Resolution samples and the corresponding synthesized High-Resolution samples.
Important
The VGG network expects pixel values in [0, 255]. Either supply batches with pixels in that range and set denormalize=False, or, if your pixels are in [-1, 1], keep denormalize=True to convert them. Any other combination will not work.
Note
The parameters hr_critic and sr_critic are not used to calculate the VGG loss, but the function needs to accept them to adhere to the (implicit) Generator loss function interface.
- Parameters
hr_batch – Tensor of real data High-Resolution samples. Pixel values must be in [0, 255] (with denormalize=False) or in [-1, 1] (with denormalize=True).
sr_batch – Tensor of synthesized High-Resolution samples with the same shape as hr_batch. Pixel values must be in [0, 255] (with denormalize=False) or in [-1, 1] (with denormalize=True).
hr_critic – Not needed, may be None.
sr_critic – Not needed, may be None.
batch_metrics – Optional dictionary to store batch metrics.
epoch_metrics – Optional dictionary to store epoch metrics.
denormalize – Whether to denormalize from [-1, 1] to [0, 255].
- Returns
(Weighted) VGG loss for batch.
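The conversion behind the denormalize flag is presumably the linear map p -> (p + 1) * 127.5, which sends -1 to 0 and 1 to 255. The sketch below assumes exactly this mapping (the library's own formula is not shown in this reference) and adds range checks that mirror the rule that only [0, 255] without denormalization or [-1, 1] with denormalization work; prepare_for_vgg is a hypothetical helper operating on flat lists of pixel values.

```python
def prepare_for_vgg(pixels, denormalize=True):
    """Return pixel values in [0, 255], converting from [-1, 1] if requested."""
    if denormalize:
        if any(p < -1.0 or p > 1.0 for p in pixels):
            raise ValueError("denormalize=True expects pixels in [-1, 1]")
        # Assumed linear map: -1 -> 0, 0 -> 127.5, 1 -> 255.
        return [(p + 1.0) * 127.5 for p in pixels]
    if any(p < 0.0 or p > 255.0 for p in pixels):
        raise ValueError("denormalize=False expects pixels in [0, 255]")
    return list(pixels)


print(prepare_for_vgg([-1.0, 0.0, 1.0]))  # [0.0, 127.5, 255.0]
print(prepare_for_vgg([0.0, 255.0], denormalize=False))  # [0.0, 255.0]
```

A range check like this catches [0, 255] input passed with denormalize=True, but it cannot detect [0, 1] input, which lies inside [-1, 1] and would be silently mapped to [127.5, 255].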
Discriminator Loss Functions
class DiscriminatorLoss(weighted=False, loss_weight=1.0, track_metrics=True)
Loss function for the Discriminator in a standard GAN setting. After initialization, the DiscriminatorLoss object can be used as a functor to calculate the loss of the Discriminator:
    discriminator_loss = DiscriminatorLoss()
    ...
    loss = discriminator_loss(sr_critic, hr_critic, sr_labels, hr_labels)
If track_metrics is True, the supplied metrics dictionaries will be updated with the calculated loss.
- Parameters
weighted – whether loss should be weighted
loss_weight – weight factor for loss
track_metrics – whether the class should update the supplied metrics dictionaries.
__call__(sr_critic, hr_critic, sr_labels, hr_labels, batch_metrics, epoch_metrics)
Calculate the Discriminator loss based on real data samples and synthesized samples.
- Parameters
sr_critic – Discriminator's critique of the synthesized High-Resolution samples from the Generator.
hr_critic – Discriminator's critique of the corresponding real data High-Resolution samples.
sr_labels – Labels for the synthesized samples to compare to the Discriminator's critique.
hr_labels – Labels for the real data samples to compare to the Discriminator's critique.
batch_metrics – Optional dictionary to store batch metrics.
epoch_metrics – Optional dictionary to store epoch metrics.
- Returns
(Weighted) Discriminator loss for batch.
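In the standard GAN setting, the Discriminator is trained with binary cross-entropy: its critique of real samples is compared to hr_labels (usually 1) and its critique of synthesized samples to sr_labels (usually 0). The plain-Python sketch below treats the critique as raw logits and sums the two terms; both choices are assumptions about DiscriminatorLoss, not confirmed details.

```python
import math


def softplus(x):
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def bce_with_logits(logits, targets):
    """Mean binary cross-entropy between sigmoid(logits) and targets."""
    total = sum(y * softplus(-x) + (1.0 - y) * softplus(x)
                for x, y in zip(logits, targets))
    return total / len(logits)


def discriminator_loss(sr_critic, hr_critic, sr_labels, hr_labels, loss_weight=1.0):
    """Standard GAN discriminator loss on logits (real term + fake term)."""
    real_term = bce_with_logits(hr_critic, hr_labels)  # real samples vs. hr_labels
    fake_term = bce_with_logits(sr_critic, sr_labels)  # synthesized samples vs. sr_labels
    return loss_weight * (real_term + fake_term)


# Argument order follows the documented __call__: sr_critic, hr_critic, sr_labels, hr_labels.
print(discriminator_loss([0.0], [0.0], [0.0], [1.0]))  # 2 * log(2): undecided critic
print(discriminator_loss([-4.0], [4.0], [0.0], [1.0]))  # small: confident and correct
```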
class RaDiscriminatorLoss(weighted=False, loss_weight=1.0, track_metrics=True)
Relativistic average loss function for the Discriminator in a relativistic GAN setting. After initialization, the RaDiscriminatorLoss object can be used as a functor to calculate the loss of the Discriminator:
    ra_loss = RaDiscriminatorLoss()
    ...
    loss = ra_loss(sr_critic, hr_critic, sr_labels, hr_labels)
If track_metrics is True, the supplied metrics dictionaries will be updated with the calculated loss.
- Parameters
weighted – whether loss should be weighted
loss_weight – weight factor for loss
track_metrics – whether the class should update the supplied metrics dictionaries.
__call__(sr_critic, hr_critic, sr_labels, hr_labels, batch_metrics, epoch_metrics)
Calculate the relativistic average Discriminator loss based on real data samples and synthesized samples.
- Parameters
sr_critic – Discriminator's critique of the synthesized High-Resolution samples from the Generator.
hr_critic – Discriminator's critique of the corresponding real data High-Resolution samples.
sr_labels – Labels for the synthesized samples to compare to the Discriminator's critique.
hr_labels – Labels for the real data samples to compare to the Discriminator's critique.
batch_metrics – Optional dictionary to store batch metrics.
epoch_metrics – Optional dictionary to store epoch metrics.
- Returns
(Weighted) relativistic average Discriminator loss for batch.
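The relativistic average Discriminator described in the ESRGAN paper (https://arxiv.org/abs/1809.00219) estimates how much more realistic a real sample is than the average synthesized one, D_Ra(x_r, x_f) = sigmoid(C(x_r) - E[C(x_f)]), and minimizes -E[log D_Ra(x_r, x_f)] - E[log(1 - D_Ra(x_f, x_r))]. The plain-Python sketch below uses hr_labels and sr_labels as the targets of those two terms, which reproduces the ESRGAN objective for labels 1 and 0; treating the critique as logits and this use of the labels are assumptions about RaDiscriminatorLoss.

```python
import math


def softplus(x):
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def bce_with_logits(logits, targets):
    """Mean binary cross-entropy between sigmoid(logits) and targets."""
    total = sum(y * softplus(-x) + (1.0 - y) * softplus(x)
                for x, y in zip(logits, targets))
    return total / len(logits)


def ra_discriminator_loss(sr_critic, hr_critic, sr_labels, hr_labels, loss_weight=1.0):
    """Relativistic average discriminator loss (ESRGAN formulation) on logits."""
    mean_hr = sum(hr_critic) / len(hr_critic)
    mean_sr = sum(sr_critic) / len(sr_critic)
    real_rel = [c - mean_sr for c in hr_critic]  # C(x_r) - E[C(x_f)]
    fake_rel = [c - mean_hr for c in sr_critic]  # C(x_f) - E[C(x_r)]
    # With hr_labels = 1 and sr_labels = 0 this is
    # -E[log D_Ra(x_r, x_f)] - E[log(1 - D_Ra(x_f, x_r))].
    loss = bce_with_logits(real_rel, hr_labels) + bce_with_logits(fake_rel, sr_labels)
    return loss_weight * loss


print(ra_discriminator_loss([0.0], [3.0], [0.0], [1.0]))  # small: reals rated above fakes
print(ra_discriminator_loss([3.0], [0.0], [0.0], [1.0]))  # large: fakes rated above reals
```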