• Docs >
  • hfai.distributed.fsdp
Shortcuts

hfai.distributed.fsdp

FullyShardedDataParallel

FullyShardedDataParallel,使用方法和 PyTorch FSDP 类似。

class hfai.distributed.fsdp.FullyShardedDataParallel(module, process_group=None, auto_wrap_policy=None, reshard_after_forward=True, backward_prefetch=None, forward_prefetch=False, fsdp_group_size=None, _is_root=True)[source]

FullyShardedDataParallel,使用方法和 PyTorch FSDP 类似。

Parameters
  • module (torch.nn.Module) – 用户的 PyTorch 模型

  • process_group (ProcessGroup) – 全局的 process group,默认是 None

  • auto_wrap_policy (Callable[nn.Module, bool, int]) – auto_wrap_policy 决定了参数如何分组,和 PyTorch FSDP 的同名参数作用一样;默认是 Nonehfai.distributed.fsdp.warp.size_based_auto_wrap_policy: 根据参数的大小分组;

  • reshard_after_forward (bool) – 是否在 forward 之后切分参数;默认是 True

  • backward_prefetch (BackwardPrefetch) – 在 backward 的时候如何 prefetch 下一个 FSDP 的参数,如果是 None 则不会做 prefetch;默认是 NoneBackwardPrefetch.BACKWARD_PRE: 在 backward 之前 prefetch 下一个 FSDP 的参数; BackwardPrefetch.BACKWARD_POST: 在 backward 之后 prefetch 下一个 FSDP 的参数。

  • forward_prefetch (bool) – 在 forward 的时候是否 prefetch 下一个 FSDP 的参数,默认是 False

  • fsdp_group_size (int) – FSDP 一个分组的大小(all_gather 的 rank 数量),FSDP 不同分组之间做数据并行;如果是 None 则全部 rank 都在同一个 FSDP 分组;默认是 None

Examples:

训练:

from hfai.distributed.fsdp import FullyShardedDataParallel

model = MyModel().cuda()
model = FullyShardedDataParallel(model)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.01)

# training ...
for epoch in range(epochs):
    for step, (x, y) in enumerate(dataloader):
        # training
        optimizer.zero_grad()
        output = model(x)
        loss_fn(y, output).backward()
        optimizer.step()

保存完整的 checkpoint:

model = FullyShardedDataParallel(model)

with model.summon_full_params():
    if rank == 0:
        state = model.state_dict()
        torch.save(state, 'model.pt')

# FSDP 模型加载完整的 checkpoint
state = torch.load('model.pt', map_location='cpu')
with model.summon_full_params():
    model.load_state_dict(state)

# 非 FSDP 模型也可以直接加载
model2 = MyModel().cuda()
model2.load_state_dict(state)

每个 GPU 都保存自己本地的 checkpoint(保存速度更快):

# 每个 GPU 上的 FSDP 都保存一份 checkpoint
model = FullyShardedDataParallel(model)
state = model.state_dict()
rank = dist.get_rank()
torch.save(state, f'model{rank}.pt')

# 每个 GPU 上的 FSDP 都需要加载
state = torch.load(f'model{rank}.pt', map_location='cpu')
model.load_state_dict(state)
summon_full_params()[source]

all_gather 完整的参数,方便用户保存完整的 checkpoint。 保存后的 checkpoint 等价于 fsdp_model.module.state_dict()

Examples:

from hfai.distributed.fsdp import FullyShardedDataParallel

model = MyModel().cuda()
model = FullyShardedDataParallel(model)

# save checkpoint
with model.summon_full_params():
    if rank == 0:
        state = model.state_dict()
        torch.save(state, 'model.pt')

# load checkpoint
state = torch.load('model.pt', map_location='cpu')
with model.summon_full_params():
    model.load_state_dict(state)

# this could also work
model2 = MyModel().cuda()
model2.load_state_dict(state)