msp.rl_algorithm package
Module contents
The msp.rl_algorithm module defines the reinforcement learning algorithm (REINFORCE) used to train the model.
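For reference, the standard REINFORCE gradient estimator with a baseline is shown below. Here θ are the trainable weights of model_train, s is an input instance, π is a solution sampled from model_train, L(π) is its cost, b(s) is the cost that model_baseline obtains on the same instance, and B is the mini-batch size. This is the textbook form of the technique; that ReinforceAlgorithm (for example, through objective_func) computes exactly this quantity is an assumption.

```latex
\nabla_\theta \mathcal{L}(\theta \mid s)
  = \mathbb{E}_{\pi \sim p_\theta(\cdot \mid s)}
    \big[ (L(\pi) - b(s)) \, \nabla_\theta \log p_\theta(\pi \mid s) \big],
\qquad
\widehat{\nabla_\theta \mathcal{L}}
  = \frac{1}{B} \sum_{i=1}^{B} (L(\pi_i) - b(s_i)) \, \nabla_\theta \log p_\theta(\pi_i \mid s_i)
```

Because b(s) does not depend on the sampled π, subtracting it leaves the gradient unbiased while reducing its variance.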
- class msp.rl_algorithm.ReinforceAlgorithm(objective_func, optimizer, train_metric, val_metric, tol=0.001)
Bases: object
- apply_gradient(model_train, model_baseline, batch_train)
Apply the gradients to the trainable model weights.
- Parameters
model_train – model being trained
model_baseline – baseline model used for comparison, as per the REINFORCE algorithm
batch_train – input mini-batch data
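A minimal sketch of one REINFORCE-with-baseline step, assuming a toy categorical policy whose trainable weights are just a list of logits. The helpers softmax and reinforce_step are hypothetical and not part of msp; the real method works on model_train's weights and uses the configured optimizer, whereas this sketch uses plain gradient descent.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def reinforce_step(
    logits: List[float],
    actions: Sequence[int],
    costs: Sequence[float],
    baseline_costs: Sequence[float],
    lr: float = 0.1,
) -> List[float]:
    """One REINFORCE-with-baseline gradient step on a toy categorical policy.

    Hypothetical stand-in for apply_gradient: `logits` play the role of
    model_train's weights, `costs` are the costs of the actions it sampled,
    and `baseline_costs` are the costs model_baseline got on the same inputs.
    """
    probs = softmax(logits)
    grad = [0.0] * len(logits)
    batch_size = len(actions)
    for a, cost, base in zip(actions, costs, baseline_costs):
        advantage = cost - base  # positive means worse than the baseline
        # Gradient of log softmax(logits)[a] w.r.t. the logits is one_hot(a) - probs.
        for k in range(len(logits)):
            indicator = 1.0 if k == a else 0.0
            grad[k] += advantage * (indicator - probs[k]) / batch_size
    # Descend on expected cost: actions that did worse than the baseline
    # lose probability, actions that did better gain it.
    return [z - lr * g for z, g in zip(logits, grad)]
```

For example, reinforce_step([0.0, 0.0], actions=[0], costs=[2.0], baseline_costs=[1.0]) returns approximately [-0.05, 0.05]: action 0 cost more than the baseline, so its probability goes down.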
- evaluate(model_train, model_baseline, val_dataset, verbose=1, callbacks=None)
Evaluate the model on the validation dataset.
- Parameters
model_train – model being trained
model_baseline – baseline model used for comparison, as per the REINFORCE algorithm
val_dataset – dataset for validation
verbose – Verbosity mode. 0 = silent, 1 = progress bar.
callbacks – List of callbacks to apply during evaluation.
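How evaluate compares the two models is not documented here. A minimal sketch, assuming each model can be treated as a hypothetical callable that maps a validation instance to its cost, computes the mean cost of both models over the same instances:

```python
from statistics import mean
from typing import Callable, Iterable, Sequence, Tuple

# Hypothetical model interface: maps one instance to the cost of its solution.
Model = Callable[[Sequence[float]], float]


def evaluate_models(
    model_train: Model,
    model_baseline: Model,
    val_dataset: Iterable[Sequence[float]],
) -> Tuple[float, float]:
    """Return (mean cost of model_train, mean cost of model_baseline)."""
    train_costs, baseline_costs = [], []
    for instance in val_dataset:
        # Both models see exactly the same instances, so the means are comparable.
        train_costs.append(model_train(instance))
        baseline_costs.append(model_baseline(instance))
    if not train_costs:
        raise ValueError("val_dataset is empty")
    return mean(train_costs), mean(baseline_costs)
```

Evaluating both models on identical instances keeps the comparison paired, which is what a baseline comparison needs.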
- run(model_train, model_baseline, train_dataset, val_dataset, epochs=1, verbose=1, callbacks=None, epoch_when_last_save=0)
Run the REINFORCE algorithm on model_train and model_baseline.
- Parameters
model_train – model being trained
model_baseline – baseline model used for comparison, as per the REINFORCE algorithm
train_dataset – dataset for training
val_dataset – dataset for validation
epochs – number of epochs
verbose – Verbosity mode. 0 = silent, 1 = progress bar.
callbacks – List of callbacks to apply during training.
epoch_when_last_save – epoch at which the model was last saved; used when retraining a previously saved model
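A minimal sketch of a multi-epoch loop with a greedy-rollout-style baseline update. Several details are assumptions, not confirmed behavior of run: epoch numbering resumes at epoch_when_last_save, the baseline is replaced by a copy of the trained weights whenever their validation cost is lower by more than a hypothetical min_improvement threshold (not necessarily what tol controls), and train_one_epoch and val_cost are hypothetical callables.

```python
import copy
from typing import Callable, List, Tuple

Params = List[float]  # hypothetical stand-in for model weights


def run_sketch(
    train_params: Params,
    train_one_epoch: Callable[[Params, Params], Params],
    val_cost: Callable[[Params], float],
    epochs: int = 1,
    epoch_when_last_save: int = 0,
    min_improvement: float = 0.0,
) -> Tuple[Params, Params, List[int]]:
    """Train for `epochs` epochs, refreshing the baseline when training beats it."""
    baseline_params = copy.deepcopy(train_params)
    epochs_run: List[int] = []
    # Assumed: numbering continues from the last saved epoch when retraining.
    for epoch in range(epoch_when_last_save, epoch_when_last_save + epochs):
        train_params = train_one_epoch(train_params, baseline_params)
        # Replace the baseline only if the trained weights are clearly better.
        if val_cost(train_params) < val_cost(baseline_params) - min_improvement:
            baseline_params = copy.deepcopy(train_params)
        epochs_run.append(epoch)
    return train_params, baseline_params, epochs_run
```

Copying the weights (rather than aliasing them) keeps the baseline frozen while the trained model keeps changing during the next epoch.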
- train_model_for_one_epoch(model_train, model_baseline, train_dataset, callbacks, verbose=1)
Compute the loss, then update the weights and metrics, for one epoch.
- Parameters
model_train – model being trained
model_baseline – baseline model used for comparison, as per the REINFORCE algorithm
train_dataset – dataset for training
callbacks – List of callbacks to apply during training.
verbose – Verbosity mode. 0 = silent, 1 = progress bar.
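A minimal sketch of the per-epoch loop, assuming a hypothetical step callable that stands in for apply_gradient (it updates the weights for one mini-batch and returns that batch's loss) and callbacks modelled as plain functions called after every batch; the actual callback interface is not documented here.

```python
from typing import Callable, Dict, Iterable, List, Optional, Sequence

# Hypothetical callback signature: (batch_index, logs) -> None.
BatchCallback = Callable[[int, Dict[str, float]], None]


def train_one_epoch_sketch(
    step: Callable[[Sequence[float]], float],
    train_dataset: Iterable[Sequence[float]],
    callbacks: Optional[List[BatchCallback]] = None,
) -> float:
    """Run `step` on every mini-batch and return the mean loss over the epoch."""
    callbacks = callbacks or []
    total, n_batches = 0.0, 0
    for batch_index, batch in enumerate(train_dataset):
        loss = step(batch)  # update weights on this batch, get its loss
        total += loss
        n_batches += 1
        # Running metric, analogous to updating train_metric after each batch.
        logs = {"loss": loss, "mean_loss": total / n_batches}
        for callback in callbacks:
            callback(batch_index, logs)
    return total / n_batches if n_batches else 0.0
```

Keeping a running mean instead of storing every loss lets the loop report progress after each batch in constant memory.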