Shortcuts

Source code for haiscale.pipeline.pipedream

import torch
import torch.nn as nn
import torch.distributed as dist

from . import comm, microbatch
from .utils import checkpoint, sync_forward, sync_backward, run_backward
from ..timer import cuda_timer


"""
PipeDream: non-interleaved 1F1B schedule


Algo:

    stage 0:  F0 F1 F2 F3          B0 F4 B1 F5          B2 B3 B4 B5
    stage 1:     F0 F1 F2       B0 F3 B1 F4 B2 F5       B3 B4 B5
    stage 2:        F0 F1    B0 F2 B1 F3 B2 F4 B3 F5    B4 B5
    stage 3:           F0 B0 F1 B1 F2 B2 F3 B3 F4 B4 F5 B5
              <---------> <---------------------------> <--------->
                warmup               1B1F                cooldown

    step 1: warmup

        F 开始前 recv(input, prev_rank)
        F 结束后 send(output, next_rank),不包括最后一个 F

    step 2: 1B1F (one backward followed by one forward)

        开始前:sendrecv(output, grad_output, next_rank)

        B 结束后 sendrecv(grad_input, input, prev_rank)
        F 结束后 sendrecv(output, grad_output, next_rank)

    step 3: cooldown

        B 开始前 recv(grad_output, next_rank),不包括第一个 B
        B 结束后 send(grad_input, prev_rank)


stage 的数量:               n
microbatch 的数量:          m
在单 GPU 上运行所需的显存:  M
在单 GPU 上运行所需的时间:  T

bubble:                   (n - 1) / m * T
每个 stage 所需的显存:     n / m * M

"""


class PipeDream(nn.Module):
    """
    Distributed pipeline-parallel schedule PipeDream (non-interleaved 1F1B).

    Paper: "PipeDream: Fast and Efficient Pipeline Parallel DNN Training".

    .. code-block::

        stage 0:  F0 F1 F2 F3          B0 F4 B1 F5          B2 B3 B4 B5
        stage 1:     F0 F1 F2       B0 F3 B1 F4 B2 F5       B3 B4 B5
        stage 2:        F0 F1    B0 F2 B1 F3 B2 F4 B3 F5    B4 B5
        stage 3:           F0 B0 F1 B1 F2 B2 F3 B3 F4 B4 F5 B5
                  <---------> <---------------------------> <--------->
                    warmup               1B1F                cooldown

    Compared with :class:`GPipe`, training speed is about the same, but
    PipeDream reduces memory usage: its memory footprint is
    ``min(1, ngpus / chunks)`` of GPipe's, so PipeDream is the recommended
    default.

    The wrapped model may have several inputs and outputs, but the outputs of
    one stage must match the inputs of the next stage.

    Args:
        module (nn.Module): this stage's part of the model; it can be obtained
            with :func:`partition`
        chunks (int): number of microbatches
        process_group (dist.ProcessGroup): process group used for pipeline
            parallelism; defaults to the global process group
        batch_dim (int): batch dimension, ``0`` by default
        checkpoint (bool): whether to use activation checkpointing,
            ``False`` by default

    Example:

        For training, call ``forward_backward`` and pass the loss function
        ``criterion`` and the labels ``labels``:

        .. code-block:: python

            from haiscale.pipeline import PipeDream, partition

            dist.init_process_group(...)
            torch.cuda.set_device(local_rank)
            rank, world_size = dist.get_rank(), dist.get_world_size()
            torch.manual_seed(12345)

            def loss_fn(out, y):
                return ((out - y)**2).sum()

            model = nn.Sequential(...).cuda()
            model = partition(model, rank, world_size)
            model = PipeDream(model, chunks=32)

            for x, y in dataloader:
                loss, _ = model.forward_backward(x, criterion=loss_fn, labels=(y,))

            # eval
            with torch.no_grad():
                out = model(x)
                if rank == world_size - 1:
                    # calculate metrics
                    ...

    NOTE:
        In the example, the dataloaders of the first and the last stage must
        produce the same data.

    NOTE:
        To combine with DDP, wrap the partitioned model in DDP first and then
        pass the DDP model to PipeDream. :func:`make_subgroups` can be used to
        set the pipeline-parallel and data-parallel sizes.
    """

    def __init__(self, module, chunks, process_group=None, batch_dim=0, checkpoint=False):
        super().__init__()

        assert isinstance(module, nn.Module)
        assert isinstance(chunks, int)
        assert isinstance(batch_dim, int)
        assert isinstance(checkpoint, bool)

        group = process_group
        assert group is None or isinstance(group, dist.ProcessGroup)
        if group is None:
            group = dist.distributed_c10d._get_default_group()

        device = torch.device("cuda", torch.cuda.current_device())

        # A stage without parameters has no device to check; skip the check
        # instead of failing with a bare StopIteration.
        first_param = next(module.parameters(), None)
        assert first_param is None or first_param.device == device, \
            "module must live on the current CUDA device"

        self.device = device
        self.module = module
        self.chunks = chunks
        self.group = group
        self.batch_dim = batch_dim
        self.checkpoint = checkpoint

        # Dedicated stream for asynchronous sends.
        self.send_stream = torch.cuda.Stream(device)

    def stage(self):
        """
        Return the index of the current stage (rank inside the pipeline group).
        """
        return self.group.rank()

    def num_stages(self):
        """
        Return the number of stages (size of the pipeline group).
        """
        return self.group.size()

    def _send_async(self, tensors, dst):
        """
        Send ``tensors`` to rank ``dst`` on ``self.send_stream``.

        The send stream first waits for the current stream, so the data being
        sent is fully computed. ``record_stream`` marks each tensor's memory
        as in use by the send stream, which matters because callers may
        release the tensors (see ``free_outputs``) right after this returns.

        ``None`` entries are skipped when recording the stream; these appear
        e.g. as ``.grad`` of inputs that received no gradient.
        """
        self.send_stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(self.send_stream):
            comm.send(tensors, dst, self.group)
            for x in tensors:
                if x is not None:
                    x.data.record_stream(self.send_stream)

    def forward(self, *inputs, chunks=None, criterion=None, labels=(), return_outputs=False):
        """
        Run one forward pass under ``torch.no_grad()``.

        Args:
            *inputs: model inputs; may be omitted on every stage but the first
            chunks (int): number of microbatches for this call; falls back to
                the value given to the constructor when ``None`` (or ``0``)
            criterion (Callable): loss function, called as
                ``criterion(*outputs, *labels)``; may be omitted on every stage
                but the last
            labels (tuple): labels and other data passed to the loss function;
                may be omitted on every stage but the last
            return_outputs (bool): whether the last stage returns the model
                outputs; ``True`` uses some extra memory. Default ``False``

        Returns:
            The last stage returns ``(loss, outputs)``; ``loss`` is ``None``
            when no ``criterion`` was given and ``outputs`` is ``None`` unless
            ``return_outputs`` is set. Other stages return ``(None, None)``.
        """
        assert not torch.is_grad_enabled(), "forward 只有在 no_grad 模式的时候可以用"

        rank = self.group.rank()
        nranks = self.group.size()
        next_rank = (rank + 1) % nranks
        prev_rank = (rank + nranks - 1) % nranks
        is_first_stage = (rank == 0)
        is_last_stage = (rank == nranks - 1)

        chunks = chunks or self.chunks

        input_chunks, output_chunks, losses = [], [], []

        if is_first_stage:
            input_chunks = microbatch.scatter(inputs, chunks, self.batch_dim)

        if is_last_stage:
            labels = microbatch.scatter(labels, chunks, self.batch_dim)
        else:
            labels = [() for _ in range(chunks)]

        for i in range(chunks):
            if is_first_stage:
                chunk_input = input_chunks[i]
            else:
                chunk_input = comm.recv(prev_rank, self.group)

            output, loss = self._forward_chunk(chunk_input, criterion, labels[i])

            if not is_last_stage:
                self._send_async(output, next_rank)
            else:
                # Outputs are only needed for the final gather; do not keep
                # them alive otherwise.
                if return_outputs:
                    output_chunks.append(output)
                losses.append(loss)

        outputs, loss = None, None
        if is_last_stage:
            if criterion is not None:
                loss = torch.stack(losses).float().sum().detach()
            if return_outputs:
                outputs = microbatch.gather(output_chunks, self.batch_dim)
                if isinstance(outputs, (tuple, list)) and len(outputs) == 1:
                    outputs = outputs[0]

        return loss, outputs

    def forward_backward(self, *inputs, criterion=None, labels=(), return_outputs=False):
        """
        Run one forward and one backward pass over all microbatches and
        return the loss.

        The schedule has three phases: warmup (forwards only), 1B1F (one
        backward followed by one forward) and cooldown (backwards only).
        When there are fewer microbatches than stages there is no 1B1F phase
        and the schedule degenerates to GPipe.

        Args:
            *inputs: model inputs; may be omitted on every stage but the first
            criterion (Callable): loss function, called as
                ``criterion(*outputs, *labels)``; required on the last stage,
                may be omitted elsewhere
            labels (tuple): labels and other data passed to the loss function;
                may be omitted on every stage but the last
            return_outputs (bool): whether the last stage returns the model
                outputs; ``True`` uses some extra memory. Default ``False``

        Returns:
            tuple: ``(loss, outputs)``

        NOTE:
            With ``return_outputs = True``, ``outputs`` is the output of the
            last stage of the model, otherwise it is ``None``.
            Only the last stage returns values; every other stage returns
            ``(None, None)``.
        """
        rank = self.group.rank()
        nranks = self.group.size()
        next_rank = (rank + 1) % nranks
        prev_rank = (rank + nranks - 1) % nranks
        is_first_stage = (rank == 0)
        is_last_stage = (rank == nranks - 1)

        # With fewer microbatches than stages, degenerate to GPipe.
        no_1B1F = (self.chunks < nranks)

        if is_last_stage:
            assert criterion is not None

        # Outputs can be released right after the forward unless the last
        # stage has to return them.
        release_outputs = (not return_outputs) or (not is_last_stage)

        input_chunks, output_chunks, losses = [], [], []

        if is_first_stage:
            input_chunks = microbatch.scatter(inputs, self.chunks, self.batch_dim)

        if is_last_stage:
            labels = microbatch.scatter(labels, self.chunks, self.batch_dim)
        else:
            labels = [() for _ in range(self.chunks)]

        #####################
        # step 1: warmup
        #####################

        num_warmup = self.chunks if no_1B1F else min(self.chunks, nranks - rank)
        grad_output = None

        for i in range(num_warmup):
            if is_first_stage:
                chunk_input = input_chunks[i]
            else:
                chunk_input = comm.recv(prev_rank, self.group)
                input_chunks.append(chunk_input)

            # The first forward needs to broadcast buffers.
            with sync_forward(self.module, i == 0):
                output, loss = self._forward_chunk(chunk_input, criterion, labels[i])
            output_chunks.append(output)
            losses.append(loss)

            # The last warmup forward is paired with receiving the first
            # gradient, unless the 1B1F phase is disabled.
            do_sendrecv = (not no_1B1F and i == num_warmup - 1)

            if not do_sendrecv and not is_last_stage:
                self._send_async(output, next_rank)

            # send output, recv grad
            if do_sendrecv and not is_last_stage:
                grad_output = comm.sendrecv(output, next_rank, self.group)
            else:
                grad_output = None

            # Once the forward is done the output data can be released;
            # only output.grad_fn is needed for the backward.
            if release_outputs:
                free_outputs(output)

        #####################
        # step 2: 1B1F
        #####################

        num_1B1F = self.chunks - num_warmup

        for i in range(num_1B1F):
            # backward
            with sync_backward(self.module, False, output_chunks[i]):
                grad_input = self._backward_chunk(
                    output_chunks[i], grad_output, input_chunks[i], losses[i])
            input_chunks[i] = None
            if release_outputs:
                output_chunks[i] = None
            if is_last_stage:
                # Keep only the value; drop the reference to the graph.
                losses[i] = losses[i].detach()

            # send grad, recv input
            if is_first_stage:
                chunk_input = input_chunks[i + num_warmup]
            else:
                chunk_input = comm.sendrecv(grad_input, prev_rank, self.group)
                input_chunks.append(chunk_input)

            # forward
            with sync_forward(self.module, False):
                output, loss = self._forward_chunk(
                    chunk_input, criterion, labels[i + num_warmup])
            output_chunks.append(output)
            losses.append(loss)

            # send output, recv grad
            if not is_last_stage:
                grad_output = comm.sendrecv(output, next_rank, self.group)
            else:
                grad_output = None

            if release_outputs:
                free_outputs(output)

        #####################
        # step 3: cooldown
        #####################

        num_cooldown = num_warmup

        for i in range(num_cooldown):
            # The gradient for the first cooldown backward was already
            # received by the preceding sendrecv, except in GPipe mode.
            do_recv = (no_1B1F or i != 0)
            if do_recv and not is_last_stage:
                grad_output = comm.recv(next_rank, self.group)

            chunk_idx = num_1B1F + i

            # The last backward needs to allreduce gradients.
            with sync_backward(self.module, chunk_idx == self.chunks - 1, output_chunks[chunk_idx]):
                grad_input = self._backward_chunk(
                    output_chunks[chunk_idx], grad_output,
                    input_chunks[chunk_idx], losses[chunk_idx])
            input_chunks[chunk_idx] = None
            if release_outputs:
                output_chunks[chunk_idx] = None
            if is_last_stage:
                # Index by chunk_idx (not the loop counter): this is the
                # microbatch whose backward just ran.
                losses[chunk_idx] = losses[chunk_idx].detach()

            if not is_first_stage:
                self._send_async(grad_input, prev_rank)

        loss, output = None, None
        if is_last_stage:
            loss = torch.stack(losses).float().sum().detach()
            if return_outputs:
                output = microbatch.gather(output_chunks, self.batch_dim)
                # Same unwrapping rule as in forward(): only unwrap a
                # one-element sequence, never index into a bare tensor.
                if isinstance(output, (tuple, list)) and len(output) == 1:
                    output = output[0]

        return loss, output

    @cuda_timer.record("forward_chunk")
    def _forward_chunk(self, input, criterion=None, label=()):
        """
        Run the wrapped module on one microbatch.

        Args:
            input: sequence of positional inputs for the module
            criterion (Callable): loss function; only used on the last stage
            label (tuple): extra arguments passed to ``criterion``

        Returns:
            tuple: ``(output, loss)``. A single-tensor output is wrapped in a
            list. ``loss`` is ``None`` unless this is the last stage and
            ``criterion`` is given.
        """
        rank = self.group.rank()
        nranks = self.group.size()
        is_last_stage = (rank == nranks - 1)

        # Activation checkpointing is only applied while gradients are being
        # recorded, and never on the last stage.
        if self.checkpoint and torch.is_grad_enabled() and (not is_last_stage):
            output = checkpoint(self.module, *input)
        else:
            output = self.module(*input)

        output = [output] if isinstance(output, torch.Tensor) else output

        loss = None
        if is_last_stage and criterion is not None:
            loss = criterion(*output, *label)

        return output, loss

    @cuda_timer.record("backward_chunk")
    def _backward_chunk(self, output, grad_output, input, loss):
        """
        Run the backward pass for one microbatch.

        Args:
            output: outputs of this stage for the microbatch
            grad_output: gradients w.r.t. ``output`` received from the next
                stage; entries may be ``None``. Unused on the last stage
            input: inputs of this stage for the microbatch
            loss: loss of the microbatch; required on the last stage

        Returns:
            list: ``x.grad`` for every ``x`` in ``input``
        """
        rank = self.group.rank()
        nranks = self.group.size()

        if rank == nranks - 1:
            assert loss is not None
            loss.backward()
        else:
            # Drop outputs for which no gradient was received.
            pairs = [(o, g) for o, g in zip(output, grad_output) if g is not None]

            # output.data may already have been released, so
            # torch.autograd.backward cannot be called directly;
            # run_backward only needs output.grad_fn.
            # Nothing to do when every gradient is None (unpacking an empty
            # zip would raise).
            if pairs:
                outs, grads = zip(*pairs)
                run_backward(outs, grads)

        grad_input = [x.grad for x in input]
        return grad_input
def free_outputs(tensors):
    """
    Release the data of every tensor in ``tensors`` in place.

    Each tensor's ``.data`` is swapped for a fresh one-element tensor with the
    same device and dtype. The tensor object itself (and therefore its
    ``grad_fn``) is kept; only the data it referenced is dropped.

    Args:
        tensors: iterable of tensors to release
    """
    for tensor in tensors:
        placeholder = torch.empty(1, dtype=tensor.dtype, device=tensor.device)
        tensor.data = placeholder