
Source code for haiscale.pipeline.gpipe

import traceback
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.utils.checkpoint

from . import comm, microbatch
from .utils import checkpoint, sync_forward, sync_backward


"""
GPipe: Efficient Training of Giant Neural Networks using Pipeline Parallelism

Algo:

    stage 0:  F0 F1 F2 F3 F4 F5                   B5 B4 B3 B2 B1 B0
    stage 1:     F0 F1 F2 F3 F4 F5             B5 B4 B3 B2 B1 B0
    stage 2:        F0 F1 F2 F3 F4 F5       B5 B4 B3 B2 B1 B0
    stage 3:           F0 F1 F2 F3 F4 F5 B5 B4 B3 B2 B1 B0
              <------------------------> <------------------------>
                       forward                   backward

stage 的数量:               n
microbatch 的数量:          m
在单 GPU 上运行所需的显存:  M
在单 GPU 上运行所需的时间:  T

bubble:                     (n - 1) / m * T
每个 stage 所需的显存:      1 / n * M

"""


class BackwardState():
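    # Bookkeeping for one forward pass: the per-microbatch inputs/outputs kept alive for
    # the manual backward, the detached outputs handed back to the caller, the hook handles
    # registered on them, and (when forward_backward is used) the per-microbatch losses.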

    def __init__(self, inputs, outputs, detached_out, losses=None) -> None:
        self.inputs = inputs
        self.outputs = outputs
        self.detached_out = detached_out
        self.grad_outputs = []
        self.handles = []
        self.queued_callback = False
        self.losses = losses


class GPipe(nn.Module):
    """
    Distributed pipeline parallelism (GPipe).

    Paper: "GPipe: Efficient Training of Giant Neural Networks using Pipeline Parallelism"

    .. code-block::

        stage 0:  F0 F1 F2 F3 F4 F5                   B5 B4 B3 B2 B1 B0
        stage 1:     F0 F1 F2 F3 F4 F5             B5 B4 B3 B2 B1 B0
        stage 2:        F0 F1 F2 F3 F4 F5       B5 B4 B3 B2 B1 B0
        stage 3:           F0 F1 F2 F3 F4 F5 B5 B4 B3 B2 B1 B0
                  <------------------------> <------------------------>
                           forward                   backward

    The model passed to GPipe may have multiple inputs and outputs, but the outputs of one
    stage must match the inputs of the next stage.

    Args:
        module (nn.Module): a part of the model, e.g. obtained with :func:`partition`
        chunks (int): number of microbatches
        process_group (dist.ProcessGroup): process group used for pipeline parallelism; defaults to the global process group
        batch_dim (int): the batch dimension, defaults to ``0``
        checkpoint (str): ``never``, ``except_last`` or ``always``; defaults to ``never``
        broadcast_outputs (bool): whether to broadcast the output of the last stage to the other
            stages and use it as the output of the GPipe instance on every stage; setting
            ``broadcast_outputs = False`` saves some communication and memory overhead and is
            recommended when the last stage's output is large; defaults to ``True``

    Example:

        GPipe can be used in two ways.  The first requires ``broadcast_outputs = True`` and is
        similar to DDP in usage: every stage returns the same output, and the loss can be
        computed on any stage, but only the loss on the last stage actually drives the
        backward pass:

        .. code-block:: python

            from haiscale.pipeline import GPipe, partition

            dist.init_process_group(...)
            torch.cuda.set_device(local_rank)
            rank, world_size = dist.get_rank(), dist.get_world_size()
            torch.manual_seed(12345)

            model = nn.Sequential(...).cuda()
            model = partition(model, rank, world_size)
            model = GPipe(model, chunks=64)

            for x, y in dataloader:
                out = model(x)
                loss = ((out - y)**2).sum()
                loss.backward()

        The second way passes a loss function ``criterion``; ``forward_backward`` then computes
        the loss by calling ``criterion(*outputs, *labels)``:

        .. code-block:: python

            from haiscale.pipeline import GPipe, partition

            dist.init_process_group(...)
            torch.cuda.set_device(local_rank)
            rank, world_size = dist.get_rank(), dist.get_world_size()
            torch.manual_seed(12345)

            def loss_fn(out, y):
                return ((out - y)**2).sum()

            model = nn.Sequential(...).cuda()
            model = partition(model, rank, world_size)
            model = GPipe(model, chunks=32)

            for x, y in dataloader:
                loss, _ = model.forward_backward(x, criterion=loss_fn, labels=(y,))

    NOTE: in the examples, the dataloaders on the first and the last stage should produce the
    same data.

    NOTE: to use GPipe together with DDP, wrap the partitioned model in DDP first and then pass
    the DDP model to GPipe.  :func:`make_subgroups` can be used to set the pipeline-parallel
    and data-parallel sizes.
    """

    def __init__(self, module, chunks, process_group=None, batch_dim=0,
                 checkpoint="never", broadcast_outputs=True):
        super().__init__()
        assert isinstance(module, nn.Module)
        assert isinstance(chunks, int)
        assert isinstance(batch_dim, int)
        group = process_group
        assert group is None or isinstance(group, dist.ProcessGroup)

        device = torch.cuda.current_device()
        device = torch.device(device)
        assert next(module.parameters()).device == device

        self.device = device
        self.module = module
        self.chunks = chunks
        group = group or dist.distributed_c10d._get_default_group()
        self.group = group
        self.batch_dim = batch_dim
        self.broadcast_outputs = broadcast_outputs

        self.checkpoint = checkpoint
        when_to_stop = {"never": 0, "always": chunks, "except_last": chunks - 1}
        self.checkpoint_stop = when_to_stop[checkpoint]

        # stream for send
        self.send_stream = torch.cuda.Stream(device)
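    # Note on the ``when_to_stop`` mapping in __init__: microbatch i runs under activation
    # checkpointing only when i < checkpoint_stop (see forward_impl below), so "never"
    # checkpoints nothing, "always" checkpoints every microbatch, and "except_last"
    # checkpoints all microbatches except the final one.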
    def stage(self):
        """ Return the index of the current stage. """
        return self.group.rank()
    def num_stages(self):
        """ Return the number of stages. """
        return self.group.size()
    def forward(self, *inputs):
        input_chunks, output_chunks, losses = self.forward_impl(*inputs)
        rank = self.group.rank()
        nranks = self.group.size()

        output = tuple()
        if rank == nranks - 1:
            output = microbatch.gather(output_chunks, dim=self.batch_dim)

        if self.broadcast_outputs:
            # broadcast the forward result to every rank so that usage looks similar to DDP,
            # at the cost of some extra computation and communication
            output = comm.broadcast(output, nranks - 1, self.group)

        # loss.backward() only computes gradients down to `output`;
        # afterwards we call torch.autograd.backward() manually to finish the backward pass
        fn = lambda x: x.detach().requires_grad_(x.requires_grad)
        output = list(map(fn, output))

        # register hooks on the outputs so that loss.backward() triggers schedule_backward
        if torch.is_grad_enabled():
            state = BackwardState(input_chunks, output_chunks, output, losses)
            for i, x in enumerate(output):
                if x.requires_grad:
                    state.handles.append(x.register_hook(self.enqueue_callback))
            self.state = state

        # forward returns the output of the last stage
        if len(output) == 1:
            return output[0]
        return output

    def enqueue_callback(self, *unused):
        # queue_callback defers do_backward until loss.backward() has finished
        if not self.state.queued_callback:
            torch.autograd.Variable._execution_engine.queue_callback(self.do_backward)
            self.state.queued_callback = True
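    # A minimal single-process sketch (illustration only, with made-up tensors) of the
    # register_hook + queue_callback pattern used by forward()/enqueue_callback() above:
    # the hook fires while loss.backward() is computing the gradient of the detached
    # output, and the queued callback runs once that backward pass has fully finished,
    # at which point ``.grad`` can be read and the backward continued manually.
    #
    #     x = torch.randn(4, requires_grad=True)
    #     out = x * 2                                # real output, still attached to the graph
    #     det = out.detach().requires_grad_(True)    # what the caller computes the loss on
    #
    #     def finish():                              # analogue of do_backward()
    #         torch.autograd.backward((out,), (det.grad,))
    #
    #     det.register_hook(
    #         lambda grad: torch.autograd.Variable._execution_engine.queue_callback(finish))
    #     det.sum().backward()                       # hook fires; `finish` then fills x.grad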
    def forward_backward(self, *inputs, criterion=None, labels=(), return_outputs=False):
        """
        Run one forward and one backward pass, and return the loss accumulated over all
        microbatches.

        Args:
            *inputs: inputs of the model; stages other than the first do not need to pass them
            criterion (Callable): loss function, called as ``criterion(*outputs, *labels)``;
                stages other than the last do not need to pass it
            labels (tuple): labels and other data passed to the loss function; stages other
                than the last do not need to pass them
            return_outputs (bool): whether to return the model output on the last stage;
                setting it to ``True`` uses extra GPU memory; defaults to ``False``

        Returns:
            tuple: ``(loss, outputs)``

        NOTE: if ``return_outputs = True``, ``outputs`` is the output of the model's last stage,
        otherwise it is ``None``.  Only the last stage returns meaningful values; the other
        stages return ``(None, None)``.
        """
        rank = self.group.rank()
        nranks = self.group.size()

        if rank == nranks - 1:
            assert criterion is not None

        input_chunks, output_chunks, losses = self.forward_impl(*inputs, criterion=criterion, labels=labels)

        if torch.is_grad_enabled():
            self.state = BackwardState(input_chunks, output_chunks, None, losses)
            self.do_backward()

        # only the last stage returns the loss (and optionally the outputs)
        if rank == nranks - 1:
            with torch.no_grad():
                loss = torch.stack(losses, 0).sum().detach()

            output = None
            if return_outputs:
                output = microbatch.gather(output_chunks, dim=self.batch_dim)
                if len(output) == 1:
                    output = output[0]

            return loss, output

        return None, None
    def forward_impl(self, *inputs, criterion=None, labels=()):
        rank = self.group.rank()
        nranks = self.group.size()
        next_rank = (rank + 1) % nranks
        prev_rank = (rank + nranks - 1) % nranks

        compute_loss = (criterion is not None and rank == nranks - 1)
        if compute_loss:
            labels = microbatch.scatter(labels, self.chunks, self.batch_dim)

        input_chunks = []
        if rank == 0:
            input_chunks = microbatch.scatter(inputs, self.chunks, dim=self.batch_dim)

        output_chunks = []
        losses = []

        for i in range(self.chunks):
            if rank == 0:
                input = input_chunks[i]
            else:
                input = comm.recv(prev_rank, self.group)
                input_chunks.append(input)

            with sync_forward(self.module, i == 0):
                if i >= self.checkpoint_stop:
                    output = self.module(*input)
                else:
                    output = checkpoint(self.module, *input)

            assert isinstance(output, (tuple, torch.Tensor))
            output = (output,) if isinstance(output, torch.Tensor) else output
            output_chunks.append(output)

            # compute loss
            if compute_loss:
                loss = criterion(*output, *labels[i])
                assert loss.numel() == 1, "criterion's output should be a scalar"
                losses.append(loss)

            if rank != nranks - 1:
                self.send_stream.wait_stream(torch.cuda.current_stream())
                with torch.cuda.stream(self.send_stream):
                    comm.send(output, next_rank, self.group)

        return input_chunks, output_chunks, losses

    def do_backward(self):
        try:
            state = self.state
            rank = self.group.rank()
            nranks = self.group.size()

            outputs = state.outputs
            if rank == nranks - 1:
                if len(state.losses) > 0:
                    grad_outputs = [(torch.ones_like(x),) for x in state.losses]
                    outputs = [(x,) for x in state.losses]
                else:
                    grads = [out.grad for out in state.detached_out]
                    grad_outputs = microbatch.scatter(grads, self.chunks, dim=self.batch_dim)
            else:
                grad_outputs = None

            self.schedule_backward(grad_outputs, state.inputs, outputs)

            for h in state.handles:
                h.remove()  # remove the hook handles to release memory
            self.state = None
        except Exception as e:
            traceback.print_exc()
            raise e

    def schedule_backward(self, grad_outputs, input_chunks, output_chunks):
        rank = self.group.rank()
        nranks = self.group.size()
        next_rank = (rank + 1) % nranks
        prev_rank = (rank + nranks - 1) % nranks

        for i in reversed(range(self.chunks)):
            if rank == nranks - 1:
                grad_output = grad_outputs[i]
            else:
                grad_output = comm.recv(next_rank, self.group)

            input = input_chunks[i]
            output = output_chunks[i]

            # filter out outputs that have no incoming gradient or do not require grad
            filtered_output = []
            filtered_grad_output = []
            for grad, out in zip(grad_output, output):
                if grad is not None and out.requires_grad:
                    filtered_output.append(out)
                    filtered_grad_output.append(grad)

            with sync_backward(self.module, i == 0, filtered_output):
                torch.autograd.backward(filtered_output, filtered_grad_output)

            if rank != 0:
                self.send_stream.wait_stream(torch.cuda.current_stream())
                with torch.cuda.stream(self.send_stream):
                    input_grads = tuple(x.grad for x in input)
                    comm.send(input_grads, prev_rank, self.group)

    def backward(self):
        assert torch.is_grad_enabled()
        assert self.group.rank() != self.group.size() - 1
        self.do_backward()
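# For reference, a minimal sketch of what ``microbatch.scatter`` / ``microbatch.gather`` are
# assumed to do in the code above (hypothetical illustration only, NOT the actual
# haiscale.pipeline.microbatch implementation, which also handles non-tensor inputs and
# other edge cases): split every tensor of a tuple into ``chunks`` pieces along ``dim``,
# and concatenate the per-microbatch results back along the same dimension.
def _scatter_sketch(tensors, chunks, dim=0):
    pieces = [torch.tensor_split(t, chunks, dim=dim) for t in tensors]
    # one tuple of tensors per microbatch
    return [tuple(p[i] for p in pieces) for i in range(chunks)]


def _gather_sketch(output_chunks, dim=0):
    # output_chunks: a list (one entry per microbatch) of tuples of tensors
    return tuple(torch.cat(parts, dim=dim) for parts in zip(*output_chunks))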