
Source code for hfai.distributed.pairwise

import torch
import torch.distributed as dist
from torch.distributed.distributed_c10d import _get_default_group, _get_global_rank

import warnings
with warnings.catch_warnings(record=True) as w:
    from cupy.cuda import nccl  # suppress the warning raised in the 202105 environment


TORCH_TO_NCCL_TYPE = {
    torch.float16: nccl.NCCL_FLOAT16,
    torch.float32: nccl.NCCL_FLOAT32,
    torch.float64: nccl.NCCL_FLOAT64,
    torch.int8: nccl.NCCL_INT8,
    torch.int32: nccl.NCCL_INT32,
    torch.int64: nccl.NCCL_INT64,
    torch.uint8: nccl.NCCL_UINT8,
}

GROUP_TO_NCCL_COMM = {}


def nccl_datatype(dtype):
    nccl_type = TORCH_TO_NCCL_TYPE.get(dtype)
    assert nccl_type is not None, f"unsupported dtype: {dtype}"
    return nccl_type


def nccl_comm(group, device):
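    """
    Return a cupy NCCL communicator for ``group``, creating and caching it on first use.

    Rank 0 of the group generates a NCCL unique id and broadcasts it through the existing
    torch process group, so that every rank joins the same NCCL communicator.
    """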
    comm = GROUP_TO_NCCL_COMM.get(group)
    if comm is not None:
        return comm

    rank = dist.get_rank(group)  # group's local rank
    nranks = dist.get_world_size(group)

    nccl_id = torch.tensor(nccl.get_unique_id(), dtype=torch.int32, device=device)
    root = _get_global_rank(group, 0)
    dist.broadcast(nccl_id, src=root, group=group)
    nccl_id = tuple(nccl_id.tolist())

    comm = nccl.NcclCommunicator(nranks, nccl_id, rank)
    GROUP_TO_NCCL_COMM[group] = comm

    return comm


def nccl_sendrecv(send_buf, recv_buf, send_to, recv_from, group):
    """
    send_to, recv_from: group's local rank
    """
    device = send_buf.device
    comm = nccl_comm(group, device)

    stream = torch.cuda.current_stream().cuda_stream
    send_dtype = nccl_datatype(send_buf.dtype)
    recv_dtype = nccl_datatype(recv_buf.dtype)

    nccl.groupStart()
    comm.send(send_buf.data_ptr(), send_buf.numel(), send_dtype, send_to, stream)
    comm.recv(recv_buf.data_ptr(), recv_buf.numel(), recv_dtype, recv_from, stream)
    nccl.groupEnd()

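
# --- Illustrative example (not part of the original module) ---
# A minimal sketch of how ``nccl_sendrecv`` can pass a tensor one hop around the ring:
# each rank sends to the next rank and receives from the previous one in a single NCCL
# group call, so the paired send/recv cannot deadlock. This is the same pattern that
# ``pairwise_apply`` below builds on. ``ring_shift`` is a hypothetical helper name and
# assumes ``x`` is a CUDA tensor with the same shape on every rank.
def ring_shift(x, group=None):
    group = group if group is not None else _get_default_group()
    rank = dist.get_rank(group)
    nranks = dist.get_world_size(group)

    send_buf = x.contiguous().view(-1)
    recv_buf = torch.empty_like(send_buf)

    next_rank = (rank + 1) % nranks
    prev_rank = (rank - 1 + nranks) % nranks
    nccl_sendrecv(send_buf, recv_buf, next_rank, prev_rank, group)
    return recv_buf.view_as(x)
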

def torch_sendrecv(send_buf, recv_buf, send_to, recv_from, group):
    """
    send_to, recv_from: group's local rank
    """
    # compare parsed version numbers rather than raw strings
    # (string comparison would wrongly accept e.g. "1.9.0" >= "1.12.0")
    version = tuple(int(v) for v in torch.__version__.split("+")[0].split(".")[:2])
    assert version >= (1, 12), f"torch_sendrecv requires torch >= 1.12.0, got {torch.__version__}"

    send_to = _get_global_rank(group, send_to)
    recv_from = _get_global_rank(group, recv_from)

    send_op = dist.P2POp(dist.isend, send_buf, send_to, group=group)
    recv_op = dist.P2POp(dist.irecv, recv_buf, recv_from, group=group)
    reqs = dist.batch_isend_irecv([send_op, recv_op])
    for req in reqs:
        req.wait()


@torch.no_grad()
def pairwise_apply(f, x, args=(), group=None, equal_size=False):
    """
    Given a function ``f`` and the input ``x_i`` on the current (``i``-th) GPU, return
    ``[f(x_i, x_j) for j in range(nranks)]``, where ``x_j`` is the input on the ``j``-th GPU.

    The result is equivalent to the following implementation:

    .. code-block:: python

        def pairwise_apply(f, x, args=(), group=None):
            nranks = dist.get_world_size(group)
            xs = [torch.empty_like(x) for _ in range(nranks)]
            dist.all_gather(xs, x, group=group)
            results = [f(x, xs[i], *args) for i in range(nranks)]
            return results

    Args:
        f (Callable[Tensor, Tensor]): the function to apply, invoked as ``f(x, y, *args)``
        x (torch.Tensor): input tensor; the shape may differ across GPUs, but the number of dimensions must be the same
        args (tuple): extra arguments passed to ``f``
        group (ProcessGroup): ProcessGroup object; defaults to ``None``
        equal_size (bool): whether the input tensor has the same shape on every GPU; defaults to ``False``

    Returns:
        the results of applying ``f`` to the inputs of all GPUs

    .. note::

        1) When the torch version is lower than 1.12.0, this function calls the underlying NCCL
           through cupy, so the first call waits a short while for NCCL initialization.
        2) Communication performance drops when the tensor on each GPU is smaller than 1 MiB.

    Examples:

    >>> import hfai.distributed as dist
    >>> def f(x, y):
    ...     return x + y
    >>> rank = dist.get_rank()
    >>> x = torch.ones(1, device="cuda") * rank
    >>> dist.pairwise_apply(f, x)
    [tensor(0), tensor(1), tensor(2), tensor(3)]  # Rank 0
    [tensor(1), tensor(2), tensor(3), tensor(4)]  # Rank 1
    [tensor(2), tensor(3), tensor(4), tensor(5)]  # Rank 2
    [tensor(3), tensor(4), tensor(5), tensor(6)]  # Rank 3

    """
    if group is None:
        group = _get_default_group()

    rank = dist.get_rank(group)  # group's local rank
    nranks = dist.get_world_size(group)

    # assume every tensor has the same number of dimensions
    if equal_size:
        shapes = [x.shape for _ in range(nranks)]
    else:
        shape = torch.tensor(x.shape, dtype=torch.int32, device=x.device)
        shapes = [torch.empty_like(shape) for _ in range(nranks)]
        dist.all_gather(shapes, shape, group=group)
        shapes = [torch.Size(s) for s in shapes]

    max_size = max(s.numel() for s in shapes)
    send_buf = x.new_empty(max_size)
    recv_buf = x.new_empty(max_size)

    prev_rank = (rank - 1 + nranks) % nranks
    next_rank = (rank + 1) % nranks

    # fill in the send buffer
    send_buf[:x.numel()].copy_(x.view(-1))

    results = [None for _ in range(nranks)]
    results[rank] = f(x, x, *args)

    for i in range(nranks - 1):
        recv_idx = (rank - i - 1 + nranks) % nranks
        nccl_sendrecv(send_buf, recv_buf, next_rank, prev_rank, group)

        shape = shapes[recv_idx]
        y = recv_buf[:shape.numel()].view(shape)
        results[recv_idx] = f(x, y, *args)

        # pass the buffer we just received on to the next rank
        send_buf, recv_buf = recv_buf, send_buf

    return results
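

# --- Illustrative usage (not part of the original module) ---
# A minimal sketch, assuming torch.distributed is already initialized with the NCCL backend
# and each process owns one CUDA device: compute squared Euclidean distances between the
# feature chunk held by this GPU and the chunks held by every other GPU.
# ``pairwise_sq_dists`` is a hypothetical helper name chosen for this example.
def pairwise_sq_dists(features):
    # features: (n_local, d) CUDA tensor; n_local may differ across GPUs, d must match
    def sq_dist(a, b):
        # (n_a, n_b) matrix of squared distances between the rows of a and the rows of b
        return torch.cdist(a, b) ** 2

    blocks = pairwise_apply(sq_dist, features)  # blocks[j] has shape (n_local, n_j)
    return torch.cat(blocks, dim=1)             # (n_local, n_total), columns ordered by rank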