from collections import defaultdict
import torch
import torch.nn as nn
import torch.distributed as dist
from . import comm, microbatch
from .utils import checkpoint, sync_forward, sync_backward, run_backward, equal_across_ranks
from ..timer import cuda_timer
"""
nranks = 4
num_model_chunks = 2
stage 0: F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 FA-B0 B1 B2 B3 B4 B5 B6 B7 B8 B9 BA
stage 1: F0 F1 F2 F3 F4 F5 F6 F7 F8-B0 F9-B1 FA-B2 B3 B4 B5 B6 B7 B8 B9 BA
stage 2: F0 F1 F2 F3 F4 F5 F6-B0 F7-B1 F8-B2 F9-B3 FA-B4 B5 B6 B7 B8 B9 BA
stage 3: F0 F1 F2 F3 F4-B0 F5-B1 F6-B2 F7-B3 F8-B4 F9-B5 FA-B6 B7 B8 B9 BA
<------------------> <---------------------------------------> <------------------>
warmup 1F1B cooldown
1. warmup
    before each F: recv(input, prev_rank)
    after each F:  send(output, next_rank), except for the last model chunk or the last F
2. 1F1B
    before each 1F1B step:
        first 1F1B step: recv(input, prev_rank), sendrecv(output, grad_output, next_rank)
        other steps:     sendrecv(grad_input, input, prev_rank), sendrecv(output, grad_output, next_rank)
    after the last 1F1B step:
        send(grad_input, prev_rank), sendrecv(output, grad_output, next_rank)
3. cooldown
    before each B: recv(grad_output, next_rank)
    after each B:  send(grad_input, prev_rank)
num_chunks % (num_model_chunks * nranks) == 0
"""
class Interleaved1F1B(nn.Module):
"""
分布式流水线并行算法 interleaved 1F1B
Interleaved1F1B 需要把模型切分成 ``world_size * num_model_chunks`` 个子模型,每块 GPU 上有 ``num_model_chunks`` 个子模型;
其中 ``num_model_chunks`` 取决于用户的设置。
传入 Interleaved1F1B 的模型可以有多个输入和输出,但相邻 stage 模型之间的输入输出要能对应上。
Args:
modules (List[nn.Module]): 切分后的模型,可以通过 :func:`partition` 分割得到
chunks (int): microbatch 的数量
process_group (dist.ProcessGroup): 流水线并行的进程组,默认使用全局的进程组
batch_dim (int): batch 的维度,默认是 ``0``
checkpoint (bool): 是否使用 activation checkpoint,默认是 ``False``
Example:
训练的时候需要调用 ``forward_backward`` 并传入损失函数 ``criterion`` 和标签 ``labels``:
.. code-block:: python
from haiscale.pipeline import Interleaved1F1B, partition
dist.init_process_group(...)
torch.cuda.set_device(local_rank)
rank, world_size = dist.get_rank(), dist.get_world_size()
torch.manual_seed(12345)
def loss_fn(out, y):
return ((out - y)**2).sum()
model = nn.Sequential(...).cuda()
modules = partition(model, rank, world_size, num_model_chunks=2) # len(modules) = 2
model = Interleaved1F1B(modules, chunks=32)
for x, y in dataloader:
loss, _ = model.forward_backward(x, criterion=loss_fn, labels=(y,))
# eval
with torch.no_grad():
out = model(x)
if rank == world_size - 1:
# calculate metrics ...
NOTE:
示例中的第一个 stage 和最后一个 stage 的 dataloader 应该要能产生相同的数据。
    NOTE:

        To use it together with DDP, wrap each partitioned sub-module with DDP first,
        and then pass the DDP-wrapped modules to Interleaved1F1B.
        The sizes of the pipeline-parallel and data-parallel groups can be set up
        with :func:`make_subgroups`; see the sketch below.
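    For example, a minimal sketch assuming 4 GPUs split into 2 pipeline stages times
    2 data-parallel replicas (the explicit group layout below is only an illustration;
    :func:`make_subgroups` can build equivalent groups):

    .. code-block:: python

        from torch.nn.parallel import DistributedDataParallel

        # every rank must create every subgroup, in the same order
        g01, g23 = dist.new_group([0, 1]), dist.new_group([2, 3])
        g02, g13 = dist.new_group([0, 2]), dist.new_group([1, 3])
        dp_group = g01 if rank in (0, 1) else g23   # data-parallel pairs hold the same stage
        pp_group = g02 if rank in (0, 2) else g13   # pipeline pairs hold consecutive stages

        modules = partition(model, pp_group.rank(), pp_group.size(), num_model_chunks=2)
        modules = [DistributedDataParallel(m, process_group=dp_group) for m in modules]
        model = Interleaved1F1B(modules, chunks=32, process_group=pp_group)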
"""
def __init__(self, modules, chunks, process_group=None, batch_dim=0, checkpoint=False):
super().__init__()
assert isinstance(modules, (list, tuple))
assert len(modules) >= 2
assert all(isinstance(m, nn.Module) for m in modules)
assert isinstance(chunks, int)
assert isinstance(batch_dim, int)
assert isinstance(checkpoint, bool)
group = process_group
assert group is None or isinstance(group, dist.ProcessGroup)
group = group or dist.distributed_c10d._get_default_group()
assert group.size() > 1
device = torch.cuda.current_device()
device = torch.device(device)
assert next(modules[0].parameters()).device == device
self.device = device
self.module = nn.ModuleList(modules)
self.chunks = chunks
self.group = group
self.batch_dim = batch_dim
self.model_chunks = len(modules)
assert self.model_chunks > 1
self.checkpoint = checkpoint
nranks = group.size()
nstages = self.model_chunks * nranks
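        # the microbatch schedule requires the number of microbatches to be divisible by
        # the pipeline size, with enough microbatches to cover the warmup phase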
assert self.chunks % nranks == 0
assert self.chunks > (nstages + nranks - 2)
assert equal_across_ranks(batch_dim, group)
assert equal_across_ranks(chunks, group)
assert equal_across_ranks(self.model_chunks, group)
    def forward(self, *inputs):
"""
在 ``torch.no_grad()`` 下做一次 forward
Args:
*inputs: 模型的输入;如果不是第一个 stage,可以不用传入
Returns:
最后一个 stage 会返回模型的输出,其他 stage 返回 ``None``
"""
assert not torch.is_grad_enabled(), "forward 只有在 no_grad 模式的时候可以用"
rank = self.group.rank()
nranks = self.group.size()
next_rank = (rank + 1) % nranks
prev_rank = (rank + nranks - 1) % nranks
nstages = self.model_chunks * nranks
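        # run each local model chunk in turn: receive its input from the previous rank
        # (unless it is the first stage) and send its output to the next rank
        # (unless it is the last stage)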
for i in range(self.model_chunks):
stage = i * nranks + rank
if stage != 0:
inputs = comm.recv(prev_rank, self.group)
outputs, _ = self._forward_chunk(stage, i, inputs)
if stage != nstages - 1:
comm.send(outputs, next_rank, self.group)
outputs = None
if isinstance(outputs, (tuple, list)) and len(outputs) == 1:
outputs = outputs[0]
return outputs
    def forward_backward(self, *inputs, criterion=None, labels=(), return_outputs=False):
"""
做一次 forward 和 backward,返回每个 microbatch 的 loss
Args:
*inputs: 模型的输入;如果不是第一个 stage,可以不用传入
criterion (Callable): 损失函数,通过 ``criterion(*outputs, *labels)`` 的方式调用;如果不是最后一个 stage,可以不用传入
labels (tuple): 传入损失函数的标签等数据;如果不是最后一个 stage,可以不用传入
return_outputs (bool): 是否在最后一个 stage 返回模型的输出,设置为 ``True`` 的时候会多占用一部分显存;默认是 ``False``
Returns:
tuple: ``(loss, outputs)``
NOTE:
如果设置 ``return_outputs = True``,``outputs`` 是模型最后一个 stage 的输出,否则为 ``None``。
只有最后一个 stage 会有返回值,其他 stage 上会返回 ``(None, None)``。
"""
rank = self.group.rank()
nranks = self.group.size()
nstages = self.model_chunks * nranks
if rank == nranks - 1:
assert criterion is not None
input_chunks, output_chunks, losses = defaultdict(list), defaultdict(list), defaultdict(list)
if rank == 0:
input_chunks[0] = microbatch.scatter(inputs, self.chunks, self.batch_dim)
if rank == nranks - 1:
labels = microbatch.scatter(labels, self.chunks, self.batch_dim)
else:
labels = [() for _ in range(self.chunks)]
#####################
# step 1: warmup
#####################
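        # warmup: (model_chunks - 1) * nranks forwards to go through every local model
        # chunk, plus 2 * (nranks - rank - 1) extra on earlier ranks, since the last
        # rank is the first one to start running backward (see the diagram at the top)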
num_warmup = (self.model_chunks - 1) * nranks + 2 * (nranks - rank - 1)
output = None
for i in range(num_warmup):
model_id, stage, microbatch_id = self._get_forward_id(i)
recv_input = (stage != 0)
input, _ = self._communicate(recv_input=recv_input, recv_grad_output=False, output=output)
# release output
if not return_outputs or stage != nstages - 1:
free_outputs(output)
# recv input
assert input is not None or stage == 0
if stage != 0:
input_chunks[model_id].append(input)
else:
input = input_chunks[model_id][microbatch_id]
# prepare label
label = ()
if stage == nstages - 1:
label = labels[microbatch_id]
with sync_forward(self.module[model_id], microbatch_id):
output, loss = self._forward_chunk(stage, model_id, input, criterion, label)
output_chunks[model_id].append(output)
losses[model_id].append(loss)
output = None if stage == nstages - 1 else output
#####################
# step 2: 1F1B
#####################
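        # steady state: each remaining forward is paired with one backward, and the
        # communication with both neighbours is fused into a single sendrecv step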
num_1F1B = self.model_chunks * self.chunks - num_warmup
grad_input = None
for i in range(num_1F1B):
model_id, stage, microbatch_id = self._get_forward_id(i + num_warmup)
bwd_stage = self._get_backward_id(i)[1]
recv_input = (stage != 0)
recv_grad_output = (bwd_stage != nstages - 1)
input, grad_output = self._communicate(recv_input=recv_input, recv_grad_output=recv_grad_output,
grad_input=grad_input, output=output)
# release output
if not return_outputs or stage != nstages - 1:
free_outputs(output)
assert input is not None or stage == 0
if stage != 0:
input_chunks[model_id].append(input)
else:
input = input_chunks[model_id][microbatch_id]
# prepare label
label = ()
if stage == nstages - 1:
label = labels[microbatch_id]
# forward
with sync_forward(self.module[model_id], microbatch_id):
output, loss = self._forward_chunk(stage, model_id, input, criterion, label)
output_chunks[model_id].append(output)
losses[model_id].append(loss)
output = None if stage == nstages - 1 else output
# backward
model_id, stage, microbatch_id = self._get_backward_id(i)
bwd_input = input_chunks[model_id][microbatch_id]
bwd_output = output_chunks[model_id][microbatch_id]
bwd_loss = losses[model_id][microbatch_id]
with sync_backward(self.module[model_id], microbatch_id == self.chunks - 1, bwd_output):
grad_input = self._backward_chunk(stage, bwd_output, grad_output, bwd_input, bwd_loss)
if bwd_loss is not None:
losses[model_id][microbatch_id] = bwd_loss.detach()
input_chunks[model_id][microbatch_id] = None
if not return_outputs or stage != nstages - 1:
output_chunks[model_id][microbatch_id] = None
#####################
# step 3: cooldown
#####################
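        # cooldown: drain the remaining backwards (one per warmup forward)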
num_cooldown = num_warmup
for i in range(num_cooldown):
model_id, stage, microbatch_id = self._get_backward_id(i + num_1F1B)
recv_grad_output = (stage != nstages - 1)
_, grad_output = self._communicate(recv_input=False, recv_grad_output=recv_grad_output,
grad_input=grad_input, output=output)
output = None
bwd_output = output_chunks[model_id][microbatch_id]
loss = losses[model_id][microbatch_id]
input = input_chunks[model_id][microbatch_id]
with sync_backward(self.module[model_id], microbatch_id == self.chunks - 1, bwd_output):
grad_input = self._backward_chunk(stage, bwd_output, grad_output, input, loss)
if loss is not None:
losses[model_id][microbatch_id] = loss.detach()
input_chunks[model_id][microbatch_id] = None
if not return_outputs or stage != nstages - 1:
output_chunks[model_id][microbatch_id] = None
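        # after the last backward there is nothing left to receive, only the final
        # grad_input to send to the previous rank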
self._communicate(recv_input=False, recv_grad_output=False, grad_input=grad_input)
loss, output = None, None
if rank == nranks - 1:
loss = torch.stack(losses[self.model_chunks - 1], 0).sum().detach()
if return_outputs:
output = microbatch.gather(output_chunks[self.model_chunks - 1], self.batch_dim)
if len(output) == 1:
output = output[0]
return loss, output
def _communicate(
self,
recv_input=True,
recv_grad_output=True,
grad_input=None,
output=None,
):
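        # one fused communication step with both neighbours: send grad_input to / receive
        # input from the previous rank, send output to / receive grad_output from the next rank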
rank = self.group.rank()
nranks = self.group.size()
next_rank = (rank + 1) % nranks
prev_rank = (rank + nranks - 1) % nranks
input, grad_output = comm.sendrecv_twice(
grad_input, recv_input, prev_rank,
output, recv_grad_output, next_rank,
self.group)
return input, grad_output
def _get_forward_id(self, i):
"""
nranks = 4, num_model_chunks = 2
model_id: 0 0 0 0 1 1 1 1 0 0 0 0 1 1 1 1 ...
microbatch_id: 0 1 2 3 0 1 2 3 4 5 6 7 4 5 6 7 ...
stage = model_id * nranks + rank
"""
rank = self.group.rank()
nranks = self.group.size()
model_id = (i // nranks) % self.model_chunks
stage = model_id * nranks + rank
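        # each block of nranks * model_chunks steps runs the same group of nranks
        # consecutive microbatches once through every local model chunk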
k, m = divmod(i, nranks * self.model_chunks)
microbatch_id = k * nranks + (m % nranks)
return model_id, stage, microbatch_id
def _get_backward_id(self, i):
"""
nranks = 4, num_model_chunks = 2
model_id: 1 1 1 1 0 0 0 0 1 1 1 1 0 0 0 0 ...
microbatch_id: 0 1 2 3 0 1 2 3 4 5 6 7 4 5 6 7 ...
stage = model_id * nranks + rank
"""
rank = self.group.rank()
nranks = self.group.size()
model_id = (i // nranks) % self.model_chunks
model_id = self.model_chunks - 1 - model_id
stage = model_id * nranks + rank # model_id -> stage
k, m = divmod(i, nranks * self.model_chunks)
microbatch_id = k * nranks + (m % nranks)
return model_id, stage, microbatch_id
@cuda_timer.record("forward_chunk")
def _forward_chunk(self, stage, model_id, input, criterion=None, label=()):
nstages = self.model_chunks * self.group.size()
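        # recompute activations during backward if activation checkpointing is enabled;
        # checkpointing is skipped under no_grad (e.g. in forward())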
if self.checkpoint and torch.is_grad_enabled():
output = checkpoint(self.module[model_id], *input)
else:
output = self.module[model_id](*input)
output = [output] if isinstance(output, torch.Tensor) else output
loss = None
if stage == nstages - 1 and criterion is not None:
loss = criterion(*output, *label)
return output, loss
@cuda_timer.record("backward_chunk")
def _backward_chunk(self, stage, output, grad_output, input, loss):
nstages = self.model_chunks * self.group.size()
if stage == nstages - 1:
loss.backward()
else:
            # filter out entries whose gradient is None
arr = [(o, g) for o, g in zip(output, grad_output) if g is not None]
output, grad_output = list(zip(*arr))
            # output.data may already have been freed, so torch.autograd.backward cannot be
            # called directly; run_backward only needs output.grad_fn
if len(output) > 0:
run_backward(output, grad_output)
grad_input = None
if stage != 0:
grad_input = [x.grad for x in input]
return grad_input
def free_outputs(tensors):
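    # free the activation storage of outputs that have already been sent, while keeping
    # the tensor objects (and their grad_fn) alive for the later run_backward call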
if tensors is None:
return
for x in tensors:
x.data = torch.empty(1, device=x.device, dtype=x.dtype)