Source code for hfai.distributed.diverse

import torch
import torch.distributed as dist
from torch.distributed.distributed_c10d import _get_global_rank, _get_default_group


def all_gather(tensor_list, tensor, group=None, async_op=False):
    """
    Same functionality as ``torch.distributed.all_gather``, but supports tensors of
    different sizes.

    Args:
        tensor_list (list[Tensor]): Output list of gathered tensors; the tensors do not
            have to be the same size
        tensor (Tensor): Tensor to be broadcast from the current rank
        group (ProcessGroup, optional): The process group to work on. If ``None``, the
            default process group is used
        async_op (bool, optional): Whether this op should be an async op

    NOTE:
        ``async_op = True`` is not supported when the tensor sizes differ

    Examples:

        >>> import hfai.distributed as dist
        >>> tensor_list = [torch.zeros(2 + i, dtype=torch.int64) for i in range(2)]
        >>> tensor_list
        [tensor([0, 0]), tensor([0, 0, 0])]  # Rank 0 and 1
        >>> tensor = torch.arange(2 + rank, dtype=torch.int64) + 1 + 2 * rank
        >>> tensor
        tensor([1, 2])     # Rank 0
        tensor([3, 4, 5])  # Rank 1
        >>> dist.all_gather(tensor_list, tensor)
        >>> tensor_list
        [tensor([1, 2]), tensor([3, 4, 5])]  # Rank 0 and 1
    """
    sizes = set(x.size() for x in tensor_list)
    equal_size = (len(sizes) <= 1)
    if equal_size:
        # all tensors have the same size: fall back to the native collective
        return dist.all_gather(tensor_list, tensor, group, async_op)

    assert not async_op, "async_op = True is not supported when the tensor sizes differ"

    group = group or _get_default_group()
    rank = dist.get_rank(group=group)
    world_size = dist.get_world_size(group=group)
    assert len(tensor_list) == world_size
    assert tensor_list[rank].size() == tensor.size()

    if group is None or group is _get_default_group():
        global_ranks = list(range(world_size))
    else:
        global_ranks = [_get_global_rank(group, i) for i in range(world_size)]

    # gather all tensors: each rank broadcasts its own tensor in turn
    tensor_list[rank].copy_(tensor)
    for i, r in enumerate(global_ranks):
        dist.broadcast(tensor_list[i], r, group=group)
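
A minimal end-to-end sketch of calling this ``all_gather`` with differently sized tensors, matching the docstring example above. It assumes the ``hfai`` package is installed, uses the gloo backend over a localhost TCP rendezvous (the port is arbitrary), and spawns two CPU processes; the ``run_worker`` name is illustrative only.

import torch
import torch.distributed
import torch.multiprocessing as mp
import hfai.distributed as dist  # assumption: hfai is installed


def run_worker(rank, world_size):
    # gloo backend so the sketch runs on CPU-only machines
    torch.distributed.init_process_group(
        "gloo", init_method="tcp://127.0.0.1:29500", rank=rank, world_size=world_size
    )

    # each rank contributes a tensor of a different length
    tensor = torch.arange(2 + rank, dtype=torch.int64) + 1 + 2 * rank
    tensor_list = [torch.zeros(2 + i, dtype=torch.int64) for i in range(world_size)]

    dist.all_gather(tensor_list, tensor)
    print(f"rank {rank}: {tensor_list}")  # both ranks: [tensor([1, 2]), tensor([3, 4, 5])]

    torch.distributed.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(run_worker, args=(2,), nprocs=2)

Because the sizes differ, the call takes the broadcast-based fallback shown in the source above rather than the native ``torch.distributed.all_gather``.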
def reduce_scatter(output, input_list, op=dist.ReduceOp.SUM, group=None, async_op=False):
    """
    Same functionality as ``torch.distributed.reduce_scatter``, but supports tensors of
    different sizes.

    Args:
        output (Tensor): Output tensor
        input_list (list[Tensor]): Input tensors to reduce and scatter; the tensors do not
            have to be the same size
        op (optional): The reduce operation, one of ``torch.distributed.ReduceOp``;
            defaults to ``ReduceOp.SUM``
        group (ProcessGroup, optional): The process group to work on. If ``None``, the
            default process group is used
        async_op (bool, optional): Whether this op should be an async op

    NOTE:
        ``async_op = True`` is not supported when the tensor sizes differ
    """
    sizes = set(x.size() for x in input_list)
    equal_size = (len(sizes) <= 1)
    if equal_size:
        # all tensors have the same size: fall back to the native collective
        return dist.reduce_scatter(output, input_list, op, group, async_op)

    assert not async_op, "async_op = True is not supported when the tensor sizes differ"

    group = group or _get_default_group()
    rank = dist.get_rank(group=group)
    world_size = dist.get_world_size(group=group)
    assert len(input_list) == world_size
    assert input_list[rank].size() == output.size()

    if group is None or group is _get_default_group():
        global_ranks = list(range(world_size))
    else:
        global_ranks = [_get_global_rank(group, i) for i in range(world_size)]

    # reduce chunk-i to rank-i
    for i, r in enumerate(global_ranks):
        dist.reduce(input_list[i], r, op=op, group=group)
    output.data.copy_(input_list[rank])
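
A symmetric sketch for the variable-size ``reduce_scatter`` path, under the same assumptions as above (``hfai`` installed, gloo backend, two CPU processes, arbitrary localhost port, illustrative ``run_worker`` name). Since the chunk sizes differ, the call goes through the per-chunk ``dist.reduce`` fallback, which gloo supports.

import torch
import torch.distributed
import torch.multiprocessing as mp
import hfai.distributed as dist  # assumption: hfai is installed


def run_worker(rank, world_size):
    torch.distributed.init_process_group(
        "gloo", init_method="tcp://127.0.0.1:29501", rank=rank, world_size=world_size
    )

    # chunk i is destined for rank i; the chunks intentionally have different lengths
    input_list = [torch.ones(2 + i, dtype=torch.int64) * (rank + 1) for i in range(world_size)]
    output = torch.zeros(2 + rank, dtype=torch.int64)

    # after the call, rank i holds the element-wise sum of every rank's chunk i
    dist.reduce_scatter(output, input_list)
    print(f"rank {rank}: {output}")  # rank 0 -> [3, 3], rank 1 -> [3, 3, 3]

    torch.distributed.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(run_worker, args=(2,), nprocs=2)

Note that the fallback reduces in place, so the entries of ``input_list`` may be modified after the call.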