from keras import layers
import tensorflow as tf
x = tf.random.normal([32,10,8])
final_memory_state = tf.random.normal([32,4])
final_carry_state = tf.random.normal([32,4])
for num in range(8):
#xt = tf.reshape(x[:,:,num],[x.shape[0],x.shape[1],1])
previous_state = (final_memory_state,final_carry_state)
xt = x[:,:,num]
print(xt.shape)
#xt = x[:,0,:] # 得到一个时间戳的输入
cell = tf.compat.v1.nn.rnn_cell.LSTMCell(4,state_is_tuple=True)
final_memory_state, final_carry_state = cell(xt,previous_state) # 前向计算
#print(final_memory_state.shape)
您好,我是有问必答小助手,您的问题已经有小伙伴解答了,您看下是否解决,可以追评进行沟通哦~
如果有您比较满意的答案 / 帮您提供解决思路的答案,可以点击【采纳】按钮,给回答的小伙伴一些鼓励哦~~
ps:问答VIP仅需29元,即可享受5次/月 有问必答服务,了解详情>>>https://vip.csdn.net/askvip?utm_source=1146287632
我试了一下,报错结果如下:解决方法是安装2.2及以上版本的TensorFlow