【Pytorch】基于CNN手写汉字的识别的错误 ValueError: invalid literal for int() with base 10: 'HWDB1_data' 该如何解决?

我在尝试【Pytorch】基于CNN手写汉字的识别中有错误不知道怎么解决
源代码是:


import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from PIL import Image

EPOCH = 10 #训练几次
BATCH_SIZE = 50 #数据集划分
LR = 0.001 #学习率
DOWNLOAD_MNIST = False 


def classes_txt(root, out_path, num_class=None):
    '''
    write image paths (containing class name) into a txt file.
    :param root: data set path
    :param out_path: txt file path
    :param num_class: how many classes needed
    :return: None
    '''
    dirs = os.listdir(root) # 列出根目录下所有类别所在文件夹名
    if not num_class:        # 不指定类别数量就读取所有
        num_class = len(dirs)

    if not os.path.exists(out_path): # 输出文件路径不存在就新建
        f = open(out_path, 'w')
        f.close()
    # 如果文件中本来就有一部分内容,只需要补充剩余部分
    # 如果文件中数据的类别数比需要的多就跳过
    with open(out_path, 'r+') as f:
        try:
            end = int(f.readlines()[-1].split('/')[-2]) + 1
        except:
            end = 0
        if end < num_class - 1:
            dirs.sort()
            dirs = dirs[end:num_class]
            for dir in dirs:
                files = os.listdir(os.path.join(root, dir))
                for file in files:
                    f.write(os.path.join(root, dir, file) + '\n')

class MyDataset(Dataset):
    def __init__(self, txt_path, num_class, transforms=None):
        super(MyDataset, self).__init__()
        images = [] # 存储图片路径
        labels = [] # 存储类别名,在本例中是数字
        # 打开上一步生成的txt文件
        with open(txt_path, 'r') as f:
            for line in f:
                if int(line.split('/')[-2]) >= num_class:  # 只读取前 num_class 个类
                    break
                line = line.strip('\n')
                images.append(line)
                labels.append(int(line.split('/')[-2]))
        self.images = images
        self.labels = labels
        self.transforms = transforms # 图片需要进行的变换,ToTensor()等等

    def __getitem__(self, index):
        image = Image.open(self.images[index]).convert('RGB') # 用PIL.Image读取图像
        label = self.labels[index]
        if self.transforms is not None:
            image = self.transforms(image) # 进行变换
        return image, label

    def __len__(self):
        return len(self.labels)



    
class NetSmall(nn.Module):
    def __init__(self):
        super(NetSmall, self).__init__()
        self.conv1 = nn.Sequential(
                nn.Conv2d(
                        in_channels = 1, # 输入图片的高度
                        out_channels = 16, # 16个特征卷积过滤器
                        kernel_size = 5, #卷积宽度(长度)
                        stride = 1, # 步长
                        padding = 2, # 扩展图片边缘长宽度
                            # if stirder = 1, padding = (kernel_size-1)/2
                ), # ->(16, 64, 64))
                nn.ReLU(),
                nn.MaxPool2d(kernel_size = 2), # 池化层→筛选重要信息 !!取一定区域内最大值!!
        ) # -> (16, 32, 32))
        self.conv2 = nn.Sequential(
                nn.Conv2d(16, 32, 5, 1, 2),
                nn.ReLU(),
                nn.MaxPool2d(2)
        ) # -> (32, 16, 16))
        self.out = nn.Linear(32 * 16 * 16, 100)
    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(x.size(0), -1)
        output = self.out(x)
        return output

    def __init__(self):
        super(NetSmall, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 3) # 3个参数分别是in_channels,out_channels,kernel_size,还可以加padding
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(2704, 512)
        self.fc2 = nn.Linear(512, 84)
        self.fc3 = nn.Linear(84, 100) 

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 2704)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


# 首先将训练集和测试集文件途径和文件名以txt保存在一个文件夹中
root = 'C:/Users/14394/Desktop/chinese100/HWDB1_data' # 这是我文件的储存位置
classes_txt(root + '/train', root+'/train.txt')
classes_txt(root + '/test', root+'/test.txt')

# 由于我的数据集图片尺寸不一,因此要进行resize,这里还可以加入数据增强,灰度变换,随机剪切等等
transform = transforms.Compose([transforms.Resize((64,64)), # 将图片大小重设为 64 * 64
                                transforms.Grayscale(),
                                transforms.ToTensor()])

train_set = MyDataset(root + '/train.txt', num_class=100, transforms=transform) # num_class 选取100种汉字  提出图片和标签
test_set = MyDataset(root + '/test.txt', num_class =100, transforms = transform)
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True) # 装进迭代器中
test_loader = DataLoader(test_set, batch_size=5473, shuffle=True) 

device = torch.device('cpu')
for step, (x,y) in enumerate(test_loader):
    test_x, labels_test = x.to(device), y.to(device)
    
model = NetSmall()
optimizer = torch.optim.Adam(model.parameters(), lr=LR)
loss_func = nn.CrossEntropyLoss() 
model.to(device)
# 训练集

for epoch in range(EPOCH):
    for step, (x,y) in enumerate(train_loader):
        picture, labels = x.to(device), y.to(device)
        
        output = model(picture)
        loss = loss_func(output, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if step % 50 == 0:
            test_output = model(test_x)
            pred_y = torch.max(test_output, 1)[1].data.squeeze()
            accuracy = (pred_y == labels_test).sum().item() / labels_test.size(0)
            print('Epoch:', epoch, '| train loss:%.4f' % loss.data, '| test accuracy:', accuracy)

print('Finish training')

以下是错误

D:\Anaconda3\envs\py38\python.exe C:\Users\14394\Desktop\chinese100\Chinese_code.py 
Traceback (most recent call last):
  File "C:\Users\14394\Desktop\chinese100\Chinese_code.py", line 138, in 
    train_set = MyDataset(root + '/train.txt', num_class=100, transforms=transform) # num_class 选取100种汉字  提出图片和标签
  File "C:\Users\14394\Desktop\chinese100\Chinese_code.py", line 59, in __init__
    if int(line.split('/')[-2]) >= num_class:  # 只读取前 num_class 个类
ValueError: invalid literal for int() with base 10: 'HWDB1_data'

Process finished with exit code 1

请问这个错误我该如何解决?

可以输出下line.split('/')[-2]看看是不是有空格或者其他字符导致的

line.split('/')[-2]导致的报错,可能是index超出。原因或许是line中不存在/,line中的分割符也许是\