
Source code for hfai.checkpoint.dist_ckpt

from pathlib import Path
from collections import defaultdict
from multiprocessing.pool import ThreadPool

import torch
import torch.distributed as dist
from torch.nn import Module
from torch.optim import Optimizer

from .utils import split_dict, check_type


LARGE_SIZE = 256 * (1 << 20)  # 256 MB
VERSION = "2.0.0"


def save(fname, model, optimizer=None, others=None, group=None) -> None:
    """
    Split the checkpoint into multiple shards, with each rank saving one shard, which speeds up checkpoint saving.

    Args:
        fname (str, os.PathLike): path to save the checkpoint to
        model (Module, List[Module]): PyTorch model, or a ``list`` of models
        optimizer (Optimizer, List[Optimizer]): optimizer, or a ``list`` of optimizers; ignored if ``None``. Defaults to ``None``
        others (dict): extra information to save alongside the model. Defaults to ``None``
        group (ProcessGroup): ProcessGroup object. Defaults to ``None``

    Examples:

    .. code-block:: python

        from hfai.checkpoint import save, load

        model = DistributedDataParallel(model, device_ids=[local_rank])
        optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

        # training ...
        for epoch in range(epochs):
            for step, data in enumerate(dataloader):
                # training
                others = {'epoch': epoch, 'step': step + 1}
                if receive_suspend:
                    save('latest.pt', model, optimizer, others=others)

        # resume training
        state = load('latest.pt', map_location='cpu')
        epoch, step = state['epoch'], state['step']
        model.load_state_dict(state['model'])
        optimizer.load_state_dict(state['optimizer'])

    NOTE: only the model buffers from rank-0 are saved

    """
    if not dist.is_initialized():
        raise RuntimeError("torch.distributed is not initialized yet.")

    if others is not None and not isinstance(others, dict):
        raise TypeError("`others` could only be None or a dict")

    assert group is None or isinstance(group, dist.ProcessGroup)

    models = check_type("model", model, Module)
    optimizers = check_type("optimizer", optimizer, Optimizer)

    others = others or {}
    for n in others:
        assert n not in ["model", "optimizer"], n

    group = group or dist.distributed_c10d._get_default_group()
    rank = dist.get_rank(group=group)
    nshards = dist.get_world_size(group=group)
    state = {"__nshards__": nshards, "__version__": VERSION}

    sharder = ModelSharder(group)
    state['model'] = [sharder.apply(model) for model in models]

    sharder = OptimizerSharder(group)
    state['optimizer'] = [sharder.apply(opt) for opt in optimizers]

    # save others by rank-0
    if rank == 0:
        state.update(others)

    # write to the filesystem
    output_dir = Path(fname)
    output_dir.mkdir(parents=True, exist_ok=True)
    torch.save(state, output_dir / f"PART_{rank:03d}.pt")
    dist.barrier(group=group)
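
A usage sketch (not part of this module): ``save`` assumes an initialized process group, since every rank writes its own ``PART_{rank:03d}.pt`` shard into the directory named by ``fname``. The toy model below is hypothetical, and the snippet assumes a launch via e.g. ``torchrun --nproc_per_node=<N>``:

    import torch
    import torch.distributed as dist
    from hfai.checkpoint import save

    # torchrun sets the env:// rendezvous variables for us
    dist.init_process_group(backend="nccl")
    local_rank = dist.get_rank() % torch.cuda.device_count()
    torch.cuda.set_device(local_rank)

    # hypothetical model and optimizer, just for illustration
    model = torch.nn.Linear(1024, 1024).cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # every rank calls save(); 'latest.pt' becomes a directory holding
    # one shard per rank of the default process group
    save("latest.pt", model, optimizer, others={"epoch": 0})
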
def load(fname, nthreads=8, **kwargs) -> dict:
    """
    Load a checkpoint saved by `hfai.checkpoint.save`

    Args:
        fname (str, os.PathLike): path of the saved checkpoint
        nthreads (int): number of threads used to read the checkpoint. Defaults to ``8``
        **kwargs: extra arguments passed to ``torch.load``

    Returns:
        state (dict): the loaded checkpoint

    Examples:

    .. code-block:: python

        from hfai.checkpoint import save, load

        others = {'epoch': epoch, 'step': step + 1}
        save('latest.pt', model, optimizer, others=others)

        # resume training
        state = load('latest.pt', map_location='cpu')
        epoch, step = state['epoch'], state['step']
        model.load_state_dict(state['model'])
        optimizer.load_state_dict(state['optimizer'])

    """
    if nthreads < 0:
        raise ValueError(f"nthreads must be >= 0, but found {nthreads}")

    ckpt_dir = Path(fname)
    assert ckpt_dir.is_dir()
    nshards = len(list(ckpt_dir.glob('PART_*.pt')))
    states = [None for _ in range(nshards)]

    # read
    if nthreads > 0:
        def work(i):
            states[i] = torch.load(ckpt_dir / f"PART_{i:03d}.pt", **kwargs)

        with ThreadPool(nthreads) as pool:
            pool.map(work, range(nshards))
        state0 = states[0]
    else:
        state0 = torch.load(ckpt_dir / "PART_000.pt", **kwargs)

    version = state0.get("__version__", "1.0.0")
    assert version == VERSION
    assert nshards == state0["__nshards__"]

    models = [ModelLoader() for _ in range(len(state0['model']))]
    opts = [OptimizerLoader() for _ in range(len(state0['optimizer']))]

    # concat
    for i in range(0, nshards):
        if i == 0:
            state = state0
        else:
            if nthreads > 0:  # shards were already read by the thread pool
                state = states[i]
                states[i] = None
            else:
                state = torch.load(ckpt_dir / f"PART_{i:03d}.pt", **kwargs)
            assert nshards == state["__nshards__"]

        for s, model in zip(state['model'], models):
            model.append(*s)

        for s, opt in zip(state['optimizer'], opts):
            opt.append(*s)

    models = [model.finalize() for model in models]
    opts = [opt.finalize() for opt in opts]

    if len(models) == 1:
        models = models[0]
    if len(opts) == 1:
        opts = opts[0]
    if len(opts) == 0:
        opts = None

    state0['model'] = models
    if opts is not None:
        state0['optimizer'] = opts

    return state0
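
A corresponding load-side sketch (again not part of the module): ``load`` only reads the ``PART_*.pt`` shards from the checkpoint directory, so it can run in a plain single process without ``torch.distributed``. The model and checkpoint path below are hypothetical and assume the shards were produced by ``save`` for a matching model, as in the sketch above:

    import torch
    from hfai.checkpoint import load

    # 'latest.pt' is assumed to be a directory of shards written by save()
    model = torch.nn.Linear(1024, 1024)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    state = load("latest.pt", nthreads=8, map_location="cpu")

    # with a single model/optimizer the merged state_dicts are plain dicts;
    # if lists were passed to save(), they come back as lists in the same order
    model.load_state_dict(state["model"])
    optimizer.load_state_dict(state["optimizer"])
    epoch = state["epoch"]
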
class ModelSharder():
    def __init__(self, group):
        self.rank = dist.get_rank(group=group)
        self.nshards = dist.get_world_size(group=group)

    def apply(self, model):
        params = model.state_dict()
        buffers = {name: params.pop(name) for name, buf in model.named_buffers()}
        large_params = self.collect_large_params(params)
        params = split_dict(params, self.rank, self.nshards)

        if self.rank == 0:
            params.update(buffers)

        return params, large_params

    def collect_large_params(self, params):
        large_params = {}
        for name in list(params.keys()):
            param = params[name]
            size = param.numel() * param.element_size()
            if size > LARGE_SIZE and param.layout == torch.strided:
                # clone() is required here; otherwise the whole underlying tensor would be saved
                new_p = param.view(-1).chunk(self.nshards)[self.rank].clone()
                large_params[name] = (param.shape, new_p)
                del params[name]

        return large_params


class ModelLoader():
    """
    Helper class for loading sharded model checkpoint
    """

    def __init__(self):
        self.large_params = defaultdict(list)
        self.large_params_shape = {}
        self.params = {}

    def append(self, params, large_params):
        self.params.update(params)

        for name in large_params:
            shape, param = large_params[name]
            self.large_params[name].append(param)
            self.large_params_shape[name] = shape

    def finalize(self):
        for name in list(self.large_params.keys()):
            assert name not in self.params
            param = self.large_params[name]
            shape = self.large_params_shape[name]
            param = torch.cat(param, dim=0).view(*shape)
            self.params[name] = param
            del self.large_params[name]

        params = self.params
        self.params = None
        self.large_params = None

        return params


class OptimizerSharder():
    def __init__(self, group):
        self.rank = dist.get_rank(group=group)
        self.nshards = dist.get_world_size(group=group)

    def apply(self, optimizer):
        opt = optimizer.state_dict().copy()
        state = opt["state"].copy()
        large_tensors = self.collect_large_tensors(state)
        opt["state"] = split_dict(state, self.rank, self.nshards)

        return opt, large_tensors

    def collect_large_tensors(self, state):
        large_tensors = {}
        for k1 in list(state.keys()):
            if not isinstance(state[k1], dict):
                continue
            for k2 in list(state[k1].keys()):
                x = state[k1][k2]
                if not isinstance(x, torch.Tensor):
                    continue
                size = x.numel() * x.element_size()
                if size > LARGE_SIZE and x.layout == torch.strided:
                    # clone() is required here; otherwise the whole underlying tensor would be saved
                    new_x = x.view(-1).chunk(self.nshards)[self.rank].clone()
                    large_tensors[(k1, k2)] = (x.shape, new_x)
                    state[k1] = state[k1].copy()
                    del state[k1][k2]

        return large_tensors


class OptimizerLoader():
    def __init__(self):
        self.large_tensors = defaultdict(list)
        self.large_tensors_shape = {}
        self.opt = {}

    def append(self, opt, large_tensors):
        if len(self.opt) == 0:
            self.opt = opt
        else:
            self.opt["state"].update(opt["state"])

        for key in large_tensors:
            shape, tensor = large_tensors[key]
            self.large_tensors[key].append(tensor)
            self.large_tensors_shape[key] = shape

    def finalize(self):
        for key in list(self.large_tensors.keys()):
            k1, k2 = key
            assert k2 not in self.opt["state"][k1]
            tensors = self.large_tensors[key]
            shape = self.large_tensors_shape[key]
            tensor = torch.cat(tensors, dim=0).view(*shape)
            self.opt["state"][k1][k2] = tensor
            del self.large_tensors[key]

        opt = self.opt
        self.opt = None
        self.large_tensors = None

        return opt
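
The large-tensor handling above rests on a simple invariant: any tensor larger than ``LARGE_SIZE`` is flattened, chunked across ranks by ``ModelSharder`` / ``OptimizerSharder``, and rebuilt by the loaders with ``torch.cat(...).view(shape)``. Below is a minimal, self-contained sketch (not part of the module) of that round trip, simulating ``nshards`` ranks in a single process; the 8x1024 tensor merely stands in for a genuinely large parameter:

    import torch

    nshards = 4
    param = torch.randn(8, 1024)  # stand-in for a parameter above the 256 MB threshold

    # what each rank keeps (cf. ModelSharder.collect_large_params)
    shards = [param.view(-1).chunk(nshards)[rank].clone() for rank in range(nshards)]

    # what load() rebuilds (cf. ModelLoader.finalize)
    restored = torch.cat(shards, dim=0).view(*param.shape)
    assert torch.equal(restored, param)
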