# Source code for msp.models.decoders._attention

"""
The :mod:`msp.models.decoders._attention` module defines the attention-based
decoder architecture.
"""
from collections import namedtuple

import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow.keras import Model

from msp.layers import ContextEmbedding, MHALayer, SHALayer
from msp.utils import MSPEnv


class AttentionDecoder(Model):
    """Attention-based autoregressive decoder.

    At every decoding step the decoder builds a context query from the
    precomputed graph context and the current environment state, refines it
    with a multi-head attention layer (``MHALayer``) over the projected node
    embeddings, scores the nodes with a single-head attention layer
    (``SHALayer``), selects one node, and advances ``MSPEnv`` with it.

    Args:
        units: Hidden dimension. Must be a multiple of ``n_heads``.
        use_bias: Whether the dense projections use a bias term.
        n_heads: Number of heads of the multi-head attention layer.
        aggregation_graph: Pooling used to build the graph embedding from
            the node embeddings: ``'sum'``, ``'max'`` or ``'mean'``. Any
            other value disables the graph embedding (it becomes all zeros).
        tanh_clipping: Clipping constant forwarded to ``SHALayer``.
    """

    # Container type for the step-independent tensors built by
    # ``_precompute``. It is defined once here instead of on every call.
    _PrecomputeResult = namedtuple(
        'precompute',
        ['fixed_context', 'mha_keys', 'mha_values', 'sha_keys'])

    def __init__(self, units, *args, use_bias=False, n_heads=8,
                 aggregation_graph='mean', tanh_clipping=10, **kwargs):
        super(AttentionDecoder, self).__init__(*args, **kwargs)
        self.aggregation_graph = aggregation_graph
        self.n_heads = n_heads
        self.tanh_clipping = tanh_clipping
        # Each head works on units // n_heads features, so the split must be
        # exact.
        assert units % n_heads == 0, \
            "Hidden dimension (units) should be a multiple of the number of heads."

        self.context_embed = ContextEmbedding(
            units, use_bias=use_bias, name='context_embed_layer')
        self.mha_layer = MHALayer(
            n_heads, units // n_heads, use_bias=use_bias, name='mha_layer')
        self.sha_layer = SHALayer(units, tanh_clipping=tanh_clipping)
        self.msp_env = MSPEnv()

        ############# For pre-computations ####################
        # 1. projection for graph embedding to build fixed context
        self.W_CG = tf.keras.layers.Dense(
            units, use_bias=use_bias, name='precomputation_W_CG')
        # 2. projections for MHA (keys and values, KV) and SHA (keys, K),
        #    fused into a single Dense layer and split afterwards.
        self.W_KVK = tf.keras.layers.Dense(
            3 * units, use_bias=use_bias, name='precomputation_W_KVK')
        # The environment is built lazily on the first call (see `call`).
        self.is_build = False

    def call(self, inputs, training=None):
        """Decode a full schedule by selecting one node per step.

        Args:
            inputs: Encoder output. It is read through ``inputs.shape``,
                ``inputs.num_node`` (number of decoding steps) and
                ``inputs.node_embed`` (assumed B x V x H -- TODO confirm),
                and is passed as-is to the context embedding and to
                ``MSPEnv.step``.
            training: If truthy, nodes are sampled from the predicted
                distribution; otherwise they are chosen greedily.

        Returns:
            Tuple ``(schedules, sum_log_probs)``:

            - ``schedules`` concatenates, over the steps, the stack of the
              selected node and ``state_t.mrg_machine``. The result is
              presumably B x V x 2.
            - ``sum_log_probs`` is the sum over steps of the log-probability
              of each selected node, with shape B.
        """
        if not self.is_build:
            # Built only once, from the shape of the first inputs seen.
            self.msp_env.build(inputs.shape)
            self.is_build = True

        schedules, log_probs = [], []
        time_step = self.msp_env.reset()

        # pre-computations (i.e., independent of timestep)
        precompute_res = self._precompute(inputs)

        for _ in range(inputs.num_node):
            state_t = self.msp_env._state
            # B x 1 x H
            query_t = self.context_embed([
                inputs, precompute_res.fixed_context, state_t])
            # B x 1 x V
            mask_t = time_step.mask
            # B x 1 x H
            query_prime_t = self.mha_layer([
                query_t, precompute_res.mha_keys, precompute_res.mha_values,
                mask_t])
            # B x 1 x V
            log_p = self.sha_layer([
                query_prime_t, precompute_res.sha_keys, mask_t])
            # B x 1
            selected_node, _ = self._select_node(log_p, mask_t, training)
            # Read the log-probability of the chosen node straight from log_p
            # instead of computing log(exp(log_p)). The round trip underflows
            # to -inf (and yields NaN gradients) for very unlikely nodes.
            # B x 1
            log_probs.append(tf.gather(
                tf.squeeze(log_p, axis=1), selected_node, batch_dims=1))

            # update the state for next timestep t=t+1
            actions = {'inputs': inputs, 'selected_node': selected_node}
            time_step = self.msp_env.step(actions)
            # NOTE(review): state_t.mrg_machine is read after step(). Whether
            # it holds the pre- or post-step value depends on whether MSPEnv
            # mutates `_state` in place -- TODO confirm.
            schedules.append(
                tf.stack([selected_node, state_t.mrg_machine], axis=-1))

        # B x num_steps -> B
        sum_log_probs = tf.reduce_sum(tf.concat(log_probs, axis=-1), axis=-1)
        # B x V x 2
        schedules = tf.concat(schedules, axis=1)
        return schedules, sum_log_probs

    def _precompute(self, inputs):
        """Precompute the step-independent context, keys and values.

        Returns:
            A ``precompute`` namedtuple with these fields:

            - ``fixed_context``: B x 1 x H.
            - ``mha_keys``, ``mha_values`` and ``sha_keys``: B x 1 x V x H
              each.

            These shapes assume ``inputs.node_embed`` is B x V x H.
        """
        # B x 1 x H
        graph_embed = self._get_graph_embed(inputs.node_embed)

        # precomputation for context embedding
        # B x 1 x H
        fixed_context = self.W_CG(graph_embed)

        # precomputation for MHA and SHA: one fused projection, split into
        # three equal feature chunks.
        # B x 1 x V x H (each)
        mha_keys, mha_values, sha_keys = tf.split(
            self.W_KVK(tf.expand_dims(inputs.node_embed, axis=-3)),
            num_or_size_splits=3, axis=-1)

        return self._PrecomputeResult(
            fixed_context=fixed_context,
            mha_keys=mha_keys,
            mha_values=mha_values,
            sha_keys=sha_keys)

    def _get_graph_embed(self, node_embed):
        """Pool node embeddings over axis -2 and keep that axis with size 1.

        The pooling is chosen by ``self.aggregation_graph`` ('sum', 'max',
        'mean'). Any other value yields an all-zero embedding.
        """
        if self.aggregation_graph == "sum":
            graph_embed = tf.reduce_sum(node_embed, axis=-2)
        elif self.aggregation_graph == "max":
            graph_embed = tf.reduce_max(node_embed, axis=-2)
        elif self.aggregation_graph == "mean":
            graph_embed = tf.reduce_mean(node_embed, axis=-2)
        else:
            # disable graph embedding
            graph_embed = tf.reduce_sum(node_embed, axis=-2) * 0.0
        return tf.expand_dims(graph_embed, axis=-2)

    def _select_node(self, log_p, mask, training):
        """Select one node per batch element from the log-probabilities.

        Args:
            log_p: Log-probabilities over the nodes (B x 1 x V).
            mask: Node mask (B x 1 x V). A node is selectable when its mask
                entry equals 0.
            training: If truthy, sample from the distribution and resample
                until every sampled node is selectable. Otherwise take the
                argmax and assert that it is selectable.

        Returns:
            ``(selected_node, prob_selected_node)``, both B x 1.
            ``selected_node`` is int64.
        """
        tf.assert_equal(tf.reduce_any(tf.math.is_nan(log_p)), False,
                        message="Log probabilities over the nodes should be defined.")
        # B x 1 x V
        probs = tf.exp(log_p)
        # NaN is the only value that compares unequal to itself.
        tf.assert_equal(tf.reduce_all(probs == probs), True,
                        message='Probs should not contain any nans')

        def is_decoding_correct(selected_node):
            # True iff every selected node has mask value 0.
            return tf.reduce_all(
                tf.equal(
                    tf.gather_nd(
                        tf.squeeze(mask, axis=1), selected_node, batch_dims=1),
                    0))

        if training:
            dist = tfp.distributions.Categorical(probs=probs, dtype=tf.int64)
            # sample(1) adds a leading sample axis; drop it to get B x 1.
            selected_node_temp = tf.squeeze(dist.sample(1), axis=0)

            # Check if sampling was correct
            def is_decoding_incorrect(selected_node):
                return tf.logical_not(is_decoding_correct(selected_node))

            def body(selected_node):
                tf.print('Sampled bad values, resampling!')
                return [tf.squeeze(dist.sample(1), axis=0)]

            selected_node = tf.while_loop(
                is_decoding_incorrect, body, [selected_node_temp],
                parallel_iterations=1)[0]
        else:
            selected_node = tf.math.argmax(probs, axis=-1)
            tf.assert_equal(is_decoding_correct(selected_node), True,
                            message="Greedy decoding: infeasible action has maximum probability")

        # Squeeze only the step axis. A bare tf.squeeze() would also drop the
        # batch axis when B == 1 (or the node axis when V == 1), which breaks
        # the batched gather below.
        # B x 1
        prob_selected_node = tf.gather(
            tf.squeeze(probs, axis=1), selected_node, batch_dims=1)
        return selected_node, prob_selected_node