Source code for haiscale.pipeline.comm

import torch
import torch.distributed as dist
from torch.distributed.distributed_c10d import _get_global_rank, _get_default_group
from ..timer import cuda_timer


DTYPES = [
    torch.bool,
    torch.uint8,
    torch.int8,
    torch.int16,
    torch.int32,
    torch.int64,
    torch.float16,
    torch.float32,
    torch.float64,
    torch.bfloat16,
]

META_SIZE = 1024
NONE_MARK = -2


TENSOR_SHAPES = None
TENSOR_DTYPE = None
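
TENSOR_SHAPES and TENSOR_DTYPE act as an optional fast path: while both are None, every point-to-point helper below first exchanges a META_SIZE-element metadata buffer describing shapes, dtypes and requires_grad flags; once they are set, build_from_tensor_shapes() allocates the receive buffers directly and that handshake is skipped. How they are meant to be assigned is not shown in this module, so the sketch below is purely illustrative:

    import torch
    import haiscale.pipeline.comm as comm

    # Hypothetical configuration: if every send/recv between stages carries a
    # single (16, 1024) float16 tensor, pre-setting these module globals lets
    # the *_meta() helpers skip the metadata exchange.
    comm.TENSOR_SHAPES = [(16, 1024)]
    comm.TENSOR_DTYPE = torch.float16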


def make_subgroups(pp_size, dp_size=None, group=None):
    """
    Split a process group into pipeline-parallel and data-parallel subgroups.

    Args:
        pp_size (int): size of each pipeline-parallel group
        dp_size (int): size of each data-parallel group; defaults to ``group.size() / pp_size``
        group (dist.ProcessGroup): process group to split; defaults to the global process group

    Returns:
        tuple: ``(dp_group, pp_group)``

    Examples:

    .. code-block:: python

        from haiscale.pipeline import GPipe, partition, make_subgroups

        dist.init_process_group(...)
        torch.cuda.set_device(local_rank)
        rank, world_size = dist.get_rank(), dist.get_world_size()
        dp_group, pp_group = make_subgroups(pp_size=4)

        model = nn.Sequential(...).cuda()
        model = partition(model, pp_group.rank(), pp_group.size())
        model = DDP(model, device_ids=[local_rank], process_group=dp_group)
        model = GPipe(model, chunks=64, process_group=pp_group)

        for x, y in dataloader:
            out = model(x)
            loss = ((out - y)**2).sum()
            loss.backward()

    """
    assert isinstance(pp_size, int)
    assert dp_size is None or isinstance(dp_size, int)
    assert group is None or isinstance(group, dist.ProcessGroup)

    group = group or _get_default_group()
    rank = group.rank()
    nranks = group.size()

    if dp_size is None:
        assert nranks % pp_size == 0, (nranks, pp_size)
        dp_size = nranks // pp_size
    else:
        assert dp_size * pp_size == nranks

    my_dp_rank = rank // pp_size  # rank within my dp group (== index of my pp group)
    my_pp_rank = rank % pp_size   # rank within my pp group (== index of my dp group)
    my_dp_group = None
    my_pp_group = None

    # make dp groups
    for i in range(pp_size):
        dp_ranks = [i + j * pp_size for j in range(dp_size)]
        dp_ranks = [get_global_rank(group, r) for r in dp_ranks]
        dp_group = dist.new_group(dp_ranks)
        if i == my_pp_rank:
            my_dp_group = dp_group

    # make pp groups
    for i in range(dp_size):
        pp_ranks = [i * pp_size + j for j in range(pp_size)]
        pp_ranks = [get_global_rank(group, r) for r in pp_ranks]
        pp_group = dist.new_group(pp_ranks)
        if i == my_dp_rank:
            my_pp_group = pp_group

    assert my_dp_group is not None and my_pp_group is not None
    return my_dp_group, my_pp_group

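For orientation, make_subgroups places each block of pp_size consecutive ranks in one pipeline-parallel group and strides by pp_size to form the data-parallel groups. A minimal sketch of the same index arithmetic (plain Python, no process groups involved), assuming 8 ranks and pp_size=4:

    # Reproduce make_subgroups' rank layout for nranks=8, pp_size=4 (so dp_size=2).
    nranks, pp_size = 8, 4
    dp_size = nranks // pp_size

    dp_groups = [[i + j * pp_size for j in range(dp_size)] for i in range(pp_size)]
    pp_groups = [[i * pp_size + j for j in range(pp_size)] for i in range(dp_size)]

    print(dp_groups)  # [[0, 4], [1, 5], [2, 6], [3, 7]]
    print(pp_groups)  # [[0, 1, 2, 3], [4, 5, 6, 7]]

    for rank in range(nranks):
        my_dp_rank = rank // pp_size  # index of this rank's pp group
        my_pp_rank = rank % pp_size   # index of this rank's dp group
        assert rank in dp_groups[my_pp_rank]
        assert rank in pp_groups[my_dp_rank]
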
@cuda_timer.record("sendrecv_twice")
def sendrecv_twice(send1, recv1, peer1, send2, recv2, peer2, group):
    if not any([send1 is not None, recv1, send2 is not None, recv2]):
        return None, None

    recv1, recv2 = sendrecv_twice_meta(send1, recv1, peer1, send2, recv2, peer2, group)

    f = lambda x: x is not None
    to_be_sent1 = send1 if send1 is None else list(filter(f, send1))
    to_be_sent2 = send2 if send2 is None else list(filter(f, send2))
    to_be_recv1 = recv1 if recv1 is None else list(filter(f, recv1))
    to_be_recv2 = recv2 if recv2 is None else list(filter(f, recv2))

    sendrecv_twice_impl(to_be_sent1, to_be_recv1, peer1,
                        to_be_sent2, to_be_recv2, peer2, group)

    sends, recvs = [], []
    if to_be_sent1 is not None:
        sends += to_be_sent1
    if to_be_sent2:
        sends += to_be_sent2
    if to_be_recv1:
        recvs += to_be_recv1
    if to_be_recv2:
        recvs += to_be_recv2
    cuda_timer.record_send("sendrecv_twice", sends)
    cuda_timer.record_recv("sendrecv_twice", recvs)

    return recv1, recv2


@cuda_timer.record("sendrecv_twice_meta")
def sendrecv_twice_meta(send1, recv1, peer1, send2, recv2, peer2, group):
    if TENSOR_SHAPES is None:
        send1_meta, recv1_meta = None, None
        send2_meta, recv2_meta = None, None

        if send1 is not None:
            send1_meta = build_empty_meta()
            encode_meta(send1, send1_meta)
        if recv1:
            recv1_meta = build_empty_meta()
        if send2 is not None:
            send2_meta = build_empty_meta()
            encode_meta(send2, send2_meta)
        if recv2:
            recv2_meta = build_empty_meta()

        sendrecv_twice_impl(send1_meta, recv1_meta, peer1,
                            send2_meta, recv2_meta, peer2, group)

        recv1 = build_tensors_from_meta(decoce_meta(recv1_meta)) if recv1 else None
        recv2 = build_tensors_from_meta(decoce_meta(recv2_meta)) if recv2 else None
    else:
        recv1 = build_from_tensor_shapes() if recv1 else None
        recv2 = build_from_tensor_shapes() if recv2 else None

    return recv1, recv2


def sendrecv_twice_impl(send1, recv1, peer1, send2, recv2, peer2, group):
    peer1 = get_global_rank(group, peer1)
    peer2 = get_global_rank(group, peer2)

    def _warp_list(x):
        if isinstance(x, torch.Tensor):
            return [x]
        return x

    send1 = _warp_list(send1)
    recv1 = _warp_list(recv1)
    send2 = _warp_list(send2)
    recv2 = _warp_list(recv2)

    ops = []
    if send1 is not None:
        ops += [dist.P2POp(dist.isend, x, peer1) for x in send1]
    if send2 is not None:
        ops += [dist.P2POp(dist.isend, x, peer2) for x in send2]

    # NOTE: if peer1 is peer2, we should do recv2 first
    if recv2 is not None:
        ops += [dist.P2POp(dist.irecv, x, peer2) for x in recv2]
    if recv1 is not None:
        ops += [dist.P2POp(dist.irecv, x, peer1) for x in recv1]

    assert len(ops) > 0
    reqs = dist.batch_isend_irecv(ops)
    for req in reqs:
        req.wait()

    torch.cuda.synchronize()


@cuda_timer.record("sendrecv")
def sendrecv(send_tensors, peer, group):
    recv_tensors = sendrecv_meta(send_tensors, peer, group)

    f = lambda x: x is not None
    to_be_sent = list(filter(f, send_tensors))
    to_be_recv = list(filter(f, recv_tensors))
    sendrecv_impl(to_be_sent, to_be_recv, peer, group)

    cuda_timer.record_send("sendrecv", to_be_sent)
    cuda_timer.record_recv("sendrecv", to_be_recv)

    return recv_tensors


@cuda_timer.record("sendrecv_meta")
def sendrecv_meta(tensors, peer, group):
    if TENSOR_SHAPES is None:
        send_meta = build_empty_meta()
        recv_meta = build_empty_meta()
        encode_meta(tensors, send_meta)
        sendrecv_impl([send_meta], [recv_meta], peer, group)
        recv_tensors_meta = decoce_meta(recv_meta)
        tensors = build_tensors_from_meta(recv_tensors_meta)
    else:
        tensors = build_from_tensor_shapes()

    return tensors
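
sendrecv_twice_impl above and sendrecv_impl below both post their sends and receives through dist.batch_isend_irecv, so two stages that need to exchange tensors at the same time cannot deadlock on the ordering of blocking send/recv calls. A minimal two-rank sketch of that pattern, assuming an NCCL backend with one GPU per rank (the launch command and tensor sizes are illustrative):

    import torch
    import torch.distributed as dist

    def exchange(send_tensor, recv_tensor, peer):
        # Post the send and the matching receive as one batch, as sendrecv_impl
        # does, so neither rank blocks waiting for the other to post first.
        ops = [dist.P2POp(dist.isend, send_tensor, peer),
               dist.P2POp(dist.irecv, recv_tensor, peer)]
        for req in dist.batch_isend_irecv(ops):
            req.wait()
        torch.cuda.synchronize()

    # e.g. launched with: torchrun --nproc_per_node=2 demo.py
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    torch.cuda.set_device(rank)
    send = torch.full((4,), float(rank), device="cuda")
    recv = torch.empty(4, device="cuda")
    exchange(send, recv, peer=1 - rank)  # each rank receives the other's tensor
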
def sendrecv_impl(send_tensors, recv_tensors, peer, group):
    peer = get_global_rank(group, peer)

    ops = [dist.P2POp(dist.isend, x, peer, group=group) for x in send_tensors]
    ops += [dist.P2POp(dist.irecv, x, peer, group=group) for x in recv_tensors]
    reqs = dist.batch_isend_irecv(ops)
    for req in reqs:
        req.wait()

    # There is a weird BUG in pytorch nccl...
    # we need to synchronize ncclStreams manually
    torch.cuda.synchronize()


def broadcast(tensors, src, group):
    if isinstance(tensors, torch.Tensor):
        tensors = (tensors,)

    tensors_meta = broadcast_meta(tensors, src, group)
    if group.rank() != src:
        tensors = build_tensors_from_meta(tensors_meta)

    src = get_global_rank(group, src)
    for tensor in tensors:
        dist.broadcast(tensor, src, group)

    return tensors


def broadcast_meta(tensors, src, group):
    meta = build_empty_meta()
    if group.rank() == src:
        encode_meta(tensors, meta)

    src = get_global_rank(group, src)
    dist.broadcast(meta, src, group)
    tensors_meta = decoce_meta(meta)

    return tensors_meta


def encode_meta(tensors, meta):
    length = sum(tensor.dim() + 3 for tensor in tensors if tensor is not None)
    num_nones = sum(1 for tensor in tensors if tensor is None)
    assert length + num_nones + 1 <= meta.size(0)

    cnt = 1
    meta[0] = len(tensors)
    for tensor in tensors:
        if tensor is None:
            meta[cnt] = NONE_MARK
            cnt += 1
        else:
            dim = tensor.dim()
            meta[cnt] = dim
            meta[cnt + 1] = DTYPES.index(tensor.dtype)
            meta[(cnt + 2):(cnt + 2 + dim)] = torch.tensor(tensor.shape)
            meta[cnt + 2 + dim] = int(tensor.requires_grad)
            cnt += 3 + dim


def decoce_meta(meta):
    tensors_meta = []
    num_tensors = meta[0].item()

    cnt = 1
    for _ in range(num_tensors):
        if meta[cnt].item() == NONE_MARK:
            tensors_meta.append(None)
            cnt += 1
        else:
            dim = meta[cnt].item()
            dtype = DTYPES[meta[cnt + 1].item()]
            shape = meta[(cnt + 2):(cnt + 2 + dim)]
            requires_grad = bool(meta[cnt + 2 + dim].item())
            tensors_meta.append((torch.Size(shape), dtype, requires_grad))
            cnt += 3 + dim

    return tensors_meta


@cuda_timer.record("recv")
def recv(src, group):
    tensors = recv_meta(src, group)
    src = get_global_rank(group, src)
    for tensor in tensors:
        if tensor is not None:
            dist.recv(tensor, src, group)
            cuda_timer.record_recv("recv", [tensor])

    return tensors


@cuda_timer.record("send")
def send(tensors, dst, group):
    send_meta(tensors, dst, group)
    dst = get_global_rank(group, dst)
    for tensor in tensors:
        if tensor is not None:
            dist.send(tensor, dst, group)
            cuda_timer.record_send("send", [tensor])


@cuda_timer.record("recv_meta")
def recv_meta(src, group):
    if TENSOR_SHAPES is None:
        meta = build_empty_meta()
        src = get_global_rank(group, src)
        dist.recv(meta, src=src, group=group)
        tensors_meta = decoce_meta(meta)
        tensors = build_tensors_from_meta(tensors_meta)
    else:
        tensors = build_from_tensor_shapes()

    return tensors


@cuda_timer.record("send_meta")
def send_meta(tensors, dst, group):
    if TENSOR_SHAPES is None:
        meta = build_empty_meta()
        encode_meta(tensors, meta)
        dst = get_global_rank(group, dst)
        dist.send(meta, dst, group)


def build_tensors_from_meta(meta):
    tensors = []
    for m in meta:
        if m is None:
            tensors.append(None)
        else:
            tensor = torch.empty(size=m[0], dtype=m[1], requires_grad=m[2], device="cuda")
            tensors.append(tensor)

    return tensors


def build_empty_meta():
    return torch.full(size=(META_SIZE,), fill_value=-1, dtype=torch.int32, device="cuda")


def get_global_rank(group, rank):
    if group is not _get_default_group():
        rank = _get_global_rank(group, rank)
    return rank


def build_from_tensor_shapes():
    return [torch.empty(s, dtype=TENSOR_DTYPE, device="cuda", requires_grad=True) for s in TENSOR_SHAPES]
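
The metadata handshake used throughout packs the number of entries at meta[0], then for each tensor [dim, index into DTYPES, *shape, requires_grad], writing NONE_MARK where an entry is None, all into a flat int32 buffer of length META_SIZE. A small CPU-only round trip through encode_meta and decoce_meta (building the buffer by hand instead of with build_empty_meta(), which allocates on CUDA; assumes haiscale is importable):

    import torch
    from haiscale.pipeline.comm import encode_meta, decoce_meta, META_SIZE

    tensors = [torch.zeros(2, 3, dtype=torch.float16, requires_grad=True), None]
    meta = torch.full((META_SIZE,), -1, dtype=torch.int32)  # CPU stand-in for build_empty_meta()

    encode_meta(tensors, meta)
    print(meta[:8].tolist())
    # -> [2, 2, 6, 2, 3, 1, -2, -1]
    #    2 entries; then dim=2, dtype index 6 (torch.float16), shape (2, 3),
    #    requires_grad=1; then NONE_MARK (-2) for the None entry; -1 is unused padding.

    print(decoce_meta(meta))  # one (shape, dtype, requires_grad) triple followed by None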