hfai.checkpoint¶
从给定的 checkpoint 中加载 model, optimizer, scheduler 等对象的状态。 |
|
加载通过 hfai.checkpoint.save 保存的 checkpoint |
|
该函数把 checkpoint 切分成多份,每个 rank 保存一份数据,从而加快保存 checkpoint 的速度。 |
- hfai.checkpoint.init(model, optimizer=None, *, scheduler=None, amp_scaler=None, group=None, ckpt_path)[source]¶
从给定的 checkpoint 中加载 model, optimizer, scheduler 等对象的状态。
本函数还会向 model 中添加一个
try_save
成员函数,用于保存训练状态。try_save
函数接受 epoch 和 step 作为输入,如果接收到 suspend 的信号 会保存模型、优化器等状态到ckpt_path
中。- Parameters
model (Module) – PyTorch 模型
optimizer (Optimizer, Iterable[Optimizer]) – PyTorch 优化器,可以是一个包含了多个优化器的迭代器对象;默认是
None
scheduler (LRScheduler) – PyTorch scheduler, 默认是
None
amp_scaler (GradScaler) – torch.cuda.amp.GradScaler 对象,默认是
None
group (ProcessGroup) – ProcessGroup 对象,默认是
None
ckpt_path (str) – checkpoint 路径
- Returns
上一次调用
try_save
时设置的 epoch、step 和其他保存的一些信息- Return type
epoch, step, others (int, int, Any)
Examples:
model = MyModel() optimizer = torch.optim.Adam(model.parameters(), lr=0.01) start_epoch, start_step, others = hfai.checkpoint.init(model, optimizer, ckpt_path='latest.pt') for epoch in range(start_epoch, epochs): for step, (x, y) in enumerate(dataloader): if step < start_step: continue output = model(x) loss_fn(y, output).backward() model.try_save(epoch, step, others=None)
- hfai.checkpoint.load(fname, nthreads=8, **kwargs)[source]¶
加载通过 hfai.checkpoint.save 保存的 checkpoint
- Parameters
fname (str, os.PathLike) – 保存的文件位置
nthreads (int) – 读取 checkpoint 的线程数,默认是
8
**kwargs – 传给
torch.load
的参数
- Returns
加载上来的 checkpoint
- Return type
state (dict)
Examples:
from hfai.checkpoint import save, load others = {'epoch': epoch, 'step': step + 1} save('latest.pt', model, optimizer, others=others) # 恢复训练 state = load('latest.pt', map_location='cpu') epoch, step = state['epoch'], state['step'] model.load_state_dict(ckpt['model']) optimizer.load_state_dict(ckpt['optimizer'])
- hfai.checkpoint.save(fname, model, optimizer, others, group=None)[source]¶
该函数把 checkpoint 切分成多份,每个 rank 保存一份数据,从而加快保存 checkpoint 的速度。
- Parameters
fname (str, os.PathLike) – 保存的文件位置
model (Module, List[Module]) – PyTorch 模型,可以是包含多个模型对象的
list
optimizer (Optimizer, List[Optimizer]) – 优化器,可以是包含多个优化器对象的
list
,如果是 None 则忽略,默认是None
others (dict) – 其他需要保存的一些信息,默认是
None
group (ProcessGroup) – ProcessGroup 对象,默认是
None
Examples:
from hfai.checkpoint import save, load model = DistributedDataParallel(model, device_ids=[local_rank]) optimizer = torch.optim.Adam(model.parameters(), lr=0.01) # training ... for epoch in range(epochs): for step, data in enumerate(dataloader): # training others = {'epoch': epoch, 'step': step + 1} if receive_suspend: save('latest.pt', model, optimizer, others=others) # 恢复训练 state = load('latest.pt', map_location='cpu') epoch, step = state['epoch'], state['step'] model.load_state_dict(ckpt['model']) optimizer.load_state_dict(ckpt['optimizer'])
Note
模型的 buffer 只会保存 rank-0 的