Source code for hfai.distributed.pairwise
import torch
import torch.distributed as dist
from torch.distributed.distributed_c10d import _get_default_group, _get_global_rank
import warnings
with warnings.catch_warnings(record=True) as w:
    from cupy.cuda import nccl  # suppress the warning raised in the 202105 environment

TORCH_TO_NCCL_TYPE = {
    torch.float16: nccl.NCCL_FLOAT16,
    torch.float32: nccl.NCCL_FLOAT32,
    torch.float64: nccl.NCCL_FLOAT64,
    torch.int8: nccl.NCCL_INT8,
    torch.int32: nccl.NCCL_INT32,
    torch.int64: nccl.NCCL_INT64,
    torch.uint8: nccl.NCCL_UINT8,
}

GROUP_TO_NCCL_COMM = {}

def nccl_datatype(dtype):
    nccl_type = TORCH_TO_NCCL_TYPE.get(dtype)
    assert nccl_type is not None, f"Unsupported dtype: {dtype}"
    return nccl_type

def nccl_comm(group, device):
    comm = GROUP_TO_NCCL_COMM.get(group)
    if comm is not None:
        return comm

    rank = dist.get_rank(group)  # group's local rank
    nranks = dist.get_world_size(group)

    # create a NCCL unique id on the group's rank 0 and broadcast it to all ranks
    nccl_id = torch.tensor(nccl.get_unique_id(), dtype=torch.int32, device=device)
    root = _get_global_rank(group, 0)
    dist.broadcast(nccl_id, src=root, group=group)
    nccl_id = tuple(nccl_id.tolist())

    comm = nccl.NcclCommunicator(nranks, nccl_id, rank)
    GROUP_TO_NCCL_COMM[group] = comm
    return comm

def nccl_sendrecv(send_buf, recv_buf, send_to, recv_from, group):
    """
    send_to, recv_from: group's local rank
    """
    device = send_buf.device
    comm = nccl_comm(group, device)
    stream = torch.cuda.current_stream().cuda_stream
    send_dtype = nccl_datatype(send_buf.dtype)
    recv_dtype = nccl_datatype(recv_buf.dtype)

    # group the send and recv so NCCL schedules them together and avoids deadlock
    nccl.groupStart()
    comm.send(send_buf.data_ptr(), send_buf.numel(), send_dtype, send_to, stream)
    comm.recv(recv_buf.data_ptr(), recv_buf.numel(), recv_dtype, recv_from, stream)
    nccl.groupEnd()

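# A minimal usage sketch (illustration only, not part of the original module): a single
# ring shift in which every rank sends its buffer to the next rank and receives the
# previous rank's buffer. It assumes ``dist.init_process_group("nccl")`` has already run
# and each process has set its CUDA device.
#
#   rank = dist.get_rank()
#   nranks = dist.get_world_size()
#   send_buf = torch.full((4,), float(rank), device="cuda")
#   recv_buf = torch.empty_like(send_buf)
#   nccl_sendrecv(send_buf, recv_buf,
#                 send_to=(rank + 1) % nranks,
#                 recv_from=(rank - 1 + nranks) % nranks,
#                 group=_get_default_group())
#   # recv_buf now holds the previous rank's values
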
def torch_sendrecv(send_buf, recv_buf, send_to, recv_from, group):
    """
    send_to, recv_from: group's local rank
    """
    assert torch.__version__ >= "1.12.0"  # TODO: fixme
    send_to = _get_global_rank(group, send_to)
    recv_from = _get_global_rank(group, recv_from)
    send_op = dist.P2POp(dist.isend, send_buf, send_to, group=group)
    recv_op = dist.P2POp(dist.irecv, recv_buf, recv_from, group=group)
    reqs = dist.batch_isend_irecv([send_op, recv_op])
    for req in reqs:
        req.wait()

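# torch_sendrecv has the same signature as nccl_sendrecv and can serve as the
# pure-PyTorch path on torch >= 1.12.0 (illustration only, not part of the original
# module); the ring shift above would become:
#
#   torch_sendrecv(send_buf, recv_buf,
#                  send_to=(rank + 1) % nranks,
#                  recv_from=(rank - 1 + nranks) % nranks,
#                  group=_get_default_group())
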
@torch.no_grad()
def pairwise_apply(f, x, args=(), group=None, equal_size=False):
"""
给定一个函数 ``f`` 和当前(第 ``i`` 个)GPU 的输入 ``x_i``,返回 ``[f(x_i, y_j) for j in range(nranks)]``, 其中 ``x_j`` 为第 ``j`` 个 GPU 的输入
返回的结果等价于以下实现:
.. code-block:: python
def pairwise_apply(f, x, args=(), group=None):
nranks = dist.get_world_size(group)
xs = [torch.empty_like(x) for _ in range(nranks)]
dist.all_gather(xs, x, group=group)
results = [f(x, xs[i], *args) for i in range(nranks)]
return results
Args:
f (Callable[Tensor, Tensor]): 需要调用的函数,通过 ``f(x, y, *args)`` 的方式调用
x (torch.Tensor): 输入的 tensor;每块 GPU 上 tensor 的形状可以不相同,但维度数量要相同
args (tuple): 需要额外传入 ``f`` 的参数
group (ProcessGroup): ProcessGroup 对象;默认是 ``None``
equal_size (bool): 每块 GPU 上输入的 tensor 形状是否相同;默认是 ``False``
Returns:
``f`` 作用在每块 GPU 上的结果
.. note::
1) 当 torch 的版本小于 1.12.0 时,本函数是通过 cupy 调用底层 nccl 实现的,第一次调用本函数的时候需要等待一小段初始化 NCCL 的时间。
2) 每块 GPU 上的 tensor 小于 1 MiB 时,通讯的性能会有所下降。
Examples:
>>> import hfai.distributed as dist
>>> def f(x, y):
... return x + y
>>> rank = dist.get_rank()
>>> x = torch.ones(1, device="cuda") * rank
>>> dist.pairwise_apply(f, x)
[tensor(0), tensor(1), tensor(2), tensor(3)] # Rank 0
[tensor(1), tensor(2), tensor(3), tensor(4)] # Rank 1
[tensor(2), tensor(3), tensor(4), tensor(5)] # Rank 2
[tensor(3), tensor(4), tensor(5), tensor(6)] # Rank 3
"""
    if group is None:
        group = _get_default_group()

    rank = dist.get_rank(group)  # group's local rank
    nranks = dist.get_world_size(group)

    # assume all tensors have the same number of dimensions
    if equal_size:
        shapes = [x.shape for _ in range(nranks)]
    else:
        shape = torch.tensor(x.shape, dtype=torch.int32, device=x.device)
        shapes = [torch.empty_like(shape) for _ in range(nranks)]
        dist.all_gather(shapes, shape, group=group)
        shapes = [torch.Size(s) for s in shapes]

    max_size = max(s.numel() for s in shapes)
    send_buf = x.new_empty(max_size)
    recv_buf = x.new_empty(max_size)

    prev_rank = (rank - 1 + nranks) % nranks
    next_rank = (rank + 1) % nranks

    # fill in send buffer
    send_buf[:x.numel()].copy_(x.view(-1))

    results = [None for _ in range(nranks)]
    results[rank] = f(x, x, *args)

    # ring exchange: in step i, each rank forwards the buffer it currently holds to
    # the next rank and receives the buffer originally owned by rank (rank - i - 1)
    for i in range(nranks - 1):
        recv_idx = (rank - i - 1 + nranks) % nranks
        nccl_sendrecv(send_buf, recv_buf, next_rank, prev_rank, group)

        shape = shapes[recv_idx]
        y = recv_buf[:shape.numel()].view(shape)
        results[recv_idx] = f(x, y, *args)

        # pass the received buffer on to the next rank
        send_buf, recv_buf = recv_buf, send_buf

    return results

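# A minimal end-to-end sketch (illustration only, not part of the original module),
# assuming the script is launched with ``torchrun --nproc_per_node=N``: every rank
# computes the dot product between its own vector and the vector of every other rank.
if __name__ == "__main__":
    import os

    dist.init_process_group("nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

    rank = dist.get_rank()
    x = torch.randn(8, device="cuda")

    # results[j] == dot(x_rank, x_j) for every rank j in the default group
    results = pairwise_apply(lambda a, b: torch.dot(a, b), x)
    print(rank, [round(r.item(), 4) for r in results])

    dist.destroy_process_group()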