我是在服务器上跑的，所以下面这几行中的文件路径需要他改一下。看看能不能优化到 0.7。CSV 文件通过邮箱发送。
import glob
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from scipy import signal
from sklearn import metrics
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from torch import Tensor
from torch_geometric.nn import GATConv
# -------------------------------------- Input file locations
# All input CSVs live in one directory. When running on another machine, set
# the EEG_DATA_DIR environment variable instead of editing the paths below.
DATA_DIR = os.environ.get('EEG_DATA_DIR', '/data1/ljh/eeg/fromgithub/save')
node_path = os.path.join(DATA_DIR, 'mean_after_process_features.csv')
ypath = os.path.join(DATA_DIR, 'mwan_label.csv')
xpath = os.path.join(DATA_DIR, 'mean_adjacent_matrix.csv')

# -------------------------------------- Node features
# The first CSV column is dropped (presumably the index pandas wrote when the
# file was saved -- TODO confirm). Selecting the columns with iloc BEFORE
# converting keeps the array numeric even if that dropped column held strings
# (an object-dtype array cannot be turned into a torch tensor).
x_node = pd.read_csv(node_path).iloc[:, 1:].to_numpy()

# -------------------------------------- Labels (rows aligned with x_node)
label = pd.read_csv(ypath).iloc[:, 1:].to_numpy()

# -------------------------------------- Edges (adjacency matrix)
x = pd.read_csv(xpath)
print(x)
print('type:\n', type(x))
print('x的大小:\n', x.shape)
y = x.iloc[:, 1:].to_numpy()
print(y)

# -------------------------------------- Shuffle the node order
# edge_index below is built from (row, column) positions in y and indexes rows
# of x_node, so y must be an (N, N) matrix over the same N nodes as x_node and
# label. Fail loudly here rather than mis-wiring edges silently later.
num_nodes = x_node.shape[0]
if label.shape[0] != num_nodes or y.shape != (num_nodes, num_nodes):
    raise ValueError(
        f'shape mismatch: x_node {x_node.shape}, label {label.shape}, '
        f'adjacency {y.shape}; expected {num_nodes} label rows and a '
        f'({num_nodes}, {num_nodes}) adjacency matrix'
    )

# Draw ONE permutation and apply it to every array. Shuffling only the rows
# of y left its columns in the old order, so an edge (i, j) no longer pointed
# at the node now stored in row j of x_node. With the legacy global RNG,
# permutation(n) after seed(112) shuffles arange(n) with the same swaps that
# np.random.shuffle used, so the x_node/label row order is unchanged.
np.random.seed(112)
perm = np.random.permutation(num_nodes)
x_node = x_node[perm]
label = label[perm]
y = y[np.ix_(perm, perm)]  # permute rows AND columns of the adjacency matrix
print('x_node:\n', x_node)
print('label:\n', label)
print('type:\n', type(y))
print('y的大小:\n', y.shape)

y_tensor = torch.from_numpy(y)
x_node = torch.tensor(x_node, dtype=torch.float32)
label = torch.tensor(label, dtype=torch.float32)

# -------------------------------------- Build edge_index
# Keep every (i, j) pair whose adjacency value exceeds the threshold. The
# result is a (2, E) tensor of int64 indices: row 0 = sources, row 1 = targets.
# NOTE(review): if y is a correlation matrix its diagonal is 1, so self-loops
# are included here, and strong negative values are ignored -- confirm intended.
threshold = 0.5
train_source_node, train_target_node = torch.where(y_tensor > threshold)
train_source_node = train_source_node.unsqueeze(0)
train_target_node = train_target_node.unsqueeze(0)
train_edge_index = torch.cat((train_source_node, train_target_node), dim=0)
# train_edge_index = train_edge_index.cuda()
# 构造网络
class Net(torch.nn.Module):
def __init__(self):
super(Net,
你把完整代码复制到记事本里，发给我看看。