
Source code for haiscale.ddp.ddp

from typing import List
import os
import uuid
from contextlib import contextmanager

import torch
import torch.distributed as dist
import torch.distributed.distributed_c10d as dist_c10d
from torch.distributed.distributed_c10d import _get_default_group
from torch._utils import _get_device_index

import hfreduce
import hfreduce.torch as hfr


class DistributedDataParallel(torch.nn.Module):
    """
    Distributed data parallel wrapper around hfreduce.

    Uses High-Flyer AI's in-house `hfreduce <https://www.high-flyer.cn/blog/hf-reduce/>`_
    multi-GPU communication library and can replace PyTorch DDP to speed up training.
    The usage is the same as ``torch.nn.parallel.DistributedDataParallel``.

    Args:
        module (torch.nn.Module): the PyTorch model
        device_ids (list): GPU id that the model lives on; if ``None``, the return value of
            ``torch.cuda.current_device()`` is used; defaults to ``None``
        broadcast_buffers (bool): whether to broadcast the buffers of rank 0 to the other
            ranks before each forward; defaults to ``True``
        process_group (ProcessGroup): ProcessGroup object; if ``None``, the default group is used
        find_unused_parameters (bool): whether to traverse the autograd graph to find parameters
            that do not take part in backward; needs to be ``True`` in a few cases
            (e.g. training a GAN); defaults to ``False``

    NOTE:
        An error may be raised when the process exits; it can be avoided by calling
        ``model.reducer.stop()`` before exiting.

    Examples:

    .. code-block:: python

        from haiscale.ddp import DistributedDataParallel

        model = DistributedDataParallel(model, device_ids=[local_rank])
        optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

        # training ...
        for epoch in range(epochs):
            for step, (x, y) in enumerate(dataloader):
                optimizer.zero_grad()
                output = model(x)
                loss_fn(y, output).backward()
                optimizer.step()

    """

    def __init__(self, module: torch.nn.Module, device_ids: List[int] = None,
                 broadcast_buffers=True, process_group=None, find_unused_parameters=False):
        super().__init__()

        if not dist.is_initialized():
            raise RuntimeError("torch.distributed is not initialized yet.")

        if not torch.cuda.is_available():
            raise RuntimeError("No available GPUs!")

        if device_ids is not None and len(device_ids) > 1:
            raise ValueError("device_ids can only be None or contain a single element.")

        # Get the master ip and port:
        # for the env:// protocol, read them from the environment variables;
        # for the tcp:// protocol, parse them from the init url.
        # The file:// protocol is not supported :(
        url = dist_c10d._default_pg_init_method
        if url == "env://":
            ip = os.getenv("MASTER_ADDR")
            port = os.getenv("MASTER_PORT")
        elif url.startswith("tcp://"):
            url = url.split("://")[1]  # tcp://ip:port
            ip, port = url.split(":")
        else:
            raise ValueError(f"Only env:// and tcp:// protocols are supported, but given {url}")

        # ranks_per_node
        if device_ids:
            device = _get_device_index(device_ids[0])
        else:
            # assume torch.cuda.set_device() has already been called
            device = torch.cuda.current_device()

        if process_group is None:
            process_group = _get_default_group()

        self.process_group = process_group
        self.find_unused_parameters = find_unused_parameters

        world_size = dist.get_world_size(group=process_group)
        node = uuid.getnode()
        info = torch.tensor([node, device], dtype=torch.int64, device=device)
        infos = [torch.empty_like(info) for _ in range(world_size)]
        dist.all_gather(infos, info, group=process_group)
        infos = [info.tolist() for info in infos]

        # num_nodes, ranks_per_node
        num_nodes = len(set(info[0] for info in infos))
        assert world_size % num_nodes == 0
        ranks_per_node = world_size // num_nodes

        # node_rank
        node_rank = 0
        seen = set()
        for n, _ in infos:
            if n == node:
                break
            if n not in seen:
                node_rank += 1
                seen.add(n)

        # local_rank
        local_rank = 0
        seen = set()
        for n, d in infos:
            if n == node:
                if d == device:
                    break
                assert d not in seen, f"Duplicated device {d} in node {node_rank}"
                seen.add(d)
                local_rank += 1

        # TODO: group id
        group_id = infos[0][1]
        port = int(port) + 1 + group_id  # avoid conflict

        # initialize hfreduce
        if not next(module.parameters()).is_cuda:
            module = module.cuda(device)

        self.module = module

        # hfreduce>=1.3.1 requires sync_backward=True
        self.reducer = hfr.AsyncReduceFloat(
            ip, port, local_rank, ranks_per_node, node_rank, num_nodes,
            module, sync_backward=True)

        # follow pytorch DDP
        self.device = local_rank
        self.broadcast_buffers = broadcast_buffers

        if hasattr(module, "_ddp_params_and_buffers_to_ignore"):
            self.parameters_to_ignore = module._ddp_params_and_buffers_to_ignore
        else:
            self.parameters_to_ignore = []

        # Reject UninitializedParameter and make sure all parameters live on the
        # same GPU and are float32.
        param0_device = next(module.parameters()).device
        for param in module.parameters():
            if isinstance(param, torch.nn.parameter.UninitializedParameter):
                raise TypeError("torch.nn.parameter.UninitializedParameter is not supported")
            if param.device != param0_device:
                raise RuntimeError("All parameters must be on the same device")
            if param.dtype != torch.float32:
                raise RuntimeError("Only torch.float32 is supported now.")

        # bucket size for broadcasting buffers, defaults to 250 MiB
        self.broadcast_bucket_size = int(250 * 1024 * 1024)

        # If the user has not fixed the random seed, the initial parameters may differ
        # across ranks; broadcast the model of rank 0 to the other ranks so that every
        # rank starts from the same parameters.
        self._sync_params_and_buffers()

    def forward(self, *args, **kwargs):
        # broadcast buffers before forward
        if torch.is_grad_enabled() and self.broadcast_buffers:
            self._sync_buffers()

        outputs = self.module(*args, **kwargs)

        if torch.is_grad_enabled() and self.reducer.sync_grad and self.find_unused_parameters:
            self.reducer.filter_unused_parameters(outputs)

        return outputs

    @contextmanager
    def no_sync(self):
        """
        There are two cases in which gradients do not need to be allreduced:

        1. ``torch.is_grad_enabled()`` is ``False``: forward produces no gradients,
           and normally no backward is run either.
        #. Several forward-backward passes accumulate gradients, so only the last
           backward needs to synchronize them.

        This method targets the second case: inside the context neither gradients
        nor buffers are synchronized.

        Example:

        .. code-block:: python

            ddp = haiscale.ddp.DistributedDataParallel(model, ...)
            with ddp.no_sync():
                for input in inputs:
                    ddp(input).backward()  # no synchronization, accumulate grads
            ddp(another_input).backward()  # synchronize grads

        """
        old_sync_grad = self.reducer.sync_grad
        old_broadcast_buffers = self.broadcast_buffers

        self.reducer.sync_grad = False
        self.broadcast_buffers = False

        try:
            yield
        finally:
            self.reducer.sync_grad = old_sync_grad
            self.broadcast_buffers = old_broadcast_buffers

    def _sync_params_and_buffers(self):
        module_states = []
        for name, param in self.module.named_parameters():
            if name not in self.parameters_to_ignore:
                module_states.append(param.detach())

        for name, buffer in self.module.named_buffers():
            if name not in self.parameters_to_ignore:
                module_states.append(buffer.detach())

        if len(module_states) > 0:
            dist._broadcast_coalesced(self.process_group, module_states,
                                      self.broadcast_bucket_size, 0)

    def _sync_buffers(self):
        buffers = []
        for name, buffer in self.module.named_buffers():
            if name not in self.parameters_to_ignore:
                buffers.append(buffer.detach())

        if len(buffers) > 0:
            dist._broadcast_coalesced(self.process_group, buffers,
                                      self.broadcast_bucket_size, 0)
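
For context, below is a minimal end-to-end sketch of how this wrapper is typically driven, put together from the docstring and the constraints in the constructor above (``env://`` initialization, one process per GPU with ``torch.cuda.set_device()`` called first, float32 parameters, and ``model.reducer.stop()`` before exit). The toy model, data, hyperparameters, and ``torchrun`` launch line are illustrative assumptions, not part of haiscale.

    # Illustrative sketch: launch with torchrun so that RANK, LOCAL_RANK,
    # WORLD_SIZE, MASTER_ADDR and MASTER_PORT are set, e.g.
    #   torchrun --nproc_per_node=8 train.py
    import os

    import torch
    import torch.distributed as dist

    from haiscale.ddp import DistributedDataParallel


    def main():
        # env:// reads MASTER_ADDR / MASTER_PORT, which the wrapper also parses
        dist.init_process_group(backend="nccl", init_method="env://")
        local_rank = int(os.environ["LOCAL_RANK"])
        torch.cuda.set_device(local_rank)

        # toy float32 model and random data, purely illustrative
        model = torch.nn.Linear(32, 4).cuda(local_rank)
        model = DistributedDataParallel(model, device_ids=[local_rank])
        optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
        loss_fn = torch.nn.MSELoss()

        for step in range(100):
            x = torch.randn(16, 32, device=local_rank)
            y = torch.randn(16, 4, device=local_rank)
            optimizer.zero_grad()
            loss_fn(model(x), y).backward()  # gradients are allreduced by hfreduce
            optimizer.step()

        # per the NOTE in the class docstring, stop the reducer before exiting
        model.reducer.stop()
        dist.destroy_process_group()


    if __name__ == "__main__":
        main()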