
model_trainer

ModelTrainer

A class used to train a model. It trains a PyTorch model with the given data loaders. The metric, loss_function, and model must be compatible with one another; see the Attributes below for details.

Attributes
  • device: A string indicating the device to use: 'cuda' or 'cpu'.

  • model: An instance of a PyTorch Module. The model to be trained.

  • train_loader: Training data wrapped in batches in a PyTorch DataLoader.

  • test_loader: Testing data wrapped in batches in a PyTorch DataLoader.

  • loss_function: A function with two parameters (prediction, target). The parameters have no specific type requirements as long as they are compatible with the model and the data loaders: prediction is the model's output for a batch, and target is the corresponding batch of targets packed in the data loaders.

  • optimizer: The optimizer used for training. The PyTorch Adam optimizer is used.

  • early_stop: An instance of class EarlyStop.

  • metric: A subclass of autokeras.metric.Metric. In its compute(prediction, target) function, both prediction and target are NumPy arrays, converted from the model's output and from the targets packed in the data loaders, respectively.

  • verbose: Verbosity mode.
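
The requirement that loss_function, metric, and model agree can be illustrated with a small, framework-free sketch. It assumes a classification model whose output for a batch is one row of raw class scores per sample and whose targets are integer class labels. The names `cross_entropy_loss`, `Metric`, and `Accuracy` are hypothetical: the `Metric` class below only mimics the `compute(prediction, target)` interface described above and is not autokeras.metric.Metric.

```python
import math
from abc import ABC, abstractmethod


def cross_entropy_loss(prediction, target):
    """Hypothetical loss_function(prediction, target).

    prediction: a batch of raw class scores, one row per sample.
    target: a batch of integer class labels.
    Returns the mean negative log-likelihood under a softmax.
    """
    total = 0.0
    for scores, label in zip(prediction, target):
        # Subtract the largest score before exponentiating, for numerical stability.
        m = max(scores)
        log_sum_exp = m + math.log(sum(math.exp(s - m) for s in scores))
        total += log_sum_exp - scores[label]
    return total / len(target)


class Metric(ABC):
    """Stand-in for the compute(prediction, target) interface (not the real autokeras class)."""

    @abstractmethod
    def compute(self, prediction, target):
        ...


class Accuracy(Metric):
    """Fraction of samples whose highest-scoring class equals the label.

    Works on nested lists or on 2-D NumPy arrays, since both iterate row by row.
    """

    def compute(self, prediction, target):
        correct = 0
        for scores, label in zip(prediction, target):
            predicted = max(range(len(scores)), key=lambda i: scores[i])
            correct += int(predicted == label)
        return correct / len(target)


# Both functions consume the same (prediction, target) layout, so they are compatible.
batch_prediction = [[2.0, 0.5, 0.1], [0.2, 0.3, 1.5]]
batch_target = [0, 1]
print(Accuracy().compute(batch_prediction, batch_target))  # 0.5
```

The key design point is that both callables agree on one layout for prediction and target; a model emitting probabilities instead of raw scores, or one-hot targets instead of integer labels, would need a different loss and metric.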

__init__

Initialize the ModelTrainer with model, x_train, y_train, x_test, y_test, and verbose.

train_model

Train the model.

Args
  • max_iter_num: An integer. The maximum number of epochs to train the model. Training stops when this number is reached.

  • max_no_improvement_num: An integer. The maximum number of epochs without a decrease in the loss value. Training stops when this number is reached.
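
How the two arguments interact can be sketched as a plain loop. This is an illustrative assumption, not the Auto-Keras implementation: `run_training` and its `train_one_epoch` / `evaluate_loss` callbacks are hypothetical, an "improvement" is taken to be a strictly lower loss than the best seen so far, and the no-improvement count is assumed to reset whenever the loss improves.

```python
def run_training(train_one_epoch, evaluate_loss, max_iter_num, max_no_improvement_num):
    """Hypothetical sketch of the stopping rule described by train_model's arguments.

    train_one_epoch(): runs one pass over the training data.
    evaluate_loss(): returns the loss value after that epoch.
    Returns the list of per-epoch losses that were observed.
    """
    best_loss = float("inf")
    no_improvement_count = 0
    losses = []
    for _ in range(max_iter_num):  # never train for more than max_iter_num epochs
        train_one_epoch()
        loss = evaluate_loss()
        losses.append(loss)
        if loss < best_loss:
            # Assumption: a strictly lower loss counts as improvement and resets the counter.
            best_loss = loss
            no_improvement_count = 0
        else:
            no_improvement_count += 1
            if no_improvement_count >= max_no_improvement_num:
                break  # the loss has not decreased for max_no_improvement_num epochs
    return losses


# Simulated losses: after 0.8 the loss fails to decrease three times in a row.
simulated = iter([1.0, 0.8, 0.9, 0.85, 0.95, 0.7])
history = run_training(lambda: None, lambda: next(simulated),
                       max_iter_num=10, max_no_improvement_num=3)
print(history)  # [1.0, 0.8, 0.9, 0.85, 0.95]
```

Whichever limit is hit first ends training: in the example above, the no-improvement limit stops it at epoch 5, well before max_iter_num.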

GANModelTrainer

__init__

Initialize the GANModelTrainer with model, x_train, y_train, x_test, y_test, and verbose.