Problem deploying a model to multiple GPUs with DistributedDataParallel

import os
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn import Linear
import torch
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data.distributed import DistributedSampler
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, SAGEConv, GATConv, HeteroConv, GlobalAttention
from htg_data import HTG_data
import torch.nn.functional as F

class HeteroGNN(torch.nn.Module):
    def __init__(self, hidden_channels, out_channels, num_layers=2, dropout=0.5):
        super().__init__()
        self.dropout = dropout
        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers):
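            # in_channels of -1 / (-1, -1) are lazily inferred, so the layer
            # parameters are only created on the first forward pass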
            conv = HeteroConv({
                ('token', 'next_token', 'token'): GCNConv(-1, hidden_channels),
                ('token', 'token_sink', 'sink'): SAGEConv((-1, -1), hidden_channels),
                ('token', 'belongs_to', 'property'): GATConv((-1, -1), hidden_channels),
                ('property', 'property_sink', 'sink'): SAGEConv((-1, -1), hidden_channels),
                ('property', 'next_property', 'property'): GATConv((-1, -1), hidden_channels),
            }, aggr='sum')
            self.convs.append(conv)
        self.pooling_gate_nn = Linear(hidden_channels, 1)
        self.pooling = GlobalAttention(self.pooling_gate_nn)
        self.lin = Linear(hidden_channels, out_channels)

    def reset_parameters(self):
        for conv in self.convs:
            conv.reset_parameters()
        self.pooling.reset_parameters()
        self.lin.reset_parameters()

    def forward(self, x_dict, edge_index_dict, batch):
        for conv in self.convs:
            x_dict = conv(x_dict, edge_index_dict)
            x_dict = {key: x.relu() for key, x in x_dict.items()}
        x = torch.cat((x_dict['sink'], x_dict['property'], x_dict['token']), dim=0)
        x = self.pooling(x, batch)
        x = F.relu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.lin(x)
        return x

def run(rank, world_size: int, root: str):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group('nccl', rank=rank, world_size=world_size)

    dataset = HTG_data(root=root)

    print(dataset[0])

    train_sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    train_loader = DataLoader(dataset, batch_size=128, sampler=train_sampler)


    torch.manual_seed(12345)
    model = HeteroGNN(hidden_channels=64, out_channels=2, num_layers=2).to(rank)
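    # this is the line that raises the RuntimeError in the traceback below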
    model = DistributedDataParallel(model, device_ids=[rank])

    print(model)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.0005)
    criterion = torch.nn.MultiLabelSoftMarginLoss()

    for epoch in range(1, 51):
        model.train()

        total_loss = 0
        for data in train_loader:
            data = data.to(rank)
            optimizer.zero_grad()
            logits = model(data.x_dict, data.edge_index_dict, data.batch)
            loss = criterion(logits, data.y.to(torch.float))
            loss.backward()
            optimizer.step()
            total_loss += float(loss) * logits.size(0)
        loss = total_loss / len(train_loader.dataset)

        dist.barrier()

        print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}')

    dist.destroy_process_group()

if __name__ == '__main__':

    root = '/home/ylzqn/HTG_Data/HTG_CAG'
    world_size = torch.cuda.device_count()
    print('Let\'s use', world_size, 'GPUs!')
    args = (world_size, root)
    mp.spawn(run, args=args, nprocs=world_size, join=True)

Traceback (most recent call last):
  File "/home/ylzqn/Jupyter Notebook/pkgcode2vec/htg_model.py", line 127, in <module>
    mp.spawn(run, args=args, nprocs=world_size, join=True)
  File "/home/ylzqn/.conda/envs/lynch_pytorch/lib/python3.7/site-packages/torch/multiprocessing/spawn.py", line 230, in spawn
    return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
  File "/home/ylzqn/.conda/envs/lynch_pytorch/lib/python3.7/site-packages/torch/multiprocessing/spawn.py", line 188, in start_processes
    while not context.join():
  File "/home/ylzqn/.conda/envs/lynch_pytorch/lib/python3.7/site-packages/torch/multiprocessing/spawn.py", line 150, in join
    raise ProcessRaisedException(msg, error_index, failed_process.pid)
torch.multiprocessing.spawn.ProcessRaisedException:

-- Process 0 terminated with the following error:
Traceback (most recent call last):
  File "/home/ylzqn/.conda/envs/lynch_pytorch/lib/python3.7/site-packages/torch/multiprocessing/spawn.py", line 59, in _wrap
    fn(i, *args)
  File "/home/ylzqn/Jupyter Notebook/pkgcode2vec/htg_model.py", line 71, in run
    model = DistributedDataParallel(model, device_ids=[rank])
  File "/home/ylzqn/.conda/envs/lynch_pytorch/lib/python3.7/site-packages/torch/nn/parallel/distributed.py", line 435, in __init__
    'Modules with uninitialized parameters can't be used with DistributedDataParallel. '
RuntimeError: Modules with uninitialized parameters can't be used with DistributedDataParallel. Run a dummy forward pass to correctly initialize the modules


I want to use DistributedDataParallel to deploy the model across 4 GPUs, but I run into the error above. I'd be grateful if someone could point me in the right direction.

Not sure whether you have already solved this problem. If not:
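The error message itself points at the usual fix. The in_channels of -1 / (-1, -1) make PyG create lazily-initialized parameters, and DistributedDataParallel refuses to wrap a module whose parameters have not been materialized yet. Below is a minimal sketch of the workaround, reusing the HeteroGNN / HTG_data / train_loader setup from your post (which batch you use for the warm-up pass is an assumption, any batch with the same node and edge types should do): run one dummy forward pass on each rank, inside run() right after the model is moved to its GPU, and only then wrap it with DDP.

    model = HeteroGNN(hidden_channels=64, out_channels=2, num_layers=2).to(rank)

    # Dummy forward pass: the lazy GNN layers infer their input sizes here and
    # create real parameters, which DDP needs before it can wrap the model.
    with torch.no_grad():
        warmup = next(iter(train_loader)).to(rank)
        model(warmup.x_dict, warmup.edge_index_dict, warmup.batch)

    model = DistributedDataParallel(model, device_ids=[rank])

Once the parameters exist, the rest of your script (optimizer, training loop) should work unchanged.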

And if you have already solved it, it would be great if you could share the solution, write it up as a blog post, and put the link in the comments to help more people ^-^