haiscale.pipeline¶
分布式流水线并行算法 PipeDream (non-interleaved 1F1B) |
|
分布式流水线并行算法 GPipe |
|
分布式流水线并行算法 interleaved 1F1B |
|
功能和 |
|
切分一个模型为 |
|
划分 pipeline parallel 和 data parallel 的进程组 |
- class haiscale.pipeline.PipeDream(module, chunks, process_group=None, batch_dim=0, checkpoint=False)[source]¶
分布式流水线并行算法 PipeDream (non-interleaved 1F1B)
论文:《 PipeDream: Fast and Efficient Pipeline Parallel DNN Training 》
stage 0: F0 F1 F2 F3 B0 F4 B1 F5 B2 B3 B4 B5 stage 1: F0 F1 F2 B0 F3 B1 F4 B2 F5 B3 B4 B5 stage 2: F0 F1 B0 F2 B1 F3 B2 F4 B3 F5 B4 B5 stage 3: F0 B0 F1 B1 F2 B2 F3 B3 F4 B4 F5 B5 <---------> <---------------------------> <---------> warmup 1B1F cooldown
相比于
GPipe
, 两者的训练速度差不多,但 PipeDream 能够减少显存的使用,显存占用是 GPipe 的min(1, ngpus / chunks)
, 因此推荐优先使用 PipeDream。传入 PipeDream 的模型可以有多个输入和输出,但相邻 stage 模型之间的输入输出要能对应上。
- Parameters
module (nn.Module) – 模型的一部分,可以通过
partition()
分割得到chunks (int) – microbatch 的数量
process_group (dist.ProcessGroup) – 流水线并行的进程组,默认使用全局的进程组
batch_dim (int) – batch 的维度,默认是
0
checkpoint (bool) – 是否使用 activation checkpoint,默认是
False
Example:
训练的时候需要调用
forward_backward
并传入损失函数criterion
和标签labels
:from haiscale.pipeline import PipeDream, partition dist.init_process_group(...) torch.cuda.set_device(local_rank) rank, world_size = dist.get_rank(), dist.get_world_size() torch.manual_seed(12345) def loss_fn(out, y): return ((out - y)**2).sum() model = nn.Sequential(...).cuda() model = partition(model, rank, world_size) model = PipeDream(model, chunks=32) for x, y in dataloader: loss, _ = model.forward_backward(x, criterion=loss_fn, labels=(y,)) # eval with torch.no_grad(): out = model(x) if rank == world_size - 1: # calculate metrics ...
Note
示例中的第一个 stage 和最后一个 stage 的 dataloader 应该要能产生相同的数据。
Note
如果要和 DDP 一起使用,需要把分割后的模型先传入 DDP,然后再把 DDP 模型传入 GPipe。 通过
make_subgroups()
可以设置 pipeline parallel 和 data parallel 的大小。- forward(*inputs, chunks=None, criterion=None, labels=(), return_outputs=False)[source]¶
在
torch.no_grad()
下做一次 forward- Parameters
*inputs – 模型的输入;如果不是第一个 stage,可以不用传入
criterion (Callable) – 损失函数,通过
criterion(*outputs, *labels)
的方式调用;如果不是最后一个 stage,可以不用传入labels (tuple) – 传入损失函数的标签等数据;如果不是最后一个 stage,可以不用传入
return_outputs (bool) – 是否在最后一个 stage 返回模型的输出,设置为
True
的时候会多占用一部分显存;默认是False
- Returns
最后一个 stage 会返回
(loss, outputs)
,其他 stage 返回(None, None)
- forward_backward(*inputs, criterion=None, labels=(), return_outputs=False)[source]¶
做一次 forward 和 backward,返回 loss
- Parameters
*inputs – 模型的输入;如果不是第一个 stage,可以不用传入
criterion (Callable) – 损失函数,通过
criterion(*outputs, *labels)
的方式调用;如果不是最后一个 stage,可以不用传入labels (tuple) – 传入损失函数的标签等数据;如果不是最后一个 stage,可以不用传入
return_outputs (bool) – 是否在最后一个 stage 返回模型的输出,设置为
True
的时候会多占用一部分显存;默认是False
- Returns
(loss, outputs)
- Return type
tuple
Note
如果设置
return_outputs = True
,outputs
是模型最后一个 stage 的输出,否则为None
。 只有最后一个 stage 会有返回值,其他 stage 上会返回(None, None)
。
- class haiscale.pipeline.Interleaved1F1B(modules, chunks, process_group=None, batch_dim=0, checkpoint=False)[source]¶
分布式流水线并行算法 interleaved 1F1B
Interleaved1F1B 需要把模型切分成
world_size * num_model_chunks
个子模型,每块 GPU 上有num_model_chunks
个子模型; 其中num_model_chunks
取决于用户的设置。传入 Interleaved1F1B 的模型可以有多个输入和输出,但相邻 stage 模型之间的输入输出要能对应上。
- Parameters
modules (List[nn.Module]) – 切分后的模型,可以通过
partition()
分割得到chunks (int) – microbatch 的数量
process_group (dist.ProcessGroup) – 流水线并行的进程组,默认使用全局的进程组
batch_dim (int) – batch 的维度,默认是
0
checkpoint (bool) – 是否使用 activation checkpoint,默认是
False
Example:
训练的时候需要调用
forward_backward
并传入损失函数criterion
和标签labels
:from haiscale.pipeline import Interleaved1F1B, partition dist.init_process_group(...) torch.cuda.set_device(local_rank) rank, world_size = dist.get_rank(), dist.get_world_size() torch.manual_seed(12345) def loss_fn(out, y): return ((out - y)**2).sum() model = nn.Sequential(...).cuda() modules = partition(model, rank, world_size, num_model_chunks=2) # len(modules) = 2 model = Interleaved1F1B(modules, chunks=32) for x, y in dataloader: loss, _ = model.forward_backward(x, criterion=loss_fn, labels=(y,)) # eval with torch.no_grad(): out = model(x) if rank == world_size - 1: # calculate metrics ...
Note
示例中的第一个 stage 和最后一个 stage 的 dataloader 应该要能产生相同的数据。
Note
如果要和 DDP 一起使用,需要把分割后的模型先传入 DDP,然后再把 DDP 模型传入 GPipe。 通过
make_subgroups()
可以设置 pipeline parallel 和 data parallel 的大小。- forward(*inputs)[source]¶
在
torch.no_grad()
下做一次 forward- Parameters
*inputs – 模型的输入;如果不是第一个 stage,可以不用传入
- Returns
最后一个 stage 会返回模型的输出,其他 stage 返回
None
- forward_backward(*inputs, criterion=None, labels=(), return_outputs=False)[source]¶
做一次 forward 和 backward,返回每个 microbatch 的 loss
- Parameters
*inputs – 模型的输入;如果不是第一个 stage,可以不用传入
criterion (Callable) – 损失函数,通过
criterion(*outputs, *labels)
的方式调用;如果不是最后一个 stage,可以不用传入labels (tuple) – 传入损失函数的标签等数据;如果不是最后一个 stage,可以不用传入
return_outputs (bool) – 是否在最后一个 stage 返回模型的输出,设置为
True
的时候会多占用一部分显存;默认是False
- Returns
(loss, outputs)
- Return type
tuple
Note
如果设置
return_outputs = True
,outputs
是模型最后一个 stage 的输出,否则为None
。 只有最后一个 stage 会有返回值,其他 stage 上会返回(None, None)
。
- class haiscale.pipeline.GPipe(module, chunks, process_group=None, batch_dim=0, checkpoint='never', broadcast_outputs=True)[source]¶
分布式流水线并行算法 GPipe
论文:《 GPipe: Efficient Training of Giant Neural Networks using Pipeline Parallelism 》
stage 0: F0 F1 F2 F3 F4 F5 B5 B4 B3 B2 B1 B0 stage 1: F0 F1 F2 F3 F4 F5 B5 B4 B3 B2 B1 B0 stage 2: F0 F1 F2 F3 F4 F5 B5 B4 B3 B2 B1 B0 stage 3: F0 F1 F2 F3 F4 F5 B5 B4 B3 B2 B1 B0 <------------------------> <------------------------> forward backward
传入 GPipe 的模型可以有多个输入和输出,但相邻 stage 模型之间的输入输出要能对应上。
- Parameters
module (nn.Module) – 模型的一部分,可以通过
partition()
分割得到chunks (int) – microbatch 的数量
process_group (dist.ProcessGroup) – 流水线并行的进程组,默认使用全局的进程组
batch_dim (int) – batch 的维度,默认是
0
checkpoint (str) –
never
,except_last
或者always
,默认是never
broadcast_outputs (bool) – 是否把最后一个 stage 的输出广播到其他 stage 上,作为每个 stage 上 GPipe 实例的输出; 设置
broadcast_outputs = False
可以节省一些通讯和显存上的开销,如果最后一个 stage 的输出比较大, 推荐设置为False
;默认是True
Example:
GPipe 有两种使用方式,第一种方式需要设置
broadcast_outputs = True
, 使用上类似于 DDP,每个 stage 返回的输出是相同的, 都可以用来计算 loss,但实际上只有最后一个 stage 的 loss 会用来做 backward:from haiscale.pipeline import GPipe, partition dist.init_process_group(...) torch.cuda.set_device(local_rank) rank, world_size = dist.get_rank(), dist.get_world_size() torch.manual_seed(12345) model = nn.Sequential(...).cuda() model = partition(model, rank, world_size) model = GPipe(model, chunks=64) for x, y in dataloader: out = model(x) loss = ((out - y)**2).sum() loss.backward()
第二种使用方法需要传入损失函数
criterion
,在forward_backward
里会通过criterion(*outputs, *labels)
的方式调用来计算 loss:from haiscale.pipeline import GPipe, partition dist.init_process_group(...) torch.cuda.set_device(local_rank) rank, world_size = dist.get_rank(), dist.get_world_size() torch.manual_seed(12345) def loss_fn(out, y): return ((out - y)**2).sum() model = nn.Sequential(...).cuda() model = partition(model, rank, world_size) model = GPipe(model, chunks=32) for x, y in dataloader: loss, _ = model.forward_backward(x, criterion=loss_fn, labels=(y,))
Note
示例中的第一个 stage 和最后一个 stage 的 dataloader 应该要能产生相同的数据。
Note
如果要和 DDP 一起使用,需要把分割后的模型先传入 DDP,然后再把 DDP 模型传入 GPipe。 通过
make_subgroups()
可以设置 pipeline parallel 和 data parallel 的大小。- forward_backward(*inputs, criterion=None, labels=(), return_outputs=False)[source]¶
做一次 forward 和 backward,返回每个 microbatch 的 loss
- Parameters
*inputs – 模型的输入;如果不是第一个 stage,可以不用传入
criterion (Callable) – 损失函数,通过
criterion(*outputs, *labels)
的方式调用;如果不是最后一个 stage,可以不用传入labels (tuple) – 传入损失函数的标签等数据;如果不是最后一个 stage,可以不用传入
return_outputs (bool) – 是否在最后一个 stage 返回模型的输出,设置为
True
的时候会多占用一部分显存;默认是False
- Returns
(loss, outputs)
- Return type
tuple
Note
如果设置
return_outputs = True
,outputs
是模型最后一个 stage 的输出,否则为None
。 只有最后一个 stage 会有返回值,其他 stage 上会返回(None, None)
。
- class haiscale.pipeline.SequentialModel(*args, **kwargs)[source]¶
功能和
torch.nn.Sequential
类似,但支持多个输入输出第
i
个子模型的输出数量要和第i + 1
个子模型的输入数量相同Example:
import torch.nn as nn from haiscale.pipeline import SequentialModel class MyLayer(nn.Module): def __init__(self, index) -> None: super().__init__() self.index = index def forward(self, x, y): return x + y, x - y def __repr__(self): return f"MyLayer{self.index}()" model = SequentialModel(*[MyLayer(i) for i in range(5)]) x = torch.ones(5) y = torch.zeros(5) print(model(x, y))
- haiscale.pipeline.partition(module, rank, nranks, balance=None, num_model_chunks=None)[source]¶
切分一个模型为
nranks
个子模型,并返回其中的第rank
个子模型模型可以有多个输入和输出
- Parameters
module (Iterable[torch.nn.Module]) – 要被切分的模型
rank (int) – 流水线并行中当前进程的 rank 编号
nranks (int) – 流水线并行的进程数量
balance – 指定每个 rank 所占的层数
num_model_chunks – 每个 rank 的子模型数量(在 Interleaved1F1B 中使用); 如果设置了本参数,函数会返回一个长度为
num_model_chunks
的列表
- Returns
返回切分后的子模型
Example:
import torch.nn as nn from haiscale.pipeline import partition class MyLayer(nn.Module): def __init__(self, index) -> None: super().__init__() self.index = index def forward(self, x, y): return x + y, x - y def __repr__(self): return f"MyLayer{self.index}()" module = nn.ModuleList([MyLayer(i) for i in range(5)]) module1 = partition(module, rank=1, nranks=2) print(module1) # SequentialModel( # (0): MyLayer3() # (1): MyLayer4() # ) module2 = partition(module, rank=1, nranks=2, balance=[2, 3]) print(module2) # SequentialModel( # (0): MyLayer2() # (1): MyLayer3() # (2): MyLayer4() # ) x = torch.ones(5) y = torch.zeros(5) print(module1(x, y)) # (tensor([2., 2., 2., 2., 2.]), tensor([0., 0., 0., 0., 0.])) print(module2(x, y)) # (tensor([2., 2., 2., 2., 2.]), tensor([2., 2., 2., 2., 2.]))
- haiscale.pipeline.make_subgroups(pp_size, dp_size=None, group=None)[source]¶
划分 pipeline parallel 和 data parallel 的进程组
- Parameters
pp_size (int) – pipeline parallel 的大小
dp_size (int) – data parallel 的大小,默认是
group.size() / pp_size
group (dist.ProcessGroup) – 进程组,默认使用全局的进程组
- Returns
(dp_group, pp_group)
- Return type
tuple
Examples:
from haiscale.pipeline import GPipe, partition, make_subgroups dist.init_process_group(...) torch.cuda.set_device(local_rank) rank, world_size = dist.get_rank(), dist.get_world_size() dp_group, pp_group = make_subgroups(pp_size=4) model = nn.Sequential(...).cuda() model = partition(model, pp_group.rank(), pp_group.size()) model = DDP(model, device_ids=[local_rank], process_group=dp_group) model = GPipe(model, chunks=64, process_group=pp_group) for x, y in dataloader: out = model(x) loss = ((out - y)**2).sum() loss.backward()