import torch, os
import numpy as np
from dataloaders.dataset_miniimagenet import MiniImagenet
from dataloaders.dataset_miniimagenet_test import MiniImagenettest
from torch.utils.data import DataLoader
import argparse
import json
from torch import optim
import random
from models.myModel import MyModel
from grad_comp import meta_grad_comp,maml_finetune
import logging
import pandas as pd
from imblearn.over_sampling import SMOTE
from copy import deepcopy
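# Meta-training script: trains a Sparse-VD (Linear_SVDO) model with MAML-style
# meta-gradients (meta_grad_comp) on episodic batches of MiniImagenet-format
# data, periodically fine-tunes a copy on a validation task (maml_finetune),
# and checkpoints the weights and training curves.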
def cycle(iterable):
    iterator = iter(iterable)
    while True:
        try:
            yield next(iterator)
        except StopIteration:
            iterator = iter(iterable)
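# Note: unlike itertools.cycle, this helper re-creates the DataLoader iterator
# each time it is exhausted, so with shuffle=True every pass over the dataset
# is re-shuffled rather than replayed from a cached first pass.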
def main(args):
    store_dir = os.path.join(args.base_results_dir)
    os.makedirs(store_dir, exist_ok=True)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)
    # flag = 1
    train_loss = []
    train_acc = []        # mean accuracy over all classes
    train_mae = []
    train_acc_class = []  # per-class mean accuracies (six groups)
    task_0 = []
    task_1 = []
    task_2 = []
    task_3 = []
    task_4 = []
    task_5 = []
    ####################### logger ###########################
    log_format = '%(levelname)-8s %(message)s'
    log_file_name = 'train.log'
    logfile = os.path.join(store_dir, log_file_name)
    logging.basicConfig(filename=logfile, level=logging.INFO, format=log_format)
    logging.getLogger().addHandler(logging.StreamHandler())  # also send log messages to the console
    ###########################################################
    with open(os.path.join(store_dir, 'args.json'), 'w+') as args_file:
        json.dump(args.__dict__, args_file)  # dump the run arguments as JSON
    if args.svdo:
        config = [
            ('Linear_SVDO', [args.conv_layers, 11]),  # 6*64
            ('bn', [args.conv_layers]),  # 64
            ('relu', [True]),
            ('Linear_SVDO', [args.conv_layers, args.conv_layers]),  # 64*64
            ('bn', [args.conv_layers]),
            ('relu', [True]),
            ('Linear_SVDO', [args.conv_layers, args.conv_layers]),
            ('bn', [args.conv_layers]),
            ('relu', [True]),
            ('flatten', []),
            ('Linear_SVDO', [11, args.conv_layers])  # 64*6
        ]
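        # Assumption (not verified against models.myModel): each
        # ('Linear_SVDO', [a, b]) tuple is read as a MAML-style layer spec for a
        # sparse variational-dropout linear layer, with the 'bn'/'relu'/'flatten'
        # entries consumed positionally by MyModel when building the network.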
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # logging.info(device)
    # initialise the network
    model = MyModel(config, threshold=args.threshold, init_linsvdo=args.init_linsvdo, pretrained=False, svdo=args.svdo, custom=True)
    model.train()  # training mode
    model = model.to(device)
    data_train = MiniImagenet(args.dset_path, phase='train', n_way=args.n_way, k_spt=args.k_spt, k_query=args.k_qry)
    db_train = cycle(DataLoader(data_train, batch_size=args.meta_batchsz, shuffle=True, pin_memory=True))
    data_val = MiniImagenet(args.dset_path, phase='test', n_way=args.n_way, k_spt=args.k_spt, k_query=args.k_qry)
    db_val = DataLoader(data_val, batch_size=1, shuffle=False, pin_memory=True)
    meta_optim = optim.Adam(model.parameters(), lr=args.meta_lr)  # build the meta-optimiser
    kl_weight = min(args.initial_kl_weight + 1e-6, 0.01)  # initial KL weight
    global_test_steps = args.start_test_step
    ############################ train ##########################
    for step in range(args.start_step, args.num_iterations + 1):
        x_spt, y_spt, x_qry, y_qry = next(db_train)  # sample one batch of tasks
        x_spt, y_spt, x_qry, y_qry = x_spt.to(device), y_spt.to(device), x_qry.to(device), y_qry.to(device)
        tr_acc_class_mean, tr_MAE_mean, accs, tr_loss = meta_grad_comp(model, args, x_spt, y_spt, x_qry, y_qry, kl_weight, meta_optim, phase='train')
        print("step:{}/{}| acc:{}| Task0:{}| Task1:{}| Task2:{}| Task3:{}| Task4:{}| Task5:{}| loss:{}| MAE:{}".format(
            step, args.num_iterations, 100 * accs[-1], tr_acc_class_mean[0], tr_acc_class_mean[1], tr_acc_class_mean[2],
            tr_acc_class_mean[3], tr_acc_class_mean[4], tr_acc_class_mean[5], tr_loss[-1], tr_MAE_mean[0]))
        train_loss.append(tr_loss[-1])
        train_acc.append(100 * accs[-1])
        train_mae.append(tr_MAE_mean[0])
        train_acc_class.append(tr_acc_class_mean)
        task_0.append(tr_acc_class_mean[0])
        task_1.append(tr_acc_class_mean[1])
        task_2.append(tr_acc_class_mean[2])
        task_3.append(tr_acc_class_mean[3])
        task_4.append(tr_acc_class_mean[4])
        task_5.append(tr_acc_class_mean[5])
        if step % 200 == 0:  # update the KL weight every 200 steps
            kl_weight = min(kl_weight + 1e-6, 1e-5)
        if step % 100 == 0:  # validate every 100 steps
            torch.manual_seed(args.val_seed)
            torch.cuda.manual_seed_all(args.val_seed)
            # torch.backends.cudnn.deterministic = True
            np.random.seed(args.val_seed)
            random.seed(args.val_seed)
            net = MyModel(config, threshold=args.threshold, init_linsvdo=args.init_linsvdo, pretrained=False, svdo=args.svdo, custom=True)
            net.load_state_dict(deepcopy(model.state_dict()))
            net = net.to(device)
            # db_val is not shuffled, so this always evaluates on the first validation task
            x_spt, y_spt, x_qry, y_qry = next(iter(db_val))
            x_spt, y_spt, x_qry, y_qry = x_spt.squeeze(0).to(device), y_spt.squeeze(0).to(device), x_qry.squeeze(0).to(device), y_qry.squeeze(0).to(device)
            test_acc_class_mean, test_MAE_mean, accs, avg_loss_q, f1 = maml_finetune(net, args, x_spt, y_spt, x_qry, y_qry, kl_weight, phase='test')  # fine-tune the copied model on the val task and evaluate it
            print("Eval| acc:{}| Task0:{}| Task1:{}| Task2:{}| Task3:{}| Task4:{}| Task5:{}| loss:{}| MAE:{}".format(
                100 * accs, test_acc_class_mean[0], test_acc_class_mean[1], test_acc_class_mean[2],
                test_acc_class_mean[3], test_acc_class_mean[4], test_acc_class_mean[5], avg_loss_q, test_MAE_mean))
            del net
            global_test_steps += 1
            torch.manual_seed(args.seed)
            torch.cuda.manual_seed_all(args.seed)
            np.random.seed(args.seed)
            random.seed(args.seed)
        if step % 1000 == 0:  # save the model and training curves every 1000 steps
            torch.save({'kl_weight': kl_weight,
                        'step': step,
                        'test_step': global_test_steps,
                        'state_dict': model.state_dict()},
                       os.path.join(store_dir, 'step_' + str(step) + '_person3_smote' + '_model_checkpoint.pt'))
            # torch.save(meta_optim.state_dict(), os.path.join(store_dir, 'optim_checkpoint_step' + str(step) + '.pt'))
            np.savetxt('./Results/B-SMALL/Train_loss{}.txt'.format(args.meta_lr), np.array(train_loss))
            np.savetxt('./Results/B-SMALL/Train_acc{}.txt'.format(args.meta_lr), np.array(train_acc))
            np.savetxt('./Results/B-SMALL/Train_mae{}.txt'.format(args.meta_lr), np.array(train_mae))
            np.savetxt('./Results/B-SMALL/Task0_acc{}.txt'.format(args.meta_lr), np.array(task_0))
            np.savetxt('./Results/B-SMALL/Task1_acc{}.txt'.format(args.meta_lr), np.array(task_1))
            np.savetxt('./Results/B-SMALL/Task2_acc{}.txt'.format(args.meta_lr), np.array(task_2))
            np.savetxt('./Results/B-SMALL/Task3_acc{}.txt'.format(args.meta_lr), np.array(task_3))
            np.savetxt('./Results/B-SMALL/Task4_acc{}.txt'.format(args.meta_lr), np.array(task_4))
            np.savetxt('./Results/B-SMALL/Task5_acc{}.txt'.format(args.meta_lr), np.array(task_5))
        # if step % args.num_iterations == 0:
        #     torch.save({'kl_weight': kl_weight,
        #                 'step': step,
        #                 'test_step': global_test_steps,
        #                 'state_dict': model.state_dict()},
        #                os.path.join(store_dir, 'model{}.pt'.format(flag)))
        #     flag = flag + 1
if __name__ == '__main__':
    argparser = argparse.ArgumentParser()
    argparser.add_argument('--base_results_dir', help='Enter path for base directory to store results', default='./Results/B-SMALL/model')  # where model files are saved
    argparser.add_argument('--num_iterations', type=int, help='epoch number', default=10000)  # number of training episodes
    argparser.add_argument('--conv_layers', type=int, default=64)  # conv layer width
    argparser.add_argument('--seed', type=int, default=100)  # seed for the random number generators
    argparser.add_argument('--val_seed', type=int, default=5)
    argparser.add_argument('--dset_path', default='./data/12组数据测试/data1/')
    argparser.add_argument('--n_way', type=int, help='n way', default=6)  # number of classes
    argparser.add_argument('--k_spt', type=int, help='k shot for train support set', default=1)
    argparser.add_argument('--k_qry', type=int, help='k shot for train query set', default=1)
    argparser.add_argument('--meta_batchsz', type=int, help='meta batch size, namely task num', default=8)  # number of tasks per batch
    argparser.add_argument('--meta_lr', type=float, help='meta-level outer learning rate', default=1e-3)  # outer-loop learning rate
    argparser.add_argument('--update_lr', type=float, help='task-level inner update learning rate', default=1e-2)  # inner-loop learning rate
    argparser.add_argument('--update_step', type=int, help='task-level inner update steps', default=5)  # gradient steps during training
    argparser.add_argument('--update_step_test', type=int, help='update steps for finetuning', default=10)  # gradient steps during testing (used by test.py)
    argparser.add_argument('--threshold', type=float, help='for removing weights (SVD)', default=3.0)  # pruning threshold
    argparser.add_argument('--initial_kl_weight', type=float, default=1e-6, help='Coeff for KL Loss')  # KL weight parameter
    argparser.add_argument('--start_step', type=int, default=1)
    argparser.add_argument('--start_test_step', type=int, default=1)
    argparser.add_argument('--init_linsvdo', type=float, default=-8.0, help='Param init value for linearSVDO')  # sigma
    argparser.add_argument('--svdo', default=True, action='store_true', help='Invoking this param implies using Sparse VD on MAML')  # use SVD
    argparser.add_argument('--custom', action='store_false', help='Switch off if you want to use models like vgg11_bn')  # custom=False doesn't work yet
    argparser.add_argument('--train', action='store_true', default=True)
    argparser.add_argument('--resume', action='store_true', default=False)
    args = argparser.parse_args()
    main(args)
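# Example invocation (a sketch using the defaults above; replace "train.py"
# with this file's actual name and adjust paths/hyperparameters to your setup):
#   python train.py --dset_path ./data/12组数据测试/data1/ \
#       --n_way 6 --k_spt 1 --k_qry 1 --meta_batchsz 8 \
#       --meta_lr 1e-3 --update_lr 1e-2 --num_iterations 10000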