Shortcuts

Source code for haiscale.pipeline.partition

from collections.abc import Iterable
import torch
import torch.nn as nn


class SequentialModel(nn.Sequential):
    """
    Works like ``torch.nn.Sequential`` but supports multiple inputs and outputs.

    The number of outputs of the ``i``-th submodule must match the number of
    inputs of the ``(i + 1)``-th submodule.

    Example:

    .. code-block:: python

        import torch.nn as nn
        from haiscale.pipeline import SequentialModel

        class MyLayer(nn.Module):
            def __init__(self, index) -> None:
                super().__init__()
                self.index = index

            def forward(self, x, y):
                return x + y, x - y

            def __repr__(self):
                return f"MyLayer{self.index}()"

        model = SequentialModel(*[MyLayer(i) for i in range(5)])
        x = torch.ones(5)
        y = torch.zeros(5)
        print(model(x, y))
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, *inputs):
        """
        Run every submodule in order, feeding each one the previous outputs.

        A submodule's output is unpacked into positional arguments for the
        next submodule; a lone ``torch.Tensor`` is first wrapped in a 1-tuple
        so it is passed as a single argument instead of being unpacked.

        Returns whatever the last submodule returns. When the container holds
        no submodules, the positional inputs are returned as a tuple.
        """
        current = inputs
        for layer in self:
            # A bare tensor must not be star-unpacked along its first dim.
            packed = (current,) if isinstance(current, torch.Tensor) else current
            current = layer(*packed)
        return current
def partition(module, rank, nranks, balance=None, num_model_chunks=None):
    """
    Split a model into ``nranks`` sub-models and return the ``rank``-th one.

    The model may take multiple inputs and produce multiple outputs.

    Args:
        module (Iterable[torch.nn.Module]): the layers of the model to split.
            Any iterable is accepted, including one-shot iterators such as
            generators.
        rank (int): rank of the current process within pipeline parallelism;
            must satisfy ``0 <= rank < nranks``.
        nranks (int): number of pipeline-parallel processes.
        balance: number of layers assigned to each stage (a list or tuple of
            ints whose length equals the number of stages and whose sum
            equals the number of layers). When omitted, layers are divided
            as evenly as possible.
        num_model_chunks: number of sub-models per rank (used by
            Interleaved1F1B). When set, the function returns a list of
            length ``num_model_chunks``.

    Returns:
        A single ``SequentialModel`` when ``num_model_chunks`` is ``None``,
        otherwise a list of ``SequentialModel`` (one per chunk).

    Raises:
        AssertionError: if any argument fails validation or a stage would
            end up with no layers.

    Example:

    .. code-block:: python

        import torch.nn as nn
        from haiscale.pipeline import partition

        class MyLayer(nn.Module):
            def __init__(self, index) -> None:
                super().__init__()
                self.index = index

            def forward(self, x, y):
                return x + y, x - y

            def __repr__(self):
                return f"MyLayer{self.index}()"

        module = nn.ModuleList([MyLayer(i) for i in range(5)])

        module1 = partition(module, rank=1, nranks=2)
        print(module1)
        # SequentialModel(
        #   (0): MyLayer3()
        #   (1): MyLayer4()
        # )

        module2 = partition(module, rank=1, nranks=2, balance=[2, 3])
        print(module2)
        # SequentialModel(
        #   (0): MyLayer2()
        #   (1): MyLayer3()
        #   (2): MyLayer4()
        # )

        x = torch.ones(5)
        y = torch.zeros(5)
        print(module1(x, y))
        # (tensor([2., 2., 2., 2., 2.]), tensor([0., 0., 0., 0., 0.]))
        print(module2(x, y))
        # (tensor([2., 2., 2., 2., 2.]), tensor([2., 2., 2., 2., 2.]))
    """
    # NOTE: validation deliberately stays as ``assert`` so that callers keep
    # seeing AssertionError, exactly as before.
    assert isinstance(module, Iterable)

    # Materialize BEFORE validating the elements: a one-shot iterable (e.g. a
    # generator) would otherwise be exhausted by the type check, leaving an
    # empty layer list behind.
    layers = list(module)
    assert all(isinstance(layer, nn.Module) for layer in layers)

    # An out-of-range rank would otherwise slice past the layer list or index
    # ``balance`` from the wrong end instead of failing clearly.
    assert 0 <= rank < nranks, f"rank {rank} is out of range for nranks {nranks}"

    chunks = 1 if num_model_chunks is None else num_model_chunks
    nlayers = len(layers)
    nstages = nranks * chunks

    if balance is not None:
        assert isinstance(balance, (tuple, list))
        assert all(isinstance(x, int) for x in balance)
        assert sum(balance) == nlayers
        assert len(balance) == nstages

    submodules = []
    for i in range(chunks):
        # Chunk ``i`` of this rank is global stage ``rank + nranks * i``.
        stage = rank + nranks * i

        if balance is not None:
            start = sum(balance[:stage])
            end = start + balance[stage]
        else:
            start, end = split(nlayers, nstages, stage)

        assert start < end, f"RANK {rank}, start {start}, end {end}, nlayers {nlayers}"
        submodules.append(SequentialModel(*layers[start:end]))

    if num_model_chunks is None:
        return submodules[0]
    return submodules
def split(tot, n, i):
    """
    Return the ``(start, end)`` bounds of the ``i``-th of ``n`` contiguous parts.

    ``tot`` items are divided into ``n`` parts of size ``tot // n``; the first
    ``tot % n`` parts each receive one extra item. ``end`` is exclusive.
    """
    base, extra = divmod(tot, n)

    def boundary(idx):
        # Offset where part ``idx`` begins: ``idx`` full parts plus one
        # extra item for each earlier part that received a remainder item.
        return idx * base + min(idx, extra)

    return boundary(i), boundary(i + 1)