"""
The :mod:`mps.utils._state` module defines RL-Environment state.
"""
import copy
import tensorflow as tf
class MSPState(tf.Module):
    """Track the decoding state of the RL environment across steps.

    The state is a set of non-trainable ``tf.Variable`` objects that are
    created lazily by :meth:`build` and advanced one decoding step at a time
    by :meth:`update` (or by calling the instance). :meth:`reset` restores
    the values that :meth:`build` starts from.

    NOTE(review): the code treats the last node feature as a node-type flag
    and builds the initial mask as "``num_node - num_machine`` job nodes
    first, machine nodes last". Presumably the flag is 1 for jobs and 0 for
    machines -- confirm against the data pipeline.
    """

    def __init__(self):
        """Create an empty state; variables are allocated by :meth:`build`."""
        super().__init__()
        # Flag flipped by `build` once all tracking variables exist.
        self.is_build = False

    def build(self, input_shape):
        """Create the tracking variables on first use.

        Args:
            input_shape: Object whose ``job_assignment`` attribute unpacks
                into ``(batch_size, num_node, num_machine)``.
        """
        self.input_shape = input_shape
        batch_size, num_node, num_machine = self.input_shape.job_assignment

        # NOTE(review): the original indentation was lost, so the exact
        # extent of this name scope is assumed; it only affects variable
        # names.
        with tf.name_scope('tracking_vars'):
            # B x 1
            self._first_node = tf.Variable(
                initial_value=tf.zeros((batch_size, 1), dtype=tf.int64),
                trainable=False,
                name='first_node'
            )
            # B x 1
            self._last_node = tf.Variable(
                initial_value=tf.zeros((batch_size, 1), dtype=tf.int64),
                trainable=False,
                name='last_node'
            )
            # B x V x 1
            self._visited_t = tf.Variable(
                initial_value=tf.zeros((batch_size, num_node, 1)),
                trainable=False,
                name='visited_nodes'
            )
            # B x n_machines x 1
            self._visited_mt = tf.Variable(
                initial_value=tf.zeros((batch_size, num_machine, 1)),
                trainable=False,
                name='generated_job_subsequence_for_machine'
            )
            # B x 1
            self.mrg_machine = tf.Variable(
                initial_value=tf.zeros((batch_size, 1), dtype=tf.int64),
                trainable=False,
                name='most_recently_generated_machine'
            )
            # B x 1 x V
            self.mask = tf.Variable(
                initial_value=self._create_initial_mask(input_shape),
                trainable=False,
                name='mask'
            )
            # Scalar decoding step counter; starts at 1.
            self.step_count = tf.Variable(
                initial_value=1,
                trainable=False,
                name='decoding_step_counter'
            )

        self.is_build = True

    @property
    def first_node(self):
        """Variable (B x 1) holding the node selected at the first step."""
        return self._first_node

    @property
    def last_node(self):
        """Variable (B x 1) holding the most recently selected node."""
        return self._last_node

    def get_mask(self):
        """Return the current mask variable (B x 1 x V)."""
        return self.mask

    def get_step_count(self):
        """Return the scalar decoding step counter variable."""
        return self.step_count

    def __call__(self, inputs, selected_node):
        """Update the current state; equivalent to :meth:`update`.

        Args:
            inputs: Problem instance, see :meth:`update`.
            selected_node: int64 tensor (B x 1) of selected node indices.
        """
        # The flag is named `is_build` everywhere else in this class; the
        # previous `is_built` spelling raised AttributeError here.
        if not self.is_build:
            # NOTE(review): `build` expects an object exposing
            # `.job_assignment`; confirm that `tf.shape(inputs)` yields one
            # for the project's input type.
            self.build(tf.shape(inputs))
        self.update(inputs, selected_node)

    def update(self, inputs, selected_node):
        """Advance the state by one decoding step.

        Updates, in order: first node, last node, per-machine visited flags,
        visited nodes, most recently generated machine, mask and finally the
        step counter.

        Args:
            inputs: Problem instance exposing ``num_job``, ``num_node``,
                ``node_features``, ``adj_matrix`` and ``job_assignment``.
            selected_node: int64 tensor (B x 1) with the node index chosen
                for every batch element at this step.
        """
        if not self.is_build:
            # NOTE(review): see the matching note in `__call__`.
            self.build(tf.shape(inputs))

        batch_size = tf.shape(selected_node)[0]
        num_job = inputs.num_job
        node_type = inputs.node_features[:, :, -1]

        # B x 1
        type_of_selected_node = tf.cast(
            tf.gather(node_type, selected_node, batch_dims=1),
            dtype=tf.int64
        )

        # Update first node: only the very first selection is recorded.
        is_first_step = tf.equal(self.get_step_count(), 1)
        self._first_node.assign(
            tf.cond(
                is_first_step,
                lambda: selected_node,
                lambda: self._first_node
            )
        )

        # Update last node
        self._last_node.assign(selected_node)

        # Update visited_mt
        is_not_first_step = tf.greater(self.get_step_count(), 1)

        def visited_mt_increment():
            # Build a B x V one-hot row per batch element. The index is the
            # previous `mrg_machine` when the selected node has type 0 and 0
            # otherwise; columns below `num_job` are sliced away afterwards.
            temp = tf.zeros(tf.shape(self._visited_t))
            temp = tf.tensor_scatter_nd_update(
                tf.squeeze(temp, axis=-1),
                tf.concat([
                    tf.range(batch_size, dtype=tf.int64)[:, tf.newaxis],
                    self.mrg_machine * (1 - type_of_selected_node)
                ], axis=1),
                tf.ones((batch_size,), dtype=temp.dtype)
            )
            return temp[:, num_job:, tf.newaxis]

        def no_increment():
            # Nothing is added on the first step. The previous code returned
            # the variable itself here, which `assign_add` would have
            # doubled had it ever been non-zero.
            return tf.zeros_like(self._visited_mt)

        self._visited_mt.assign_add(
            tf.cond(is_not_first_step, visited_mt_increment, no_increment)
        )

        # Update visited nodes: write 1 at [b, selected_node[b], 0].
        index_shape = tf.shape(selected_node)
        self._visited_t.scatter_nd_update(
            tf.concat([
                tf.reshape(tf.range(batch_size, dtype=tf.int64), index_shape),
                selected_node,
                tf.reshape(tf.zeros(batch_size, dtype=tf.int64), index_shape),
            ], axis=1),
            tf.ones((batch_size,))
        )

        # Update most recently generated machine (mrg): replaced only where
        # the selected node has type 0.
        self.mrg_machine.assign(
            tf.where(
                tf.cast(1 - type_of_selected_node, dtype=tf.bool),
                selected_node,
                self.mrg_machine
            )
        )

        # Update mask: recomputed only while steps remain.
        is_still_decoding = self.step_count < inputs.num_node
        self.mask.assign(
            tf.cond(
                is_still_decoding,
                lambda: self._compute_mask(inputs),
                lambda: self.mask
            )
        )

        # Update step count
        self.step_count.assign_add(1)

    def reset(self):
        """Restore every tracking variable to its initial value.

        Raises:
            AssertionError: If :meth:`build` has not been called yet.
        """
        assert self.is_build, "build the state module...."

        self._first_node.assign(
            tf.zeros(tf.shape(self._first_node), dtype=tf.int64))
        self._last_node.assign(
            tf.zeros(tf.shape(self._last_node), dtype=tf.int64))
        self._visited_t.assign(tf.zeros(tf.shape(self._visited_t)))
        self._visited_mt.assign(tf.zeros(tf.shape(self._visited_mt)))
        self.mrg_machine.assign(
            tf.zeros(tf.shape(self.mrg_machine), dtype=tf.int64))
        self.mask.assign(self._create_initial_mask(self.input_shape))
        self.step_count.assign(1)

    # ##########################################################################
    # ...........................PRIVATE METHODS................................
    # ##########################################################################

    def _create_initial_mask(self, input_shape):
        """Return the mask used at the first decoding step (B x 1 x V).

        The first ``num_node - num_machine`` positions receive ``-1e10`` and
        the last ``num_machine`` positions receive 0.

        Args:
            input_shape: Object whose ``job_assignment`` attribute unpacks
                into ``(batch_size, num_node, num_machine)``.
        """
        batch_size, num_node, num_machine = input_shape.job_assignment
        large_negative_constant = tf.negative(1e10)

        # At timestep t=1
        job_type = tf.ones(
            (batch_size, 1, num_node - num_machine), dtype=tf.float32)
        machine_type = tf.zeros((batch_size, 1, num_machine), dtype=tf.float32)
        node_type = tf.concat([job_type, machine_type], axis=-1)

        return large_negative_constant * node_type

    def _compute_mask(self, inputs):
        """Compute the mask for the next decoding step (B x 1 x V).

        The deadlock-prevention, eligible-neighbour and unvisited masks are
        multiplied together; positions where the product is 0 receive
        ``-1e10``.

        Args:
            inputs: Problem instance, see :meth:`update`.

        Returns:
            Float tensor of shape B x 1 x V.
        """
        large_negative_constant = tf.negative(1e10)

        # B x 1 x V
        node_type = tf.transpose(inputs.node_features[:, :, -1:], perm=[0, 2, 1])

        mask_D_t = self._deadlock_prevention_mask(inputs, node_type)
        mask_N_t = self._eligible_neighbor_nodes(inputs, node_type)
        mask_V_t = self._unvisited_mask()
        mask_t_prime = mask_D_t * mask_N_t * mask_V_t

        # B x 1 x V
        mask_t = large_negative_constant * (1 - mask_t_prime)

        # A row in which every node is masked leaves nothing to select.
        is_mask_invalid = tf.reduce_any(
            tf.reduce_all(tf.equal(mask_t, large_negative_constant), axis=-1))
        tf.assert_equal(is_mask_invalid, False, message='invalid_mask')

        return mask_t

    def _deadlock_prevention_mask(self, inputs, node_type):
        """Return ``node_type`` where a deadlock is detected, ones elsewhere.

        Args:
            inputs: Problem instance, see :meth:`update`.
            node_type: Tensor of shape B x 1 x V with the node-type flag.

        Returns:
            Tensor of shape B x 1 x V.
        """
        # B x 1 x 1
        is_deadlock = self._is_deadlock(inputs)
        # B x 1 x V
        mask_D_t = tf.where(is_deadlock, node_type, tf.ones(tf.shape(node_type)))
        return mask_D_t

    def _eligible_neighbor_nodes(self, inputs, node_type):
        """Return the neighbour mask of the last selected node (B x 1 x V).

        The adjacency row of ``last_node`` is multiplied by
        ``(1 - node_type)`` plus the ``job_assignment`` row selected by
        ``mrg_machine mod num_job``.

        Args:
            inputs: Problem instance, see :meth:`update`.
            node_type: Tensor of shape B x 1 x V with the node-type flag.
        """
        adj_matrix = inputs.adj_matrix
        job_assignment = inputs.job_assignment
        num_job = inputs.num_job

        # B x 1
        # NOTE(review): presumably maps a machine node index to a machine
        # index; confirm this holds when num_machine > num_job.
        mrg_machine_prime = tf.math.mod(self.mrg_machine, num_job)

        # B x 1 x V
        neighbours = tf.gather_nd(
            adj_matrix,
            indices=self.last_node[:, :, tf.newaxis],
            batch_dims=1
        )

        # B x 1 x V
        assigned_to_mrg_machine = tf.gather(
            tf.transpose(job_assignment, perm=[0, 2, 1]),
            indices=mrg_machine_prime,
            batch_dims=1
        )
        mask_N_t = tf.multiply(
            neighbours,
            tf.add(1 - node_type, assigned_to_mrg_machine)
        )
        return mask_N_t

    def _unvisited_mask(self):
        """Return 1 for nodes not yet visited and 0 otherwise (B x 1 x V)."""
        # B x 1 x V
        mask_V_t = 1 - tf.transpose(self._visited_t, perm=[0, 2, 1])
        return mask_V_t

    def _is_deadlock(self, inputs):
        """Return a boolean deadlock indicator per batch element (B x 1 x 1).

        Args:
            inputs: Problem instance, see :meth:`update`.
        """
        alpha = inputs.job_assignment
        num_job = inputs.num_job

        # B x n_machines x V
        alpha_prime = tf.transpose(alpha, perm=[0, 2, 1]) * (1 - self._visited_mt)

        # B x 1
        mrg_machine_prime = tf.math.mod(self.mrg_machine, num_job)

        # B x 1 x V
        alpha_j_prime = tf.gather(alpha_prime, indices=mrg_machine_prime,
                                  batch_dims=1)

        # B x 1 x V: summed rows of all machines except the selected one.
        alpha_others = tf.subtract(
            tf.reduce_sum(alpha_prime, axis=-2, keepdims=True),
            alpha_j_prime
        )

        # B x 1 x 1: counts unvisited nodes where the selected machine's row
        # exceeds the sum of the others.
        delta_t = tf.matmul(
            tf.cast(tf.greater(alpha_j_prime, alpha_others), dtype=tf.float32),
            1 - self._visited_t
        )

        is_dead_lock = tf.greater(delta_t, 0)
        return is_dead_lock