I'm building an RNN model with TensorFlow and want to replace its GRUCells with LSTMCells:
self.edge_update = tf.keras.layers.GRUCell(self.hparams.link_state_dim, name="edge_update")
self.path_update = tf.keras.layers.GRUCell(self.hparams.path_state_dim, name="path_update")
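A likely source of trouble is that the two cells carry different state structures. `GRUCell` keeps a single state tensor `h`. `LSTMCell` keeps two tensors `[h, c]` and reads `states[1]` internally, so any call that passes only the old one-element state can fail with an IndexError. The sketch below assumes TF2 / eager `tf.keras`. The sizes, the message tensor `m`, and the extra cell-state tensor `link_c` are hypothetical placeholders, because the post does not show how `edge_update` is called.

```python
import tensorflow as tf

# Hypothetical sizes; the original code reads them from self.hparams.
link_state_dim = 8
batch, msg_dim = 4, 16

gru = tf.keras.layers.GRUCell(link_state_dim, name="edge_update_gru")
lstm = tf.keras.layers.LSTMCell(link_state_dim, name="edge_update")

# Hypothetical inputs standing in for the real message-passing tensors.
m = tf.random.normal([batch, msg_dim])           # aggregated messages
link_state = tf.zeros([batch, link_state_dim])   # h, same as with the GRU
link_c = tf.zeros([batch, link_state_dim])       # c, only needed by the LSTM

# GRU: a one-element state list is enough, since the state is just h.
h_gru, _ = gru(m, [link_state])

# LSTM: the cell expects both h and c. Passing only [link_state] makes it
# index states[1], which does not exist. Supply both tensors and keep the
# new c for the next iteration.
link_state, (_, link_c) = lstm(m, [link_state, link_c])
```

In a loop over message-passing iterations, `link_c` would presumably need to be stored and fed back alongside `link_state`, just as `link_state` already is.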
After changing GRUCell to LSTMCell, I get:
IndexError: list index out of range
The problem seems to be here:
outputs, path_state = tf.nn.dynamic_rnn(self.path_update,
                                        link_inputs,
                                        sequence_length=lens,
                                        initial_state=path_state,
                                        dtype=tf.float32)
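Here `initial_state=path_state` is a single tensor, which matches a GRU but not an LSTM, whose initial state should be the pair `[h, c]`. The sketch below is a TF2 rewrite, not the original TF1 code: it swaps `tf.nn.dynamic_rnn` for `tf.keras.layers.RNN` and replaces `sequence_length` with a mask from `tf.sequence_mask`. All sizes and input tensors are hypothetical placeholders. Staying on TF1 `tf.nn.dynamic_rnn` should analogously require passing both tensors as `initial_state` and unpacking two tensors from the returned state. That TF1 variant is an assumption and is not shown here.

```python
import tensorflow as tf

# Hypothetical sizes and inputs; the original gets these from hparams and the graph.
path_state_dim = 8
batch, max_len, feat = 3, 5, 8

path_update = tf.keras.layers.LSTMCell(path_state_dim, name="path_update")
# return_state=True makes the layer return the outputs plus the final h and c.
path_rnn = tf.keras.layers.RNN(path_update, return_sequences=True, return_state=True)

link_inputs = tf.random.normal([batch, max_len, feat])
lens = tf.constant([5, 3, 2])                    # valid length of each path
path_state = tf.zeros([batch, path_state_dim])   # h, as with the GRU
path_c = tf.zeros([batch, path_state_dim])       # c, the extra LSTM state

# The boolean mask plays the role of dynamic_rnn's sequence_length: steps past
# each path's length are skipped and the state is carried forward unchanged.
mask = tf.sequence_mask(lens, maxlen=max_len)

# Initial state is [h, c]; the final state comes back as two separate tensors.
outputs, path_state, path_c = path_rnn(
    link_inputs, initial_state=[path_state, path_c], mask=mask)
```

As with the link update, `path_c` would likely need to be kept between iterations so that the LSTM's cell state is not reset to zero on every pass.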