Source code for hfai.checkpoint.auto_ckpt
from collections.abc import Iterable
import types
import shutil
from pathlib import Path
import torch
from torch.nn import Module
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.cuda.amp import GradScaler
import torch.distributed as dist
import hfai.client
from .dist_ckpt import save, load


def init(model, optimizer=None, *, scheduler=None, amp_scaler=None, group=None, ckpt_path):
"""
从给定的 checkpoint 中加载 model, optimizer, scheduler 等对象的状态。
本函数还会向 model 中添加一个 ``try_save`` 成员函数,用于保存训练状态。
``try_save`` 函数接受 epoch 和 step 作为输入,如果接收到 suspend 的信号
会保存模型、优化器等状态到 ``ckpt_path`` 中。
Args:
model (Module): PyTorch 模型
optimizer (Optimizer, Iterable[Optimizer]): PyTorch 优化器,可以是一个包含了多个优化器的迭代器对象;默认是 ``None``
scheduler (LRScheduler): PyTorch scheduler, 默认是 ``None``
amp_scaler (GradScaler): torch.cuda.amp.GradScaler 对象,默认是 ``None``
group (ProcessGroup): ProcessGroup 对象,默认是 ``None``
ckpt_path (str): checkpoint 路径
Returns:
epoch, step, others (int, int, Any): 上一次调用 ``try_save`` 时设置的 epoch、step 和其他保存的一些信息
Examples:
.. code-block:: python
model = MyModel()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
start_epoch, start_step, others = hfai.checkpoint.init(model, optimizer, ckpt_path='latest.pt')
for epoch in range(start_epoch, epochs):
for step, (x, y) in enumerate(dataloader):
if step < start_step:
continue
output = model(x)
loss_fn(y, output).backward()
model.try_save(epoch, step, others=None)
"""
    if not dist.is_initialized():
        raise RuntimeError("torch.distributed is not initialized yet.")

    assert isinstance(model, Module)
    assert optimizer is None or isinstance(optimizer, (Iterable, Optimizer))
    assert scheduler is None or isinstance(scheduler, LRScheduler)
    assert amp_scaler is None or isinstance(amp_scaler, GradScaler)
    assert group is None or isinstance(group, dist.ProcessGroup)
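    # Normalize the optimizer argument: a one-element iterable is unwrapped into a plain Optimizer.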
    # TODO: do we need multiple models?
    if optimizer is not None and not isinstance(optimizer, Optimizer):
        optimizer = list(optimizer)
        if len(optimizer) == 1:
            optimizer = optimizer[0]
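    # Resolve the process group and stash everything the bound ``try_save`` method will need later.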
    group = group or dist.distributed_c10d._get_default_group()
    model._ckpt_info = [optimizer, scheduler, amp_scaler, str(ckpt_path), group]
    model.try_save = types.MethodType(_try_save, model)
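    # No checkpoint on disk yet: this is a fresh run, start from epoch 0, step 0.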
    p = Path(ckpt_path)
    if not p.exists():
        return 0, 0, None
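    # Otherwise resume: restore the model / optimizer / scheduler / amp_scaler states saved by ``try_save``.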
    state = load(p, map_location='cpu')
    epoch, step = state['epoch'], state['step']
    model.load_state_dict(state['model'])
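    # A single Optimizer is restored directly; a list of optimizers is restored
    # element-wise, in the same order in which it was saved.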
    if optimizer is not None:
        if isinstance(optimizer, Optimizer):
            optimizer.load_state_dict(state['optimizer'])
        else:
            opt_states = state['optimizer']
            for i, opt in enumerate(optimizer):
                opt.load_state_dict(opt_states[i])

    if scheduler is not None:
        scheduler.load_state_dict(state['scheduler'])

    if amp_scaler is not None:
        amp_scaler.load_state_dict(state['amp_scaler'])

    others = state.get('others', None)

    rank = dist.get_rank(group=group)
    if rank == 0:
        print(f"Resume from epoch {epoch}, step {step}", flush=True)

    return epoch, step, others


def _try_save(self, epoch, step, others=None, force=False):
    """
    Save the training state if a suspend signal has been received.

    Args:
        epoch (int): the epoch the training has reached
        step (int): the step the training has reached
        others (Any): extra information to save alongside the states; defaults to ``None``
        force (bool): save a checkpoint regardless of whether a suspend signal was received; defaults to ``False``
    """
    optimizer, scheduler, amp_scaler, ckpt_path, group = self._ckpt_info

    # Only node-0 receives the suspend signal, so we broadcast it to every rank.
    receive_suspend = hfai.client.receive_suspend_command()
    signal = torch.tensor(receive_suspend).bool().cuda()
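    # The flag is broadcast as a CUDA tensor so the collective also works with the NCCL backend.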
    # FIXME: only the first node can receive the signal
    root = 0
    if group is not dist.distributed_c10d._get_default_group():
        root = dist.distributed_c10d._get_global_rank(group, 0)
        assert root < torch.cuda.device_count(), "rank-0 must be on the first node"
    dist.broadcast(signal, src=root, group=group)
    receive_suspend = signal.item()
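    # Nothing to do: no suspend signal was received and the caller did not force a save.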
    if not receive_suspend and not force:
        return

    state = {'epoch': epoch, 'step': step, 'others': others}
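    # The model and optimizer objects themselves go straight into ``save`` below; the
    # scheduler and amp_scaler states travel inside this metadata dict (see ``init``).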
    if scheduler is not None:
        state['scheduler'] = scheduler.state_dict()
    if amp_scaler is not None:
        state['amp_scaler'] = amp_scaler.state_dict()

    # Swap the files in place so an interrupted save never corrupts the old checkpoint:
    # 1. write the new checkpoint to xxx.tmp
    # 2. move the old checkpoint to xxx.old
    # 3. rename xxx.tmp to xxx
    # 4. delete xxx.old
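    # ``save`` comes from .dist_ckpt; judging by the shutil.rmtree call below, the
    # checkpoint path is treated as a directory (presumably holding one shard per rank).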
    save(ckpt_path + '.tmp', self, optimizer, others=state, group=group)

    rank = dist.get_rank(group=group)
    if rank == 0:
        if Path(ckpt_path).exists():
            shutil.move(ckpt_path, ckpt_path + ".old")
        shutil.move(ckpt_path + ".tmp", ckpt_path)
        if Path(ckpt_path + ".old").exists():
            shutil.rmtree(ckpt_path + ".old")
    # Other ranks may not have finished saving yet; make sure every rank has fully
    # written its part before this function returns.
    dist.barrier(group=group)

    if rank == 0:
        print(f"epoch {epoch}, step {step}: saved model to {ckpt_path}", flush=True)
    if not receive_suspend:
        return
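    # A suspend signal was received: rank 0 asks the cluster client to suspend this task,
    # and the trailing barrier keeps the other ranks alive until the suspension happens.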
    if rank == 0:
        print(f"epoch {epoch}, step {step}: going to suspend...", flush=True)
        hfai.client.go_suspend()

    dist.barrier(group=group)
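

# A minimal usage sketch (illustrative only): ``MyModel``, ``dataloader``, ``loss_fn`` and
# ``epochs`` are placeholders, and torch.distributed.init_process_group is assumed to have
# been called by the launcher already. Pass ``force=True`` to checkpoint unconditionally,
# e.g. at the end of every epoch.
#
#     model = MyModel().cuda()
#     optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
#     start_epoch, start_step, others = hfai.checkpoint.init(model, optimizer, ckpt_path='latest.pt')
#
#     for epoch in range(start_epoch, epochs):
#         for step, (x, y) in enumerate(dataloader):
#             if epoch == start_epoch and step < start_step:
#                 continue
#             loss = loss_fn(model(x.cuda()), y.cuda())
#             optimizer.zero_grad()
#             loss.backward()
#             optimizer.step()
#             model.try_save(epoch, step, others={'loss': loss.item()})
#         model.try_save(epoch + 1, 0, force=True)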