import torch
import torch.nn as nn
import torch.distributed as dist
from . import comm, microbatch
from .utils import checkpoint, sync_forward, sync_backward, run_backward
from ..timer import cuda_timer
"""
PipeDream: non-interleaved 1F1B schedule
Schedule (4 stages, 6 microbatches):
stage 0: F0 F1 F2 F3 B0 F4 B1 F5 B2 B3 B4 B5
stage 1: F0 F1 F2 B0 F3 B1 F4 B2 F5 B3 B4 B5
stage 2: F0 F1 B0 F2 B1 F3 B2 F4 B3 F5 B4 B5
stage 3: F0 B0 F1 B1 F2 B2 F3 B3 F4 B4 F5 B5
<---------> <---------------------------> <--------->
warmup 1B1F cooldown
step 1: warmup
    before each F: recv(input, prev_rank)
    after each F: send(output, next_rank), except for the last F
step 2: 1B1F (one backward followed by one forward)
    before the phase starts: sendrecv(output, grad_output, next_rank)
    after each B: sendrecv(grad_input, input, prev_rank)
    after each F: sendrecv(output, grad_output, next_rank)
step 3: cooldown
    before each B: recv(grad_output, next_rank), except for the first B
    after each B: send(grad_input, prev_rank)
number of stages: n
number of microbatches: m
memory needed to run on a single GPU: M
time needed to run on a single GPU: T
bubble: (n - 1) / m * T
memory needed per stage: n / m * M
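e.g. with n = 4 stages and m = 16 microbatches, the bubble is (4 - 1) / 16 * T = 3T/16 and each stage needs 4 / 16 * M = M/4 of memory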
"""
class PipeDream(nn.Module):
"""
Distributed pipeline parallelism with the PipeDream algorithm (non-interleaved 1F1B).
Paper: "PipeDream: Fast and Efficient Pipeline Parallel DNN Training"
.. code-block::
stage 0: F0 F1 F2 F3 B0 F4 B1 F5 B2 B3 B4 B5
stage 1: F0 F1 F2 B0 F3 B1 F4 B2 F5 B3 B4 B5
stage 2: F0 F1 B0 F2 B1 F3 B2 F4 B3 F5 B4 B5
stage 3: F0 B0 F1 B1 F2 B2 F3 B3 F4 B4 F5 B5
<---------> <---------------------------> <--------->
warmup 1B1F cooldown
Compared with :class:`GPipe`, the two train at roughly the same speed, but PipeDream uses less GPU memory:
its memory footprint is ``min(1, ngpus / chunks)`` times that of GPipe, so PipeDream is the recommended choice.
The model passed to PipeDream may have multiple inputs and outputs, but the outputs of each stage must match the inputs of the next stage.
Args:
module (nn.Module): one partition of the model, e.g. obtained via :func:`partition`
chunks (int): number of microbatches
process_group (dist.ProcessGroup): process group for pipeline parallelism; defaults to the global process group
batch_dim (int): dimension of the batch, defaults to ``0``
checkpoint (bool): whether to use activation checkpointing, defaults to ``False``
Example:
During training, call ``forward_backward`` with the loss function ``criterion`` and the labels ``labels``:
.. code-block:: python
from haiscale.pipeline import PipeDream, partition
dist.init_process_group(...)
torch.cuda.set_device(local_rank)
rank, world_size = dist.get_rank(), dist.get_world_size()
torch.manual_seed(12345)
def loss_fn(out, y):
return ((out - y)**2).sum()
model = nn.Sequential(...).cuda()
model = partition(model, rank, world_size)
model = PipeDream(model, chunks=32)
for x, y in dataloader:
loss, _ = model.forward_backward(x, criterion=loss_fn, labels=(y,))
# eval
with torch.no_grad():
out = model(x)
if rank == world_size - 1:
# calculate metrics ...
NOTE:
In the example, the dataloaders of the first and the last stage should produce the same data.
NOTE:
To use PipeDream together with DDP, wrap the partitioned model with DDP first, and then pass the DDP model to PipeDream.
The sizes of the pipeline-parallel and data-parallel groups can be set with :func:`make_subgroups`, as sketched below.
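A minimal sketch of that setup (the exact signature and return order of ``make_subgroups`` are assumptions here, check the haiscale documentation):
.. code-block:: python
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from haiscale.pipeline import PipeDream, partition, make_subgroups
dist.init_process_group(...)
torch.cuda.set_device(local_rank)
# assumed: make_subgroups splits the world into data-parallel and
# pipeline-parallel subgroups and returns (dp_group, pp_group)
dp_group, pp_group = make_subgroups(pp_size=2)
model = nn.Sequential(...).cuda()
model = partition(model, pp_group.rank(), pp_group.size())
# wrap the partitioned stage with DDP first ...
model = DDP(model, process_group=dp_group)
# ... then hand the DDP-wrapped stage to PipeDream
model = PipeDream(model, chunks=32, process_group=pp_group)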
"""
def __init__(self, module, chunks, process_group=None, batch_dim=0, checkpoint=False):
super().__init__()
assert isinstance(module, nn.Module)
assert isinstance(chunks, int)
assert isinstance(batch_dim, int)
assert isinstance(checkpoint, bool)
group = process_group
assert group is None or isinstance(group, dist.ProcessGroup)
group = group or dist.distributed_c10d._get_default_group()
device = torch.cuda.current_device()
device = torch.device(device)
assert next(module.parameters()).device == device
self.device = device
self.module = module
self.chunks = chunks
self.group = group
self.batch_dim = batch_dim
self.checkpoint = checkpoint
# stream for send
self.send_stream = torch.cuda.Stream(device)
def stage(self):
"""
Returns the index of the current stage.
"""
return self.group.rank()
def num_stages(self):
"""
Returns the number of stages.
"""
return self.group.size()
def forward(self, *inputs, chunks=None, criterion=None, labels=(), return_outputs=False):
"""
Runs one forward pass under ``torch.no_grad()``.
Args:
*inputs: inputs to the model; can be omitted on any stage other than the first
chunks (int): number of microbatches for this call; defaults to ``self.chunks``
criterion (Callable): loss function, called as ``criterion(*outputs, *labels)``; can be omitted on any stage other than the last
labels (tuple): labels (and any other data) passed to the loss function; can be omitted on any stage other than the last
return_outputs (bool): whether to return the model outputs on the last stage; setting it to ``True`` uses extra GPU memory; defaults to ``False``
Returns:
The last stage returns ``(loss, outputs)``; every other stage returns ``(None, None)``.
"""
assert not torch.is_grad_enabled(), "forward can only be used in no_grad mode"
rank = self.group.rank()
nranks = self.group.size()
next_rank = (rank + 1) % nranks
prev_rank = (rank + nranks - 1) % nranks
is_first_stage = (rank == 0)
is_last_stage = (rank == nranks - 1)
chunks = chunks or self.chunks
input_chunks, output_chunks, losses = [], [], []
if is_first_stage:
input_chunks = microbatch.scatter(inputs, chunks, self.batch_dim)
if is_last_stage:
labels = microbatch.scatter(labels, chunks, self.batch_dim)
else:
labels = [() for _ in range(chunks)]
for i in range(chunks):
if rank != 0:
input = comm.recv(prev_rank, self.group)
else:
input = input_chunks[i]
output, loss = self._forward_chunk(input, criterion, labels[i])
if not is_last_stage:
self.send_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(self.send_stream):
comm.send(output, next_rank, self.group)
# mark the buffers as in use on send_stream so their memory is not reused before the send completes
for x in output:
    x.data.record_stream(self.send_stream)
else:
output_chunks.append(output)
losses.append(loss)
outputs, loss = None, None
if is_last_stage:
if criterion is not None:
loss = torch.stack(losses).float().sum().detach()
if return_outputs:
outputs = microbatch.gather(output_chunks, self.batch_dim)
if isinstance(outputs, (tuple, list)) and len(outputs) == 1:
outputs = outputs[0]
return loss, outputs
def forward_backward(self, *inputs, criterion=None, labels=(), return_outputs=False):
"""
Runs one forward and one backward pass and returns the loss.
Args:
*inputs: inputs to the model; can be omitted on any stage other than the first
criterion (Callable): loss function, called as ``criterion(*outputs, *labels)``; required on the last stage, can be omitted elsewhere
labels (tuple): labels (and any other data) passed to the loss function; can be omitted on any stage other than the last
return_outputs (bool): whether to return the model outputs on the last stage; setting it to ``True`` uses extra GPU memory; defaults to ``False``
Returns:
tuple: ``(loss, outputs)``
NOTE:
If ``return_outputs = True``, ``outputs`` is the output of the last stage of the model; otherwise it is ``None``.
Only the last stage returns meaningful values; every other stage returns ``(None, None)``.
"""
rank = self.group.rank()
nranks = self.group.size()
next_rank = (rank + 1) % nranks
prev_rank = (rank + nranks - 1) % nranks
is_first_stage = (rank == 0)
is_last_stage = (rank == nranks - 1)
# when the number of microbatches is smaller than the number of stages, the schedule degenerates into GPipe (no 1B1F phase)
no_1B1F = (self.chunks < nranks)
if is_last_stage:
assert criterion is not None
input_chunks, output_chunks, losses = [], [], []
if is_first_stage:
input_chunks = microbatch.scatter(inputs, self.chunks, self.batch_dim)
if is_last_stage:
labels = microbatch.scatter(labels, self.chunks, self.batch_dim)
else:
labels = [() for _ in range(self.chunks)]
#####################
# step 1: warmup
#####################
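# each stage keeps (nranks - rank) forwards in flight before its first backward, capped at the total number of chunks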
num_warmup = self.chunks if no_1B1F else min(self.chunks, nranks - rank)
for i in range(num_warmup):
if is_first_stage:
input = input_chunks[i]
else:
input = comm.recv(prev_rank, self.group)
input_chunks.append(input)
# the first forward needs to broadcast the buffers
with sync_forward(self.module, i == 0):
output, loss = self._forward_chunk(input, criterion, labels[i])
output_chunks.append(output)
losses.append(loss)
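# on the last warmup forward (when a 1B1F phase follows), sending the output is fused with receiving the first grad_output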
do_sendrecv = (not no_1B1F and i == num_warmup - 1)
if not do_sendrecv and not is_last_stage:
self.send_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(self.send_stream):
comm.send(output, next_rank, self.group)
for x in output:
    x.data.record_stream(self.send_stream)
# send output, recv grad
if do_sendrecv and not is_last_stage:
grad_output = comm.sendrecv(output, next_rank, self.group)
else:
grad_output = None
# after the forward we can free the outputs' storage;
# only output.grad_fn is needed for the backward pass
if not return_outputs or not is_last_stage:
free_outputs(output)
#####################
# step 2: 1B1F
#####################
num_1B1F = self.chunks - num_warmup
for i in range(num_1B1F):
# backward
with sync_backward(self.module, False, output_chunks[i]):
grad_input = self._backward_chunk(output_chunks[i], grad_output, input_chunks[i], losses[i])
input_chunks[i] = None
if not return_outputs or not is_last_stage:
output_chunks[i] = None
if is_last_stage:
losses[i] = losses[i].detach()
# send grad, recv input
if rank != 0:
input = comm.sendrecv(grad_input, prev_rank, self.group)
input_chunks.append(input)
else:
input = input_chunks[i + num_warmup]
# forward
with sync_forward(self.module, False):
output, loss = self._forward_chunk(input, criterion, labels[i + num_warmup])
output_chunks.append(output)
losses.append(loss)
# send output, recv grad
if not is_last_stage:
grad_output = comm.sendrecv(output, next_rank, self.group)
else:
grad_output = None
if not return_outputs or not is_last_stage:
free_outputs(output)
#####################
# step 3: cooldown
#####################
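# every warmup forward still has a pending backward, so the cooldown length equals the warmup length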
num_cooldown = num_warmup
for i in range(num_cooldown):
do_recv = (no_1B1F or i != 0)
if do_recv and not is_last_stage:
grad_output = comm.recv(next_rank, self.group)
chunk_idx = num_1B1F + i
# the last backward needs to allreduce the gradients
with sync_backward(self.module, chunk_idx == self.chunks - 1, output_chunks[chunk_idx]):
grad_input = self._backward_chunk(output_chunks[chunk_idx], grad_output, input_chunks[chunk_idx], losses[chunk_idx])
input_chunks[chunk_idx] = None
if not return_outputs or not is_last_stage:
output_chunks[chunk_idx] = None
if is_last_stage:
losses[chunk_idx] = losses[chunk_idx].detach()
if not is_first_stage:
self.send_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(self.send_stream):
comm.send(grad_input, prev_rank, self.group)
for x in grad_input:
    x.data.record_stream(self.send_stream)
loss, output = None, None
if is_last_stage:
loss = torch.stack(losses).float().sum().detach()
if return_outputs:
output = microbatch.gather(output_chunks, self.batch_dim)
if len(output) == 1:
output = output[0]
return loss, output
@cuda_timer.record("forward_chunk")
def _forward_chunk(self, input, criterion=None, label=()):
rank = self.group.rank()
nranks = self.group.size()
is_last_stage = (rank == nranks - 1)
if self.checkpoint and torch.is_grad_enabled() and (not is_last_stage):
output = checkpoint(self.module, *input)
else:
output = self.module(*input)
output = [output] if isinstance(output, torch.Tensor) else output
loss = None
if is_last_stage and criterion is not None:
loss = criterion(*output, *label)
return output, loss
@cuda_timer.record("backward_chunk")
def _backward_chunk(self, output, grad_output, input, loss):
rank = self.group.rank()
nranks = self.group.size()
if rank == nranks - 1:
assert loss is not None
loss.backward()
else:
# filter out entries whose gradient is None
arr = [(o, g) for o, g in zip(output, grad_output) if g is not None]
output, grad_output = list(zip(*arr))
# output.data may already have been freed, so torch.autograd.backward cannot be called directly;
# run_backward only needs output.grad_fn
if len(output) > 0:
run_backward(output, grad_output)
grad_input = [x.grad for x in input]
return grad_input
def free_outputs(tensors):
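# replace each tensor's storage with a tiny placeholder so the activation memory is released
# while the autograd graph (grad_fn) stays alive for the later backward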
for x in tensors:
x.data = torch.empty(1, device=x.device, dtype=x.dtype)