AttributeError: module 'torch.distributed' has no attribute '_all_gather_base'
This error is usually caused by a PyTorch version mismatch. _all_gather_base is a private helper used by PyTorch's distributed training code, and it may be removed or renamed in some releases. Try upgrading PyTorch to the latest version, or pinning it to a different version, to resolve the problem. You can install a specific PyTorch version with the following command:
pip install torch==1.7.1
If you are using a Conda environment, use the following command instead:
conda install pytorch==1.7.1 torchvision==0.8.2 torchaudio==0.7.2 -c pytorch
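After reinstalling, a quick way to confirm which PyTorch version is actually active in the environment (assuming the environment's Python is the one on your PATH):
python -c "import torch; print(torch.__version__)"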
If upgrading or downgrading PyTorch still does not resolve the issue, check that your code imports the torch.distributed module correctly and that the installed build actually exposes this attribute.
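A minimal diagnostic sketch for that check (names used only for illustration) verifies that the distributed backend is compiled into the installed build and whether the private helper and its public replacement exist:

import torch
import torch.distributed as dist

# Confirm the distributed backend is available in this build, then check
# for the private helper apex relies on and its newer public counterpart.
print("distributed available:", dist.is_available())
print("_all_gather_base present:", hasattr(dist, "_all_gather_base"))
print("all_gather_into_tensor present:", hasattr(dist, "all_gather_into_tensor"))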
After installing apex, the following error appears when it is invoked:
File "/home/shuyuan/anaconda3/envs/shuyuan/lib/python3.8/site-packages/apex/transformer/pipeline_parallel/schedules/__init__.py", line 3, in <module>
from apex.transformer.pipeline_parallel.schedules.fwd_bwd_no_pipelining import (
File "/home/shuyuan/anaconda3/envs/shuyuan/lib/python3.8/site-packages/apex/transformer/pipeline_parallel/schedules/fwd_bwd_no_pipelining.py", line 10, in <module>
from apex.transformer.pipeline_parallel.schedules.common import Batch
File "/home/shuyuan/anaconda3/envs/shuyuan/lib/python3.8/site-packages/apex/transformer/pipeline_parallel/schedules/common.py", line 9, in <module>
from apex.transformer.pipeline_parallel.p2p_communication import FutureTensor
File "/home/shuyuan/anaconda3/envs/shuyuan/lib/python3.8/site-packages/apex/transformer/pipeline_parallel/p2p_communication.py", line 25, in <module>
from apex.transformer.utils import split_tensor_into_1d_equal_chunks
File "/home/shuyuan/anaconda3/envs/shuyuan/lib/python3.8/site-packages/apex/transformer/utils.py", line 11, in <module>
torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base
AttributeError: module 'torch.distributed' has no attribute '_all_gather_base'
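The failing line in apex/transformer/utils.py simply aliases the newer public name to the old private one. As a hedged workaround sketch (an assumption, not an official apex fix), you can create the missing alias in the opposite direction before importing apex, provided your PyTorch build already exposes the public all_gather_into_tensor API:

import torch.distributed as dist

# Workaround sketch: if this PyTorch build exposes the public API but not the
# private _all_gather_base helper that apex expects, alias it before apex is
# imported so apex's own patch line can find it.
if not hasattr(dist, "_all_gather_base") and hasattr(dist, "all_gather_into_tensor"):
    dist._all_gather_base = dist.all_gather_into_tensor

import apex  # import apex only after the alias is in place

If neither attribute exists, the installed PyTorch is likely too old or was built without distributed support; in that case, reinstalling PyTorch as described above is the more reliable fix.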