Shortcuts

Source code for hfai.pl.callbacks.model_checkpoint_hf

import os
import pytorch_lightning
import time
import torch
from pytorch_lightning.callbacks import ModelCheckpoint
from hfai.pl.utilities import _HFAI_AVAILABLE

if _HFAI_AVAILABLE: import hfai.client


[docs]class ModelCheckpointHF(ModelCheckpoint): """ 这是一个可以自动处理 Hfai 打断信号,自动挂起任务的 checkpoint 回调函数管理类, 支持 ``1.6.0 <= pytorch_lightning.__version__ <= 1.7.6`` Args: dirpath (str): 模型文件的保存目录(默认为 ``None``) filename (str): 模型文件的保存名称,例如 ``{epoch}-{val_loss:.2f}-{other_metric:.2f}``(默认为 ``None``) monitor (str): 监测的指标名称(默认为 ``None``) verbose (bool): 输出状态(默认为 ``False``) save_last (bool): 是否保存最后一个模型文件(默认为 ``None``) save_top_k (int): 保存前 ``k`` 好的模型文件,``k`` 为 ``0`` 时不保存,``k`` 为 ``-1`` 时保存所有模型文件(默认为 ``1``) save_weights_only (bool): 是否仅保存模型的权重(默认为 ``False``) mode (str): 指标的排序方式,包括:从大到小(``max``)或者从小到大(``min``),(默认为 ``min``) auto_insert_metric_name (bool): 是否在模型名称上自动插入指标的数值(默认为 ``True``) every_n_train_steps (int): 保存模型文件的训练间隔 step 数量(默认为 ``None``),不能和 ``train_time_interval`` 和 ``every_n_epochs`` 一同使用 train_time_interval (timedelta): 保存模型文件的训练间隔时间(默认为 ``None``),不能和 ``every_n_train_steps`` 和 ``every_n_epochs`` 一同使用 every_n_epochs (int): 保存模型文件的训练间隔 epoch 数量(默认为 ``None``),不能和 ``every_n_train_steps`` 和 ``train_time_interval`` 一同使用 save_on_train_epoch_end (bool): 是否在训练 epoch 时保存模型文件(默认为 ``None``) Raises: MisconfigurationException: 如果 ``save_top_k`` 比 ``-1``小 如果 ``monitor`` 不是 ``None`` 同时 ``save_top_k`` 不是 ``None``、``-1``、``0`` 如果 ``mode`` 不是 ``"min"`` 或者 ``"max"`` ValueError: 如果 ``trainer.save_checkpoint`` 是 ``None`` Examples: .. code-block:: python from hfai.pl import ModelCheckpointHF output_dir = 'hfai_out' checkpoint_callback = ModelCheckpointHF(dirpath=output_dir) # 第一步:定义 checkpoint_callback trainer = pytorch_lightning.Trainer( max_epochs=3, gpus=8, strategy="ddp_bind_numa", # hfai 支持 ddp_bind_numa, ddp_spawn_bind_numa, hfreduce_bind_numa, hfreduce_spawn_bind_numa plugins=[HFAIEnvironment()], callbacks=[checkpoint_callback] # 第二步:将 checkpoint_callback 输入到 trainer ) model_module = ToyNetModule() hfai_suspend_ckpt_path = f'{output_dir}/{checkpoint_callback.CHECKPOINT_NAME_SUSPEND}.ckpt' hfai_suspend_ckpt_path = hfai_suspend_ckpt_path if os.path.exists(hfai_suspend_ckpt_path) else None trainer.fit( model_module, ckpt_path=hfai_suspend_ckpt_path # 第三步:重启后载入打断前的最新模型 ) """ CHECKPOINT_NAME_SUSPEND = 'hfai_latest' def hfai_suspend(self, trainer: "pytorch_lightning.Trainer"): return _HFAI_AVAILABLE and trainer.global_rank == 0 and hfai.client.receive_suspend_command() def _save_hfai_suspend_checkpoint(self, trainer: "pytorch_lightning.Trainer", force_save: bool = False) -> None: if not force_save: if not self.hfai_suspend(trainer): return if pytorch_lightning.__version__ >= '1.6.0': monitor_candidates = self._monitor_candidates(trainer) filepath = self.format_checkpoint_name(monitor_candidates, self.CHECKPOINT_NAME_SUSPEND) _checkpoint = trainer._checkpoint_connector.dump_checkpoint() if not os.path.exists(os.path.dirname(filepath)): os.makedirs(os.path.dirname(filepath)) torch.save(_checkpoint, filepath) else: monitor_candidates = self._monitor_candidates(trainer, trainer.current_epoch, trainer.global_step - 1) filepath = self.format_checkpoint_name(monitor_candidates, self.CHECKPOINT_NAME_SUSPEND) trainer.save_checkpoint(filepath) if self.hfai_suspend(trainer): print(f'Receive suspend command. Now save checkpoint and go suspend! ' f'Global rank {trainer.global_rank}. Save checkpoint to {filepath}') time.sleep(3) hfai.client.go_suspend() def on_train_epoch_end(self, *args, **kwargs) -> None: trainer = kwargs.get('trainer', None) or args[0] self._save_hfai_suspend_checkpoint(trainer, force_save=trainer.global_rank == 0) super().on_train_epoch_end(*args, **kwargs) def on_validation_epoch_end(self, *args, **kwargs) -> None: trainer = kwargs.get('trainer', None) or args[0] self._save_hfai_suspend_checkpoint(trainer) super().on_validation_epoch_end(*args, **kwargs) def on_train_batch_end(self, *args, **kwargs) -> None: trainer = kwargs.get('trainer', None) or args[0] self._save_hfai_suspend_checkpoint(trainer) super().on_train_batch_end(*args, **kwargs) def on_validation_batch_end(self, *args, **kwargs) -> None: trainer = kwargs.get('trainer', None) or args[0] self._save_hfai_suspend_checkpoint(trainer) super().on_validation_batch_end(*args, **kwargs)