Source code for hfai.utils.detr2.checkpointer

import hfai
import torch.distributed as dist
from detectron2.engine import HookBase

class SuspendCheckpointer(HookBase):
    """detectron2 hook that checkpoints and suspends when HFAI requests it.

    After every training iteration, rank 0 polls
    ``hfai.client.receive_suspend_command()``. When that returns a truthy
    value, the hook saves a checkpoint named ``"<file_prefix>_latest.hfai"``
    through the wrapped checkpointer, passing the current iteration as extra
    state, prints a notice, and calls ``hfai.client.go_suspend()``.

    Args:
        checkpointer: Object used to write checkpoints. It must expose a
            ``path_manager`` attribute and a ``save(name, **kwargs)`` method
            (presumably a detectron2/fvcore ``Checkpointer`` -- not enforced).
        file_prefix: Prefix of the checkpoint file name.
    """

    def __init__(
        self,
        checkpointer,
        file_prefix: str = "model",
    ) -> None:
        self.checkpointer = checkpointer
        # Not read anywhere in this class; retained as a public attribute.
        self.last_checkpoint = None
        self.path_manager = checkpointer.path_manager
        self.file_prefix = file_prefix

    def step(self, iteration: int) -> None:
        """Save a checkpoint and suspend if a suspend command was received.

        Only rank 0 polls for the command; when ``torch.distributed`` is not
        available or not initialized, this process is treated as rank 0.
        Other ranks return without doing anything.
        NOTE(review): this hook does not itself stop the other ranks;
        presumably ``hfai.client.go_suspend()`` handles the whole job --
        confirm.

        Args:
            iteration: Current training iteration; stored in the checkpoint
                under the ``"iteration"`` key.
        """
        iteration = int(iteration)
        additional_state = {"iteration": iteration}
        # is_available() is False on PyTorch builds without distributed
        # support; check it before asking whether a process group exists.
        if dist.is_available() and dist.is_initialized():
            rank = dist.get_rank()
        else:
            rank = 0
        # Short-circuit: non-zero ranks never poll the HFAI client.
        if rank == 0 and hfai.client.receive_suspend_command():
            self.checkpointer.save(
                "{}_latest.hfai".format(self.file_prefix), **additional_state
            )
            print("[HFAI SuspendCheckpointer] going to suspend...", flush=True)
            hfai.client.go_suspend()

    def save(self, name: str, **kwargs) -> None:
        """Save a checkpoint called *name* through the wrapped checkpointer.

        Extra keyword arguments are forwarded unchanged to
        ``self.checkpointer.save``.
        """
        self.checkpointer.save(name, **kwargs)

    def after_step(self) -> None:
        """Hook callback: run :meth:`step` for the trainer's current iteration."""
        # No way to use **kwargs: the callback takes no arguments, so the
        # iteration is read from ``self.trainer`` (presumably attached by the
        # detectron2 trainer when the hook is registered).
        self.step(self.trainer.iter)