Source code for hfai.pl.callbacks.model_checkpoint_hf
import os
import pytorch_lightning
import time
import torch
from pytorch_lightning.callbacks import ModelCheckpoint
from hfai.pl.utilities import _HFAI_AVAILABLE
if _HFAI_AVAILABLE: import hfai.client
[docs]class ModelCheckpointHF(ModelCheckpoint):
"""
这是一个可以自动处理 Hfai 打断信号,自动挂起任务的 checkpoint 回调函数管理类, 支持 ``1.6.0 <= pytorch_lightning.__version__ <= 1.7.6``
Args:
dirpath (str): 模型文件的保存目录(默认为 ``None``)
filename (str): 模型文件的保存名称,例如 ``{epoch}-{val_loss:.2f}-{other_metric:.2f}``(默认为 ``None``)
monitor (str): 监测的指标名称(默认为 ``None``)
verbose (bool): 输出状态(默认为 ``False``)
save_last (bool): 是否保存最后一个模型文件(默认为 ``None``)
save_top_k (int): 保存前 ``k`` 好的模型文件,``k`` 为 ``0`` 时不保存,``k`` 为 ``-1`` 时保存所有模型文件(默认为 ``1``)
save_weights_only (bool): 是否仅保存模型的权重(默认为 ``False``)
mode (str): 指标的排序方式,包括:从大到小(``max``)或者从小到大(``min``),(默认为 ``min``)
auto_insert_metric_name (bool): 是否在模型名称上自动插入指标的数值(默认为 ``True``)
every_n_train_steps (int): 保存模型文件的训练间隔 step 数量(默认为 ``None``),不能和 ``train_time_interval`` 和 ``every_n_epochs`` 一同使用
train_time_interval (timedelta): 保存模型文件的训练间隔时间(默认为 ``None``),不能和 ``every_n_train_steps`` 和 ``every_n_epochs`` 一同使用
every_n_epochs (int): 保存模型文件的训练间隔 epoch 数量(默认为 ``None``),不能和 ``every_n_train_steps`` 和 ``train_time_interval`` 一同使用
save_on_train_epoch_end (bool): 是否在训练 epoch 时保存模型文件(默认为 ``None``)
Raises:
MisconfigurationException:
如果 ``save_top_k`` 比 ``-1``小
如果 ``monitor`` 不是 ``None`` 同时 ``save_top_k`` 不是 ``None``、``-1``、``0``
如果 ``mode`` 不是 ``"min"`` 或者 ``"max"``
ValueError:
如果 ``trainer.save_checkpoint`` 是 ``None``
Examples:
.. code-block:: python
from hfai.pl import ModelCheckpointHF
output_dir = 'hfai_out'
checkpoint_callback = ModelCheckpointHF(dirpath=output_dir) # 第一步:定义 checkpoint_callback
trainer = pytorch_lightning.Trainer(
max_epochs=3,
gpus=8,
strategy="ddp_bind_numa", # hfai 支持 ddp_bind_numa, ddp_spawn_bind_numa, hfreduce_bind_numa, hfreduce_spawn_bind_numa
plugins=[HFAIEnvironment()],
callbacks=[checkpoint_callback] # 第二步:将 checkpoint_callback 输入到 trainer
)
model_module = ToyNetModule()
hfai_suspend_ckpt_path = f'{output_dir}/{checkpoint_callback.CHECKPOINT_NAME_SUSPEND}.ckpt'
hfai_suspend_ckpt_path = hfai_suspend_ckpt_path if os.path.exists(hfai_suspend_ckpt_path) else None
trainer.fit(
model_module,
ckpt_path=hfai_suspend_ckpt_path # 第三步:重启后载入打断前的最新模型
)
"""
CHECKPOINT_NAME_SUSPEND = 'hfai_latest'
def hfai_suspend(self, trainer: "pytorch_lightning.Trainer"):
return _HFAI_AVAILABLE and trainer.global_rank == 0 and hfai.client.receive_suspend_command()
def _save_hfai_suspend_checkpoint(self, trainer: "pytorch_lightning.Trainer", force_save: bool = False) -> None:
if not force_save:
if not self.hfai_suspend(trainer):
return
if pytorch_lightning.__version__ >= '1.6.0':
monitor_candidates = self._monitor_candidates(trainer)
filepath = self.format_checkpoint_name(monitor_candidates, self.CHECKPOINT_NAME_SUSPEND)
_checkpoint = trainer._checkpoint_connector.dump_checkpoint()
if not os.path.exists(os.path.dirname(filepath)):
os.makedirs(os.path.dirname(filepath))
torch.save(_checkpoint, filepath)
else:
monitor_candidates = self._monitor_candidates(trainer, trainer.current_epoch, trainer.global_step - 1)
filepath = self.format_checkpoint_name(monitor_candidates, self.CHECKPOINT_NAME_SUSPEND)
trainer.save_checkpoint(filepath)
if self.hfai_suspend(trainer):
print(f'Receive suspend command. Now save checkpoint and go suspend! '
f'Global rank {trainer.global_rank}. Save checkpoint to {filepath}')
time.sleep(3)
hfai.client.go_suspend()
def on_train_epoch_end(self, *args, **kwargs) -> None:
trainer = kwargs.get('trainer', None) or args[0]
self._save_hfai_suspend_checkpoint(trainer, force_save=trainer.global_rank == 0)
super().on_train_epoch_end(*args, **kwargs)
def on_validation_epoch_end(self, *args, **kwargs) -> None:
trainer = kwargs.get('trainer', None) or args[0]
self._save_hfai_suspend_checkpoint(trainer)
super().on_validation_epoch_end(*args, **kwargs)
def on_train_batch_end(self, *args, **kwargs) -> None:
trainer = kwargs.get('trainer', None) or args[0]
self._save_hfai_suspend_checkpoint(trainer)
super().on_train_batch_end(*args, **kwargs)
def on_validation_batch_end(self, *args, **kwargs) -> None:
trainer = kwargs.get('trainer', None) or args[0]
self._save_hfai_suspend_checkpoint(trainer)
super().on_validation_batch_end(*args, **kwargs)