labels = df.category.tolist() 这行代码报错
import os
import numpy as np
#import cv2
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import pandas as pd
from torch.utils.data import DataLoader, Dataset
import time
import torchvision
from torch.autograd import Variable
from torchvision import datasets, transforms
import os # os包集成了一些对文件路径和目录进行操作的类
from PIL import Image
from torch import optim
#from models.Category import category
class TrafficData(Dataset):
def __init__(self, path, train=True):
super(TrafficData, self).__init__()
df = pd.read_csv(os.path.join(path, 'annotation.csv'))
labels = df.category.tolist()
image_files = df.file_name.tolist()
self.path = path
del df
self.transform = torchvision.transforms.Compose([
torchvision.transforms.Resize((128, 128)),
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize((0.1307), (0.3081))
])
if train:
self.image_files = image_files[:int(len(image_files)*0.9)]
self.labels = labels[:int(len(image_files)*0.9)]
else:
self.image_files = image_files[int(len(image_files)*0.9):]
self.labels = labels[int(len(image_files)*0.9):]
def __getitem__(self, index):
image = Image.open(os.path.join(self.path + '/images/', self.image_files[index]))
return self.transform(image), self.labels[index]
def __len__(self):
return len(self.image_files)
TrafficData1=TrafficData('C:/Users/17204/Desktop/traffic/train', train=True)
TrafficData2==TrafficData('C:/Users/17204/Desktop/traffic/test', train=False)
train_loader = DataLoader(dataset=TrafficData1,
batch_size=32, shuffle=True, drop_last=True)
test_loader = DataLoader(dataset=TrafficData2,
batch_size=32, shuffle=True, drop_last=True)
AttributeError Traceback (most recent call last)
in
46
47
48 TrafficData1=TrafficData('C:/Users/17204/Desktop/traffic/train', train=True)
49
50 TrafficData2==TrafficData('C:/Users/17204/Desktop/traffic/test', train=False)
in init(self, path, train)
22 super(TrafficData, self).init()
23 df = pd.read_csv(os.path.join(path, 'annotation.csv'))
24 labels = df.category.tolist()
25 image_files = df.file_name.tolist()
26 self.path = path
~\anaconda3\envs\pytorch\lib\site-packages\pandas\core\generic.py in getattr(self, name)
5139 if self._info_axis._can_hold_identifiers_and_holds_name(name):
5140 return self[name]
-> 5141 return object.getattribute(self, name)
5142
5143 def setattr(self, name: str, value) -> None:
AttributeError: 'DataFrame' object has no attribute 'category'
怎么解决
r'C:/Users/17204/Desktop/traffic/train'
加个r试试