Source code for msp.layers._ggcn

"""
The :mod:`msp.layers._ggcn` module defines the `Gated Graph Convolution Net`
layer, which inherits from the TensorFlow Keras ``Layer`` class.

Reference:
    - V. P. Dwivedi, C. K. Joshi, T. Laurent, Y. Bengio, and X. Bresson. 
    `Benchmarking graph neural networks. arXiv preprint arXiv:2003.00982, 2020`.
"""
import tensorflow as tf
from tensorflow.keras.layers import Layer

from msp.graphs import MSPEmbedGraph


class GGCNLayer(Layer):

    def __init__(self, units, *args, activation='relu', use_bias=True,
                 normalization='batch', aggregation='mean', **kwargs):
        """Gated Graph Convolution Layer.

        Args:
            units: Number of hidden dimensions.
            activation: Activation function to use (anything accepted by
                `tf.keras.activations.get`).
            use_bias: Boolean, whether the dense sub-layers use a bias vector.
            normalization: Type of normalization, `batch` or `layer`. Any
                other value disables normalization.
            aggregation: Method to aggregate the messages from neighbor
                nodes: `mean` or `sum`. Any other value falls back to a
                max-reduction.

        Input shape:
            Instance of `MSPEmbedGraph` consisting of at least the following
            tensors:
            Tensor `adj_matrix` with shape: `(batch_size, num_node, num_node)`.
            Tensor `node_embed` with shape: `(batch_size, num_node, input_dim)`.
            Tensor `edge_embed` with shape:
            `(batch_size, num_node, num_node, input_dim)`.
            The attributes `node_features`, `edge_features` and
            `job_assignment` are forwarded to the output untouched.

        Output shape:
            Instance of `MSPEmbedGraph` consisting of at least the following
            tensors:
            Tensor `adj_matrix` with shape: `(batch_size, num_node, num_node)`.
            Tensor `node_embed` with shape: `(batch_size, num_node, units)`.
            Tensor `edge_embed` with shape:
            `(batch_size, num_node, num_node, units)`.

        Note:
            Both updates end with a residual addition (`h + h_new`,
            `e + e_new`), so `input_dim` has to be broadcast-compatible with
            `units` (in practice: equal).
        """
        super().__init__(*args, **kwargs)
        self.units = units
        self.activation = tf.keras.activations.get(activation)
        self.use_bias = use_bias
        self.normalization = normalization
        self.aggregation = aggregation

    def _make_normalization(self):
        """Return a new normalization layer for `self.normalization`, or None."""
        if self.normalization == 'batch':
            return tf.keras.layers.BatchNormalization()
        if self.normalization == 'layer':
            return tf.keras.layers.LayerNormalization(axis=-1)
        return None

    def build(self, input_shape):
        """Create the state of the layer (weights).

        Args:
            input_shape: Object exposing `node_embed` and `edge_embed`
                attributes holding the shapes of the corresponding inputs.
        """
        node_embed_shape = input_shape.node_embed
        edge_embed_shape = input_shape.edge_embed

        # Sub-layers are built eagerly inside explicit name scopes so that
        # their variables are grouped under node/* and edge/*.
        with tf.name_scope('node'):
            with tf.name_scope('U'):
                self.U = tf.keras.layers.Dense(self.units, use_bias=self.use_bias)
                self.U.build(node_embed_shape)
            with tf.name_scope('V'):
                self.V = tf.keras.layers.Dense(self.units, use_bias=self.use_bias)
                self.V.build(node_embed_shape)
            with tf.name_scope('norm'):
                # Only the requested normalization layer is instantiated.
                self.norm_h = self._make_normalization()
                if self.norm_h is not None:
                    self.norm_h.build(node_embed_shape)

        with tf.name_scope('edge'):
            with tf.name_scope('A'):
                self.A = tf.keras.layers.Dense(self.units, use_bias=self.use_bias)
                self.A.build(edge_embed_shape)
            with tf.name_scope('B'):
                self.B = tf.keras.layers.Dense(self.units, use_bias=self.use_bias)
                self.B.build(node_embed_shape)
            with tf.name_scope('C'):
                self.C = tf.keras.layers.Dense(self.units, use_bias=self.use_bias)
                self.C.build(node_embed_shape)
            with tf.name_scope('norm'):
                self.norm_e = self._make_normalization()
                if self.norm_e is not None:
                    self.norm_e.build(edge_embed_shape)

        super().build(input_shape)

    def call(self, inputs, training=None):
        """Run one round of gated message passing.

        Args:
            inputs: Graph object providing `adj_matrix`, `node_embed`,
                `edge_embed`, `node_features`, `edge_features` and
                `job_assignment`.
            training: Forwarded to the normalization layers.

        Returns:
            A new `MSPEmbedGraph` carrying the original `adj_matrix`,
            `node_features`, `edge_features` and `job_assignment` together
            with the updated `node_embed` and `edge_embed`.
        """
        adj_matrix = inputs.adj_matrix
        h = inputs.node_embed
        e = inputs.edge_embed

        # Edge features
        Ae = self.A(e)
        Bh = self.B(h)
        Ch = self.C(h)
        e = self._update_edges(e, [Ae, Bh, Ch], training)
        # Gates are derived from the already-updated edge embeddings.
        edge_gates = tf.sigmoid(e)

        # Node features
        Uh = self.U(h)
        Vh = self.V(h)
        h = self._update_nodes(
            h,
            [Uh, self._aggregate(Vh, edge_gates, adj_matrix)],
            training,
        )

        return MSPEmbedGraph(
            inputs.adj_matrix,
            inputs.node_features,
            inputs.edge_features,
            inputs.job_assignment,
            node_embed=h,
            edge_embed=e,
        )

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = super().get_config()
        config.update({
            'units': self.units,
            'activation': tf.keras.activations.serialize(self.activation),
            'use_bias': self.use_bias,
            'normalization': self.normalization,
            'aggregation': self.aggregation,
        })
        return config

    def _update_edges(self, e, transformations: list, training):
        """Update edge features.

        Args:
            e: Current edge embeddings (used for the residual connection).
            transformations: `[Ae, Bh, Ch]`, the dense projections of the
                edge embeddings and (twice) of the node embeddings.
            training: Forwarded to the normalization layer.

        Returns:
            `e + activation(norm(Ae + Bh[:, None] + Ch[:, :, None]))`.
        """
        Ae, Bh, Ch = transformations
        # Bh is broadcast along axis 1 and Ch along axis 2, so entry (i, j)
        # receives Bh[j] and Ch[i].
        e_new = Ae + tf.expand_dims(Bh, axis=1) + tf.expand_dims(Ch, axis=2)
        # Normalization. `training` must be a keyword argument:
        # LayerNormalization.call() has no positional `training` parameter.
        if self.norm_e is not None:
            e_new = self.norm_e(e_new, training=training)
        # Activation
        e_new = self.activation(e_new)
        # Skip/residual connection
        return e + e_new

    def _update_nodes(self, h, transformations: list, training):
        """Update node features.

        Args:
            h: Current node embeddings (used for the residual connection).
            transformations: `[Uh, aggregated_messages]`.
            training: Forwarded to the normalization layer.

        Returns:
            `h + activation(norm(Uh + aggregated_messages))`.
        """
        Uh, aggregated_messages = transformations
        h_new = tf.math.add_n([Uh, aggregated_messages])
        # Normalization (keyword `training`, see `_update_edges`).
        if self.norm_h is not None:
            h_new = self.norm_h(h_new, training=training)
        # Activation
        h_new = self.activation(h_new)
        # Skip/residual connection
        return h + h_new

    def _aggregate(self, Vh, edge_gates, adj_matrix):
        """Aggregate gated neighbor messages for every node.

        Args:
            Vh: Projected node embeddings, one row per node.
            edge_gates: Gate values, one vector per (node, node) pair.
            adj_matrix: Adjacency matrix; zero entries mask out messages.

        Returns:
            Reduction over axis 2 of the gated, masked messages: mean, sum or
            max depending on `self.aggregation`.
        """
        # messages[b, i, j] = edge_gates[b, i, j] * Vh[b, j]; plain
        # broadcasting avoids materializing an expanded copy of Vh.
        messages = edge_gates * tf.expand_dims(Vh, axis=1)

        # Apply graph structure. The cast lets integer/boolean adjacency
        # matrices be multiplied with floating-point messages.
        neighbor_mask = tf.cast(
            tf.expand_dims(adj_matrix, axis=-1), messages.dtype)
        messages = messages * neighbor_mask

        # Message aggregation
        if self.aggregation == 'mean':
            # divide_no_nan yields 0 (instead of NaN) for nodes that have no
            # neighbors, e.g. isolated or padded nodes.
            return tf.math.divide_no_nan(
                tf.math.reduce_sum(messages, axis=2),
                tf.math.reduce_sum(neighbor_mask, axis=2))
        if self.aggregation == 'sum':
            return tf.math.reduce_sum(messages, axis=2)
        # NOTE(review): masked-out entries are 0 here, so the maximum is
        # never below 0 even when every real neighbor message is negative.
        # Kept as-is to preserve existing model outputs; confirm intent.
        return tf.reduce_max(messages, axis=2)