Source code for haiscale.ddp.ddp
from typing import List, Optional
import os
import uuid
from contextlib import contextmanager
import torch
import torch.distributed as dist
import torch.distributed.distributed_c10d as dist_c10d
from torch.distributed.distributed_c10d import _get_default_group
from torch._utils import _get_device_index
import hfreduce
import hfreduce.torch as hfr
class DistributedDataParallel(torch.nn.Module):
"""
分布式数据并行工具,封装了 hfreduce
使用幻方 AI 自研的 `hfreduce <https://www.high-flyer.cn/blog/hf-reduce/>`_ 多卡通信工具,可以替换 PyTorch DDP 加速训练。
使用方法与 ``torch.nn.parallel.DistributedDataParallel`` 相同。
Args:
module (torch.nn.Module): PyTorch 模型
device_ids (list): 模型所在的 GPU id,如果是 None 则会用 `torch.cuda.current_device()` 的返回值,默认是 ``None``
broadcast_buffers (bool): 是否在 forward 之前把 rank-0 上的 buffer 广播到其他 rank 上,默认是 ``True``
process_group (ProcessGroup): ProcessGroup 对象,如果是 ``None`` 会用默认的分组
find_unused_parameters (bool): 是否遍历计算图,找到不参与 backward 的参数;
在少数情况下(比如训练 GAN)需要设置成 ``True``;默认是 ``False``
NOTE:
进程退出时可能会报错,可通过退出之前调用 ``model.reducer.stop()`` 来解决
Examples:
.. code-block:: python
from haiscale.ddp import DistributedDataParallel
model = DistributedDataParallel(model, device_ids=[local_rank])
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
# training ...
for epoch in range(epochs):
for step, (x, y) in enumerate(dataloader):
# training
optimizer.zero_grad()
output = model(x)
loss_fn(y, output).backward()
optimizer.step()
"""
    def __init__(self, module: torch.nn.Module, device_ids: Optional[List[int]] = None, broadcast_buffers=True,
                 process_group=None, find_unused_parameters=False):
super().__init__()
if not dist.is_initialized():
raise RuntimeError("torch.distributed is not initialized yet.")
if not torch.cuda.is_available():
raise RuntimeError("No available GPUs!")
if device_ids is not None and len(device_ids) > 1:
raise ValueError("device_ids can only be None or contain a single element.")
        # Get the master ip and port.
        # With the env:// protocol they come from environment variables;
        # with the tcp:// protocol they are parsed from the init url.
        # The file:// protocol is not supported :(
url = dist_c10d._default_pg_init_method
if url == "env://":
ip = os.getenv("MASTER_ADDR")
port = os.getenv("MASTER_PORT")
elif url.startswith("tcp://"):
url = url.split("://")[1] # tcp://ip:port
ip, port = url.split(":")
else:
raise ValueError(f"Only env:// and tcp:// protocols are supported, but given {url}")
# ranks_per_node
if device_ids:
device = _get_device_index(device_ids[0])
else:
            # assume torch.cuda.set_device() has already been called
device = torch.cuda.current_device()
if process_group is None:
process_group = _get_default_group()
self.process_group = process_group
self.find_unused_parameters = find_unused_parameters
world_size = dist.get_world_size(group=process_group)
node = uuid.getnode()
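        # uuid.getnode() returns this machine's hardware (MAC) address, so all ranks running on
        # the same physical node share the same value and it can serve as a node id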
info = torch.tensor([node, device], dtype=torch.int64, device=device)
infos = [torch.empty_like(info) for _ in range(world_size)]
dist.all_gather(infos, info, group=process_group)
infos = [info.tolist() for info in infos]
# num_nodes, ranks_per_node
num_nodes = len(set(info[0] for info in infos))
assert world_size % num_nodes == 0
ranks_per_node = world_size // num_nodes
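        # e.g. world_size == 16 with 2 distinct node ids gives num_nodes == 2, ranks_per_node == 8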
# node_rank
node_rank = 0
seen = set()
for n, _ in infos:
if n == node:
break
if n not in seen:
node_rank += 1
seen.add(n)
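        # node_rank == number of distinct node ids that appear before ours in the gathered list,
        # so every rank on the same node ends up with the same node_rank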
# local_rank
local_rank = 0
seen = set()
for n, d in infos:
if n == node:
if d == device:
break
assert d not in seen, f"Duplicated device {d} in node {node_rank}"
seen.add(d)
local_rank += 1
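        # local_rank == number of ranks on this node that appear before ours in the gathered list,
        # i.e. the index of this GPU within its node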
# TODO: group id
group_id = infos[0][1]
port = int(port) + 1 + group_id # avoid conflict
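        # hfreduce listens on its own port; shifting it past the torch.distributed master port
        # keeps the two from colliding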
        # initialize hfreduce
if not next(module.parameters()).is_cuda:
module = module.cuda(device)
self.module = module
        # hfreduce >= 1.3.1 requires sync_backward=True
self.reducer = hfr.AsyncReduceFloat(
ip, port, local_rank, ranks_per_node, node_rank,
num_nodes, module, sync_backward=True)
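        # from here on, gradient all-reduce for ``module`` is handled by the hfreduce reducer
        # (its sync_grad flag is toggled by ``no_sync`` below)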
# follow pytorch DDP
self.device = local_rank
self.broadcast_buffers = broadcast_buffers
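        # honor the same ignore list that torch.nn.parallel.DistributedDataParallel reads
        # from the wrapped module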
if hasattr(module, "_ddp_params_and_buffers_to_ignore"):
self.parameters_to_ignore = module._ddp_params_and_buffers_to_ignore
else:
self.parameters_to_ignore = []
        # check for UninitializedParameter and make sure all parameters are on the same GPU
param0_device = next(module.parameters()).device
for param in module.parameters():
if isinstance(param, torch.nn.parameter.UninitializedParameter):
raise TypeError("torch.nn.parameter.UninitializedParameter is not supported")
if param.device != param0_device:
raise RuntimeError("All parameters must be on the same device")
if param.dtype != torch.float32:
raise RuntimeError("Only torch.float32 is supported now.")
        # bucket size used for broadcasting, 250 MiB by default
self.broadcast_bucket_size = int(250 * 1024 * 1024)
        # If the user has not fixed the random seed, the initial parameters may differ across ranks.
        # Broadcast rank 0's model to the other ranks so that every rank starts from the same parameters.
self._sync_params_and_buffers()
def forward(self, *args, **kwargs):
        # broadcast buffers before running forward
if torch.is_grad_enabled() and self.broadcast_buffers:
self._sync_buffers()
outputs = self.module(*args, **kwargs)
if torch.is_grad_enabled() and self.reducer.sync_grad and self.find_unused_parameters:
self.reducer.filter_unused_parameters(outputs)
return outputs
    @contextmanager
def no_sync(self):
"""
有两种情况不需要对梯度做 allreduce:
1. torch.is_grad_enabled 为 False:这样 forward 的时候不产生梯度, 通常来说也不用做 backward
#. 做多次 forward-backward, 把梯度进行累加。这样只有最后一次 backward 需要对梯度进行同步
该方法主要是考虑到了第二种情况,梯度和 buffer 都不会进行同步
Example:
.. code-block:: python
ddp = haiscale.ddp.DistributedDataParallel(model, ...)
with ddp.no_sync():
for input in inputs:
ddp(input).backward() # no synchronization, accumulate grads
ddp(another_input).backward() # synchronize grads
"""
old_sync_grad = self.reducer.sync_grad
old_broadcast_buffers = self.broadcast_buffers
self.reducer.sync_grad = False
self.broadcast_buffers = False
try:
yield
finally:
self.reducer.sync_grad = old_sync_grad
self.broadcast_buffers = old_broadcast_buffers
def _sync_params_and_buffers(self):
module_states = []
for name, param in self.module.named_parameters():
if name not in self.parameters_to_ignore:
module_states.append(param.detach())
for name, buffer in self.module.named_buffers():
if name not in self.parameters_to_ignore:
module_states.append(buffer.detach())
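        # broadcast the collected tensors from rank 0, coalesced into buckets of at most
        # ``broadcast_bucket_size`` bytes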
if len(module_states) > 0:
dist._broadcast_coalesced(self.process_group, module_states, self.broadcast_bucket_size, 0)
def _sync_buffers(self):
buffers = []
for name, buffer in self.module.named_buffers():
if name not in self.parameters_to_ignore:
buffers.append(buffer.detach())
if len(buffers) > 0:
dist._broadcast_coalesced(self.process_group, buffers, self.broadcast_bucket_size, 0)