Shortcuts

Source code for haiscale.pipeline.interleaved

from collections import defaultdict
import torch
import torch.nn as nn
import torch.distributed as dist

from . import comm, microbatch
from .utils import checkpoint, sync_forward, sync_backward, run_backward, equal_across_ranks
from ..timer import cuda_timer


# Schedule notes for the interleaved 1F1B pipeline. The string below is a bare
# string expression (not assigned, not referenced by the code) and is kept verbatim.
#
# The timeline sketches use nranks = 4 and num_model_chunks = 2. Each rank's
# timeline is split into warmup / 1F1B / cooldown phases. The communication
# notes translate as:
#
# 1. warmup
#      before each F: recv(input, prev_rank)
#      after each F:  send(output, next_rank), except for the last model chunk
#                     or the last F
# 2. 1F1B
#      before a 1F1B step:
#        first step:   recv(input, prev_rank), sendrecv(output, grad_output, next_rank)
#        other steps:  sendrecv(grad_input, input, prev_rank),
#                      sendrecv(output, grad_output, next_rank)
#      after the last 1F1B step:
#        send(grad_input, prev_rank), sendrecv(output, grad_output, next_rank)
# 3. cooldown
#      before each B: recv(grad_output, next_rank)
#      after each B:  send(grad_input, prev_rank)
#
# NOTE(review): the constraint stated at the end of the string
# (num_chunks % (num_model_chunks * nranks) == 0) is stricter than what
# Interleaved1F1B.__init__ asserts (chunks % nranks == 0) -- confirm which
# one is intended.
"""

nranks = 4
num_model_chunks = 2


    stage 0:  F0 F1 F2 F3 F4 F5 F6 F7 F8 F9          FA-B0          B1 B2 B3 B4 B5 B6 B7 B8 B9 BA
    stage 1:     F0 F1 F2 F3 F4 F5 F6 F7       F8-B0 F9-B1 FA-B2       B3 B4 B5 B6 B7 B8 B9 BA
    stage 2:        F0 F1 F2 F3 F4 F5    F6-B0 F7-B1 F8-B2 F9-B3 FA-B4    B5 B6 B7 B8 B9 BA
    stage 3:           F0 F1 F2 F3 F4-B0 F5-B1 F6-B2 F7-B3 F8-B4 F9-B5 FA-B6 B7 B8 B9 BA
              <------------------> <---------------------------------------> <------------------>
                    warmup                          1F1B                          cooldown


    stage 0:  F0 F1 F2 F3 F4 F5 F6 F7 F8 F9          FA-B0 B1    B2 B3    B4 B5    B6 B7 B8 B9 BA BB
    stage 1:     F0 F1 F2 F3 F4 F5 F6 F7       F8-B0 F9-B1 FA-B2 B3    B4 B5    B6 B7 B8 B9 BA BB
    stage 2:        F0 F1 F2 F3 F4 F5    F6-B0 F7-B1 F8-B2 F9-B3 FA-B4 B5    B6 B7 B8 B9 BA BB
    stage 3:           F0 F1 F2 F3 F4-B0 F5-B1 F6-B2 F7-B3 F8-B4 F9-B5 FA-B6 B7 B8 B9 BA BB
              <------------------> <---------------------------------------> <---------------------->
                    warmup                          1F1B                            cooldown


1. warmup

      F 开始前 recv(input, prev_rank)
      F 结束后 send(output, next_rank),不包括最后一个 model chunk 或者最后一个 F

2. 1F1B

      1F1B 开始前:
            第一个 1F1B: recv(input, prev_rank), sendrecv(output, grad_output, next_rank)
            其他: sendrecv(grad_input, input, prev_rank), sendrecv(output, grad_output, next_rank)

      最后一个 1F1B 结束后:
            send(grad_input, prev_rank), sendrecv(output, grad_output, next_rank)

3. cooldown

      B 开始前 recv(grad_output, next_rank)
      B 结束后 send(grad_input, prev_rank)



num_chunks % (num_model_chunks * nranks) == 0

"""


class Interleaved1F1B(nn.Module):
    """
    Distributed pipeline parallelism with the interleaved 1F1B schedule.

    Interleaved1F1B requires the model to be split into
    ``world_size * num_model_chunks`` sub-models. Each GPU holds
    ``num_model_chunks`` of them, and ``num_model_chunks`` is chosen by the
    user.

    The sub-models may take multiple inputs and produce multiple outputs, but
    the outputs of one stage must line up with the inputs of the next stage.

    Args:
        modules (List[nn.Module]): the partitioned sub-models held by this
            rank, e.g. obtained with :func:`partition`
        chunks (int): number of microbatches
        process_group (dist.ProcessGroup): process group used for pipeline
            parallelism; defaults to the global (default) process group
        batch_dim (int): the batch dimension, defaults to ``0``
        checkpoint (bool): whether to use activation checkpointing, defaults
            to ``False``

    Example:

        During training, call ``forward_backward`` and pass the loss function
        ``criterion`` and the ``labels``:

        .. code-block:: python

            from haiscale.pipeline import Interleaved1F1B, partition

            dist.init_process_group(...)
            torch.cuda.set_device(local_rank)
            rank, world_size = dist.get_rank(), dist.get_world_size()
            torch.manual_seed(12345)

            def loss_fn(out, y):
                return ((out - y)**2).sum()

            model = nn.Sequential(...).cuda()
            modules = partition(model, rank, world_size, num_model_chunks=2)  # len(modules) = 2
            model = Interleaved1F1B(modules, chunks=32)

            for x, y in dataloader:
                loss, _ = model.forward_backward(x, criterion=loss_fn, labels=(y,))

            # eval
            with torch.no_grad():
                out = model(x)
                if rank == world_size - 1:
                    # calculate metrics
                    ...

    NOTE: in the example, the dataloaders of the first stage and of the last
    stage must produce the same data.

    NOTE: to use this together with DDP, wrap the partitioned sub-models in
    DDP first, then pass the DDP-wrapped modules to Interleaved1F1B. The
    pipeline-parallel and data-parallel sizes can be configured with
    :func:`make_subgroups`.
    """

    def __init__(self, modules, chunks, process_group=None, batch_dim=0, checkpoint=False):
        super().__init__()
        assert isinstance(modules, (list, tuple))
        assert len(modules) >= 2
        assert all(isinstance(m, nn.Module) for m in modules)
        assert isinstance(chunks, int)
        assert isinstance(batch_dim, int)
        assert isinstance(checkpoint, bool)

        group = process_group
        assert group is None or isinstance(group, dist.ProcessGroup)
        group = group or dist.distributed_c10d._get_default_group()
        assert group.size() > 1

        device = torch.device(torch.cuda.current_device())
        # The sub-models must already live on the current CUDA device.
        assert next(modules[0].parameters()).device == device
        self.device = device

        self.module = nn.ModuleList(modules)
        self.chunks = chunks
        self.group = group
        self.batch_dim = batch_dim
        self.model_chunks = len(modules)
        assert self.model_chunks > 1
        self.checkpoint = checkpoint

        nranks = group.size()
        nstages = self.model_chunks * nranks
        # Microbatch ids are handed out in groups of ``nranks``
        # (see _get_forward_id / _get_backward_id).
        assert self.chunks % nranks == 0
        # Rank 0 has the longest warmup phase (nstages + nranks - 2 forwards).
        # Requiring more microbatches than that guarantees every rank gets at
        # least one 1F1B step.
        assert self.chunks > (nstages + nranks - 2)

        # Every rank must agree on the schedule parameters.
        assert equal_across_ranks(batch_dim, group)
        assert equal_across_ranks(chunks, group)
        assert equal_across_ranks(self.model_chunks, group)

    def forward(self, *inputs):
        """
        Run a single forward pass under ``torch.no_grad()``.

        The batch is not split into microbatches. Each of this rank's model
        chunks runs once, receiving its input from the previous rank (unless
        it is stage 0) and sending its output to the next rank (unless it is
        the last stage).

        Args:
            *inputs: model inputs; may be omitted on ranks that do not hold
                the first stage

        Returns:
            The model output on the rank holding the last stage. A
            single-element output is unwrapped. Other ranks return ``None``.
        """
        assert not torch.is_grad_enabled(), "forward 只有在 no_grad 模式的时候可以用"
        rank = self.group.rank()
        nranks = self.group.size()
        next_rank = (rank + 1) % nranks
        prev_rank = (rank + nranks - 1) % nranks
        nstages = self.model_chunks * nranks

        for i in range(self.model_chunks):
            stage = i * nranks + rank
            if stage != 0:
                inputs = comm.recv(prev_rank, self.group)
            outputs, _ = self._forward_chunk(stage, i, inputs)
            if stage != nstages - 1:
                comm.send(outputs, next_rank, self.group)
                outputs = None

        if isinstance(outputs, (tuple, list)) and len(outputs) == 1:
            outputs = outputs[0]
        return outputs

    def forward_backward(self, *inputs, criterion=None, labels=(), return_outputs=False):
        """
        Run one forward and one backward pass over all microbatches.

        Args:
            *inputs: model inputs; only used on rank 0 (first stage), may be
                omitted elsewhere
            criterion (Callable): loss function, called as
                ``criterion(*outputs, *labels)``; required on the last rank,
                may be omitted elsewhere
            labels (tuple): labels and any other data passed to the loss
                function; only used on the last rank
            return_outputs (bool): whether the last stage should also return
                the model outputs; ``True`` uses extra GPU memory. Defaults to
                ``False``.

        Returns:
            tuple: ``(loss, outputs)``

        NOTE: ``loss`` is the sum of the per-microbatch losses (detached). If
        ``return_outputs = True``, ``outputs`` is the gathered output of the
        last stage, otherwise it is ``None``. Only the last stage returns
        values; every other rank returns ``(None, None)``.
        """
        rank = self.group.rank()
        nranks = self.group.size()
        nstages = self.model_chunks * nranks
        if rank == nranks - 1:
            assert criterion is not None

        # Per model chunk: inputs, outputs and losses, indexed by microbatch id.
        input_chunks, output_chunks, losses = defaultdict(list), defaultdict(list), defaultdict(list)
        if rank == 0:
            input_chunks[0] = microbatch.scatter(inputs, self.chunks, self.batch_dim)
        if rank == nranks - 1:
            labels = microbatch.scatter(labels, self.chunks, self.batch_dim)
        else:
            labels = [() for _ in range(self.chunks)]

        #####################
        # step 1: warmup
        #####################
        num_warmup = (self.model_chunks - 1) * nranks + 2 * (nranks - rank - 1)
        output = None
        for i in range(num_warmup):
            model_id, stage, microbatch_id = self._get_forward_id(i)
            recv_input = (stage != 0)
            mb_input, _ = self._communicate(recv_input=recv_input, recv_grad_output=False, output=output)
            # ``output`` (the previous forward's result) has just been sent.
            # It is never a last-stage output (those are reset to None below),
            # so its data can always be released; backward only needs grad_fn.
            free_outputs(output)

            assert mb_input is not None or stage == 0
            if stage != 0:
                input_chunks[model_id].append(mb_input)
            else:
                mb_input = input_chunks[model_id][microbatch_id]

            label = ()
            if stage == nstages - 1:
                label = labels[microbatch_id]

            with sync_forward(self.module[model_id], microbatch_id):
                output, loss = self._forward_chunk(stage, model_id, mb_input, criterion, label)
            output_chunks[model_id].append(output)
            losses[model_id].append(loss)
            # The last stage sends nothing downstream.
            output = None if stage == nstages - 1 else output

        #####################
        # step 2: 1F1B
        #####################
        num_1F1B = self.model_chunks * self.chunks - num_warmup
        grad_input = None
        for i in range(num_1F1B):
            model_id, stage, microbatch_id = self._get_forward_id(i + num_warmup)
            bwd_stage = self._get_backward_id(i)[1]
            recv_input = (stage != 0)
            recv_grad_output = (bwd_stage != nstages - 1)
            mb_input, grad_output = self._communicate(recv_input=recv_input,
                                                      recv_grad_output=recv_grad_output,
                                                      grad_input=grad_input,
                                                      output=output)
            free_outputs(output)

            assert mb_input is not None or stage == 0
            if stage != 0:
                input_chunks[model_id].append(mb_input)
            else:
                mb_input = input_chunks[model_id][microbatch_id]

            label = ()
            if stage == nstages - 1:
                label = labels[microbatch_id]

            # forward
            with sync_forward(self.module[model_id], microbatch_id):
                output, loss = self._forward_chunk(stage, model_id, mb_input, criterion, label)
            output_chunks[model_id].append(output)
            losses[model_id].append(loss)
            output = None if stage == nstages - 1 else output

            # backward
            model_id, stage, microbatch_id = self._get_backward_id(i)
            bwd_input = input_chunks[model_id][microbatch_id]
            bwd_output = output_chunks[model_id][microbatch_id]
            bwd_loss = losses[model_id][microbatch_id]
            with sync_backward(self.module[model_id], microbatch_id == self.chunks - 1, bwd_output):
                grad_input = self._backward_chunk(stage, bwd_output, grad_output, bwd_input, bwd_loss)
            if bwd_loss is not None:
                losses[model_id][microbatch_id] = bwd_loss.detach()
            input_chunks[model_id][microbatch_id] = None
            if not return_outputs or stage != nstages - 1:
                output_chunks[model_id][microbatch_id] = None

        #####################
        # step 3: cooldown
        #####################
        num_cooldown = num_warmup
        for i in range(num_cooldown):
            model_id, stage, microbatch_id = self._get_backward_id(i + num_1F1B)
            recv_grad_output = (stage != nstages - 1)
            _, grad_output = self._communicate(recv_input=False,
                                               recv_grad_output=recv_grad_output,
                                               grad_input=grad_input,
                                               output=output)
            # The first iteration sends the last 1F1B output; release it like
            # every other sent output (no-op afterwards since output is None).
            free_outputs(output)
            output = None

            bwd_output = output_chunks[model_id][microbatch_id]
            bwd_loss = losses[model_id][microbatch_id]
            bwd_input = input_chunks[model_id][microbatch_id]
            with sync_backward(self.module[model_id], microbatch_id == self.chunks - 1, bwd_output):
                grad_input = self._backward_chunk(stage, bwd_output, grad_output, bwd_input, bwd_loss)
            if bwd_loss is not None:
                losses[model_id][microbatch_id] = bwd_loss.detach()
            input_chunks[model_id][microbatch_id] = None
            if not return_outputs or stage != nstages - 1:
                output_chunks[model_id][microbatch_id] = None

        # Hand the gradient of the final backward to the previous rank.
        self._communicate(recv_input=False, recv_grad_output=False, grad_input=grad_input)

        loss, output = None, None
        if rank == nranks - 1:
            loss = torch.stack(losses[self.model_chunks - 1], 0).sum().detach()
            if return_outputs:
                output = microbatch.gather(output_chunks[self.model_chunks - 1], self.batch_dim)
                if len(output) == 1:
                    output = output[0]
        return loss, output

    def _communicate(
        self,
        recv_input=True,
        recv_grad_output=True,
        grad_input=None,
        output=None,
    ):
        """
        Exchange tensors with both neighbouring ranks in one call.

        ``(grad_input, recv_input)`` are paired with the previous rank and
        ``(output, recv_grad_output)`` with the next rank, as arguments to
        ``comm.sendrecv_twice``.

        Returns:
            ``(input, grad_output)`` as returned by ``comm.sendrecv_twice``.
        """
        rank = self.group.rank()
        nranks = self.group.size()
        next_rank = (rank + 1) % nranks
        prev_rank = (rank + nranks - 1) % nranks
        input, grad_output = comm.sendrecv_twice(
            grad_input, recv_input, prev_rank,
            output, recv_grad_output, next_rank,
            self.group)
        return input, grad_output

    def _get_forward_id(self, i):
        """
        Map the i-th forward step on this rank to ``(model_id, stage, microbatch_id)``.

        nranks = 4, num_model_chunks = 2

        model_id:      0 0 0 0 1 1 1 1 0 0 0 0 1 1 1 1 ...
        microbatch_id: 0 1 2 3 0 1 2 3 4 5 6 7 4 5 6 7 ...

        stage = model_id * nranks + rank
        """
        rank = self.group.rank()
        nranks = self.group.size()
        model_id = (i // nranks) % self.model_chunks
        stage = model_id * nranks + rank
        k, m = divmod(i, nranks * self.model_chunks)
        microbatch_id = k * nranks + (m % nranks)
        return model_id, stage, microbatch_id

    def _get_backward_id(self, i):
        """
        Map the i-th backward step on this rank to ``(model_id, stage, microbatch_id)``.

        Same as the forward order, but model chunks are visited in reverse.

        nranks = 4, num_model_chunks = 2

        model_id:      1 1 1 1 0 0 0 0 1 1 1 1 0 0 0 0 ...
        microbatch_id: 0 1 2 3 0 1 2 3 4 5 6 7 4 5 6 7 ...

        stage = model_id * nranks + rank
        """
        rank = self.group.rank()
        nranks = self.group.size()
        model_id = (i // nranks) % self.model_chunks
        model_id = self.model_chunks - 1 - model_id
        stage = model_id * nranks + rank  # model_id -> stage
        k, m = divmod(i, nranks * self.model_chunks)
        microbatch_id = k * nranks + (m % nranks)
        return model_id, stage, microbatch_id

    @cuda_timer.record("forward_chunk")
    def _forward_chunk(self, stage, model_id, input, criterion=None, label=()):
        """
        Run model chunk ``model_id`` on ``input``.

        Uses activation checkpointing when enabled and grad mode is on. A
        single-tensor output is wrapped in a list. The loss is computed only
        on the last stage when a criterion is given.

        Returns:
            ``(output, loss)``, where ``loss`` is ``None`` except on the last
            stage.
        """
        nstages = self.model_chunks * self.group.size()
        if self.checkpoint and torch.is_grad_enabled():
            output = checkpoint(self.module[model_id], *input)
        else:
            output = self.module[model_id](*input)
        output = [output] if isinstance(output, torch.Tensor) else output

        loss = None
        if stage == nstages - 1 and criterion is not None:
            loss = criterion(*output, *label)
        return output, loss

    @cuda_timer.record("backward_chunk")
    def _backward_chunk(self, stage, output, grad_output, input, loss):
        """
        Backpropagate one microbatch through one model chunk.

        On the last stage this calls ``loss.backward()``. Otherwise it
        backpropagates ``grad_output`` through ``output``, skipping entries
        whose received gradient is ``None``; if every entry is ``None``,
        nothing is run.

        Returns:
            The gradients of ``input`` (a list), or ``None`` on stage 0.
        """
        nstages = self.model_chunks * self.group.size()
        if stage == nstages - 1:
            loss.backward()
        else:
            pairs = [(o, g) for o, g in zip(output, grad_output) if g is not None]
            # output.data may already have been released (see free_outputs),
            # so torch.autograd.backward cannot be called directly;
            # run_backward only needs output.grad_fn.
            # Guard the empty case: unpacking zip(*[]) into two names would
            # raise ValueError.
            if pairs:
                outs, grads = zip(*pairs)
                run_backward(outs, grads)

        grad_input = None
        if stage != 0:
            grad_input = [x.grad for x in input]
        return grad_input
def free_outputs(tensors):
    """
    Drop the storage of already-sent output tensors, in place.

    Each tensor's ``.data`` is swapped for a fresh one-element empty tensor
    with the same device and dtype. The tensor object itself (and its
    ``grad_fn``) is kept. ``None`` is accepted and ignored.
    """
    if tensors is not None:
        for t in tensors:
            placeholder = torch.empty(1, device=t.device, dtype=t.dtype)
            t.data = placeholder