tf.compat.v1.get_variable没有重用已经创建的变量,而是将已经创建的变量进行数值更新。
def foo():
with tf.compat.v1.variable_scope("foo", reuse=tf.compat.v1.AUTO_REUSE):
v = tf.compat.v1.get_variable("v", [1])
return v
v1 = foo() # Creates v.
v2 = foo() # Gets the same, existing v.
assert v1 == v2
---------------------------------------------------------------------------
AssertionError Traceback (most recent call last)
D:\Temp/ipykernel_12096/3386910621.py in <module>
6 v1 = foo() # Creates v.
7 v2 = foo() # Gets the same, existing v.
----> 8 assert v1 == v2
AssertionError:
查看了v1,与,v2的数据后发现
<tf.Variable 'foo/v:0' shape=(1,) dtype=float32, numpy=array([-1.6809169], dtype=float32)>
<tf.Variable 'foo/v:0' shape=(1,) dtype=float32, numpy=array([-0.75641733], dtype=float32)>
两者之间只是数据发生了变化
希望assert 函数不会报错
assert v1 is v2
你可以改成这样