Source code for msp.layers._context_embedding

"""
The :mod:`msp.layers._context_embedding` module defines the ContextEmbedding
layer, which inherits from the TensorFlow Keras ``Layer`` class.
"""
import tensorflow as tf
from tensorflow.keras.layers import Layer


class ContextEmbedding(Layer):
    """Keras layer that adds a projected first/last-node embedding to a context.

    On the first step (``state.get_step_count() == 1``) a learnable placeholder
    vector of size ``2 * units`` stands in for the first/last node embeddings.
    On every other step the embeddings indexed by ``state.first_node`` and
    ``state.last_node`` are gathered from the node embeddings and concatenated
    along the last axis. The result is projected to ``units`` features by a
    ``Dense`` layer and added to the fixed context.

    Args:
        units: Output size of the projection. The placeholder vector and the
            projection input both have size ``2 * units``.
        use_bias: Whether the projection ``Dense`` layer uses a bias term.
        **kwargs: Forwarded unchanged to ``tf.keras.layers.Layer``.
    """

    def __init__(self, units, use_bias=False, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.use_bias = use_bias

    def build(self, input_shape):
        """Create the state of the layer (weights).

        Args:
            input_shape: Shape information passed through to the parent
                ``build``; it is not used to size the weights created here.
        """
        with tf.name_scope('W_placeholder'):
            # Learnable stand-in for the first and last node at timestep t=1.
            self.W_context_placeholder = self.add_weight(
                shape=(2 * self.units,),
                initializer=tf.random_uniform_initializer(minval=-1, maxval=1),
                trainable=True,
                name='W_context_placeholder',
            )

        with tf.name_scope('W_step_context'):
            # Projection applied to the concatenated first and last node
            # embeddings; built eagerly so its weights exist after build().
            W_step_context_shape = tf.TensorShape((None, None, 2 * self.units))
            self.W_step_context = tf.keras.layers.Dense(
                self.units,
                use_bias=self.use_bias,
            )
            self.W_step_context.build(W_step_context_shape)

        super().build(input_shape)

    def call(self, inputs):
        """Compute the step context embedding.

        Args:
            inputs: Indexable of three items:
                ``inputs[0]`` exposes a ``node_embed`` tensor,
                ``inputs[1]`` is the fixed context tensor, and
                ``inputs[2]`` is a state object exposing ``first_node``,
                ``last_node`` and ``get_step_count()``.

        Returns:
            The fixed context plus the projected first/last-node embedding
            (annotated ``B x 1 x H`` by the original author -- TODO confirm
            against the callers).
        """
        node_embed = inputs[0].node_embed
        fixed_context = inputs[1]
        state = inputs[2]

        batch_size = tf.shape(node_embed)[0]

        def placeholder_embed():
            # Broadcast the learned placeholder to one row per batch element.
            placeholder = self.W_context_placeholder
            return tf.broadcast_to(
                placeholder[None, None, :],
                shape=(batch_size, 1, tf.shape(placeholder)[0]),
            )

        def gathered_embed():
            # Look up the per-example first and last node embeddings and
            # concatenate them along the feature axis.
            first_embed = tf.gather(node_embed, state.first_node, batch_dims=1)
            last_embed = tf.gather(node_embed, state.last_node, batch_dims=1)
            return tf.concat([first_embed, last_embed], axis=-1)

        # NOTE(review): tf.cond needs a scalar predicate, so this assumes
        # get_step_count() yields a scalar -- confirm against the state class.
        is_first_step = tf.equal(state.get_step_count(), 1)

        # B x 1 x 2H (original author's shape annotation)
        first_last_embed = tf.cond(
            is_first_step,
            placeholder_embed,
            gathered_embed,
        )

        # B x 1 x H (original author's shape annotation)
        context_embed = fixed_context + self.W_step_context(first_last_embed)
        return context_embed

    def get_config(self):
        """Return the layer configuration, including the constructor arguments.

        Including ``units`` and ``use_bias`` lets Keras re-create the layer
        from its config (e.g. when saving/loading or cloning a model).
        """
        config = super().get_config()
        config.update({
            'units': self.units,
            'use_bias': self.use_bias,
        })
        return config