msp.rl_algorithm package

Module contents

The msp.rl_algorithm module defines the reinforcement learning algorithm used to train the model.

class msp.rl_algorithm.ReinforceAlgorithm(objective_func, optimizer, train_metric, val_metric, tol=0.001)

Bases: object

apply_gradient(model_train, model_baseline, batch_train)

Apply the gradients to the trainable model weights.

Parameters
  • model_train – model to be trained

  • model_baseline – baseline model used for comparison, as per the REINFORCE algorithm

  • batch_train – input mini-batch of training data
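
The update described above can be illustrated with a minimal, library-independent sketch of one REINFORCE-with-baseline step. It assumes a softmax policy over a plain list of logits; the names softmax, reinforce_gradient and apply_gradient_sketch are hypothetical and are not part of msp. The actual apply_gradient method operates on the models and the optimizer passed to ReinforceAlgorithm and may differ in its details.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def reinforce_gradient(logits, action, cost, baseline_cost):
    """Gradient of (cost - baseline_cost) * log pi(action) w.r.t. the logits.

    Assumption: the policy is a softmax over `logits`, and lower cost is better.
    """
    probs = softmax(logits)
    advantage = cost - baseline_cost
    # d log pi(a) / d logit_k = 1[k == a] - pi(k)
    return [
        advantage * ((1.0 if k == action else 0.0) - p)
        for k, p in enumerate(probs)
    ]


def apply_gradient_sketch(logits, grads, learning_rate=0.1):
    """Plain gradient-descent step; stands in for a real optimizer."""
    return [w - learning_rate * g for w, g in zip(logits, grads)]


# The sampled action cost more than the baseline, so its logit is pushed down.
logits = [0.0, 0.0]
grads = reinforce_gradient(logits, action=0, cost=2.0, baseline_cost=1.0)
new_logits = apply_gradient_sketch(logits, grads)
```

When the sampled cost equals the baseline cost, the advantage is zero and the weights are left unchanged, which is the variance-reduction role the baseline plays in this sketch.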

evaluate(model_train, model_baseline, val_dataset, verbose=1, callbacks=None)

Perform model evaluation.

Parameters
  • model_train – model to be trained

  • model_baseline – baseline model used for comparison, as per the REINFORCE algorithm

  • val_dataset – dataset for validation

  • verbose – Verbosity mode. 0 = silent, 1 = progress bar.

  • callbacks – List of callbacks to apply during evaluation.
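
As a rough illustration of what an evaluation pass of this kind might look like, the sketch below averages the per-sample cost of both models over a validation set without updating any weights. It assumes that each model is a callable returning one cost per sample and that each callback is a callable taking a step index and a dict of running values; evaluate_sketch and these conventions are hypothetical, not the msp implementation.

```python
def evaluate_sketch(model_train, model_baseline, val_dataset, callbacks=None):
    """Return (mean train cost, mean baseline cost) over all validation samples.

    Assumptions: `val_dataset` is an iterable of batches (lists), and each
    model maps a batch to a list with one cost per sample.
    """
    callbacks = callbacks or []
    total_train = 0.0
    total_baseline = 0.0
    count = 0
    for step, batch in enumerate(val_dataset):
        total_train += sum(model_train(batch))
        total_baseline += sum(model_baseline(batch))
        count += len(batch)
        running = {
            "train_cost": total_train / count,
            "baseline_cost": total_baseline / count,
        }
        for callback in callbacks:
            callback(step, running)
    if count == 0:
        raise ValueError("val_dataset is empty")
    return total_train / count, total_baseline / count


# Toy usage: the "models" simply scale each input value.
seen_steps = []
mean_train, mean_baseline = evaluate_sketch(
    lambda batch: [2.0 * x for x in batch],
    lambda batch: [3.0 * x for x in batch],
    [[1.0, 2.0], [3.0]],
    callbacks=[lambda step, logs: seen_steps.append(step)],
)
```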

run(model_train, model_baseline, train_dataset, val_dataset, epochs=1, verbose=1, callbacks=None, epoch_when_last_save=0)

Run the REINFORCE algorithm on model_train and model_baseline.

Parameters
  • model_train – model to be trained

  • model_baseline – baseline model used for comparison, as per the REINFORCE algorithm

  • train_dataset – dataset for training

  • val_dataset – dataset for validation

  • epochs – number of epochs

  • verbose – Verbosity mode. 0 = silent, 1 = progress bar.

  • callbacks – List of callbacks to apply during training.

  • epoch_when_last_save – epoch at which the model was last saved; used when retraining a model
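
The sketch below shows one way an outer loop with these parameters could be organised when training is resumed. It assumes that epochs counts additional epochs to run and that epoch numbering continues from epoch_when_last_save; the source does not confirm either point. The function run_sketch and its two callable arguments are hypothetical stand-ins for the per-epoch training and evaluation steps.

```python
def run_sketch(train_one_epoch, evaluate, epochs=1, epoch_when_last_save=0):
    """Alternate training and evaluation, continuing the epoch count.

    Assumptions: `train_one_epoch(epoch)` returns a training loss,
    `evaluate(epoch)` returns a validation cost, and `epochs` is the
    number of additional epochs to run after `epoch_when_last_save`.
    """
    history = []
    first = epoch_when_last_save + 1
    for epoch in range(first, first + epochs):
        train_loss = train_one_epoch(epoch)
        val_cost = evaluate(epoch)
        history.append(
            {"epoch": epoch, "train_loss": train_loss, "val_cost": val_cost}
        )
    return history


# Resuming after epoch 3: the two new epochs are numbered 4 and 5.
history = run_sketch(
    train_one_epoch=lambda epoch: 1.0 / epoch,
    evaluate=lambda epoch: 2.0 / epoch,
    epochs=2,
    epoch_when_last_save=3,
)
```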

train_model_for_one_epoch(model_train, model_baseline, train_dataset, callbacks, verbose=1)

Compute the loss, then update the weights and metrics, for one epoch.

Parameters
  • model_train – model to be trained

  • model_baseline – baseline model used for comparison, as per the REINFORCE algorithm

  • train_dataset – dataset for training

  • callbacks – List of callbacks to apply during training.

  • verbose – Verbosity mode. 0 = silent, 1 = progress bar.
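
A single epoch of this kind can be sketched as a loop over mini-batches that performs one update per batch, tracks a running mean of the loss, and notifies the callbacks. The MeanMetric class, the step_fn argument and the callback signature below are hypothetical simplifications; in msp the metric and optimizer are the objects passed to ReinforceAlgorithm.

```python
class MeanMetric:
    """Running mean of scalar values; a stand-in for a training metric."""

    def __init__(self):
        self.total = 0.0
        self.count = 0

    def update(self, value):
        self.total += value
        self.count += 1

    def result(self):
        return self.total / self.count if self.count else 0.0


def train_one_epoch_sketch(step_fn, train_dataset, callbacks=(), verbose=1):
    """Run `step_fn` on every batch and return the mean loss for the epoch.

    Assumption: `step_fn(batch)` applies one weight update and returns the
    scalar loss for that batch.
    """
    metric = MeanMetric()
    for step, batch in enumerate(train_dataset):
        loss = step_fn(batch)
        metric.update(loss)
        for callback in callbacks:
            callback(step, loss)
        if verbose:
            print(f"step {step + 1}: mean loss {metric.result():.4f}")
    return metric.result()


# Toy usage: the "loss" of a batch is simply the mean of its values.
batch_losses = []
epoch_loss = train_one_epoch_sketch(
    step_fn=lambda batch: sum(batch) / len(batch),
    train_dataset=[[1.0, 3.0], [5.0, 7.0]],
    callbacks=[lambda step, loss: batch_losses.append(loss)],
    verbose=0,
)
```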