Source code for hfai.checkpoint.auto_ckpt

from collections.abc import Iterable
import types
import shutil
from pathlib import Path

import torch
from torch.nn import Module
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.cuda.amp import GradScaler
import torch.distributed as dist

import hfai.client

from .dist_ckpt import save, load


def init(model, optimizer=None, *, scheduler=None, amp_scaler=None, group=None, ckpt_path):
    """
    Load the states of the model, optimizer, scheduler, etc. from the given checkpoint.

    This function also attaches a ``try_save`` method to ``model`` for saving the
    training state. ``try_save`` takes the current epoch and step as inputs and, if a
    suspend signal has been received, saves the model, optimizer and other states to
    ``ckpt_path``.

    Args:
        model (Module): PyTorch model
        optimizer (Optimizer, Iterable[Optimizer]): PyTorch optimizer, or an iterable containing multiple optimizers; defaults to ``None``
        scheduler (LRScheduler): PyTorch scheduler; defaults to ``None``
        amp_scaler (GradScaler): torch.cuda.amp.GradScaler object; defaults to ``None``
        group (ProcessGroup): ProcessGroup object; defaults to ``None``
        ckpt_path (str): path of the checkpoint

    Returns:
        epoch, step, others (int, int, Any): the epoch, step and extra information set in the last call to ``try_save``

    Examples:

    .. code-block:: python

        model = MyModel()
        optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
        start_epoch, start_step, others = hfai.checkpoint.init(model, optimizer, ckpt_path='latest.pt')

        for epoch in range(start_epoch, epochs):
            for step, (x, y) in enumerate(dataloader):
                if step < start_step:
                    continue
                output = model(x)
                loss_fn(y, output).backward()
                model.try_save(epoch, step, others=None)

    """
    if not dist.is_initialized():
        raise RuntimeError("torch.distributed is not initialized yet.")

    assert isinstance(model, Module)
    assert optimizer is None or isinstance(optimizer, (Iterable, Optimizer))
    assert scheduler is None or isinstance(scheduler, LRScheduler)
    assert amp_scaler is None or isinstance(amp_scaler, GradScaler)
    assert group is None or isinstance(group, dist.ProcessGroup)

    # TODO: do we need multiple models?
    if optimizer is not None and not isinstance(optimizer, Optimizer) and len(optimizer) == 1:
        optimizer = optimizer[0]

    group = group or dist.distributed_c10d._get_default_group()
    model._ckpt_info = [optimizer, scheduler, amp_scaler, str(ckpt_path), group]
    model.try_save = types.MethodType(_try_save, model)

    p = Path(ckpt_path)
    if not p.exists():
        return 0, 0, None

    state = load(p, map_location='cpu')
    epoch, step = state['epoch'], state['step']
    model.load_state_dict(state['model'])

    if optimizer is not None:
        if isinstance(optimizer, Optimizer):
            optimizer.load_state_dict(state['optimizer'])
        else:
            opt_states = state['optimizer']
            for i, opt in enumerate(optimizer):
                opt.load_state_dict(opt_states[i])

    if scheduler is not None:
        scheduler.load_state_dict(state['scheduler'])

    if amp_scaler is not None:
        amp_scaler.load_state_dict(state['amp_scaler'])

    others = state.get('others', None)

    rank = dist.get_rank(group=group)
    if rank == 0:
        print(f"Resume from epoch {epoch}, step {step}", flush=True)

    return epoch, step, others
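

# Usage sketch (illustrative, not part of the original module): a resume-capable
# training loop built on ``init`` and the ``try_save`` method it attaches.  The
# tiny linear model, the synthetic data and the hyper-parameters below are
# hypothetical placeholders, and the sketch assumes the process was launched by
# torchrun (or a similar launcher) so that the NCCL process group can be
# initialized.  Unlike the docstring example above, ``start_step`` is only
# skipped in the resumed epoch, so later epochs run in full.
def _example_train_loop(epochs=10, steps_per_epoch=100):
    dist.init_process_group("nccl")
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

    model = torch.nn.Linear(8, 1).cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)

    # In user code this is ``hfai.checkpoint.init(...)``
    start_epoch, start_step, _ = init(
        model, optimizer, scheduler=scheduler, ckpt_path="latest.pt")

    for epoch in range(start_epoch, epochs):
        for step in range(steps_per_epoch):
            if epoch == start_epoch and step < start_step:
                continue  # skip the steps finished before the last suspension
            x = torch.randn(4, 8, device="cuda")
            y = torch.randn(4, 1, device="cuda")
            loss = torch.nn.functional.mse_loss(model(x), y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Saves (and then suspends) only if a suspend signal was received
            model.try_save(epoch, step, others=None)
        scheduler.step()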


def _try_save(self, epoch, step, others=None, force=False):
    """
    Save the training state if a suspend signal has been received.

    Args:
        epoch (int): the current training epoch
        step (int): the current training step
        others (Any): any other information to save; defaults to ``None``
        force (bool): save the checkpoint regardless of whether a suspend signal was received; defaults to ``False``
    """
    optimizer, scheduler, amp_scaler, ckpt_path, group = self._ckpt_info

    # Only node-0 receives the suspend signal, so we broadcast it once here
    receive_suspend = hfai.client.receive_suspend_command()
    signal = torch.tensor(receive_suspend).bool().cuda()

    # FIXME: only the first node can receive the signal
    root = 0
    if group is not dist.distributed_c10d._get_default_group():
        root = dist.distributed_c10d._get_global_rank(group, 0)
        assert root < torch.cuda.device_count(), "rank-0 must be on the first node"
    dist.broadcast(signal, src=root, group=group)
    receive_suspend = signal.item()

    if not receive_suspend and not force:
        return

    state = {'epoch': epoch, 'step': step, 'others': others}
    if scheduler is not None:
        state['scheduler'] = scheduler.state_dict()
    if amp_scaler is not None:
        state['amp_scaler'] = amp_scaler.state_dict()

    # Atomic swap, so an interrupted save cannot corrupt the old checkpoint:
    # 1. new checkpoint -> xxx.tmp
    # 2. old checkpoint -> xxx.old
    # 3. xxx.tmp -> xxx
    # 4. delete xxx.old
    save(ckpt_path + '.tmp', self, optimizer, others=state, group=group)

    rank = dist.get_rank(group=group)
    if rank == 0:
        if Path(ckpt_path).exists():
            shutil.move(ckpt_path, ckpt_path + ".old")
        shutil.move(ckpt_path + ".tmp", ckpt_path)
        if Path(ckpt_path + ".old").exists():
            shutil.rmtree(ckpt_path + ".old")

    # Other ranks may not have finished saving yet; make sure every rank has
    # fully written its data before this function returns
    dist.barrier(group=group)

    if rank == 0:
        print(f"epoch {epoch}, step {step}: saved model to {ckpt_path}", flush=True)

    if not receive_suspend:
        return

    if rank == 0:
        print(f"epoch {epoch}, step {step}: going to suspend...", flush=True)
        hfai.client.go_suspend()

    dist.barrier(group=group)
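

# A small illustration (hypothetical helper, not part of the library) of how
# ``force=True`` combines with the suspend-triggered saving above: checkpoint
# unconditionally every ``save_every`` steps, and otherwise let ``try_save``
# decide based on the suspend signal.  ``others`` can carry any picklable
# bookkeeping, e.g. the best validation accuracy seen so far.
def _example_periodic_save(model, epoch, step, best_acc, save_every=1000):
    model.try_save(
        epoch,
        step,
        others={"best_acc": best_acc},
        force=((step + 1) % save_every == 0),
    )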