做minist手写数字分类识别的时候遇到这个报错,如何解决?(语言-python)

做minist手写数字分类识别的时候遇到这个报错

这个是做交叉验证数据集切分的时候报的错

img


```python
skflods = StratifiedKFold(n_splits=3,random_state=None) #把数据集切分成3份
for train_index,test_index in skflods.split(X_train,y_train_5): #拿到train和test的index
    clone_clf = clone(sgd_clf) # 克隆和之前参数一样的模型
    X_train_folds = X_train[train_index]
    y_train_folds= y_train_5[train_index]
    X_test_folds=X_train[test_index]
    y_test_folds=y_train_5[test_index]
    
    clone_clf.fit(X_train_folds,y_train_folds)
    y_pred = clone_clf.predict(X_test_folds)
    n_correct = sum(y_pred == y_test_folds)
    print(n_correct/len(y_pred))

```

看起来循环中 数据集没有有效取到数据,可以 debug 或 print 看一下 X_train_folds 等数组。

先print一下以下内容检查一下:

  • X_train.shape, y_train_5.shape
  • train_index, test_index