Shortcuts

haiscale.pipeline

PipeDream

分布式流水线并行算法 PipeDream (non-interleaved 1F1B)

GPipe

分布式流水线并行算法 GPipe

Interleaved1F1B

分布式流水线并行算法 interleaved 1F1B

SequentialModel

功能和 torch.nn.Sequential 类似,但支持多个输入输出

partition

切分一个模型为 nranks 个子模型,并返回其中的第 rank 个子模型

make_subgroups

划分 pipeline parallel 和 data parallel 的进程组

class haiscale.pipeline.PipeDream(module, chunks, process_group=None, batch_dim=0, checkpoint=False)[source]

分布式流水线并行算法 PipeDream (non-interleaved 1F1B)

论文:《 PipeDream: Fast and Efficient Pipeline Parallel DNN Training 》

stage 0:  F0 F1 F2 F3          B0 F4 B1 F5          B2 B3 B4 B5
stage 1:     F0 F1 F2       B0 F3 B1 F4 B2 F5       B3 B4 B5
stage 2:        F0 F1    B0 F2 B1 F3 B2 F4 B3 F5    B4 B5
stage 3:           F0 B0 F1 B1 F2 B2 F3 B3 F4 B4 F5 B5
          <---------> <---------------------------> <--------->
             warmup               1B1F                cooldown

相比于 GPipe, 两者的训练速度差不多,但 PipeDream 能够减少显存的使用,显存占用是 GPipe 的 min(1, ngpus / chunks), 因此推荐优先使用 PipeDream。

传入 PipeDream 的模型可以有多个输入和输出,但相邻 stage 模型之间的输入输出要能对应上。

Parameters
  • module (nn.Module) – 模型的一部分,可以通过 partition() 分割得到

  • chunks (int) – microbatch 的数量

  • process_group (dist.ProcessGroup) – 流水线并行的进程组,默认使用全局的进程组

  • batch_dim (int) – batch 的维度,默认是 0

  • checkpoint (bool) – 是否使用 activation checkpoint,默认是 False

Example:

训练的时候需要调用 forward_backward 并传入损失函数 criterion 和标签 labels:

from haiscale.pipeline import PipeDream, partition

dist.init_process_group(...)
torch.cuda.set_device(local_rank)
rank, world_size = dist.get_rank(), dist.get_world_size()
torch.manual_seed(12345)

def loss_fn(out, y):
    return ((out - y)**2).sum()

model = nn.Sequential(...).cuda()
model = partition(model, rank, world_size)
model = PipeDream(model, chunks=32)

for x, y in dataloader:
    loss, _ = model.forward_backward(x, criterion=loss_fn, labels=(y,))

# eval
with torch.no_grad():
    out = model(x)
    if rank == world_size - 1:
        # calculate metrics ...

Note

示例中的第一个 stage 和最后一个 stage 的 dataloader 应该要能产生相同的数据。

Note

如果要和 DDP 一起使用,需要把分割后的模型先传入 DDP,然后再把 DDP 模型传入 GPipe。 通过 make_subgroups() 可以设置 pipeline parallel 和 data parallel 的大小。

forward(*inputs, chunks=None, criterion=None, labels=(), return_outputs=False)[source]

torch.no_grad() 下做一次 forward

Parameters
  • *inputs – 模型的输入;如果不是第一个 stage,可以不用传入

  • criterion (Callable) – 损失函数,通过 criterion(*outputs, *labels) 的方式调用;如果不是最后一个 stage,可以不用传入

  • labels (tuple) – 传入损失函数的标签等数据;如果不是最后一个 stage,可以不用传入

  • return_outputs (bool) – 是否在最后一个 stage 返回模型的输出,设置为 True 的时候会多占用一部分显存;默认是 False

Returns

最后一个 stage 会返回 (loss, outputs),其他 stage 返回 (None, None)

forward_backward(*inputs, criterion=None, labels=(), return_outputs=False)[source]

做一次 forward 和 backward,返回 loss

Parameters
  • *inputs – 模型的输入;如果不是第一个 stage,可以不用传入

  • criterion (Callable) – 损失函数,通过 criterion(*outputs, *labels) 的方式调用;如果不是最后一个 stage,可以不用传入

  • labels (tuple) – 传入损失函数的标签等数据;如果不是最后一个 stage,可以不用传入

  • return_outputs (bool) – 是否在最后一个 stage 返回模型的输出,设置为 True 的时候会多占用一部分显存;默认是 False

Returns

(loss, outputs)

Return type

tuple

Note

如果设置 return_outputs = Trueoutputs 是模型最后一个 stage 的输出,否则为 None。 只有最后一个 stage 会有返回值,其他 stage 上会返回 (None, None)

num_stages()[source]

返回 stage 的数量

stage()[source]

返回当前的 stage

class haiscale.pipeline.Interleaved1F1B(modules, chunks, process_group=None, batch_dim=0, checkpoint=False)[source]

分布式流水线并行算法 interleaved 1F1B

Interleaved1F1B 需要把模型切分成 world_size * num_model_chunks 个子模型,每块 GPU 上有 num_model_chunks 个子模型; 其中 num_model_chunks 取决于用户的设置。

传入 Interleaved1F1B 的模型可以有多个输入和输出,但相邻 stage 模型之间的输入输出要能对应上。

Parameters
  • modules (List[nn.Module]) – 切分后的模型,可以通过 partition() 分割得到

  • chunks (int) – microbatch 的数量

  • process_group (dist.ProcessGroup) – 流水线并行的进程组,默认使用全局的进程组

  • batch_dim (int) – batch 的维度,默认是 0

  • checkpoint (bool) – 是否使用 activation checkpoint,默认是 False

Example:

训练的时候需要调用 forward_backward 并传入损失函数 criterion 和标签 labels:

from haiscale.pipeline import Interleaved1F1B, partition

dist.init_process_group(...)
torch.cuda.set_device(local_rank)
rank, world_size = dist.get_rank(), dist.get_world_size()
torch.manual_seed(12345)

def loss_fn(out, y):
    return ((out - y)**2).sum()

model = nn.Sequential(...).cuda()
modules = partition(model, rank, world_size, num_model_chunks=2)  # len(modules) = 2
model = Interleaved1F1B(modules, chunks=32)

for x, y in dataloader:
    loss, _ = model.forward_backward(x, criterion=loss_fn, labels=(y,))

# eval
with torch.no_grad():
    out = model(x)
    if rank == world_size - 1:
        # calculate metrics ...

Note

示例中的第一个 stage 和最后一个 stage 的 dataloader 应该要能产生相同的数据。

Note

如果要和 DDP 一起使用,需要把分割后的模型先传入 DDP,然后再把 DDP 模型传入 GPipe。 通过 make_subgroups() 可以设置 pipeline parallel 和 data parallel 的大小。

forward(*inputs)[source]

torch.no_grad() 下做一次 forward

Parameters

*inputs – 模型的输入;如果不是第一个 stage,可以不用传入

Returns

最后一个 stage 会返回模型的输出,其他 stage 返回 None

forward_backward(*inputs, criterion=None, labels=(), return_outputs=False)[source]

做一次 forward 和 backward,返回每个 microbatch 的 loss

Parameters
  • *inputs – 模型的输入;如果不是第一个 stage,可以不用传入

  • criterion (Callable) – 损失函数,通过 criterion(*outputs, *labels) 的方式调用;如果不是最后一个 stage,可以不用传入

  • labels (tuple) – 传入损失函数的标签等数据;如果不是最后一个 stage,可以不用传入

  • return_outputs (bool) – 是否在最后一个 stage 返回模型的输出,设置为 True 的时候会多占用一部分显存;默认是 False

Returns

(loss, outputs)

Return type

tuple

Note

如果设置 return_outputs = Trueoutputs 是模型最后一个 stage 的输出,否则为 None。 只有最后一个 stage 会有返回值,其他 stage 上会返回 (None, None)

class haiscale.pipeline.GPipe(module, chunks, process_group=None, batch_dim=0, checkpoint='never', broadcast_outputs=True)[source]

分布式流水线并行算法 GPipe

论文:《 GPipe: Efficient Training of Giant Neural Networks using Pipeline Parallelism 》

stage 0:  F0 F1 F2 F3 F4 F5                   B5 B4 B3 B2 B1 B0
stage 1:     F0 F1 F2 F3 F4 F5             B5 B4 B3 B2 B1 B0
stage 2:        F0 F1 F2 F3 F4 F5       B5 B4 B3 B2 B1 B0
stage 3:           F0 F1 F2 F3 F4 F5 B5 B4 B3 B2 B1 B0
          <------------------------> <------------------------>
                  forward                   backward

传入 GPipe 的模型可以有多个输入和输出,但相邻 stage 模型之间的输入输出要能对应上。

Parameters
  • module (nn.Module) – 模型的一部分,可以通过 partition() 分割得到

  • chunks (int) – microbatch 的数量

  • process_group (dist.ProcessGroup) – 流水线并行的进程组,默认使用全局的进程组

  • batch_dim (int) – batch 的维度,默认是 0

  • checkpoint (str) – never, except_last 或者 always,默认是 never

  • broadcast_outputs (bool) – 是否把最后一个 stage 的输出广播到其他 stage 上,作为每个 stage 上 GPipe 实例的输出; 设置 broadcast_outputs = False 可以节省一些通讯和显存上的开销,如果最后一个 stage 的输出比较大, 推荐设置为 False;默认是 True

Example:

GPipe 有两种使用方式,第一种方式需要设置 broadcast_outputs = True, 使用上类似于 DDP,每个 stage 返回的输出是相同的, 都可以用来计算 loss,但实际上只有最后一个 stage 的 loss 会用来做 backward:

from haiscale.pipeline import GPipe, partition

dist.init_process_group(...)
torch.cuda.set_device(local_rank)
rank, world_size = dist.get_rank(), dist.get_world_size()
torch.manual_seed(12345)

model = nn.Sequential(...).cuda()
model = partition(model, rank, world_size)
model = GPipe(model, chunks=64)

for x, y in dataloader:
    out = model(x)
    loss = ((out - y)**2).sum()
    loss.backward()

第二种使用方法需要传入损失函数 criterion,在 forward_backward 里会通过 criterion(*outputs, *labels) 的方式调用来计算 loss:

from haiscale.pipeline import GPipe, partition

dist.init_process_group(...)
torch.cuda.set_device(local_rank)
rank, world_size = dist.get_rank(), dist.get_world_size()
torch.manual_seed(12345)

def loss_fn(out, y):
    return ((out - y)**2).sum()

model = nn.Sequential(...).cuda()
model = partition(model, rank, world_size)

model = GPipe(model, chunks=32)

for x, y in dataloader:
    loss, _ = model.forward_backward(x, criterion=loss_fn, labels=(y,))

Note

示例中的第一个 stage 和最后一个 stage 的 dataloader 应该要能产生相同的数据。

Note

如果要和 DDP 一起使用,需要把分割后的模型先传入 DDP,然后再把 DDP 模型传入 GPipe。 通过 make_subgroups() 可以设置 pipeline parallel 和 data parallel 的大小。

forward_backward(*inputs, criterion=None, labels=(), return_outputs=False)[source]

做一次 forward 和 backward,返回每个 microbatch 的 loss

Parameters
  • *inputs – 模型的输入;如果不是第一个 stage,可以不用传入

  • criterion (Callable) – 损失函数,通过 criterion(*outputs, *labels) 的方式调用;如果不是最后一个 stage,可以不用传入

  • labels (tuple) – 传入损失函数的标签等数据;如果不是最后一个 stage,可以不用传入

  • return_outputs (bool) – 是否在最后一个 stage 返回模型的输出,设置为 True 的时候会多占用一部分显存;默认是 False

Returns

(loss, outputs)

Return type

tuple

Note

如果设置 return_outputs = Trueoutputs 是模型最后一个 stage 的输出,否则为 None。 只有最后一个 stage 会有返回值,其他 stage 上会返回 (None, None)

num_stages()[source]

返回 stage 的数量

stage()[source]

返回当前的 stage

class haiscale.pipeline.SequentialModel(*args, **kwargs)[source]

功能和 torch.nn.Sequential 类似,但支持多个输入输出

i 个子模型的输出数量要和第 i + 1 个子模型的输入数量相同

Example:

import torch.nn as nn
from haiscale.pipeline import SequentialModel

class MyLayer(nn.Module):

    def __init__(self, index) -> None:
        super().__init__()
        self.index = index

    def forward(self, x, y):
        return x + y, x - y

    def __repr__(self):
        return f"MyLayer{self.index}()"

model = SequentialModel(*[MyLayer(i) for i in range(5)])
x = torch.ones(5)
y = torch.zeros(5)
print(model(x, y))
haiscale.pipeline.partition(module, rank, nranks, balance=None, num_model_chunks=None)[source]

切分一个模型为 nranks 个子模型,并返回其中的第 rank 个子模型

模型可以有多个输入和输出

Parameters
  • module (Iterable[torch.nn.Module]) – 要被切分的模型

  • rank (int) – 流水线并行中当前进程的 rank 编号

  • nranks (int) – 流水线并行的进程数量

  • balance – 指定每个 rank 所占的层数

  • num_model_chunks – 每个 rank 的子模型数量(在 Interleaved1F1B 中使用); 如果设置了本参数,函数会返回一个长度为 num_model_chunks 的列表

Returns

返回切分后的子模型

Example:

import torch.nn as nn
from haiscale.pipeline import partition

class MyLayer(nn.Module):

    def __init__(self, index) -> None:
        super().__init__()
        self.index = index

    def forward(self, x, y):
        return x + y, x - y

    def __repr__(self):
        return f"MyLayer{self.index}()"

module = nn.ModuleList([MyLayer(i) for i in range(5)])
module1 = partition(module, rank=1, nranks=2)
print(module1)
# SequentialModel(
#   (0): MyLayer3()
#   (1): MyLayer4()
# )

module2 = partition(module, rank=1, nranks=2, balance=[2, 3])
print(module2)
# SequentialModel(
#   (0): MyLayer2()
#   (1): MyLayer3()
#   (2): MyLayer4()
# )

x = torch.ones(5)
y = torch.zeros(5)
print(module1(x, y))
# (tensor([2., 2., 2., 2., 2.]), tensor([0., 0., 0., 0., 0.]))
print(module2(x, y))
# (tensor([2., 2., 2., 2., 2.]), tensor([2., 2., 2., 2., 2.]))
haiscale.pipeline.make_subgroups(pp_size, dp_size=None, group=None)[source]

划分 pipeline parallel 和 data parallel 的进程组

Parameters
  • pp_size (int) – pipeline parallel 的大小

  • dp_size (int) – data parallel 的大小,默认是 group.size() / pp_size

  • group (dist.ProcessGroup) – 进程组,默认使用全局的进程组

Returns

(dp_group, pp_group)

Return type

tuple

Examples:

from haiscale.pipeline import GPipe, partition, make_subgroups

dist.init_process_group(...)
torch.cuda.set_device(local_rank)
rank, world_size = dist.get_rank(), dist.get_world_size()

dp_group, pp_group = make_subgroups(pp_size=4)

model = nn.Sequential(...).cuda()
model = partition(model, pp_group.rank(), pp_group.size())

model = DDP(model, device_ids=[local_rank], process_group=dp_group)
model = GPipe(model, chunks=64, process_group=pp_group)

for x, y in dataloader:
    out = model(x)
    loss = ((out - y)**2).sum()
    loss.backward()