import torch
import torch.distributed as dist
from torch.distributed.distributed_c10d import _get_global_rank, _get_default_group
from ..timer import cuda_timer

DTYPES = [
    torch.bool,
    torch.uint8,
    torch.int8,
    torch.int16,
    torch.int32,
    torch.int64,
    torch.float16,
    torch.float32,
    torch.float64,
    torch.bfloat16,
]
META_SIZE = 1024
NONE_MARK = -2
TENSOR_SHAPES = None
TENSOR_DTYPE = None


def make_subgroups(pp_size, dp_size=None, group=None):
    """
    Create the pipeline-parallel and data-parallel process groups.

    Args:
        pp_size (int): size of each pipeline-parallel group
        dp_size (int): size of each data-parallel group, defaults to ``group.size() / pp_size``
        group (dist.ProcessGroup): process group to split, defaults to the global process group

    Returns:
        tuple: ``(dp_group, pp_group)``

    Examples:

    .. code-block:: python

        from haiscale.pipeline import GPipe, partition, make_subgroups

        dist.init_process_group(...)
        torch.cuda.set_device(local_rank)
        rank, world_size = dist.get_rank(), dist.get_world_size()

        dp_group, pp_group = make_subgroups(pp_size=4)

        model = nn.Sequential(...).cuda()
        model = partition(model, pp_group.rank(), pp_group.size())
        model = DDP(model, device_ids=[local_rank], process_group=dp_group)
        model = GPipe(model, chunks=64, process_group=pp_group)

        for x, y in dataloader:
            out = model(x)
            loss = ((out - y)**2).sum()
            loss.backward()
    """
    assert isinstance(pp_size, int)
    assert dp_size is None or isinstance(dp_size, int)
    assert group is None or isinstance(group, dist.ProcessGroup)

    group = group or _get_default_group()
    rank = group.rank()
    nranks = group.size()

    if dp_size is None:
        assert nranks % pp_size == 0, (nranks, pp_size)
        dp_size = nranks // pp_size
    else:
        assert dp_size * pp_size == nranks

    my_dp_rank = rank // pp_size  # rank within the dp group, i.e. the id of this rank's pp group
    my_pp_rank = rank % pp_size   # rank within the pp group, i.e. the id of this rank's dp group
    my_dp_group = None
    my_pp_group = None

    # make dp groups
    for i in range(pp_size):
        dp_ranks = [i + j * pp_size for j in range(dp_size)]
        dp_ranks = [get_global_rank(group, r) for r in dp_ranks]
        dp_group = dist.new_group(dp_ranks)
        if i == my_pp_rank:
            my_dp_group = dp_group

    # make pp groups
    for i in range(dp_size):
        pp_ranks = [i * pp_size + j for j in range(pp_size)]
        pp_ranks = [get_global_rank(group, r) for r in pp_ranks]
        pp_group = dist.new_group(pp_ranks)
        if i == my_dp_rank:
            my_pp_group = pp_group

    assert my_dp_group is not None and my_pp_group is not None
    return my_dp_group, my_pp_group
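
# Illustrative layout (editor's sketch, derived from the grouping arithmetic above): with
# 8 ranks, pp_size=4 and dp_size=2, make_subgroups builds
#
#   dp groups: [0, 4], [1, 5], [2, 6], [3, 7]   (ranks holding the same pipeline stage)
#   pp groups: [0, 1, 2, 3], [4, 5, 6, 7]       (ranks forming one pipeline replica)
#
# e.g. rank 5 gets my_pp_rank == 1 (stage 1) and my_dp_rank == 1 (second pipeline replica).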
@cuda_timer.record("sendrecv_twice")
def sendrecv_twice(send1, recv1, peer1, send2, recv2, peer2, group):
if not any([send1 is not None, recv1, send2 is not None, recv2]):
return None, None
recv1, recv2 = sendrecv_twice_meta(send1, recv1, peer1, send2, recv2, peer2, group)
f = lambda x: x is not None
to_be_sent1 = send1 if send1 is None else list(filter(f, send1))
to_be_sent2 = send2 if send2 is None else list(filter(f, send2))
to_be_recv1 = recv1 if recv1 is None else list(filter(f, recv1))
to_be_recv2 = recv2 if recv2 is None else list(filter(f, recv2))
sendrecv_twice_impl(to_be_sent1, to_be_recv1, peer1, to_be_sent2, to_be_recv2, peer2, group)
sends, recvs = [], []
if to_be_sent1 is not None:
sends += to_be_sent1
if to_be_sent2:
sends += to_be_sent2
if to_be_recv1:
recvs += to_be_recv1
if to_be_recv2:
recvs += to_be_recv2
cuda_timer.record_send("sendrecv_twice", sends)
cuda_timer.record_recv("sendrecv_twice", recvs)
return recv1, recv2
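
# Hypothetical usage sketch (not from the original source): a middle pipeline stage can
# batch both directions into a single exchange, e.g. receiving activations from the
# previous stage while sending its own activations to the next one. The names prev_rank,
# next_rank, outputs and pp_group are illustrative only; recv1/recv2 are flags telling
# the call to allocate and return receive buffers from the exchanged metadata.
#
#   prev_outputs, _ = sendrecv_twice(
#       send1=None, recv1=True, peer1=prev_rank,      # receive activations from prev stage
#       send2=outputs, recv2=False, peer2=next_rank,  # send this stage's outputs (a list of tensors)
#       group=pp_group)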
@cuda_timer.record("sendrecv_twice_meta")
def sendrecv_twice_meta(send1, recv1, peer1, send2, recv2, peer2, group):
if TENSOR_SHAPES is None:
send1_meta, recv1_meta = None, None
send2_meta, recv2_meta = None, None
if send1 is not None:
send1_meta = build_empty_meta()
encode_meta(send1, send1_meta)
if recv1:
recv1_meta = build_empty_meta()
if send2 is not None:
send2_meta = build_empty_meta()
encode_meta(send2, send2_meta)
if recv2:
recv2_meta = build_empty_meta()
sendrecv_twice_impl(send1_meta, recv1_meta, peer1, send2_meta, recv2_meta, peer2, group)
recv1 = build_tensors_from_meta(decoce_meta(recv1_meta)) if recv1 else None
recv2 = build_tensors_from_meta(decoce_meta(recv2_meta)) if recv2 else None
else:
recv1 = build_from_tensor_shapes() if recv1 else None
recv2 = build_from_tensor_shapes() if recv2 else None
return recv1, recv2


def sendrecv_twice_impl(send1, recv1, peer1, send2, recv2, peer2, group):
    peer1 = get_global_rank(group, peer1)
    peer2 = get_global_rank(group, peer2)

    def _wrap_list(x):
        if isinstance(x, torch.Tensor):
            return [x]
        return x

    send1 = _wrap_list(send1)
    recv1 = _wrap_list(recv1)
    send2 = _wrap_list(send2)
    recv2 = _wrap_list(recv2)

    ops = []
    if send1 is not None:
        ops += [dist.P2POp(dist.isend, x, peer1) for x in send1]
    if send2 is not None:
        ops += [dist.P2POp(dist.isend, x, peer2) for x in send2]

    # NOTE: if peer1 is peer2, we should do recv2 first
    if recv2 is not None:
        ops += [dist.P2POp(dist.irecv, x, peer2) for x in recv2]
    if recv1 is not None:
        ops += [dist.P2POp(dist.irecv, x, peer1) for x in recv1]

    assert len(ops) > 0
    reqs = dist.batch_isend_irecv(ops)
    for req in reqs:
        req.wait()

    torch.cuda.synchronize()
@cuda_timer.record("sendrecv")
def sendrecv(send_tensors, peer, group):
recv_tensors = sendrecv_meta(send_tensors, peer, group)
f = lambda x: x is not None
to_be_sent = list(filter(f, send_tensors))
to_be_recv = list(filter(f, recv_tensors))
sendrecv_impl(to_be_sent, to_be_recv, peer, group)
cuda_timer.record_send("sendrecv", to_be_sent)
cuda_timer.record_recv("sendrecv", to_be_recv)
return recv_tensors
@cuda_timer.record("sendrecv_meta")
def sendrecv_meta(tensors, peer, group):
if TENSOR_SHAPES is None:
send_meta = build_empty_meta()
recv_meta = build_empty_meta()
encode_meta(tensors, send_meta)
sendrecv_impl([send_meta], [recv_meta], peer, group)
recv_tensors_meta = decoce_meta(recv_meta)
tensors = build_tensors_from_meta(recv_tensors_meta)
else:
tensors = build_from_tensor_shapes()
return tensors


def sendrecv_impl(send_tensors, recv_tensors, peer, group):
    peer = get_global_rank(group, peer)
    ops = [dist.P2POp(dist.isend, x, peer, group=group) for x in send_tensors]
    ops += [dist.P2POp(dist.irecv, x, peer, group=group) for x in recv_tensors]
    reqs = dist.batch_isend_irecv(ops)
    for req in reqs:
        req.wait()

    # There is a weird bug in PyTorch's NCCL backend...
    # we need to synchronize the NCCL streams manually.
    torch.cuda.synchronize()


def broadcast(tensors, src, group):
    if isinstance(tensors, torch.Tensor):
        tensors = (tensors,)

    tensors_meta = broadcast_meta(tensors, src, group)
    if group.rank() != src:
        tensors = build_tensors_from_meta(tensors_meta)

    src = get_global_rank(group, src)
    for tensor in tensors:
        dist.broadcast(tensor, src, group)

    return tensors


def broadcast_meta(tensors, src, group):
    meta = build_empty_meta()
    if group.rank() == src:
        encode_meta(tensors, meta)

    src = get_global_rank(group, src)
    dist.broadcast(meta, src, group)
    tensors_meta = decoce_meta(meta)
    return tensors_meta


def encode_meta(tensors, meta):
    # layout: meta[0] = number of entries; then, per entry, either
    # [dim, dtype index, *shape, requires_grad] or NONE_MARK for a None entry
    length = sum(tensor.dim() + 3 for tensor in tensors if tensor is not None)
    num_nones = sum(1 for tensor in tensors if tensor is None)
    assert length + num_nones + 1 <= meta.size(0)

    cnt = 1
    meta[0] = len(tensors)
    for tensor in tensors:
        if tensor is None:
            meta[cnt] = NONE_MARK
            cnt += 1
        else:
            dim = tensor.dim()
            meta[cnt] = dim
            meta[cnt + 1] = DTYPES.index(tensor.dtype)
            meta[(cnt + 2):(cnt + 2 + dim)] = torch.tensor(tensor.shape)
            meta[cnt + 2 + dim] = int(tensor.requires_grad)
            cnt += 3 + dim


def decoce_meta(meta):
    tensors_meta = []
    num_tensors = meta[0].item()

    cnt = 1
    for _ in range(num_tensors):
        if meta[cnt].item() == NONE_MARK:
            tensors_meta.append(None)
            cnt += 1
        else:
            dim = meta[cnt].item()
            dtype = DTYPES[meta[cnt + 1].item()]
            shape = meta[(cnt + 2):(cnt + 2 + dim)]
            requires_grad = bool(meta[cnt + 2 + dim].item())
            tensors_meta.append((torch.Size(shape), dtype, requires_grad))
            cnt += 3 + dim

    return tensors_meta
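
# Worked example of the metadata layout (editor's sketch, derived from encode_meta and
# decoce_meta above). Encoding
# ``[torch.empty(2, 3, dtype=torch.float32, requires_grad=True), None]``
# writes into the int32 meta buffer (unused slots keep the fill value -1):
#
#   meta[0] = 2                 # number of entries
#   meta[1] = 2                 # first tensor: dim
#   meta[2] = 7                 # DTYPES.index(torch.float32)
#   meta[3:5] = (2, 3)          # shape
#   meta[5] = 1                 # requires_grad
#   meta[6] = -2                # NONE_MARK for the second entry
#
# decoce_meta reverses this into [(torch.Size([2, 3]), torch.float32, True), None].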
@cuda_timer.record("recv")
def recv(src, group):
tensors = recv_meta(src, group)
src = get_global_rank(group, src)
for tensor in tensors:
if tensor is not None:
dist.recv(tensor, src, group)
cuda_timer.record_recv("recv", [tensor])
return tensors
@cuda_timer.record("send")
def send(tensors, dst, group):
send_meta(tensors, dst, group)
dst = get_global_rank(group, dst)
for tensor in tensors:
if tensor is not None:
dist.send(tensor, dst, group)
cuda_timer.record_send("send", [tensor])
@cuda_timer.record("recv_meta")
def recv_meta(src, group):
if TENSOR_SHAPES is None:
meta = build_empty_meta()
src = get_global_rank(group, src)
dist.recv(meta, src=src, group=group)
tensors_meta = decoce_meta(meta)
tensors = build_tensors_from_meta(tensors_meta)
else:
tensors = build_from_tensor_shapes()
return tensors
@cuda_timer.record("send_meta")
def send_meta(tensors, dst, group):
if TENSOR_SHAPES is None:
meta = build_empty_meta()
encode_meta(tensors, meta)
dst = get_global_rank(group, dst)
dist.send(meta, dst, group)


def build_tensors_from_meta(meta):
    tensors = []
    for m in meta:
        if m is None:
            tensors.append(None)
        else:
            tensor = torch.empty(size=m[0], dtype=m[1], requires_grad=m[2], device="cuda")
            tensors.append(tensor)
    return tensors


def build_empty_meta():
    return torch.full(size=(META_SIZE,), fill_value=-1, dtype=torch.int32, device="cuda")


def get_global_rank(group, rank):
    if group is not _get_default_group():
        rank = _get_global_rank(group, rank)
    return rank


def build_from_tensor_shapes():
    return [torch.empty(s, dtype=TENSOR_DTYPE, device="cuda", requires_grad=True) for s in TENSOR_SHAPES]
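
# Editor's note / sketch: the module-level TENSOR_SHAPES and TENSOR_DTYPE globals act as an
# optional fast path. When they are set (presumably by the pipeline engine; the caller is not
# shown in this file), the point-to-point *_meta helpers above skip the metadata exchange and
# allocate receive buffers directly via build_from_tensor_shapes(). A hypothetical way to
# enable this fast path for fixed-shape micro-batches (module path is illustrative only):
#
#   import haiscale.pipeline.comm as comm   # hypothetical module path
#   comm.TENSOR_SHAPES = [(micro_batch_size, hidden_dim)]
#   comm.TENSOR_DTYPE = torch.float16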