model_trainer

ModelTrainerBase

The base class from which all model trainers inherit.

Attributes
  • device: A string indicating the device to use, either 'cuda' or 'cpu'.

  • train_loader: Training data, batched in a PyTorch DataLoader.

  • test_loader: Testing data, batched in a PyTorch DataLoader.

  • loss_function: A function with two parameters (prediction, target). There is no specific requirement for the types of the parameters, as long as they are compatible with the model and the data loaders. The prediction should be the output of the model for a batch. The target should be a batch of targets packed in the data loaders.

  • metric: A subclass of autokeras.metric.Metric. In its compute(prediction, target) function, prediction and target are both numpy arrays, converted from the output of the model and from the targets packed in the data loaders, respectively.

  • verbose: Verbosity mode.
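
As a rough illustration of the metric contract described above, the sketch below implements an accuracy metric with a compute(prediction, target) method. The Metric class here is a hypothetical stand-in for autokeras.metric.Metric, whose real definition is not reproduced, and plain Python lists stand in for the numpy arrays so that the example has no dependencies.

```python
class Metric:
    """Hypothetical stand-in for autokeras.metric.Metric (an assumption)."""

    def compute(self, prediction, target):
        raise NotImplementedError


class Accuracy(Metric):
    """Fraction of samples whose highest-scoring class equals the target."""

    def compute(self, prediction, target):
        # prediction: one row of class scores per sample.
        # target: one integer class label per sample.
        if len(prediction) != len(target):
            raise ValueError("prediction and target must have the same length")
        if not target:
            raise ValueError("cannot compute accuracy on an empty batch")
        correct = 0
        for scores, label in zip(prediction, target):
            # Index of the largest score is the predicted class.
            predicted_label = max(range(len(scores)), key=lambda i: scores[i])
            if predicted_label == label:
                correct += 1
        return correct / len(target)


# Three samples, two classes; the first two predictions are correct.
prediction = [[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]]
target = [1, 0, 0]
accuracy = Accuracy().compute(prediction, target)
```

Whatever the metric computes, it only needs to accept the converted model output and the matching targets from the data loaders.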

train_model

Train the model.

Args
  • max_iter_num: An integer. The maximum number of iterations.

  • max_no_improvement_num: If the model still shows no improvement after this many iterations, training finishes.

ModelTrainer

A class used to train the model. It can train a PyTorch model with the given data loaders. The metric, loss_function, and model must be compatible with each other. See the Attributes for details.

Attributes
  • temp_model_path: The path where the temporary model is stored.

  • model: An instance of a PyTorch Module. The model to be trained.

  • early_stop: An instance of class EarlyStop.

  • optimizer: The optimizer used for training. The PyTorch Adam optimizer is used.

  • current_epoch: The current epoch.

train_model

Train the model until max_iter_num or max_no_improvement_num is reached.

Args
  • max_iter_num: An integer. The maximum number of epochs to train the model. The training will stop when this number is reached.

  • max_no_improvement_num: An integer. The maximum number of epochs in which the loss value does not decrease. The training will stop when this number is reached.
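
The two stopping conditions above can be sketched as a plain loop. This is a minimal illustration, not the trainer's actual code: train_with_limits and run_epoch are hypothetical names, and the sketch assumes that "no improvement" means the loss did not fall below the best value seen so far and that the counter resets whenever it does.

```python
def train_with_limits(run_epoch, max_iter_num, max_no_improvement_num):
    """Call run_epoch() until either limit is reached; return the losses.

    run_epoch is a hypothetical callable that trains for one epoch and
    returns that epoch's loss as a float.
    """
    losses = []
    best_loss = float("inf")
    epochs_without_improvement = 0
    for _ in range(max_iter_num):  # hard cap on the number of epochs
        loss = run_epoch()
        losses.append(loss)
        if loss < best_loss:
            best_loss = loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= max_no_improvement_num:
                break  # the loss has stalled for too long
    return losses


# Simulated per-epoch losses: the loss stops improving after the second epoch.
fake_losses = iter([1.0, 0.8, 0.9, 0.85, 0.95, 0.7])
history = train_with_limits(
    lambda: next(fake_losses), max_iter_num=10, max_no_improvement_num=3
)
```

In this run the loop stops after five epochs because three epochs in a row fail to beat 0.8, even though max_iter_num would have allowed ten.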

Returns

_train

Runs the actual training.

_test

Function for evaluation.

GANModelTrainer

A ModelTrainer specialized for GANs.

Attributes
  • d_model: A discriminator model.

  • g_model: A generator model.

  • out_f: The output file.

  • out_size: Size of the output image.

  • optimizer_d: Optimizer for discriminator.

  • optimizer_g: Optimizer for generator.

__init__

Initialize the GANModelTrainer.

Args
  • g_model: The generator model to be trained.

  • d_model: The discriminator model to be trained.

  • train_data: The training data.

  • loss_function: The loss function for both discriminator and generator.

  • verbose: Whether to print output during training.

  • gen_training_result: Whether to generate intermediate results while training.

_train

Performs the actual training.

EarlyStop

A class that checks the early stop condition.

Attributes
  • training_losses: A record of all training losses.

  • minimum_loss: The minimum loss achieved so far. New losses are compared against it to determine the no-improvement condition.

  • no_improvement_count: The current no-improvement count.

  • _max_no_improvement_num: The maximum no-improvement count specified.

  • _done: Whether the early stop condition has been met.

  • _min_loss_dec: A threshold for loss improvement.
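
To show how these attributes might fit together, here is a minimal sketch of an early-stopping helper. It is not the library's EarlyStop: the class name EarlyStopSketch, the default values, the exact comparison, and the convention that on_epoch_end returns True to continue and False to stop are all assumptions made for illustration.

```python
class EarlyStopSketch:
    """Illustrative early-stopping helper; not the library's EarlyStop."""

    def __init__(self, max_no_improvement_num=3, min_loss_dec=1e-4):
        # Default values are assumptions chosen for this sketch.
        self.training_losses = []
        self.minimum_loss = float("inf")
        self.no_improvement_count = 0
        self._max_no_improvement_num = max_no_improvement_num
        self._done = False
        self._min_loss_dec = min_loss_dec

    def on_train_begin(self):
        # Reset all state so the same object can be reused for a new run.
        self.training_losses = []
        self.minimum_loss = float("inf")
        self.no_improvement_count = 0
        self._done = False

    def on_epoch_end(self, loss):
        """Record the loss; return True to keep training, False to stop."""
        self.training_losses.append(loss)
        # Only a drop larger than the threshold counts as an improvement.
        if loss < self.minimum_loss - self._min_loss_dec:
            self.minimum_loss = loss
            self.no_improvement_count = 0
        else:
            self.no_improvement_count += 1
        if self.no_improvement_count >= self._max_no_improvement_num:
            self._done = True
        return not self._done


stopper = EarlyStopSketch(max_no_improvement_num=2, min_loss_dec=0.01)
stopper.on_train_begin()
epochs_run = 0
for loss in [1.0, 0.5, 0.495, 0.499, 0.3]:
    epochs_run += 1
    if not stopper.on_epoch_end(loss):
        break
```

In this usage, 0.495 and 0.499 do not beat 0.5 by at least the 0.01 threshold, so the helper signals a stop after the fourth epoch and the final 0.3 is never seen.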

on_train_begin

Initialize the early stop condition. Call this each time a training iteration begins.

on_epoch_end

Check the early stop condition. Call this each time a training iteration ends.

Args
  • loss: The loss value achieved in the epoch.
Returns