Shortcuts

hfai.checkpoint

init

从给定的 checkpoint 中加载 model, optimizer, scheduler 等对象的状态。

load

加载通过 hfai.checkpoint.save 保存的 checkpoint

save

该函数把 checkpoint 切分成多份,每个 rank 保存一份数据,从而加快保存 checkpoint 的速度。

hfai.checkpoint.init(model, optimizer=None, *, scheduler=None, amp_scaler=None, group=None, ckpt_path)[source]

从给定的 checkpoint 中加载 model, optimizer, scheduler 等对象的状态。

本函数还会向 model 中添加一个 try_save 成员函数,用于保存训练状态。 try_save 函数接受 epoch 和 step 作为输入,如果接收到 suspend 的信号 会保存模型、优化器等状态到 ckpt_path 中。

Parameters
  • model (Module) – PyTorch 模型

  • optimizer (Optimizer, Iterable[Optimizer]) – PyTorch 优化器,可以是一个包含了多个优化器的迭代器对象;默认是 None

  • scheduler (LRScheduler) – PyTorch scheduler, 默认是 None

  • amp_scaler (GradScaler) – torch.cuda.amp.GradScaler 对象,默认是 None

  • group (ProcessGroup) – ProcessGroup 对象,默认是 None

  • ckpt_path (str) – checkpoint 路径

Returns

上一次调用 try_save 时设置的 epoch、step 和其他保存的一些信息

Return type

epoch, step, others (int, int, Any)

Examples:

model = MyModel()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

start_epoch, start_step, others = hfai.checkpoint.init(model, optimizer, ckpt_path='latest.pt')
for epoch in range(start_epoch, epochs):
    for step, (x, y) in enumerate(dataloader):
        if step < start_step:
            continue

        output = model(x)
        loss_fn(y, output).backward()
        model.try_save(epoch, step, others=None)
hfai.checkpoint.load(fname, nthreads=8, **kwargs)[source]

加载通过 hfai.checkpoint.save 保存的 checkpoint

Parameters
  • fname (str, os.PathLike) – 保存的文件位置

  • nthreads (int) – 读取 checkpoint 的线程数,默认是 8

  • **kwargs – 传给 torch.load 的参数

Returns

加载上来的 checkpoint

Return type

state (dict)

Examples:

from hfai.checkpoint import save, load

others = {'epoch': epoch, 'step': step + 1}
save('latest.pt', model, optimizer, others=others)

# 恢复训练
state = load('latest.pt', map_location='cpu')
epoch, step = state['epoch'], state['step']
model.load_state_dict(ckpt['model'])
optimizer.load_state_dict(ckpt['optimizer'])
hfai.checkpoint.save(fname, model, optimizer, others, group=None)[source]

该函数把 checkpoint 切分成多份,每个 rank 保存一份数据,从而加快保存 checkpoint 的速度。

Parameters
  • fname (str, os.PathLike) – 保存的文件位置

  • model (Module, List[Module]) – PyTorch 模型,可以是包含多个模型对象的 list

  • optimizer (Optimizer, List[Optimizer]) – 优化器,可以是包含多个优化器对象的 list,如果是 None 则忽略,默认是 None

  • others (dict) – 其他需要保存的一些信息,默认是 None

  • group (ProcessGroup) – ProcessGroup 对象,默认是 None

Examples:

from hfai.checkpoint import save, load

model = DistributedDataParallel(model, device_ids=[local_rank])
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# training ...
for epoch in range(epochs):
    for step, data in enumerate(dataloader):
        # training

        others = {'epoch': epoch, 'step': step + 1}

        if receive_suspend:
            save('latest.pt', model, optimizer, others=others)

# 恢复训练
state = load('latest.pt', map_location='cpu')
epoch, step = state['epoch'], state['step']
model.load_state_dict(ckpt['model'])
optimizer.load_state_dict(ckpt['optimizer'])

Note

模型的 buffer 只会保存 rank-0 的