y_train[y_train == label1] = 0怎么理解?

一般是训练一个binary的二分类任务
:param label1
:param label2
'''
x_train, y_train, x_test, y_test = return_part_mnist([label1, label2])
print(np.shape(x_train), np.shape(y_train), np.shape(x_test), np.shape(y_test))

y_train[y_train == label1] = 0
y_train[y_train == label2] = 1
y_test[y_test == label1] = 0
y_test[y_test == label2] = 1
print(y_train[:128])
print(y_test[:128])

代码中的 y_train 是一个包含两种标签的数组,分别为 label1 和 label2。

  • 代码中,y_train[y_train == label1] = 0 语句的作用是将 y_train 中所有等于 label1 的元素修改为 0。
  • 同理,y_train[y_train == label2] = 1 将 y_train 中所有等于 label2 的元素修改为 1。

这样,就实现了将两种标签转化为二分类标签的目的。对于 y_test 数组的处理方式也是类似的。

y_train[y_train == label1] = 0,是将训练模型输出为标签 label1 的样本赋值为 0,y_train[y_train == label2] = 1 将训练模型输出为标签 label1 的样本赋值为 1。
接下来统计 ytrain/ytest 累计有多少个 0,多少个 1,就可以计算训练精度/检验精度。