import traceback
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.utils.checkpoint
from . import comm, microbatch
from .utils import checkpoint, sync_forward, sync_backward
"""
GPipe: Efficient Training of Giant Neural Networks using Pipeline Parallelism
Algo:
stage 0: F0 F1 F2 F3 F4 F5 B5 B4 B3 B2 B1 B0
stage 1: F0 F1 F2 F3 F4 F5 B5 B4 B3 B2 B1 B0
stage 2: F0 F1 F2 F3 F4 F5 B5 B4 B3 B2 B1 B0
stage 3: F0 F1 F2 F3 F4 F5 B5 B4 B3 B2 B1 B0
<------------------------> <------------------------>
forward backward
stage 的数量: n
microbatch 的数量: m
在单 GPU 上运行所需的显存: M
在单 GPU 上运行所需的时间: T
bubble: (n - 1) / m * T
每个 stage 所需的显存: 1 / n * M
"""
class BackwardState():
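    # Bookkeeping for the deferred backward pass: the per-microbatch inputs and outputs,
    # the detached outputs handed back to the caller, the gradient hooks registered on them,
    # and (when forward_backward() is used) the per-microbatch losses.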
def __init__(self, inputs, outputs, detached_out, losses=None) -> None:
self.inputs = inputs
self.outputs = outputs
self.detached_out = detached_out
self.grad_outputs = []
self.handles = []
self.queued_callback = False
self.losses = losses
class GPipe(nn.Module):
"""
    Distributed pipeline parallelism with the GPipe algorithm.

    Paper: `GPipe: Efficient Training of Giant Neural Networks using Pipeline Parallelism`

    .. code-block::

        stage 0: F0 F1 F2 F3 F4 F5 B5 B4 B3 B2 B1 B0
        stage 1: F0 F1 F2 F3 F4 F5 B5 B4 B3 B2 B1 B0
        stage 2: F0 F1 F2 F3 F4 F5 B5 B4 B3 B2 B1 B0
        stage 3: F0 F1 F2 F3 F4 F5 B5 B4 B3 B2 B1 B0
                 <------------------------> <------------------------>
                           forward                    backward

    The model passed to GPipe may take multiple inputs and produce multiple outputs, but the
    outputs of each stage must match the inputs of the next stage.

    Args:
        module (nn.Module): one part of the model, e.g. obtained from :func:`partition`
        chunks (int): number of microbatches
        process_group (dist.ProcessGroup): process group used for pipeline parallelism; defaults to the global process group
        batch_dim (int): batch dimension, defaults to ``0``
        checkpoint (str): ``never``, ``except_last`` or ``always``, defaults to ``never``
        broadcast_outputs (bool): whether to broadcast the output of the last stage to the other stages
            and return it from the GPipe instance on every stage; setting ``broadcast_outputs = False``
            saves some communication and GPU memory and is recommended when the output of the last
            stage is large; defaults to ``True``

    Example:

        GPipe can be used in two ways. The first requires ``broadcast_outputs = True`` and feels
        similar to DDP: every stage returns the same output and any of them can be used to compute
        the loss, but only the loss on the last stage is actually used for the backward pass:

        .. code-block:: python

            from haiscale.pipeline import GPipe, partition

            dist.init_process_group(...)
            torch.cuda.set_device(local_rank)
            rank, world_size = dist.get_rank(), dist.get_world_size()
            torch.manual_seed(12345)

            model = nn.Sequential(...).cuda()
            model = partition(model, rank, world_size)
            model = GPipe(model, chunks=64)

            for x, y in dataloader:
                out = model(x)
                loss = ((out - y)**2).sum()
                loss.backward()

        The second way passes a loss function ``criterion``; ``forward_backward`` computes the loss
        by calling ``criterion(*outputs, *labels)``:

        .. code-block:: python

            from haiscale.pipeline import GPipe, partition

            dist.init_process_group(...)
            torch.cuda.set_device(local_rank)
            rank, world_size = dist.get_rank(), dist.get_world_size()
            torch.manual_seed(12345)

            def loss_fn(out, y):
                return ((out - y)**2).sum()

            model = nn.Sequential(...).cuda()
            model = partition(model, rank, world_size)
            model = GPipe(model, chunks=32)

            for x, y in dataloader:
                loss, _ = model.forward_backward(x, criterion=loss_fn, labels=(y,))

    NOTE:
        In the examples, the dataloaders on the first stage and on the last stage must produce
        the same data.

    NOTE:
        To use GPipe together with DDP, wrap the partitioned model in DDP first and then pass the
        DDP model to GPipe. The sizes of the pipeline-parallel and data-parallel groups can be
        configured with :func:`make_subgroups`; a sketch of this combination follows.
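
        The sketch below shows only the wrapping order; ``dp_group`` and ``pp_group`` are
        placeholders for the data-parallel and pipeline-parallel process groups (for example
        as returned by :func:`make_subgroups`, whose exact signature is not shown here):

        .. code-block:: python

            # dp_group / pp_group: hypothetical handles to the data-parallel and the
            # pipeline-parallel process groups
            part = partition(model, pp_group.rank(), pp_group.size())
            part = nn.parallel.DistributedDataParallel(part, process_group=dp_group)
            model = GPipe(part, chunks=32, process_group=pp_group)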
"""
def __init__(self, module, chunks, process_group=None, batch_dim=0,
checkpoint="never", broadcast_outputs=True):
super().__init__()
assert isinstance(module, nn.Module)
assert isinstance(chunks, int)
assert isinstance(batch_dim, int)
group = process_group
assert group is None or isinstance(group, dist.ProcessGroup)
device = torch.cuda.current_device()
device = torch.device(device)
assert next(module.parameters()).device == device
self.device = device
self.module = module
self.chunks = chunks
group = group or dist.distributed_c10d._get_default_group()
self.group = group
self.batch_dim = batch_dim
self.broadcast_outputs = broadcast_outputs
self.checkpoint = checkpoint
when_to_stop = {"never": 0, "always": chunks, "except_last": chunks - 1}
self.checkpoint_stop = when_to_stop[checkpoint]
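        # Microbatch i is recomputed with activation checkpointing iff i < checkpoint_stop:
        # "never" checkpoints no microbatch, "always" checkpoints all of them, and
        # "except_last" skips the last microbatch, whose backward pass follows immediately.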
# stream for send
self.send_stream = torch.cuda.Stream(device)
    def stage(self):
        """
        Return the index of the current stage.
        """
return self.group.rank()
    def num_stages(self):
        """
        Return the number of stages.
        """
return self.group.size()
def forward(self, *inputs):
input_chunks, output_chunks, losses = self.forward_impl(*inputs)
rank = self.group.rank()
nranks = self.group.size()
output = tuple()
if rank == nranks - 1:
output = microbatch.gather(output_chunks, dim=self.batch_dim)
if self.broadcast_outputs:
            # Broadcast the forward result to every rank so that usage looks similar to DDP,
            # at the cost of some extra computation and communication.
output = comm.broadcast(output, nranks - 1, self.group)
        # loss.backward() will only compute gradients up to ``output``; afterwards we call
        # torch.autograd.backward() manually to run the rest of the backward pass.
fn = lambda x: x.detach().requires_grad_(x.requires_grad)
output = list(map(fn, output))
        # Register hooks on the outputs so that loss.backward() schedules the pipeline
        # backward pass (do_backward) through enqueue_callback.
if torch.is_grad_enabled():
state = BackwardState(input_chunks, output_chunks, output, losses)
for i, x in enumerate(output):
if x.requires_grad:
state.handles.append(x.register_hook(self.enqueue_callback))
self.state = state
        # forward() returns the output of the last stage.
if len(output) == 1:
return output[0]
return output
def enqueue_callback(self, *unused):
        # Queue do_backward via queue_callback so that it runs when loss.backward() finishes.
if not self.state.queued_callback:
torch.autograd.Variable._execution_engine.queue_callback(self.do_backward)
self.state.queued_callback = True
    def forward_backward(self, *inputs, criterion=None, labels=(), return_outputs=False):
        """
        Run one forward and one backward pass and return the loss summed over all microbatches.

        Args:
            *inputs: inputs of the model; may be omitted on every stage except the first
            criterion (Callable): loss function, called as ``criterion(*outputs, *labels)``; may be omitted on every stage except the last
            labels (tuple): labels and other data passed to the loss function; may be omitted on every stage except the last
            return_outputs (bool): whether the last stage should return the model outputs; setting this to ``True`` uses extra GPU memory; defaults to ``False``

        Returns:
            tuple: ``(loss, outputs)``

        NOTE:
            If ``return_outputs = True``, ``outputs`` is the output of the last stage of the model; otherwise it is ``None``.
            Only the last stage returns meaningful values; every other stage returns ``(None, None)``.
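
        For example (mirroring the class-level example above, every stage can make the same call;
        arguments that a stage does not need are simply ignored there):

        .. code-block:: python

            loss, _ = model.forward_backward(x, criterion=loss_fn, labels=(y,))
            if model.stage() == model.num_stages() - 1:
                print(loss.item())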
"""
rank = self.group.rank()
nranks = self.group.size()
if rank == nranks - 1:
assert criterion is not None
input_chunks, output_chunks, losses = self.forward_impl(*inputs, criterion=criterion, labels=labels)
if torch.is_grad_enabled():
self.state = BackwardState(input_chunks, output_chunks, None, losses)
self.do_backward()
        # Return values are produced on the last stage only.
if rank == nranks - 1:
with torch.no_grad():
loss = torch.stack(losses, 0).sum().detach()
output = None
if return_outputs:
output = microbatch.gather(output_chunks, dim=self.batch_dim)
if len(output) == 1:
output = output[0]
return loss, output
return None, None
def forward_impl(self, *inputs, criterion=None, labels=()):
rank = self.group.rank()
nranks = self.group.size()
next_rank = (rank + 1) % nranks
prev_rank = (rank + nranks - 1) % nranks
compute_loss = (criterion is not None and rank == nranks - 1)
if compute_loss:
labels = microbatch.scatter(labels, self.chunks, self.batch_dim)
input_chunks = []
if rank == 0:
input_chunks = microbatch.scatter(inputs, self.chunks, dim=self.batch_dim)
output_chunks = []
losses = []
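        # Forward over all microbatches: stage 0 takes microbatch i from its own (scattered)
        # input, every other stage receives the activations of microbatch i from the previous
        # stage; the resulting output is sent to the next stage on a separate CUDA stream so
        # that communication can overlap with the computation of the next microbatch.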
for i in range(self.chunks):
if rank == 0:
input = input_chunks[i]
else:
input = comm.recv(prev_rank, self.group)
input_chunks.append(input)
with sync_forward(self.module, i == 0):
if i >= self.checkpoint_stop:
output = self.module(*input)
else:
output = checkpoint(self.module, *input)
assert isinstance(output, (tuple, torch.Tensor))
output = (output,) if isinstance(output, torch.Tensor) else output
output_chunks.append(output)
# compute loss
if compute_loss:
loss = criterion(*output, *labels[i])
assert loss.numel() == 1, "criterion's output should be a scalar"
losses.append(loss)
if rank != nranks - 1:
self.send_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(self.send_stream):
comm.send(output, next_rank, self.group)
return input_chunks, output_chunks, losses
def do_backward(self):
try:
state = self.state
rank = self.group.rank()
nranks = self.group.size()
outputs = state.outputs
if rank == nranks - 1:
if len(state.losses) > 0:
grad_outputs = [(torch.ones_like(x),) for x in state.losses]
outputs = [(x,) for x in state.losses]
else:
grads = [out.grad for out in state.detached_out]
grad_outputs = microbatch.scatter(grads, self.chunks, dim=self.batch_dim)
else:
grad_outputs = None
self.schedule_backward(grad_outputs, state.inputs, outputs)
for h in state.handles:
h.remove() # remove handle to release memory
self.state = None
except Exception as e:
traceback.print_exc()
raise e
def schedule_backward(self, grad_outputs, input_chunks, output_chunks):
rank = self.group.rank()
nranks = self.group.size()
next_rank = (rank + 1) % nranks
prev_rank = (rank + nranks - 1) % nranks
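        # Walk the microbatches in reverse order (B5 ... B0 in the schedule above). The last
        # stage seeds autograd with the gradients prepared in do_backward(), every other stage
        # receives the gradient of its outputs from the next stage; every stage except stage 0
        # then sends the gradient of its inputs back to the previous stage.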
for i in reversed(range(self.chunks)):
if rank == nranks - 1:
grad_output = grad_outputs[i]
else:
grad_output = comm.recv(next_rank, self.group)
input = input_chunks[i]
output = output_chunks[i]
            # Filter out outputs that have no incoming gradient or that do not require grad.
filtered_output = []
filtered_grad_output = []
for grad, out in zip(grad_output, output):
if grad is not None and out.requires_grad:
filtered_output.append(out)
filtered_grad_output.append(grad)
with sync_backward(self.module, i == 0, filtered_output):
torch.autograd.backward(filtered_output, filtered_grad_output)
if rank != 0:
self.send_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(self.send_stream):
input_grads = tuple(x.grad for x in input)
comm.send(input_grads, prev_rank, self.group)
    def backward(self):
        # Manually trigger the pipeline backward pass on a stage other than the last one,
        # for example when ``broadcast_outputs=False`` leaves this stage without an output
        # on which loss.backward() could be called.
        assert torch.is_grad_enabled()
        assert self.group.rank() != self.group.size() - 1
        self.do_backward()
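
# A minimal sketch (not part of the library) of the ``broadcast_outputs=False`` flow, under the
# assumption that the final stage drives the backward pass through ``loss.backward()`` while the
# other stages call ``GPipe.backward()`` manually; ``part``, ``x`` and ``y`` are placeholders:
#
#     model = GPipe(part, chunks=32, broadcast_outputs=False)
#     out = model(x)                       # only the last stage gets a real output
#     if model.stage() == model.num_stages() - 1:
#         loss = ((out - y) ** 2).sum()
#         loss.backward()                  # hooks schedule the pipeline backward pass
#     else:
#         model.backward()                 # other stages join the backward pass manually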