Source code for msp.models._models

"""
The :mod:`msp.models._models` module defines the encoder-decoder architecture
used to solve the MSP problem.
"""
import tensorflow as tf
from tensorflow.keras import Model


class MSPModel(Model):
    """Keras model that feeds its inputs through an encoder, then a decoder.

    Both sub-models are built in the constructor from caller-supplied
    classes and parameter mappings; ``call`` chains them.
    """

    def __init__(self,
                 encoder_class,
                 encoder_params,
                 decoder_class,
                 decoder_params,
                 *args,
                 **kwargs):
        """Build the encoder and decoder sub-models.

        Args:
            encoder_class: Callable invoked as
                ``encoder_class(units, layers, name='ggcn_encoder')``.
            encoder_params: Mapping read with ``.get``. Recognised keys:
                ``'units'`` (default 128) and ``'layers'`` (default 3).
            decoder_class: Callable invoked with ``units`` positionally and
                ``use_bias``, ``n_heads``, ``aggregation_graph``,
                ``tanh_clipping`` and ``name='attention_decoder'`` as
                keyword arguments.
            decoder_params: Mapping read with ``.get``. Recognised keys:
                ``'units'`` (default 128), ``'use_bias'`` (default False),
                ``'n_heads'`` (default 8), ``'aggregation_graph'``
                (default ``'mean'``) and ``'tanh_clipping'`` (default 10).
            *args: Forwarded to the ``Model`` constructor.
            **kwargs: Forwarded to the ``Model`` constructor.
        """
        super().__init__(*args, **kwargs)

        # Encoder: hidden width and depth come from encoder_params.
        encoder_units = encoder_params.get('units', 128)
        encoder_layers = encoder_params.get('layers', 3)
        self.encoder = encoder_class(
            encoder_units,
            encoder_layers,
            name='ggcn_encoder',
        )

        # Decoder: collect the keyword options first, then construct.
        # NOTE(original author): extra **kwargs are not forwarded to the
        # decoder; the original note suggests using ``pop`` instead of
        # ``get`` on decoder_params if that is ever added.
        decoder_units = decoder_params.get('units', 128)
        decoder_options = {
            'use_bias': decoder_params.get('use_bias', False),
            'n_heads': decoder_params.get('n_heads', 8),
            'aggregation_graph': decoder_params.get('aggregation_graph',
                                                    'mean'),
            'tanh_clipping': decoder_params.get('tanh_clipping', 10),
        }
        self.decoder = decoder_class(
            decoder_units,
            **decoder_options,
            name='attention_decoder',
        )

    @tf.function
    def call(self, inputs, training=None):
        """Run ``inputs`` through the encoder and then the decoder.

        Args:
            inputs: Passed unchanged to the encoder.
            training: Passed as the second positional argument to both the
                encoder and the decoder.

        Returns:
            Whatever the decoder returns for the encoder's output.
        """
        encoded = self.encoder(inputs, training)
        return self.decoder(encoded, training)