Shortcuts

Source code for hfai.nn.sync_function.syncfunc

import torch
import torch.distributed as dist
from torch.distributed.distributed_c10d import _get_global_rank, _get_default_group
from collections import defaultdict
from .time import timeout as timeout_warp
from .time import TimeoutError


class CudaTimer:
    """Accumulates forward/backward timings measured with CUDA events.

    Public counters:
        fwd_time / bwd_time: summed elapsed time reported by
            ``Event.elapsed_time`` for finalized forward / backward intervals.
        iters: number of forward intervals that have been closed.
        comm_size: sum of the ``size`` values passed to ``record_fwd_end``.
        tot_iters: free counter maintained by callers (never touched here
            except for being zeroed in ``reset``).
    """

    def __init__(self):
        # Four timing-enabled events: a start/end pair per direction.
        (self.fwd_start,
         self.fwd_end,
         self.bwd_start,
         self.bwd_end) = (torch.cuda.Event(enable_timing=True) for _ in range(4))

        self.reset()

    def reset(self):
        """Clear all pending flags and zero every accumulated counter."""
        self.fwd_recorded = self.bwd_recorded = False

        self.fwd_time, self.bwd_time = 0., 0.
        self.iters, self.comm_size = 0, 0

        self.tot_iters = 0

    def record_fwd_start(self):
        """Open a forward interval, first folding in any pending one."""
        # The events are reused, so the previous interval must be read out
        # before the start event is recorded again.
        self.finalize_fwd()
        self.fwd_start.record()

    def record_fwd_end(self, size):
        """Close the forward interval and account ``size`` for this call."""
        self.comm_size += size
        self.fwd_end.record()
        self.iters += 1
        self.fwd_recorded = True

    def record_bwd_start(self):
        """Open a backward interval, first folding in any pending one."""
        self.finalize_bwd()
        self.bwd_start.record()

    def record_bwd_end(self):
        """Close the backward interval; it is accumulated on next finalize."""
        self.bwd_end.record()
        self.bwd_recorded = True

    def finalize_fwd(self):
        """Add the pending forward interval (if any) to ``fwd_time``."""
        if not self.fwd_recorded:
            return

        # Wait for the end event so that elapsed_time() can be queried.
        self.fwd_end.synchronize()
        self.fwd_time += self.fwd_start.elapsed_time(self.fwd_end)
        self.fwd_recorded = False

    def finalize_bwd(self):
        """Add the pending backward interval (if any) to ``bwd_time``."""
        if not self.bwd_recorded:
            return

        self.bwd_end.synchronize()
        self.bwd_time += self.bwd_start.elapsed_time(self.bwd_end)
        self.bwd_recorded = False

    def finalize(self):
        """Fold in both the pending forward and backward intervals."""
        self.finalize_fwd()
        self.finalize_bwd()


# Module-level registry mapping a tag to its CudaTimer; the defaultdict
# constructs a fresh CudaTimer the first time a tag is looked up.
timers = defaultdict(CudaTimer)


def sync(x, dist_group=False, dim=0, equal_size=False, tag=None, enable_timer=True,
         log_every_steps=1, timeout=60, reduce_grad=True):
    """
    All-gather the input tensor and concatenate the pieces along ``dim``.

    Autograd is supported: during backward the gradient is routed back to the
    rank that contributed each slice.

    ``F.sync.get_metrics`` returns a dictionary of the following form:

    .. code-block:: python

        {
            "tag1": {"iters": 100, "fwd": 25, "bwd": 40, "size": 16},
            "tag2": {"iters": 100, "fwd": 25, "bwd": 40, "size": 16},
        }

    ``iters`` is the number of timed calls for that tag, ``fwd`` / ``bwd`` are
    the average forward / backward time per call (ms), and ``size`` is the
    average size of the forward result per call (bytes).

    Args:
        x (Tensor): input tensor
        dist_group (ProcessGroup): process group to gather over. ``False``
            disables the all-gather (``x`` is returned as is); ``None`` falls
            back to the default process group.
        dim (int): dimension along which the gathered tensors are concatenated
        equal_size (bool): whether ``x`` has the same size along ``dim`` on every rank
        tag (str): timing label; each tag may be used only once per forward.
            No timing is done when ``tag`` is ``None``.
        enable_timer (bool): whether to time the call
        log_every_steps (int): time one call out of every ``log_every_steps`` calls
        timeout (int): timeout of this function in seconds; an exception is
            raised once exceeded. ``0`` means no limit. Default: ``60``
        reduce_grad (bool): whether to reduce the incoming gradient across
            ranks during backward. Default: ``True``

    Returns:
        out (Tensor): the concatenated result

    Raises:
        RuntimeError: if the call does not finish within ``timeout`` seconds.

    Examples:

    .. code-block:: python

        import torch.distributed as dist
        import hfai.nn.functional as F

        # init process group
        ...

        rank = dist.get_rank()
        x = torch.ones(1, requires_grad=True, device='cuda') * rank
        out = F.sync(x, dist_group, dim=0, tag='tag1')
        out.sum().backward()

        # print timings, communication volume, etc.
        F.sync.print_metrics()

        # fetch the metrics
        print(F.sync.get_metrics())

        # reset the metrics
        F.sync.reset()

    """
    enable_timer = enable_timer and (log_every_steps >= 1) and (tag is not None)
    if enable_timer:
        # Runtime message kept verbatim; it reads "tag must be a string".
        assert isinstance(tag, str), "tag 必须是一个字符串"
        timers[tag].tot_iters += 1
        # Only every `log_every_steps`-th call of this tag is actually timed.
        if timers[tag].tot_iters % log_every_steps != 0:
            enable_timer = False

    f = timeout_warp(timeout)(SyncFunction.apply)
    try:
        result = f(x, dist_group, dim, equal_size, tag, enable_timer, reduce_grad)
    except TimeoutError as e:
        group = dist_group or _get_default_group()
        rank = dist.get_rank(group=group)
        world_size = dist.get_world_size(group=group)
        msg = f"F.sync is timeout for {e.sec} seconds! RANK {rank} / {world_size}, " \
              f"x.shape {x.shape}, dim {dim}, equal_size {equal_size}, tag {tag}"
        # Chain explicitly so the original timeout stays attached as the cause.
        raise RuntimeError(msg) from e

    return result


def reset():
    """Drop every accumulated timer by replacing the module-level registry."""
    global timers
    timers = defaultdict(CudaTimer)


def print_metrics():
    """Print the dictionary returned by ``get_metrics`` and flush stdout."""
    metrics = get_metrics()
    print(metrics, flush=True)


def get_metrics():
    """Return per-tag averages for every tag with at least one timed call.

    Each timer is finalized first (pending CUDA events are synchronized and
    accumulated). All three averages are divided by ``timer.iters``, i.e. the
    number of timed forward calls.
    """
    metrics = {}
    for tag, timer in timers.items():
        timer.finalize()
        it = timer.iters
        if it > 0:
            metrics[tag] = {
                'iters': it,
                'fwd': timer.fwd_time / it,
                'bwd': timer.bwd_time / it,
                'size': timer.comm_size / it,
            }

    return metrics


# Expose the metric helpers as attributes of `sync` (F.sync.reset(), ...).
sync.reset = reset
sync.print_metrics = print_metrics
sync.get_metrics = get_metrics


class SyncFunction(torch.autograd.Function):
    """Autograd function behind ``sync``: gather in forward, route grads in backward."""

    @staticmethod
    def forward(ctx, x, dist_group=False, dim=0, equal_size=False, tag=None,
                enable_timer=True, reduce_grad=True):
        if enable_timer:
            timers[tag].record_fwd_start()

        # Stash what backward needs; sync_forward adds dist_group / dim /
        # batch_sizes to ctx itself.
        ctx.tag = tag
        ctx.enable_timer = enable_timer
        ctx.reduce_grad = reduce_grad
        out = sync_forward(ctx, x, dist_group, dim, equal_size)

        if enable_timer:
            size = out.numel() * out.element_size()  # bytes of the result
            timers[tag].record_fwd_end(size)

        return out

    @staticmethod
    def backward(ctx, grad_output):
        # Backward is only timed when a gradient reduction actually happens.
        timed = ctx.enable_timer and ctx.reduce_grad
        if timed:
            timers[ctx.tag].record_bwd_start()

        out = sync_backward(ctx, grad_output)

        if timed:
            timers[ctx.tag].record_bwd_end()

        return out


def _input_grads(grad):
    """Build the backward return value: grad for ``x``, None for the 6 non-tensor inputs."""
    return grad, None, None, None, None, None, None


def sync_forward(ctx, x, dist_group, dim, equal_size):
    """Gather ``x`` from every rank of ``dist_group`` and concatenate along ``dim``.

    Records ``dist_group`` on ``ctx`` always, and ``batch_sizes`` / ``dim``
    when a gather is performed. With ``dist_group is False`` the input is
    returned untouched.
    """
    ctx.dist_group = dist_group
    if dist_group is False:
        return x

    dist_group = dist_group or _get_default_group()
    rank = dist.get_rank(group=dist_group)
    world_size = dist.get_world_size(group=dist_group)

    if equal_size:
        batch_sizes = [x.size(dim) for _ in range(world_size)]
    else:
        # Every rank writes its own extent into its slot; a sum all-reduce
        # then yields the full list of per-rank extents on every rank.
        sizes = torch.zeros(world_size, dtype=torch.int32, device=x.device)
        sizes[rank] = x.size(dim)
        dist.all_reduce(sizes, group=dist_group)
        batch_sizes = sizes.tolist()

    ctx.batch_sizes = batch_sizes
    ctx.dim = dim

    # Number of elements in one slice along `dim`. Computed from the other
    # extents rather than `numel() // size(dim)` so that a rank holding an
    # empty tensor (size(dim) == 0) does not divide by zero.
    other_extents = list(x.shape)
    other_extents[dim] = 1
    slice_numel = 1
    for extent in other_extents:
        slice_numel *= extent

    # Small outputs take the single all-reduce path, larger ones all-gather.
    if slice_numel * sum(batch_sizes) < 1024 * world_size:
        return fwd_allreduce_impl(dim, batch_sizes, dist_group, x)

    return fwd_allgather_impl(dim, batch_sizes, dist_group, x)


def sync_backward(ctx, grad_output):
    """Route ``grad_output`` back to this rank's slice, reducing across ranks if requested.

    Returns a 7-tuple matching the inputs of ``SyncFunction.forward``.
    """
    dist_group = ctx.dist_group
    if dist_group is False:
        # No gather happened in forward, so ctx.dim / ctx.batch_sizes were
        # never set and must not be read here.
        return _input_grads(grad_output)

    dim, batch_sizes = ctx.dim, ctx.batch_sizes

    dist_group = dist_group or _get_default_group()
    rank = dist.get_rank(group=dist_group)
    world_size = dist.get_world_size(group=dist_group)

    if not ctx.reduce_grad:
        # No communication: just cut out the slice this rank contributed.
        start = sum(batch_sizes[:rank])
        grad = grad_output.narrow(dim, start, batch_sizes[rank])
        return _input_grads(grad)

    if grad_output.numel() < 1024 * world_size:
        return bwd_allreduce_impl(dim, batch_sizes, dist_group, grad_output)

    # Reduce-scatter needs every rank to own an equally sized slice.
    if len(set(batch_sizes)) == 1:
        return bwd_reducescatter_impl(dim, batch_sizes, dist_group, grad_output)

    # Heuristic choice between per-rank reduce and a single all-reduce.
    # `size` is in units of 2**20 elements.
    # NOTE(review): `world_size // 8` presumably assumes 8 ranks per node; confirm.
    size = grad_output.numel() / (1 << 20)
    nodes = world_size // 8
    node2size = [(1, 4), (2, 16), (4, 64), (8, 512)]  # (max nodes, min size for reduce)
    for n, min_size in node2size:
        if nodes <= n:
            if size >= min_size:
                return bwd_reduce_impl(dim, batch_sizes, dist_group, grad_output)
            return bwd_allreduce_impl(dim, batch_sizes, dist_group, grad_output)

    return bwd_allreduce_impl(dim, batch_sizes, dist_group, grad_output)


def fwd_allgather_impl(dim, batch_sizes, dist_group, x):
    """Forward gather using one flat buffer with a fixed-size slot per rank.

    Each slot has room for ``max(batch_sizes)`` rows; after the gather the
    slots are compacted so the padding of shorter ranks disappears.
    """
    if dim != 0:
        # Work along dim 0 so that every rank's slot is a contiguous chunk.
        x = x.transpose(0, dim)

    rank = dist.get_rank(group=dist_group)
    world_size = dist.get_world_size(group=dist_group)
    max_batch_size = max(batch_sizes)

    shape = list(x.size())
    shape[0] = max_batch_size * world_size
    result_tensor = torch.empty(shape, dtype=x.dtype, device=x.device)
    tensors = result_tensor.chunk(world_size)
    tensors[rank][:x.size(0)].copy_(x)

    dist_group = dist_group or _get_default_group()

    # gather all tensors
    all_gather_base(result_tensor, tensors[rank], group=dist_group)

    # unroll: slide each rank's valid rows left over the preceding padding
    tot = 0
    for i in range(world_size):
        slot_start = i * max_batch_size
        if tot < slot_start:
            # Source and destination may overlap, hence the clone.
            result_tensor[tot:tot + batch_sizes[i]] = \
                result_tensor[slot_start:slot_start + batch_sizes[i]].clone()
        tot += batch_sizes[i]

    output = result_tensor[:tot]

    if dim != 0:
        output = output.transpose(0, dim).contiguous()

    return output


def fwd_allreduce_impl(dim, batch_sizes, dist_group, x):
    """Forward gather via all-reduce: write ``x`` into a zero buffer and sum across ranks."""
    rank = dist.get_rank(group=dist_group)

    shape = list(x.size())
    shape[dim] = sum(batch_sizes)
    result_tensor = torch.zeros(shape, dtype=x.dtype, device=x.device)

    # Only this rank's slice is non-zero, so the sum reproduces the concat.
    start = sum(batch_sizes[:rank])
    size = batch_sizes[rank]
    result_tensor.narrow(dim, start, size).data.copy_(x)

    dist.all_reduce(result_tensor, group=dist_group)

    return result_tensor


def bwd_reduce_impl(dim, batch_sizes, dist_group, grad_output):
    """Backward via one ``dist.reduce`` per rank: slice ``i`` is reduced onto rank ``i``."""
    rank = dist.get_rank(group=dist_group)
    world_size = dist.get_world_size(group=dist_group)

    grads = []
    i0 = 0
    for i in range(world_size):
        g = grad_output.narrow(dim, i0, batch_sizes[i])
        grads.append(g.contiguous())
        i0 += batch_sizes[i]

    # The destination passed to dist.reduce below is a global rank, so
    # group-local ranks are translated for non-default groups.
    if dist_group is None or dist_group is _get_default_group():
        global_ranks = list(range(world_size))
    else:
        global_ranks = [_get_global_rank(dist_group, i) for i in range(world_size)]

    for i, global_rank in enumerate(global_ranks):
        dist.reduce(grads[i], global_rank, group=dist_group)

    return _input_grads(grads[rank])


def bwd_allreduce_impl(dim, batch_sizes, dist_group, grad_output):
    """Backward via a full all-reduce followed by cutting out this rank's slice."""
    grad_output = grad_output.contiguous()
    rank = dist.get_rank(group=dist_group)
    # NOTE(review): when grad_output is already contiguous this reduces into
    # the incoming gradient buffer in place.
    dist.all_reduce(grad_output, group=dist_group)

    start = sum(batch_sizes[:rank])
    grad = grad_output.narrow(dim, start, batch_sizes[rank])

    return _input_grads(grad)


def bwd_reducescatter_impl(dim, batch_sizes, dist_group, grad_output):
    """Backward via reduce-scatter; the caller guarantees all batch sizes are equal."""
    if dim != 0:
        # Scatter happens along dim 0 of a contiguous tensor.
        grad_output = grad_output.transpose(0, dim)
    grad_output = grad_output.contiguous()

    shape = list(grad_output.shape)
    shape[0] = batch_sizes[0]
    grad = grad_output.new_empty(shape)
    reduce_scatter_base(grad, grad_output, group=dist_group)

    if dim != 0:
        grad = grad.transpose(0, dim)

    return _input_grads(grad)


def all_gather_base(output, input, group):
    """All-gather ``input`` from every rank into the flat tensor ``output``.

    Uses the single-tensor collective when this torch build offers one
    (public name first, then the older private one); otherwise falls back to
    ``dist.all_gather`` on equal chunks of ``output``.
    """
    if hasattr(dist, "all_gather_into_tensor"):
        return dist.all_gather_into_tensor(output, input, group=group)
    if hasattr(dist, "_all_gather_base"):
        return dist._all_gather_base(output, input, group=group)

    chunks = list(output.chunk(group.size()))
    dist.all_gather(chunks, input, group=group)


def reduce_scatter_base(output, input, group):
    """Reduce the flat tensor ``input`` across ranks and scatter equal chunks into ``output``.

    Uses the single-tensor collective when this torch build offers one
    (public name first, then the older private one); otherwise falls back to
    ``dist.reduce_scatter`` on equal chunks of ``input``.
    """
    if hasattr(dist, "reduce_scatter_tensor"):
        return dist.reduce_scatter_tensor(output, input, group=group)
    if hasattr(dist, "_reduce_scatter_base"):
        return dist._reduce_scatter_base(output, input, group=group)

    chunks = list(input.chunk(group.size()))
    dist.reduce_scatter(output, chunks, group=group)