BalancedDataParallel: later access to the model reports that the `model` attribute does not exist

Problem symptoms and background

I wanted to even out PyTorch's unbalanced GPU memory usage during training, so I switched to BalancedDataParallel and copied a data_parallel implementation from GitHub, shown below:


from torch.nn.parallel import DataParallel
import torch
from torch.nn.parallel._functions import Scatter
from torch.nn.parallel.parallel_apply import parallel_apply

def scatter(inputs, target_gpus, chunk_sizes, dim=0):
    r"""
    Slices tensors into approximately equal chunks and
    distributes them across given GPUs. Duplicates
    references to objects that are not tensors.
    """
    def scatter_map(obj):
        if isinstance(obj, torch.Tensor):
            try:
                return Scatter.apply(target_gpus, chunk_sizes, dim, obj)
            except Exception:
                # debug output left in from the copied snippet; re-raise
                # instead of quit() so the real traceback is preserved
                print('obj', obj.size())
                print('dim', dim)
                print('chunk_sizes', chunk_sizes)
                raise
        if isinstance(obj, tuple) and len(obj) > 0:
            return list(zip(*map(scatter_map, obj)))
        if isinstance(obj, list) and len(obj) > 0:
            return list(map(list, zip(*map(scatter_map, obj))))
        if isinstance(obj, dict) and len(obj) > 0:
            return list(map(type(obj), zip(*map(scatter_map, obj.items()))))
        # non-tensor, non-container objects are referenced once per target GPU
        return [obj for targets in target_gpus]

    # After scatter_map is called, a scatter_map cell will exist. This cell
    # has a reference to the actual function scatter_map, which has references
    # to a closure that has a reference to the scatter_map cell (because the
    # fn is recursive). To avoid this reference cycle, we set the function to
    # None, clearing the cell
    try:
        return scatter_map(inputs)
    finally:
        scatter_map = None

def scatter_kwargs(inputs, kwargs, target_gpus, chunk_sizes, dim=0):
    r"""Scatter with support for kwargs dictionary"""
    inputs = scatter(inputs, target_gpus, chunk_sizes, dim) if inputs else []
    kwargs = scatter(kwargs, target_gpus, chunk_sizes, dim) if kwargs else []
    if len(inputs) < len(kwargs):
        inputs.extend([() for _ in range(len(kwargs) - len(inputs))])
    elif len(kwargs) < len(inputs):
        kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))])
    inputs = tuple(inputs)
    kwargs = tuple(kwargs)
    return inputs, kwargs

class BalancedDataParallel(DataParallel):
    """DataParallel variant that gives GPU 0 a custom (typically smaller)
    share of each batch, since GPU 0 also hosts the gathered outputs."""

    def __init__(self, gpu0_bsz, *args, **kwargs):
        self.gpu0_bsz = gpu0_bsz
        super().__init__(*args, **kwargs)

    def forward(self, *inputs, **kwargs):
        if not self.device_ids:
            return self.module(*inputs, **kwargs)
        if self.gpu0_bsz == 0:
            device_ids = self.device_ids[1:]
        else:
            device_ids = self.device_ids
        inputs, kwargs = self.scatter(inputs, kwargs, device_ids)
        if len(self.device_ids) == 1:
            return self.module(*inputs[0], **kwargs[0])
        if self.gpu0_bsz == 0:
            # replicate onto one extra device so that dropping the GPU-0
            # replica still leaves one replica per scattered input chunk
            replicas = self.replicate(self.module, self.device_ids[:len(inputs) + 1])
            replicas = replicas[1:]
        else:
            replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
        outputs = self.parallel_apply(replicas, device_ids, inputs, kwargs)
        return self.gather(outputs, self.output_device)

    def parallel_apply(self, replicas, device_ids, inputs, kwargs):
        return parallel_apply(replicas, inputs, kwargs, device_ids[:len(inputs)])

    def scatter(self, inputs, kwargs, device_ids):
        bsz = inputs[0].size(self.dim)
        num_dev = len(self.device_ids)
        gpu0_bsz = self.gpu0_bsz
        # GPU 0 keeps its fixed share; the remainder is split evenly
        bsz_unit = (bsz - gpu0_bsz) // (num_dev - 1)
        if gpu0_bsz < bsz_unit:
            chunk_sizes = [gpu0_bsz] + [bsz_unit] * (num_dev - 1)
            # hand out any leftover samples one at a time, skipping GPU 0
            delta = bsz - sum(chunk_sizes)
            for i in range(delta):
                chunk_sizes[i + 1] += 1
            if gpu0_bsz == 0:
                chunk_sizes = chunk_sizes[1:]
        else:
            return super().scatter(inputs, kwargs, device_ids)
        return scatter_kwargs(inputs, kwargs, device_ids, chunk_sizes, dim=self.dim)
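
For reference, this is how the wrapper is meant to be used (a minimal sketch; the toy nn.Sequential network, the device ids, and the batch share of 4 are placeholders, not my real training setup):

import torch
import torch.nn as nn

# hypothetical two-layer network standing in for the real model
net = nn.Sequential(nn.Linear(10, 10), nn.Linear(10, 2)).cuda()

# give GPU 0 a smaller share (4 samples) of every batch; the rest is
# split evenly across the remaining visible GPUs
model = BalancedDataParallel(4, net, device_ids=[0, 1])

out = model(torch.randn(16, 10).cuda())  # 4 samples on GPU 0, 12 on GPU 1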

Runtime result and error message

Traceback (most recent call last):
  File "train_cstrack.py", line 523, in 
    train(hyp, opt, device, tb_writer)
  File "train_cstrack.py", line 329, in train
    loss, loss_items = mot_loss(pred, targets.to(device), model)  # scaled by batch_size
  File "/data/amax/b510/cs/.conda/envs/CSTrack/lib/python3.8/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/data/amax/b510/cs/SOTS/CSTrack/tracking/../lib/core/mot/base_trainer.py", line 319, in forward
    tcls, tbox, indices, anchors, indices_id, tids = build_targets(p, targets, model)  # targets
  File "/data/amax/b510/cs/SOTS/CSTrack/tracking/../lib/core/mot/base_trainer.py", line 227, in build_targets
    det = model.module.model[-1] if is_parallel(model) else model.model[-1]  # Detect() module
  File "/data/amax/b510/cs/.conda/envs/CSTrack/lib/python3.8/site-packages/torch/nn/modules/module.py", line 771, in __getattr__
    raise ModuleAttributeError("'{}' object has no attribute '{}'".format(
torch.nn.modules.module.ModuleAttributeError: 'BalancedDataParallel' object has no attribute 'model'
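
Reading the traceback: build_targets picks between model.module.model[-1] and model.model[-1] based on an is_parallel check. CSTrack is built on YOLOv5, and if its helper is the YOLOv5-style one below (an assumption about this repo, not something I verified), it compares exact types, so a DataParallel subclass like BalancedDataParallel fails the test, the code falls into the model.model[-1] branch, and __getattr__ raises the error above:

import torch.nn as nn

# YOLOv5-style helper (assumed; CSTrack inherits much of YOLOv5's torch_utils)
def is_parallel(model):
    # exact-type comparison: returns False for subclasses of DataParallel
    return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)

# a subclass-tolerant variant would use isinstance instead
def is_parallel_loose(model):
    return isinstance(model, (nn.DataParallel, nn.parallel.DistributedDataParallel))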

My approach and what I have tried

I swapped in a few alternative implementations and tried placing the wrapper in different spots, but I get the same error every time. I also saw a suggestion to delete the *.pyc files; that did not help either.
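
Based on the traceback, a workaround I am considering (just a sketch, assuming the error path above is right) is to always reach the inner network through .module when the model is wrapped:

# nn.DataParallel (and this subclass) stores the wrapped network as .module,
# so attributes defined on the original model must be reached through it
inner = model.module if hasattr(model, 'module') else model
det = inner.model[-1]  # the Detect() module that build_targets looks up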

What I want to achieve

I want training to run with memory balanced across the cards: right now the first GPU runs out of memory while the second one sits at only 40% usage.

Hope this helps: https://b23.tv/9ucseej

Reference links

ThinkPHP5: how to call a model from inside a controller (hexiaoniao, CSDN): https://blog.csdn.net/hexiaoniao/article/details/84543676

Problems when loading a model trained with nn.DataParallel (景唯acr, CSDN): recommends saving with torch.save(model.module.state_dict(), model_out_path) so the checkpoint can later be loaded without the DataParallel wrapper via model.load_state_dict(torch.load(PATH, map_location=device)). https://blog.csdn.net/weixin_41735859/article/details/108610687