Source code for msp.models.encoders._ggcn

"""
The :mod:`msp.models.encoders._ggcn` module defines an encoder architecture
based on the `Gated Graph Convolution` layer.
"""
import tensorflow as tf
from tensorflow.keras import Model

from msp.graphs import MSPEmbedGraph
from msp.layers import GGCNLayer


class GGCNEncoder(Model):
    """Graph encoder made of two dense input projections and a GGCN stack.

    Node features and edge features are each passed through their own
    :class:`tf.keras.layers.Dense` layer, the results are attached to an
    :class:`MSPEmbedGraph`, and that graph is then fed through the
    :class:`GGCNLayer` instances one after another.
    """

    def __init__(self, units, layers, *args, activation='relu', **kwargs):
        """Create the input projections and the stack of GGCN layers.

        :param units: size argument given to both initial ``Dense`` layers
            and to every ``GGCNLayer``.
        :param layers: how many ``GGCNLayer`` instances to create.
        :param activation: forwarded as ``activation`` to every
            ``GGCNLayer``.
        :param args: extra positional arguments for the ``Model``
            constructor.
        :param kwargs: extra keyword arguments for the ``Model``
            constructor.
        """
        super().__init__(*args, **kwargs)

        dense = tf.keras.layers.Dense
        # Separate projections: one for node features, one for edge features.
        self.initial_layer_1 = dense(
            units, name='initial_layer_for_node_features')
        self.initial_layer_2 = dense(
            units, name='initial_layer_for_edge_features')

        # Build the stack locally, then assign it once as a plain list.
        stack = []
        for _ in range(layers):
            stack.append(GGCNLayer(units, activation=activation))
        self.ggcn_layers = stack

    def call(self, inputs, training=None):
        """Embed the input graph and run it through every GGCN layer.

        :param inputs: graph object exposing ``node_features``,
            ``edge_features`` and ``_asdict()`` (presumably a namedtuple;
            confirm against the caller).
        :param training: passed unchanged, as the second positional
            argument, to each GGCN layer.
        :return: the output of the last GGCN layer, or the freshly built
            ``MSPEmbedGraph`` when the stack is empty.
        """
        node_embed = self.initial_layer_1(inputs.node_features)
        edge_embed = self.initial_layer_2(inputs.edge_features)

        # Carry every field of the input over and add the two embeddings.
        graph = MSPEmbedGraph(
            **inputs._asdict(),
            node_embed=node_embed,
            edge_embed=edge_embed,
        )

        for ggcn in self.ggcn_layers:
            graph = ggcn(graph, training)
        return graph