Source code for simple_sr.operations.training

import os
from tqdm import tqdm
import time
import logging
import tensorflow as tf
from simple_sr.operations import evaluation
from simple_sr.utils import logger
from simple_sr.utils.train_result import TrainResult

log = logging.getLogger(logger.RESULTS_LOGGER)


def run_training(config, data_pipeline, sr_model, plotting_interval=1):
    """
    Run training session with initialized `DataPipeline` and `SRModel`.

    Each epoch trains on every training batch, then runs every validation
    batch. After the last epoch (or an early stop) the collected histories
    are dumped as JSON below ``config.save_path`` and the generator is
    evaluated on the validation data.

    :param config:
        Initialized :code:`SimpleSR.ConfigUtil` object.
    :param data_pipeline:
        Initialized :code:`SimpleSR.DataPipeline` object.
    :param sr_model:
        Initialized :code:`SimpleSR.SRModel` object.
    :param plotting_interval:
        | Interval of epochs to plot and save images. The first epoch
          (index 0) is never plotted; ``0`` or ``None`` disables plotting.
        | Every batch of the test set will be plotted (interpolated and upsampled
          by model respectively) and saved. Additionally one batch of training
          and validation set will be plotted as well.

    .. note::
        Plotting can be helpful to monitor progress during training, but will
        impact performance and therefore lead to longer epoch times.
    """
    _log_configurations(config, data_pipeline, sr_model)

    start_training = time.perf_counter()
    for epoch in tqdm(range(config.num_epochs)):
        optimizer = sr_model.generator_optimizer()
        # NOTE(review): `_decayed_lr` is a private optimizer method and may
        # not exist in every TensorFlow/Keras version - confirm before upgrading.
        learn_rate_info = (
            f"optimizer step: {optimizer.iterations.numpy()}, "
            f"current learnrate: {optimizer._decayed_lr(tf.float32).numpy()}"
        )
        log.info(learn_rate_info)

        if sr_model.stop_early():
            break

        sr_model.before_epoch()
        start_epoch = time.perf_counter()

        # training
        for lr_batch, hr_batch in data_pipeline.train_batch_generator():
            sr_model.train_step(lr_batch, hr_batch)
            sr_model.after_train_batch()

        # validation
        for lr_batch, hr_batch in data_pipeline.validation_batch_generator():
            sr_model.validation_step(lr_batch, hr_batch)
            sr_model.after_validation_batch()

        # A falsy interval means "no plotting"; it also keeps the modulo
        # below from dividing by zero.
        if epoch != 0 and plotting_interval and epoch % plotting_interval == 0:
            # test and plot on one train set batch
            _plot_first_batch(
                sr_model, data_pipeline.train_batch_generator(), config.pic_dir_train
            )

            # test and plot on one validation set batch
            validation_batch_generator = data_pipeline.validation_batch_generator()
            # A plain list has no `.take`; presumably the pipeline returns one
            # when there is no validation data - confirm against DataPipeline.
            if not isinstance(validation_batch_generator, list):
                _plot_first_batch(
                    sr_model, validation_batch_generator, config.pic_dir_val
                )

            # test and plot on testing data
            try:
                _plot_test_batches(
                    sr_model, data_pipeline, config.pic_dir_test, config.batch_size
                )
            except tf.errors.InvalidArgumentError:
                # Retry one image at a time; presumably the configured batch
                # size fails when test images cannot be stacked - confirm.
                _plot_test_batches(sr_model, data_pipeline, config.pic_dir_test, 1)

        epoch_duration = time.perf_counter() - start_epoch
        log_string = _prepare_log_string(sr_model, epoch, epoch_duration)
        log.info(log_string)
        sr_model.after_epoch()

    training_duration = time.perf_counter() - start_training
    log.info(f"finished training ({training_duration:.2f} sec)")
    sr_model.after_training()

    exp_res = TrainResult(
        sr_model.epoch_history(train=True),
        sr_model.epoch_history(train=False),
        sr_model.batch_history(train=True),
        sr_model.batch_history(train=False)
    )
    json_dump_dir = os.path.join(config.save_path, "json_dump")
    os.makedirs(json_dump_dir, exist_ok=True)
    exp_res.save_as_json(json_dump_dir)

    log.info("calculating psnr/ssim on validation set")
    evaluation.evaluate_on_validationdata(
        config=config, model_name=sr_model.name,
        pipeline=data_pipeline, model=sr_model.generator()
    )
    log.info("done")


def _plot_first_batch(sr_model, batch_generator, pic_dir):
    """Plot the first (lr, hr) batch of `batch_generator`; do nothing if it is empty."""
    first_batch = next(iter(batch_generator.take(1)), None)
    if first_batch is None:
        # Nothing to plot; never fall back to a batch left over from another set.
        return
    lr_batch, hr_batch = first_batch
    sr_model.test_and_plot(lr_batch, pic_dir, 0, hr_batch)


def _plot_test_batches(sr_model, data_pipeline, pic_dir, batch_size):
    """Plot every test set batch, using the batch index as the plot step."""
    test_batch_generator = data_pipeline.test_batch_generator(batch_size=batch_size)
    for step, (lr_batch, _file_path) in enumerate(test_batch_generator):
        sr_model.test_and_plot(lr_batch, pic_dir, step, None)
def _log_configurations(config, data_pipeline, model_config):
    """
    Append the string form of all three configurations to the config logfile.

    :param config:
        Base configuration; its ``config_logfile`` attribute is the path
        of the file that is appended to.
    :param data_pipeline:
        Data pipeline whose string form is logged.
    :param model_config:
        Model (configuration) whose string form is logged.
    """
    sections = (
        ("# Base config\n", config),
        ("# Data Pipeline config\n", data_pipeline),
        ("# Model config\n", model_config),
    )
    with open(config.config_logfile, "a+") as logfile:
        for heading, settings in sections:
            logfile.write(heading)
            logfile.write(f"{settings}\n")


def _prepare_log_string(sr_model, epoch, epoch_duration):
    """
    Build the log message for a finished epoch.

    :param sr_model:
        Model providing ``formatted_epoch_metrics()``.
    :param epoch:
        Index of the finished epoch.
    :param epoch_duration:
        Duration of the epoch in seconds.
    :return:
        Header line with epoch index and duration, followed by the
        formatted epoch metrics of the model.
    """
    header = f"epoch: {epoch} ({epoch_duration:.2f} sec)\n"
    return header + sr_model.formatted_epoch_metrics()