在做图神经网络训练,想咨询下,PygNodePropPredDataset 能读取本地数据吗? 我现在的代码是下载数据集,但是我第二次登录时,数据集已经下好了,如何才能直接读取本地数据集呢?
# Imports for the OGB node-property-prediction demo.
# (Fixed: the original imported `flatten` and `collections` twice.)
from torch.utils.data import DataLoader, Dataset
import torch_geometric
import torch
import torch.nn.functional as F
from tqdm import tqdm
from torch_geometric.data import NeighborSampler
from torch_geometric.nn import SAGEConv
import os.path as osp
import pandas as pd
import numpy as np
import collections
from pandas.core.common import flatten
# importing the OGB dataset helpers
from ogb.nodeproppred import PygNodePropPredDataset, Evaluator
import seaborn as sns
import matplotlib.pyplot as plt

# Global plotting defaults for all figures below.
sns.set(rc={'figure.figsize':(16.7,8.27)})
sns.set_theme(style="ticks")

from scipy.special import softmax
import umap

# Download (first run only) and load the OGB dataset.
# NOTE(review): when `root` already holds the extracted/processed data,
# PygNodePropPredDataset reads it from disk instead of re-downloading; it
# only prompts about re-downloading when the on-disk RELEASE version
# marker file is missing (see the library source excerpted below).
root = osp.join(osp.dirname(osp.realpath('./')), 'data', 'products', 'ogbn-products')
dataset = PygNodePropPredDataset('ogbn-products', root)
这个版本检查逻辑是代码库自带的,
其中包含两个判断:
这部分的逻辑代码如下(在 PyCharm 中按住 Ctrl 并左击第14行中的 PygNodePropPredDataset,即可进入第三方库源码):
# NOTE(review): excerpt of the OGB library source — the class body is elided
# with `...`, so this snippet is illustrative only and not runnable as-is.
class PygNodePropPredDataset:
    ...
    # check version
    # First check whether the dataset has been already downloaded or not.
    # If so, check whether the dataset version is the newest or not.
    # If the dataset is not the newest version, notify this to the user.
    # A local folder without the RELEASE_v<version>.txt marker is treated as
    # outdated; the local copy is deleted (forcing a re-download) only when
    # the user explicitly answers 'y'.
    if osp.isdir(self.root) and (not osp.exists(osp.join(self.root, 'RELEASE_v' + str(self.meta_info['version']) + '.txt'))):
        print(self.name + ' has been updated.')
        if input('Will you update the dataset now? (y/N)\n').lower() == 'y':
            shutil.rmtree(self.root)
如果你希望只在第一次运行时下载、之后无论本地是否为最新版本都不再重新下载,修改这段代码即可。
如有问题及时沟通
源码在这里,可以去研究一下~
import pandas as pd
import shutil, os
import os.path as osp
import torch
import numpy as np
from torch_geometric.data import InMemoryDataset
from ogb.utils.url import decide_download, download_url, extract_zip
from ogb.io.read_graph_pyg import read_graph_pyg
class PygGraphPropPredDataset(InMemoryDataset):
    """OGB graph-property-prediction dataset in PyTorch Geometric format.

    Excerpt of the OGB library source, pasted here for reference. On
    construction it resolves the on-disk dataset folder, reads the dataset's
    meta information from the bundled ``master.csv``, runs a version check
    (asking the user before deleting an outdated local copy), and finally
    loads the processed graphs through ``InMemoryDataset``.
    """

    def __init__(self, name, root = 'dataset', transform=None, pre_transform = None, meta_dict = None):
        '''
            - name (str): name of the dataset
            - root (str): root directory to store the dataset folder
            - transform, pre_transform (optional): transform/pre-transform graph objects
            - meta_dict: dictionary that stores all the meta-information about data. Default is None,
                    but when something is passed, it uses its information. Useful for debugging for external contributers.
        '''
        self.name = name ## original name, e.g., ogbg-molhiv

        if meta_dict is None:
            self.dir_name = '_'.join(name.split('-'))

            # check if previously-downloaded folder exists.
            # If so, use that one.
            if osp.exists(osp.join(root, self.dir_name + '_pyg')):
                self.dir_name = self.dir_name + '_pyg'

            self.original_root = root
            self.root = osp.join(root, self.dir_name)

            # Meta info (URL, version, task type, ...) ships with the library
            # as master.csv, one column per dataset name.
            master = pd.read_csv(os.path.join(os.path.dirname(__file__), 'master.csv'), index_col = 0)
            if not self.name in master:
                error_mssg = 'Invalid dataset name {}.\n'.format(self.name)
                error_mssg += 'Available datasets are as follows:\n'
                error_mssg += '\n'.join(master.keys())
                raise ValueError(error_mssg)
            self.meta_info = master[self.name]
        else:
            # Caller supplied the meta information directly (debugging path).
            self.dir_name = meta_dict['dir_path']
            self.original_root = ''
            self.root = meta_dict['dir_path']
            self.meta_info = meta_dict

        # check version
        # First check whether the dataset has been already downloaded or not.
        # If so, check whether the dataset version is the newest or not.
        # If the dataset is not the newest version, notify this to the user.
        # A local copy without the RELEASE_v<version>.txt marker is outdated;
        # it is deleted (triggering a re-download) only if the user answers 'y'.
        if osp.isdir(self.root) and (not osp.exists(osp.join(self.root, 'RELEASE_v' + str(self.meta_info['version']) + '.txt'))):
            print(self.name + ' has been updated.')
            if input('Will you update the dataset now? (y/N)\n').lower() == 'y':
                shutil.rmtree(self.root)

        self.download_name = self.meta_info['download_name'] ## name of downloaded file, e.g., tox21
        self.num_tasks = int(self.meta_info['num tasks'])
        self.eval_metric = self.meta_info['eval metric']
        self.task_type = self.meta_info['task type']
        self.__num_classes__ = int(self.meta_info['num classes'])
        self.binary = self.meta_info['binary'] == 'True'

        # NOTE(review): the InMemoryDataset base presumably calls download()/
        # process() only when raw/processed files are absent under self.root,
        # which is why a second run loads locally — confirm against PyG docs.
        super(PygGraphPropPredDataset, self).__init__(self.root, transform, pre_transform)

        self.data, self.slices = torch.load(self.processed_paths[0])

    def get_idx_split(self, split_type = None):
        """Return the train/valid/test split indices as a dict of long tensors."""
        if split_type is None:
            split_type = self.meta_info['split']

        path = osp.join(self.root, 'split', split_type)

        # short-cut if split_dict.pt exists
        if os.path.isfile(os.path.join(path, 'split_dict.pt')):
            return torch.load(os.path.join(path, 'split_dict.pt'))

        # Otherwise read the three gzip-compressed single-column CSVs.
        train_idx = pd.read_csv(osp.join(path, 'train.csv.gz'), compression='gzip', header = None).values.T[0]
        valid_idx = pd.read_csv(osp.join(path, 'valid.csv.gz'), compression='gzip', header = None).values.T[0]
        test_idx = pd.read_csv(osp.join(path, 'test.csv.gz'), compression='gzip', header = None).values.T[0]

        return {'train': torch.tensor(train_idx, dtype = torch.long), 'valid': torch.tensor(valid_idx, dtype = torch.long), 'test': torch.tensor(test_idx, dtype = torch.long)}

    @property
    def num_classes(self):
        # Class count taken from master.csv (stored in __init__), not
        # inferred from the data itself.
        return self.__num_classes__

    @property
    def raw_file_names(self):
        # Binary datasets ship one .npz; CSV datasets ship an edge file plus
        # optional node-/edge-feature files, depending on the meta info.
        if self.binary:
            return ['data.npz']
        else:
            file_names = ['edge']
            if self.meta_info['has_node_attr'] == 'True':
                file_names.append('node-feat')
            if self.meta_info['has_edge_attr'] == 'True':
                file_names.append('edge-feat')
            return [file_name + '.csv.gz' for file_name in file_names]

    @property
    def processed_file_names(self):
        return 'geometric_data_processed.pt'

    def download(self):
        """Download and extract the dataset archive into self.root."""
        url = self.meta_info['url']
        if decide_download(url):
            path = download_url(url, self.original_root)
            extract_zip(path, self.original_root)
            os.unlink(path)
            # Replace any stale folder with the freshly extracted one.
            shutil.rmtree(self.root)
            shutil.move(osp.join(self.original_root, self.download_name), self.root)
        else:
            # User declined the download: clean up and abort the process.
            print('Stop downloading.')
            shutil.rmtree(self.root)
            exit(-1)

    def process(self):
        """Parse the raw files into PyG graphs, attach labels, and save."""
        ### read pyg graph list
        add_inverse_edge = self.meta_info['add_inverse_edge'] == 'True'

        if self.meta_info['additional node files'] == 'None':
            additional_node_files = []
        else:
            additional_node_files = self.meta_info['additional node files'].split(',')

        if self.meta_info['additional edge files'] == 'None':
            additional_edge_files = []
        else:
            additional_edge_files = self.meta_info['additional edge files'].split(',')

        data_list = read_graph_pyg(self.raw_dir, add_inverse_edge = add_inverse_edge, additional_node_files = additional_node_files, additional_edge_files = additional_edge_files, binary=self.binary)

        if self.task_type == 'subtoken prediction':
            # Labels are space-separated subtoken sequences, kept as str lists.
            graph_label_notparsed = pd.read_csv(osp.join(self.raw_dir, 'graph-label.csv.gz'), compression='gzip', header = None).values
            graph_label = [str(graph_label_notparsed[i][0]).split(' ') for i in range(len(graph_label_notparsed))]

            for i, g in enumerate(data_list):
                g.y = graph_label[i]
        else:
            if self.binary:
                graph_label = np.load(osp.join(self.raw_dir, 'graph-label.npz'))['graph_label']
            else:
                graph_label = pd.read_csv(osp.join(self.raw_dir, 'graph-label.csv.gz'), compression='gzip', header = None).values

            # NaN labels (missing targets) force float storage even for
            # classification tasks.
            has_nan = np.isnan(graph_label).any()

            for i, g in enumerate(data_list):
                if 'classification' in self.task_type:
                    if has_nan:
                        g.y = torch.from_numpy(graph_label[i]).view(1,-1).to(torch.float32)
                    else:
                        g.y = torch.from_numpy(graph_label[i]).view(1,-1).to(torch.long)
                else:
                    g.y = torch.from_numpy(graph_label[i]).view(1,-1).to(torch.float32)

        if self.pre_transform is not None:
            data_list = [self.pre_transform(data) for data in data_list]

        data, slices = self.collate(data_list)

        print('Saving...')
        torch.save((data, slices), self.processed_paths[0])
我之前用你的代码跑过,第一次下载了之后不需要重新下载了。不过我是自己手动下载了之后放到指定目录的,程序里自己下载很慢。用你的代码跑了之后,数据集的路径是这样子的: