"""
The :mod:`mps.rl_algorithm._reinforce` module defines the REINFORCE
reinforcement learning algorithm used to train the model.
Reference:
Kool, Wouter, H. V. Hoof and M. Welling.
“Attention, Learn to Solve Routing Problems!” ICLR (2019).
"""
import time
import numpy as np
import tensorflow as tf
class ReinforceAlgorithm:
    """REINFORCE training loop that compares a trained model to a baseline.

    The trained model is optimised against the baseline model's objective
    values; after each epoch the baseline takes over the trained model's
    weights when the validation metric shows the trained model is better by
    more than ``tol``.
    """

    def __init__(self, objective_func, optimizer, train_metric, val_metric,
                 tol=1e-3):
        """REINFORCE Algorithm to train the model.

        Args:
            objective_func: objective function of the problem; called as
                ``objective_func(batch, schedules)``.
            optimizer: optimizer used to update the model weights.
            train_metric: metric recording performance on the training
                dataset. Must expose ``update_state(train, baseline)``,
                ``result()`` (a mapping with ``'train'`` and ``'baseline'``
                keys) and ``reset_states()``.
            val_metric: metric recording performance on the validation
                dataset; same interface as ``train_metric``.
            tol: tolerance used when deciding to update the baseline model.
        """
        self.optimizer = optimizer
        self.objective_func = objective_func
        self.train_metric = train_metric
        self.val_metric = val_metric
        self.tol = tol

    @staticmethod
    def _count_batches(dataset):
        """Return the number of batches one pass over ``dataset`` yields.

        Counts lazily so the batches are never all held in memory at once.
        """
        return sum(1 for _ in dataset)

    def run(self, model_train, model_baseline, train_dataset, val_dataset,
            epochs=1, verbose=1, callbacks=None, epoch_when_last_save=0):
        """Run the REINFORCE algorithm on ``model_train`` and the baseline.

        Args:
            model_train: model being trained.
            model_baseline: baseline model used for comparison.
            train_dataset: iterable of training batches.
            val_dataset: iterable of validation batches.
            epochs: index of the last epoch (exclusive).
            verbose: verbosity mode. 0 = silent, 1 = progress bar.
            callbacks: list of callbacks to apply during training. The LAST
                entry is attached to ``model_baseline``; all preceding
                entries are attached to ``model_train``. ``None`` or an
                empty list means no user callbacks.
            epoch_when_last_save: epoch index to resume from when retraining
                a model.
        """
        # The original indexed ``callbacks[-1]`` unconditionally, which
        # raised on the documented default ``callbacks=None``.
        user_callbacks = list(callbacks) if callbacks is not None else []
        baseline_user_callbacks = user_callbacks[-1:]
        train_user_callbacks = user_callbacks[:-1]

        # Containers that configure and call `tf.keras.Callback`s.
        baseline_callbacks = tf.keras.callbacks.CallbackList(
            baseline_user_callbacks,
            add_history=True,
            add_progbar=False,
            model=model_baseline,
            verbose=verbose,
            epochs=epochs)
        callbacks = tf.keras.callbacks.CallbackList(
            train_user_callbacks,
            add_history=True,
            add_progbar=verbose != 0,
            model=model_train,
            verbose=verbose,
            epochs=epochs,
            steps=self._count_batches(train_dataset))

        callbacks.on_train_begin()
        baseline_callbacks.on_train_begin()

        train_cum_time, val_cum_time = 0, 0
        for epoch in range(epoch_when_last_save, epochs):
            logs = {}
            callbacks.on_epoch_begin(epoch)
            baseline_callbacks.on_epoch_begin(epoch)

            # ---------------- Model training ----------------
            start_time = time.perf_counter()
            train_losses = self.train_model_for_one_epoch(
                model_train,
                model_baseline,
                train_dataset,
                callbacks,
                verbose=verbose
            )
            train_metric_result = self.train_metric.result()
            train_time = time.perf_counter() - start_time
            train_cum_time += train_time

            # ---------------- Model validation ----------------
            start_time = time.perf_counter()
            # Forward ``verbose`` so that verbose=0 also silences validation.
            val_losses = self.evaluate(
                model_train, model_baseline, val_dataset, verbose=verbose)
            val_metric_result = self.val_metric.result()
            val_time = time.perf_counter() - start_time
            val_cum_time += val_time

            # Update the baseline model if the trained model is better by
            # more than the tolerance (lower metric value is better).
            update_baseline = (
                val_metric_result['train'] + self.tol
                < val_metric_result['baseline']
            )
            if update_baseline:
                model_baseline.set_weights(model_train.get_weights())

            logs['train_time'] = train_time
            logs['cum_train_time'] = train_cum_time
            logs['train_loss'] = np.mean(train_losses)
            logs['makespan_train_data_by_train_model'] = train_metric_result['train']
            logs['makespan_train_data_by_baseline_model'] = train_metric_result['baseline']
            logs['val_time'] = val_time
            logs['cum_val_time'] = val_cum_time
            logs['val_loss'] = np.mean(val_losses)
            logs['makespan_val_data_by_train_model'] = val_metric_result['train']
            logs['makespan_val_data_by_baseline_model'] = val_metric_result['baseline']
            logs['is_baseline_updated'] = update_baseline

            # Reset states of all metrics before the next epoch.
            self.train_metric.reset_states()
            self.val_metric.reset_states()

            baseline_callbacks.on_epoch_end(epoch, logs)
            callbacks.on_epoch_end(epoch, logs)

        baseline_callbacks.on_train_end()
        callbacks.on_train_end()

    def train_model_for_one_epoch(self, model_train, model_baseline,
                                  train_dataset, callbacks, verbose=1):
        """Compute the loss, then update weights and metrics for one epoch.

        Args:
            model_train: model being trained.
            model_baseline: baseline model used for comparison.
            train_dataset: iterable of training batches.
            callbacks: callback container exposing ``on_train_batch_begin``
                and ``on_train_batch_end``.
            verbose: verbosity mode (accepted for interface compatibility;
                not used here).

        Returns:
            List with the loss value of every batch, in iteration order.
        """
        losses = []
        for step, batch_train in enumerate(train_dataset):
            callbacks.on_train_batch_begin(step)
            objective_train, objective_baseline, loss_value = self.apply_gradient(
                model_train,
                model_baseline,
                batch_train
            )
            losses.append(loss_value)

            self.train_metric.update_state(objective_train, objective_baseline)
            metric_result = self.train_metric.result()
            logs_ = {
                'train_loss': float(loss_value),
                'makespan_train': metric_result['train'],
                'makespan_baseline': metric_result['baseline']
            }
            # Keras batch hooks take the same zero-based batch index for
            # begin and end; the progress logger adds 1 itself, so passing
            # ``step + 1`` made the bar overshoot by one.
            callbacks.on_train_batch_end(step, logs=logs_)
        return losses

    def apply_gradient(self, model_train, model_baseline, batch_train):
        """Apply one gradient step to the trainable model weights.

        Args:
            model_train: model being trained.
            model_baseline: baseline model used for comparison.
            batch_train: input mini-batch data.

        Returns:
            Tuple ``(objective_train, objective_baseline, loss_value)``.
        """
        # Baseline forward pass happens outside the tape: no gradient needed.
        schedules_baseline, _ = model_baseline(batch_train, training=False)

        # Open a GradientTape to record the operations run during the
        # forward pass, which enables auto-differentiation.
        with tf.GradientTape() as tape:
            # Shapes noted by the original author: B x V x 2, B x 1.
            schedules_train, sum_log_probs = model_train(batch_train, training=True)

            # Compute the problem objective for both models.
            objective_train = self.objective_func(batch_train, schedules_train)
            objective_baseline = self.objective_func(batch_train, schedules_baseline)

            # Loss for this mini-batch: objective gap to the baseline,
            # weighted by the summed log-probabilities.
            loss_value = tf.reduce_mean(
                (objective_train - objective_baseline) * sum_log_probs)

        # Retrieve the gradients of the trainable variables w.r.t. the loss.
        grads = tape.gradient(loss_value, model_train.trainable_weights)

        # Run one optimizer step to minimise the loss.
        self.optimizer.apply_gradients(zip(grads, model_train.trainable_weights))

        return objective_train, objective_baseline, loss_value

    def evaluate(self, model_train, model_baseline, val_dataset, verbose=1,
                 callbacks=None):
        """Perform model evaluation.

        Args:
            model_train: model being trained.
            model_baseline: baseline model used for comparison.
            val_dataset: iterable of validation batches.
            verbose: verbosity mode. 0 = silent, 1 = progress bar.
            callbacks: list of callbacks to apply during evaluation, or an
                already-built ``tf.keras.callbacks.CallbackList``.

        Returns:
            List with the loss value of every validation batch.
        """
        # Container that configures and calls `tf.keras.Callback`s.
        if not isinstance(callbacks, tf.keras.callbacks.CallbackList):
            callbacks = tf.keras.callbacks.CallbackList(
                callbacks,
                add_history=True,
                add_progbar=verbose != 0,
                model=model_train,
                verbose=verbose,
                epochs=1,
                steps=self._count_batches(val_dataset))

        losses = []
        callbacks.on_test_begin()
        for step, batch_val in enumerate(val_dataset):
            callbacks.on_test_batch_begin(step)

            schedules_train, sum_log_probs = model_train(batch_val, training=False)
            schedules_baseline, _ = model_baseline(batch_val, training=False)

            # Compute the problem objective for both models.
            objective_train = self.objective_func(batch_val, schedules_train)
            objective_baseline = self.objective_func(batch_val, schedules_baseline)

            # Loss value for this mini-batch (same formula as in training).
            loss_value = tf.reduce_mean(
                (objective_train - objective_baseline) * sum_log_probs)
            losses.append(loss_value)

            # Update validation metrics.
            self.val_metric.update_state(objective_train, objective_baseline)
            metric_result = self.val_metric.result()
            logs_ = {
                'val_loss': float(loss_value),
                'makespan_train': metric_result['train'],
                'makespan_baseline': metric_result['baseline']
            }
            # Same zero-based index as on_test_batch_begin (Keras contract).
            callbacks.on_test_batch_end(step, logs=logs_)
        callbacks.on_test_end()
        return losses