Source code for hfai.distributed.diverse
import torch
import torch.distributed as dist
from torch.distributed.distributed_c10d import _get_global_rank, _get_default_group


def all_gather(tensor_list, tensor, group=None, async_op=False):
    """
    Same as ``torch.distributed.all_gather``, but tensors of different sizes are supported.

    Args:
        tensor_list (list[Tensor]): Output tensors of the gather; they do not have to be the same size
        tensor (Tensor): Tensor to be broadcast from the current rank
        group (ProcessGroup, optional): The process group to work on. If ``None``, the default group is used
        async_op (bool, optional): Whether this op should be an async op

    NOTE:
        ``async_op = True`` is not supported when the tensor sizes differ

    Examples:

        >>> import hfai.distributed as dist
        >>> tensor_list = [torch.zeros(2 + i, dtype=torch.int64) for i in range(2)]
        >>> tensor_list
        [tensor([0, 0]), tensor([0, 0, 0])]  # Rank 0 and 1
        >>> tensor = torch.arange(2 + rank, dtype=torch.int64) + 1 + 2 * rank
        >>> tensor
        tensor([1, 2])  # Rank 0
        tensor([3, 4, 5])  # Rank 1
        >>> dist.all_gather(tensor_list, tensor)
        >>> tensor_list
        [tensor([1, 2]), tensor([3, 4, 5])]  # Rank 0 and 1
    """
    sizes = set(x.size() for x in tensor_list)
    equal_size = (len(sizes) <= 1)

    # all tensors have the same size: fall back to the native implementation
    if equal_size:
        return dist.all_gather(tensor_list, tensor, group, async_op)

    assert not async_op, "async_op = True is not supported when tensor sizes differ"

    group = group or _get_default_group()
    rank = dist.get_rank(group=group)
    world_size = dist.get_world_size(group=group)

    assert len(tensor_list) == world_size
    assert tensor_list[rank].size() == tensor.size()

    # map group-local ranks to global ranks, which ``broadcast`` expects as ``src``
    if group is None or group is _get_default_group():
        global_ranks = list(range(world_size))
    else:
        global_ranks = [_get_global_rank(group, i) for i in range(world_size)]

    # gather all tensors: each rank broadcasts its own tensor to every other rank
    tensor_list[rank].copy_(tensor)
    for i, r in enumerate(global_ranks):
        dist.broadcast(tensor_list[i], r, group=group)
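

# A minimal usage sketch (illustrative, not part of the original module): one way to gather
# variable-length per-rank results with the ``all_gather`` defined above. It assumes the
# default process group has already been initialised (e.g. via
# ``torch.distributed.init_process_group``) and that each rank holds a 1-D int64 tensor;
# the helper name ``_example_gather_variable`` is hypothetical.
def _example_gather_variable(local_values):
    """Gather a 1-D int64 tensor whose length may differ per rank."""
    world_size = dist.get_world_size()
    # First exchange the lengths so every rank can allocate correctly sized receive buffers.
    length = torch.tensor([local_values.numel()], dtype=torch.int64)
    lengths = [torch.zeros(1, dtype=torch.int64) for _ in range(world_size)]
    dist.all_gather(lengths, length)  # equal sizes: the native all_gather suffices here
    # Now gather the payloads themselves, one differently sized buffer per rank.
    buffers = [torch.zeros(int(n.item()), dtype=torch.int64) for n in lengths]
    all_gather(buffers, local_values)  # the variable-size version defined above
    return buffers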


def reduce_scatter(output, input_list, op=dist.ReduceOp.SUM, group=None, async_op=False):
    """
    Same as ``torch.distributed.reduce_scatter``, but tensors of different sizes are supported.

    Args:
        output (Tensor): Output tensor
        input_list (list[Tensor]): Input tensors to reduce and scatter; they do not have to be the same size
        op (optional): One of the values from ``torch.distributed.ReduceOp``, specifying the reduction operation
        group (ProcessGroup, optional): The process group to work on. If ``None``, the default group is used
        async_op (bool, optional): Whether this op should be an async op

    NOTE:
        ``async_op = True`` is not supported when the tensor sizes differ
    """
    sizes = set(x.size() for x in input_list)
    equal_size = (len(sizes) <= 1)

    # all tensors have the same size: fall back to the native implementation
    if equal_size:
        return dist.reduce_scatter(output, input_list, op, group, async_op)

    assert not async_op, "async_op = True is not supported when tensor sizes differ"

    group = group or _get_default_group()
    rank = dist.get_rank(group=group)
    world_size = dist.get_world_size(group=group)

    assert len(input_list) == world_size
    assert input_list[rank].size() == output.size()

    # map group-local ranks to global ranks, which ``reduce`` expects as ``dst``
    if group is None or group is _get_default_group():
        global_ranks = list(range(world_size))
    else:
        global_ranks = [_get_global_rank(group, i) for i in range(world_size)]

    # reduce chunk i onto rank i, then copy the local chunk into ``output``
    for i, r in enumerate(global_ranks):
        dist.reduce(input_list[i], r, op=op, group=group)
    output.data.copy_(input_list[rank])
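

# A minimal usage sketch (illustrative, not part of the original module): summing
# variable-length chunks across ranks with the ``reduce_scatter`` defined above, so that
# rank i ends up with the element-wise sum of every rank's chunk i. It assumes the default
# process group is already initialised and that all ranks build ``input_list`` with the
# same per-chunk sizes; the helper name ``_example_reduce_scatter_variable`` is hypothetical.
def _example_reduce_scatter_variable():
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    # Chunk i has i + 1 elements on every rank; rank i will receive the reduced chunk i.
    input_list = [torch.full((i + 1,), float(rank)) for i in range(world_size)]
    output = torch.zeros(rank + 1)
    reduce_scatter(output, input_list, op=dist.ReduceOp.SUM)  # variable-size version above
    # Each element of ``output`` now equals 0 + 1 + ... + (world_size - 1).
    return output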