请问如何把这个pytorch代码改成处理batch的

我在github上找到一个开源的用pytorch写的 用FastCMA-ES求解TSP问题的代码
但他目前只能求解一个实例,我想问一下,可以把他改成求解batch个实例的代码嘛
也就是说 输入从(n,2)改成(batch_size,n,2)

github链接是https://github.com/kayuksel/torch-tsp-es

代码是:

import torch
from math import log


def reward_func(sol, dist_mat):
     # sol.shape:samples_size * n
     samples_size = len(sol)
     rews = torch.zeros(samples_size).cuda()  # sample_size
     for i, row in enumerate(sol.argsort()):
         a = row 
         b = torch.cat((a[1:], a[0].unsqueeze(0)))  
         rews[i] = dist_mat[a, b].sum() 

     return rews


class FastCMA(object):
     def __init__(self, N, samples):
         self.samples = samples
         mu = samples // 2
         self.weights = torch.tensor([log(mu + 0.5)]).cuda()
         self.weights = self.weights - torch.linspace(
             start=1, end=mu, steps=mu).cuda().log()
         self.weights /= self.weights.sum()
  
         self.mueff = (self.weights.sum() ** 2 / (self.weights ** 2).sum()).item()

         # settings
         self.cc = (4 + self.mueff / N) / (N + 4 + 2 * self.mueff / N)
         self.c1 = 2 / ((N + 1.3) ** 2 + self.mueff)
         self.cmu = 2 * (self.mueff - 2 + 1 / self.mueff)
         self.cmu /= ((N + 2) ** 2 + 2 * self.mueff / 2)

         # variables 
         self.mean = torch.zeros(N).cuda()
         self.b = torch.eye(N).cuda()
         self.d = self.b.clone()
         bd = self.b * self.d
         self.c = bd * bd.T
         self.pc = self.mean.clone()

    
     def step(self, objective_f, dist_mat, step_size):
        
         z = torch.randn(self.mean.size(0), self.samples).cuda()
         s = self.mean.view(-1, 1) + step_size * self.b.matmul(self.d.matmul(z))
 
         results = [{'parameters': s.T[i], 'z': z.T[i],
                     'fitness': f.item()} for i, f in enumerate(objective_f(s.T,dist_mat))]

   
         ranked_results = sorted(results, key=lambda x: x['fitness'])
         selected_results = ranked_results[0:self.samples // 2]
         z = torch.stack([g['z'] for g in selected_results])
         g = torch.stack([g['parameters'] for g in selected_results])

         self.mean = (g * self.weights.unsqueeze(1)).sum(0)
         zmean = (z * self.weights.unsqueeze(1)).sum(0)
         self.pc *= (1 - self.cc)
         pc_cov = self.pc.unsqueeze(1) * self.pc.unsqueeze(1).T
         pc_cov = pc_cov + self.cc * (2 - self.cc) * self.c

         bdz = self.b.matmul(self.d).matmul(z.T)
         cmu_cov = bdz.matmul(self.weights.diag_embed())
         cmu_cov = cmu_cov.matmul(bdz.T)

         self.c *= (1 - self.c1 - self.cmu)
         self.c += (self.c1 * pc_cov) + (self.cmu * cmu_cov)
         self.d, self.b = torch.linalg.eigh(self.c, UPLO='U')
         self.d = self.d.sqrt().diag_embed()
         return ranked_results

def FastCMA_ES(dist_mat,step_size,sample_size,max_epochs):
     best_reward = None
     best_route = None
     n,_ = dist_mat.size()

     with torch.no_grad():
         cma_es = FastCMA(N=n, samples=sample_size)  
         for epoch in range(max_epochs):
             try:
                 res = cma_es.step(objective_f=reward_func,dist_mat=dist_mat,step_size=step_size)
             except Exception as e:
                 print(e)
                 break
             if best_reward is None: best_reward = res[0]['fitness']
             if res[0]['fitness'] < best_reward:
                 best_reward = res[0]['fitness']
                 # print("%i %f" % (epoch, best_reward))
                 best_route = res[0]['parameters'].argsort()
                 # print('cmes y:{}'.format(best_route))
     return best_route,best_reward

n=10
points = torch.rand(n, 2)
dist_mat = torch.cdist(points, points)
last_route,last_reward = FastCMA_ES(dist_mat,0.5,512,1000)
 # print('last_route:{}'.format(last_route))

要将此代码修改为可以处理batch的代码有以下步骤:

  1. 将 reward_func 函数中的第一行修改为 samples_size, n = sol.size(),以便可以接受不同形状的数据。
  2. 将 FastCMA_ES 函数中的 best_reward,best_route,n以及 points 改为适应新输入的参数。也就是将 n=10 改为输入数据的大小,将 points 改为输入数据,将 best_route 和 best_reward 改为一个大小为 batch_size 的数组(best_route = [None] * batch_size,best_reward = [None] * batch_size)。
  3. 相应地修改所有引用 n 的地方(如 cma_es = FastCMA(N=n, samples=sample_size))。
  4. 修改 FastCMA 类的初始化方法中用于计算 mueff 的代码,使其适应新的样本数。即将 self.mueff = (self.weights.sum() ** 2 / (self.weights ** 2). sum()).item() 改为 self.mueff = (self.weights.sum(dim=0) ** 2 / (self.weights ** 2). sum(dim=0))..view(-1, 1)。
  5. 对 FastCMA 类中所有引用 N 和 samples 的地方进行修改,使其在计算时使用正确的维度。例如,将 z = torch.randn(self.mean.size(0), self.samples).cuda() 改为 z = torch.randn(samples_size, n, 2).cuda()
  6. 在 FastCMA.step 方法中,将 结果 变量初始化为 结果 = [None] * samples_size,并使用 enumerate 遍历 sol 的第一个维度。
  7. 将 FastCMA.step 方法中通过枚举 sol.argsort() 获得特定行并计算奖励的代码修改为适应新的输入数据。即:将 a = rows 改为 a = sol[i],并将 b 转化为 [a[:, 1:], a[:, 0].unsqueeze(1)],再计算奖励。
  8. 最后,将 FastCMA_ES 函数中的 res = cma_es.step(objective_f=reward_func,dist_mat=dist_mat,step_size=step_size) 行改为 res = cma_es.step(objective_f=reward_func, sol=sol, dist_mat=dist_mat, step_size=step_size),并添加一个新的 sol = s.permute(1, 0, 2) 行,以便将数据重新排列为 batching 类型。
    如此一来,即可实现将输入从(n,2)改变为(batch_size,n,2)的功能。