haiscale.fsdp¶
FullyShardedDataParallel,使用方法和 PyTorch FSDP 类似。 |
- class haiscale.fsdp.FullyShardedDataParallel(module, process_group=None, auto_wrap_policy=None, reshard_after_forward=True, backward_prefetch=None, forward_prefetch=False, fsdp_group_size=None, _is_root=True, _fsdp_group=None, _ddp_group=None)[source]¶
FullyShardedDataParallel,使用方法和 PyTorch FSDP 类似。
- Parameters
module (torch.nn.Module) – 用户的 PyTorch 模型
process_group (ProcessGroup) – 全局的 process group,默认是
None
auto_wrap_policy (Callable[nn.Module, bool, int]) –
auto_wrap_policy
决定了参数如何分组,和 PyTorch FSDP 的同名参数作用一样;默认是None
。reshard_after_forward (bool) – 是否在 forward 之后切分参数;默认是
True
backward_prefetch (BackwardPrefetch) – 在 backward 的时候如何 prefetch 下一个 FSDP 的参数,如果是
None
则不会做 prefetch;默认是None
。BackwardPrefetch.BACKWARD_PRE
: 在 backward 之前 prefetch 下一个 FSDP 的参数;BackwardPrefetch.BACKWARD_POST
: 在 backward 之后 prefetch 下一个 FSDP 的参数。forward_prefetch (bool) – 在 forward 的时候是否 prefetch 下一个 FSDP 的参数,默认是
False
fsdp_group_size (int) – FSDP 一个分组的大小(all_gather 的 rank 数量),FSDP 不同分组之间做数据并行; 如果是 None 则全部 rank 都在同一个 FSDP 分组;默认是
None
Examples:
训练:
from haiscale.fsdp import FullyShardedDataParallel model = MyModel().cuda() model = FullyShardedDataParallel(model) optimizer = torch.optim.AdamW(model.parameters(), lr=0.01) # training ... for epoch in range(epochs): for step, (x, y) in enumerate(dataloader): # training optimizer.zero_grad() output = model(x) loss_fn(y, output).backward() optimizer.step()
保存完整的 checkpoint:
model = FullyShardedDataParallel(model) with model.summon_full_params(): if rank == 0: state = model.state_dict() torch.save(state, 'model.pt') # FSDP 模型加载完整的 checkpoint state = torch.load('model.pt', map_location='cpu') with model.summon_full_params(): model.load_state_dict(state) # 非 FSDP 模型也可以直接加载 model2 = MyModel().cuda() model2.load_state_dict(state)
每个 GPU 都保存自己本地的 checkpoint(保存速度更快):
# 每个 GPU 上的 FSDP 都保存一份 checkpoint model = FullyShardedDataParallel(model) state = model.state_dict() rank = dist.get_rank() torch.save(state, f'model{rank}.pt') # 每个 GPU 上的 FSDP 都需要加载 state = torch.load(f'model{rank}.pt', map_location='cpu') model.load_state_dict(state)
- summon_full_params()[source]¶
all_gather 完整的参数,方便用户保存完整的 checkpoint。 保存后的 checkpoint 等价于 fsdp_model.module.state_dict()
Examples:
from haiscale.fsdp import FullyShardedDataParallel model = MyModel().cuda() model = FullyShardedDataParallel(model) # save checkpoint with model.summon_full_params(): if rank == 0: state = model.state_dict() torch.save(state, 'model.pt') # load checkpoint state = torch.load('model.pt', map_location='cpu') with model.summon_full_params(): model.load_state_dict(state) # this could also work model2 = MyModel().cuda() model2.load_state_dict(state)