# Source code for hfai.utils.detr2.checkpointer
import hfai
import torch.distributed as dist
from detectron2.engine import HookBase
class SuspendCheckpointer(HookBase):
    """Training hook that checkpoints and suspends the job on an HFAI suspend command.

    After every training step the rank-0 process polls
    ``hfai.client.receive_suspend_command()``. When that returns a truthy
    value, it saves ``"{file_prefix}_latest.hfai"`` through the wrapped
    checkpointer (recording the current iteration as extra state), prints a
    notice, and calls ``hfai.client.go_suspend()``. Other ranks never save or
    suspend from this hook.

    Args:
        checkpointer: object used to write the snapshot. It must provide
            ``save(name, **kwargs)`` and a ``path_manager`` attribute
            (presumably a detectron2 ``DetectionCheckpointer`` -- confirm
            against callers).
        file_prefix: prefix of the checkpoint file name; the file written is
            ``"{file_prefix}_latest.hfai"``.
    """

    def __init__(
        self,
        checkpointer,
        file_prefix: str = "model",
    ) -> None:
        self.checkpointer = checkpointer
        # Not read anywhere in this class; kept so external code that may
        # inspect it keeps working.
        self.last_checkpoint = None
        self.path_manager = checkpointer.path_manager
        self.file_prefix = file_prefix

    @staticmethod
    def _get_rank() -> int:
        """Return this process's distributed rank, or 0 when not running distributed."""
        # When torch is built without distributed support, torch.distributed
        # exposes no API beyond is_available() -- calling is_initialized()
        # directly would raise AttributeError, so check availability first.
        if dist.is_available() and dist.is_initialized():
            return dist.get_rank()
        return 0

    def step(self, iteration: int) -> None:
        """Save a checkpoint and suspend if rank 0 has received a suspend command.

        Args:
            iteration: current training iteration. It is converted to ``int``
                and passed to ``checkpointer.save`` as the ``iteration``
                keyword argument.
        """
        iteration = int(iteration)
        additional_state = {"iteration": iteration}
        # Only rank 0 polls the HFAI client (short-circuit: other ranks never
        # call receive_suspend_command()).
        if self._get_rank() == 0 and hfai.client.receive_suspend_command():
            self.checkpointer.save(
                f"{self.file_prefix}_latest.hfai", **additional_state
            )
            print("[HFAI SuspendCheckpointer] going to suspend...", flush=True)
            hfai.client.go_suspend()

    def save(self, name: str, **kwargs) -> None:
        """Save a checkpoint called ``name`` via the wrapped checkpointer.

        Args:
            name: checkpoint name forwarded unchanged to ``checkpointer.save``.
            **kwargs: extra state forwarded unchanged to ``checkpointer.save``.
        """
        self.checkpointer.save(name, **kwargs)

    def after_step(self):
        """Run the suspend check after each training step."""
        # after_step() takes no arguments, so only the trainer's current
        # iteration can be forwarded to step(). ``self.trainer`` is not set in
        # this class; presumably the trainer assigns it when registering the
        # hook (detectron2 HookBase convention) -- confirm.
        self.step(self.trainer.iter)