Source code for haiscale.pipeline.partition
from collections.abc import Iterable
import torch
import torch.nn as nn
class SequentialModel(nn.Sequential):
    """
    A ``torch.nn.Sequential``-like container that supports layers with
    multiple inputs and multiple outputs.

    The values returned by the ``i``-th sub-model are unpacked and passed as
    the positional arguments of the ``(i + 1)``-th sub-model, so the number of
    outputs of one must match the number of inputs of the next.

    Example:

    .. code-block:: python

        import torch
        import torch.nn as nn
        from haiscale.pipeline import SequentialModel

        class MyLayer(nn.Module):
            def __init__(self, index) -> None:
                super().__init__()
                self.index = index

            def forward(self, x, y):
                return x + y, x - y

            def __repr__(self):
                return f"MyLayer{self.index}()"

        model = SequentialModel(*[MyLayer(i) for i in range(5)])
        x = torch.ones(5)
        y = torch.zeros(5)
        print(model(x, y))
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, *inputs):
        """
        Run every contained layer in order, chaining outputs to inputs.

        Args:
            *inputs: positional arguments handed to the first layer.

        Returns:
            Whatever the last layer returns. With no layers at all, the
            ``inputs`` tuple is returned unchanged.
        """
        current = inputs
        for layer in self:
            # A layer that returns a bare tensor is treated as having a single
            # output; wrap it so it can be unpacked as one positional argument.
            call_args = (current,) if isinstance(current, torch.Tensor) else current
            current = layer(*call_args)
        return current
def partition(module, rank, nranks, balance=None, num_model_chunks=None):
    """
    Split a model into pipeline stages and return the stage(s) owned by ``rank``.

    The layers of ``module`` are divided into ``nranks * num_model_chunks``
    consecutive stages (``num_model_chunks`` counts as 1 when it is ``None``).
    Rank ``rank`` owns stages ``rank``, ``rank + nranks``,
    ``rank + 2 * nranks``, ...; each stage is wrapped in a
    ``SequentialModel``, so the model may have multiple inputs and outputs.

    Args:
        module (Iterable[torch.nn.Module]): the layers of the model to split.
            Any iterable is accepted, including one-shot iterators such as
            generators.
        rank (int): rank of the current process within the pipeline; must
            satisfy ``0 <= rank < nranks``.
        nranks (int): number of pipeline-parallel processes.
        balance (list[int] | tuple[int, ...] | None): number of layers in each
            stage. Must contain ``nranks * num_model_chunks`` integers that sum
            to the number of layers. When ``None``, the layers are divided as
            evenly as possible with ``split``.
        num_model_chunks (int | None): number of sub-models per rank (used by
            Interleaved1F1B). If set, the function returns a list of length
            ``num_model_chunks``.

    Returns:
        A single ``SequentialModel`` when ``num_model_chunks`` is ``None``,
        otherwise a list of ``SequentialModel`` with one entry per chunk.

    Raises:
        AssertionError: if an argument fails validation or a stage would
            contain no layers.

    Example:

    .. code-block:: python

        import torch
        import torch.nn as nn
        from haiscale.pipeline import partition

        class MyLayer(nn.Module):
            def __init__(self, index) -> None:
                super().__init__()
                self.index = index

            def forward(self, x, y):
                return x + y, x - y

            def __repr__(self):
                return f"MyLayer{self.index}()"

        module = nn.ModuleList([MyLayer(i) for i in range(5)])
        module1 = partition(module, rank=1, nranks=2)
        print(module1)
        # SequentialModel(
        #   (0): MyLayer3()
        #   (1): MyLayer4()
        # )

        module2 = partition(module, rank=1, nranks=2, balance=[2, 3])
        print(module2)
        # SequentialModel(
        #   (0): MyLayer2()
        #   (1): MyLayer3()
        #   (2): MyLayer4()
        # )

        x = torch.ones(5)
        y = torch.zeros(5)
        print(module1(x, y))
        # (tensor([2., 2., 2., 2., 2.]), tensor([0., 0., 0., 0., 0.]))
        print(module2(x, y))
        # (tensor([2., 2., 2., 2., 2.]), tensor([2., 2., 2., 2., 2.]))
    """
    assert isinstance(module, Iterable)
    # Materialize before validating: the type check below iterates the layers,
    # which would otherwise exhaust a one-shot iterator (e.g. a generator) and
    # leave nothing to partition.
    layers = list(module)
    assert all(isinstance(x, nn.Module) for x in layers)
    # An out-of-range rank could otherwise slice past the layer list and
    # silently yield an empty or wrong sub-model.
    assert 0 <= rank < nranks, f"rank {rank} is out of range for nranks {nranks}"

    chunks = 1 if num_model_chunks is None else num_model_chunks
    nlayers = len(layers)
    nstages = nranks * chunks

    if balance is not None:
        assert isinstance(balance, (tuple, list))
        assert all(isinstance(x, int) for x in balance)
        assert sum(balance) == nlayers
        assert len(balance) == nstages

    submodules = []
    for i in range(chunks):
        # Stages are dealt out round-robin: chunk i of this rank is global
        # stage rank + nranks * i.
        stage = rank + nranks * i
        if balance is not None:
            start = sum(balance[:stage])
            end = start + balance[stage]
        else:
            start, end = split(nlayers, nstages, stage)
        assert start < end, f"RANK {rank}, start {start}, end {end}, nlayers {nlayers}"
        submodules.append(SequentialModel(*layers[start:end]))

    if num_model_chunks is None:
        return submodules[0]
    return submodules
def split(tot, n, i):
    """
    Return the ``[start, end)`` index range of part ``i`` when ``tot`` items
    are divided into ``n`` consecutive parts.

    Every part gets ``tot // n`` items, and the first ``tot % n`` parts get
    one extra item each.

    Args:
        tot (int): total number of items.
        n (int): number of parts.
        i (int): index of the requested part.

    Returns:
        tuple[int, int]: ``(start, end)`` with ``end`` exclusive.
    """
    base, extra = divmod(tot, n)

    def boundary(part):
        # Offset where ``part`` begins: ``part`` full shares of ``base`` items
        # plus one extra item for each earlier part that received one.
        return part * base + min(part, extra)

    return boundary(i), boundary(i + 1)