是在处理mnist数据集时遇到的问题。
X_train, y_train = load_mnist('E:\数据集\MINIST', kind ='train')_
#打印出来的X_train的形状为(60000, 784)
#打印出来的y_train的形状为(60000,)
在下面这段代码中遇到了问题:
for i in range(10):
img = X_train[y_train == i][0].reshape(28, 28)
test = X_train[y_train == i][0]
# print(y_train == i)#出来的结果是[true ,false...]
# print((y_train == i).shape)#出来的结果是(6000,)
# print(test.shape)#(784,)
怎么理解 X_train[y_train == i][0]这一部分呢?
X_train明明是二维数组,X_train[][]第一个[]内不应该是指定行,第二个[]不应该是指定列,出来的不应该是一个元素吗?
还有y_train == i出来的是6000个true/false值组成的,怎么理解?
楼主可以试试先打印 X_train[y_train == i] 这个,这是一个推导式,返回 X_train 中满足 y_train ==i 条件的记录。
后面那个 [0] 操作则是获取索引位置为 0 的数据。
可以尝试先打印print(X_train[True]),会发现X_train中所有元素打打印出来,X_train[y_train == i]这个只是打印X_train中和y_train索引位置一一对应位置为True的值,其他为False的值不打印。