import torch
import torch.distributed as dist
from torch.distributed.distributed_c10d import _get_global_rank, _get_default_group
from collections import defaultdict
from .time import timeout as timeout_warp
from .time import TimeoutError
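

# CudaTimer accumulates forward/backward timings for a single tag using CUDA
# events; synchronization is deferred until the next record_*_start() or an
# explicit finalize(), so timing adds little overhead on the hot path.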
class CudaTimer():

    def __init__(self):
        self.fwd_start = torch.cuda.Event(enable_timing=True)
        self.fwd_end = torch.cuda.Event(enable_timing=True)
        self.bwd_start = torch.cuda.Event(enable_timing=True)
        self.bwd_end = torch.cuda.Event(enable_timing=True)
        self.reset()

    def reset(self):
        self.fwd_recorded = False
        self.bwd_recorded = False
        self.fwd_time = 0.
        self.bwd_time = 0.
        self.iters = 0
        self.comm_size = 0
        self.tot_iters = 0

    def record_fwd_start(self):
        self.finalize_fwd()
        self.fwd_start.record()

    def record_fwd_end(self, size):
        self.comm_size += size
        self.fwd_end.record()
        self.fwd_recorded = True
        self.iters += 1

    def record_bwd_start(self):
        self.finalize_bwd()
        self.bwd_start.record()

    def record_bwd_end(self):
        self.bwd_end.record()
        self.bwd_recorded = True

    def finalize_fwd(self):
        if self.fwd_recorded:
            self.fwd_end.synchronize()
            t = self.fwd_start.elapsed_time(self.fwd_end)
            self.fwd_time += t
            self.fwd_recorded = False

    def finalize_bwd(self):
        if self.bwd_recorded:
            self.bwd_end.synchronize()
            t = self.bwd_start.elapsed_time(self.bwd_end)
            self.bwd_time += t
            self.bwd_recorded = False

    def finalize(self):
        self.finalize_fwd()
        self.finalize_bwd()
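

# One CudaTimer per tag, created lazily on first use.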
timers = defaultdict(CudaTimer)


def sync(x, dist_group=False, dim=0, equal_size=False, tag=None, enable_timer=True, log_every_steps=1, timeout=60, reduce_grad=True):
    """
    All-gather the input tensor and concatenate the results along the given dimension.
    Autograd is supported: during backward the gradients are propagated back to the input.

    ``F.sync.get_metrics`` returns a dict in the following format:

    .. code-block:: python

        {
            "tag1": {"iters": 100, "fwd": 25, "bwd": 40, "size": 16},
            "tag2": {"iters": 100, "fwd": 25, "bwd": 40, "size": 16},
        }

    ``iters`` is the number of timed calls for a tag, ``fwd`` / ``bwd`` are the average
    time of each forward / backward pass (ms), and ``size`` is the average size of the
    tensor returned by each forward pass (bytes).

    Args:
        x (Tensor): input tensor
        dist_group (ProcessGroup): ProcessGroup object; if ``False``, no allgather is performed
        dim (int): dimension along which the gathered tensors are concatenated
        equal_size (bool): whether the tensor has the same size on every GPU
        tag (str): timing label; each tag may be used at most once per forward pass; no timing is done when tag is ``None``
        enable_timer (bool): whether to enable timing
        log_every_steps (int): record timing once every this many steps
        timeout (int): timeout of this function in seconds; an exception is raised once it is exceeded; ``0`` means no time limit; default is ``60``
        reduce_grad (bool): whether to reduce the incoming gradients; default is ``True``

    Returns:
        out (Tensor): the concatenated result

    Examples:

    .. code-block:: python

        import torch.distributed as dist
        import hfai.nn.functional as F

        # init process group ...
        rank = dist.get_rank()
        x = torch.ones(1, requires_grad=True, device='cuda') * rank
        out = F.sync(x, dist_group, dim=0, tag='tag1')
        out.sum().backward()

        # print timing, communication volume, etc.
        F.sync.print_metrics()

        # get the metrics
        print(F.sync.get_metrics())

        # reset the metrics
        F.sync.reset()

    """
    enable_timer = enable_timer and (log_every_steps >= 1) and (tag is not None)
    if enable_timer:
        assert isinstance(tag, str), "tag must be a string"
        timers[tag].tot_iters += 1
        if timers[tag].tot_iters % log_every_steps != 0:
            enable_timer = False

    f = timeout_warp(timeout)(SyncFunction.apply)
    try:
        result = f(x, dist_group, dim, equal_size, tag, enable_timer, reduce_grad)
    except TimeoutError as e:
        group = dist_group or _get_default_group()
        rank = dist.get_rank(group=group)
        world_size = dist.get_world_size(group=group)
        msg = f"F.sync timed out after {e.sec} seconds! RANK {rank} / {world_size}, " \
              f"x.shape {x.shape}, dim {dim}, equal_size {equal_size}, tag {tag}"
        raise RuntimeError(msg)

    return result


def reset():
    global timers
    timers = defaultdict(CudaTimer)


def print_metrics():
    metrics = get_metrics()
    print(metrics, flush=True)


def get_metrics():
    global timers
    metrics = {}
    for tag, timer in timers.items():
        timer.finalize()
        it = timer.iters
        if it > 0:
            metrics[tag] = {
                'iters': it,
                'fwd': timer.fwd_time / it,
                'bwd': timer.bwd_time / it,
                'size': timer.comm_size / it
            }
    return metrics


sync.reset = reset
sync.print_metrics = print_metrics
sync.get_metrics = get_metrics
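

# Autograd wrapper around the gather: forward concatenates the per-rank tensors,
# backward returns each rank its own slice of the incoming gradient (reduced
# across ranks unless reduce_grad is off). Timing hooks wrap both directions
# when timing is enabled.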
class SyncFunction(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x, dist_group=False, dim=0, equal_size=False, tag=None, enable_timer=True, reduce_grad=True):
        if enable_timer:
            timers[tag].record_fwd_start()

        ctx.tag = tag
        ctx.enable_timer = enable_timer
        ctx.reduce_grad = reduce_grad
        out = sync_forward(ctx, x, dist_group, dim, equal_size)

        if enable_timer:
            size = out.numel() * out.element_size()
            timers[tag].record_fwd_end(size)

        return out

    @staticmethod
    def backward(ctx, grad_output):
        if ctx.enable_timer and ctx.reduce_grad:
            timers[ctx.tag].record_bwd_start()

        out = sync_backward(ctx, grad_output)

        if ctx.enable_timer and ctx.reduce_grad:
            timers[ctx.tag].record_bwd_end()

        return out
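

# Forward path: collect every rank's size along `dim` (skipped when equal_size
# is set), then gather. When the gathered output would hold fewer than 1024
# elements per rank on average, the zero-padded all_reduce path is used
# (presumably cheaper than a padded all_gather at that scale); otherwise the
# all_gather path is used.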
def sync_forward(ctx, x, dist_group, dim, equal_size):
    ctx.dist_group = dist_group
    if dist_group is False:
        return x

    dist_group = dist_group or _get_default_group()
    rank = dist.get_rank(group=dist_group)
    world_size = dist.get_world_size(group=dist_group)

    if equal_size:
        batch_sizes = [x.size(dim) for _ in range(world_size)]
    else:
        # exchange the per-rank sizes along `dim` with a single all_reduce
        sizes = torch.zeros(world_size, dtype=torch.int32, device=x.device)
        sizes[rank] = x.size(dim)
        dist.all_reduce(sizes, group=dist_group)
        batch_sizes = sizes.tolist()

    ctx.batch_sizes = batch_sizes
    ctx.dim = dim

    if x.numel() // x.size(dim) * sum(batch_sizes) < 1024 * world_size:
        return fwd_allreduce_impl(dim, batch_sizes, dist_group, x)

    return fwd_allgather_impl(dim, batch_sizes, dist_group, x)
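

# Backward path: when reduce_grad is off each rank just slices out its own part
# of grad_output; otherwise the gradients are combined with one of three
# strategies (all_reduce, reduce_scatter, or per-rank reduce) picked by the
# size / node-count heuristic below.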
def sync_backward(ctx, grad_output):
    dist_group = ctx.dist_group
    dim, batch_sizes = ctx.dim, ctx.batch_sizes
    if dist_group is False:
        return grad_output, None, None, None, None, None, None

    dist_group = dist_group or _get_default_group()
    rank = dist.get_rank(group=dist_group)
    world_size = dist.get_world_size(group=dist_group)

    if not ctx.reduce_grad:
        # no reduction requested: every rank keeps its own slice of the gradient
        start = sum(batch_sizes[:rank])
        grad = grad_output.narrow(dim, start, batch_sizes[rank])
        return grad, None, None, None, None, None, None

    if grad_output.numel() < 1024 * world_size:
        return bwd_allreduce_impl(dim, batch_sizes, dist_group, grad_output)

    if len(set(batch_sizes)) == 1:
        return bwd_reducescatter_impl(dim, batch_sizes, dist_group, grad_output)

    # choose between per-rank reduce and all_reduce from the gradient's element
    # count (in units of 2**20) and the node count (world_size // 8, apparently
    # assuming 8 GPUs per node)
    size = grad_output.numel() / (1 << 20)
    nodes = world_size // 8
    node2size = [(1, 4), (2, 16), (4, 64), (8, 512)]
    for n, min_size in node2size:
        if nodes <= n:
            if size >= min_size:
                return bwd_reduce_impl(dim, batch_sizes, dist_group, grad_output)
            else:
                return bwd_allreduce_impl(dim, batch_sizes, dist_group, grad_output)

    return bwd_allreduce_impl(dim, batch_sizes, dist_group, grad_output)
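

# All-gather implementation: every rank pads its tensor to the largest batch
# size so the flat all_gather can be used, then the padded result is compacted
# in place to drop the padding.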
def fwd_allgather_impl(dim, batch_sizes, dist_group, x):
    if dim != 0:
        x = x.transpose(0, dim)

    rank = dist.get_rank(group=dist_group)
    world_size = dist.get_world_size(group=dist_group)

    max_batch_size = max(batch_sizes)
    shape = list(x.size())
    shape[0] = max_batch_size * world_size
    result_tensor = torch.empty(shape, dtype=x.dtype, device=x.device)
    tensors = result_tensor.chunk(world_size)
    tensors[rank][:x.size(0)].copy_(x)

    dist_group = dist_group or _get_default_group()

    # gather all tensors
    all_gather_base(result_tensor, tensors[rank], group=dist_group)

    # unroll: shift each rank's valid rows left to remove the padding
    tot = 0
    for i in range(world_size):
        if tot < i * max_batch_size:  # left shift
            result_tensor[tot:tot + batch_sizes[i]] = \
                result_tensor[i * max_batch_size:i * max_batch_size + batch_sizes[i]].clone()
        tot += batch_sizes[i]

    output = result_tensor[:tot]
    if dim != 0:
        output = output.transpose(0, dim).contiguous()

    return output
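

# All-reduce based gather for small tensors: each rank writes its slice into a
# zero-initialized tensor of the full output shape, and a single all_reduce
# sums the slices together.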
def fwd_allreduce_impl(dim, batch_sizes, dist_group, x):
    rank = dist.get_rank(group=dist_group)
    shape = list(x.size())
    shape[dim] = sum(batch_sizes)
    result_tensor = torch.zeros(shape, dtype=x.dtype, device=x.device)
    start = sum(batch_sizes[:rank])
    size = batch_sizes[rank]
    result_tensor.narrow(dim, start, size).data.copy_(x)
    dist.all_reduce(result_tensor, group=dist_group)
    return result_tensor
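

# One dist.reduce per rank: rank i receives the sum of everyone's gradient for
# its own slice. Used for large gradients with uneven batch sizes, where
# reduce_scatter is not applicable.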
def bwd_reduce_impl(dim, batch_sizes, dist_group, grad_output):
    rank = dist.get_rank(group=dist_group)
    world_size = dist.get_world_size(group=dist_group)

    # split grad_output into one contiguous slice per rank
    grads = []
    i0 = 0
    for i in range(world_size):
        g = grad_output.narrow(dim, i0, batch_sizes[i])
        grads.append(g.contiguous())
        i0 += batch_sizes[i]

    # dist.reduce expects global ranks, so map group ranks to global ranks
    if dist_group is None or dist_group is _get_default_group():
        global_ranks = list(range(world_size))
    else:
        global_ranks = [_get_global_rank(dist_group, i) for i in range(world_size)]

    for i, global_rank in enumerate(global_ranks):
        dist.reduce(grads[i], global_rank, group=dist_group)

    return grads[rank], None, None, None, None, None, None
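

# all_reduce the full gradient, then return this rank's slice. Simple but moves
# the whole gradient to every rank, so the heuristic above reserves it for
# small gradients or node counts where it is still competitive.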
def bwd_allreduce_impl(dim, batch_sizes, dist_group, grad_output):
    grad_output = grad_output.contiguous()
    rank = dist.get_rank(group=dist_group)
    dist.all_reduce(grad_output, group=dist_group)
    start = sum(batch_sizes[:rank])
    grad = grad_output.narrow(dim, start, batch_sizes[rank])
    return grad, None, None, None, None, None, None
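

# reduce_scatter the gradient when all ranks contributed equally sized slices:
# each rank directly receives the reduced gradient for its own slice.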
def bwd_reducescatter_impl(dim, batch_sizes, dist_group, grad_output):
    if dim != 0:
        grad_output = grad_output.transpose(0, dim)
    grad_output = grad_output.contiguous()

    shape = list(grad_output.shape)
    shape[0] = batch_sizes[0]
    grad = grad_output.new_empty(shape)
    reduce_scatter_base(grad, grad_output, group=dist_group)

    if dim != 0:
        grad = grad.transpose(0, dim)

    return grad, None, None, None, None, None, None
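

# Thin wrappers that prefer the flat-tensor collectives (_all_gather_base /
# _reduce_scatter_base) when this PyTorch build provides them, and fall back to
# the list-based all_gather / reduce_scatter otherwise.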
def all_gather_base(output, input, group):
    if hasattr(dist, "_all_gather_base"):
        return dist._all_gather_base(output, input, group=group)
    chunks = list(output.chunk(group.size()))
    dist.all_gather(chunks, input, group=group)


def reduce_scatter_base(output, input, group):
    if hasattr(dist, "_reduce_scatter_base"):
        return dist._reduce_scatter_base(output, input, group=group)
    chunks = list(input.chunk(group.size()))
    dist.reduce_scatter(output, chunks, group=group)