
Source code for haiscale.fsdp.fsdp

import contextlib
from enum import Enum, auto

import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.distributed as dist
from torch.distributed.distributed_c10d import _get_default_group

from .param import FlattenedParameter
from .communication import reduce_scatter_base
from .logger import FSDPLogger


# TODO: verify these assumptions
# We assume:
# [x] 1. fp32 only
# [x] 2. no shared parameters
# [x] 3. only one forward pass per iteration

# TODO
# - [ ] support gradient accumulation


class FSDPState(Enum):

    IDLE = auto()           # before forward / after backward
    IN_FORWARD = auto()     # forward in progress
    POST_FORWARD = auto()   # after forward, before backward

    IN_BACKWARD = auto()    # backward in progress
    PRE_BACKWARD = auto()   # pre_backward_hook is being called
    POST_BACKWARD = auto()  # post_backward_hook is being called

    UNFLATTEN_PARAM = auto()

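# Typical lifecycle of a single FSDP module, as driven by the code below:
#   IDLE --forward()--> POST_FORWARD --_pre_backward_hook--> IN_BACKWARD
#        --_finish_backward--> IDLE
# UNFLATTEN_PARAM is only entered inside _unflatten_parameters() /
# summon_full_params().
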

class BackwardPrefetch(Enum):

    BACKWARD_PRE = auto()
    BACKWARD_POST = auto()

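# A minimal usage sketch (illustrative only, not executed here; `MyModel` and
# the group size are placeholders) showing how the prefetch/grouping options
# are passed to the wrapper defined below:
#
#   model = FullyShardedDataParallel(
#       MyModel().cuda(),
#       reshard_after_forward=True,
#       forward_prefetch=True,
#       backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
#       fsdp_group_size=8,   # shard within groups of 8 ranks, DDP across groups
#   )
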

class FullyShardedDataParallel(nn.Module):
    """
    FullyShardedDataParallel, used in much the same way as PyTorch FSDP.

    Args:
        module (torch.nn.Module): the user's PyTorch model
        process_group (ProcessGroup): the global process group; defaults to ``None``
        auto_wrap_policy (Callable[nn.Module, bool, int]): ``auto_wrap_policy`` decides how the parameters
            are grouped, with the same meaning as the PyTorch FSDP argument of the same name;
            defaults to ``None``.
        reshard_after_forward (bool): whether to re-shard the parameters after forward; defaults to ``True``
        backward_prefetch (BackwardPrefetch): how to prefetch the next FSDP module's parameters during
            backward; if ``None``, no prefetching is done; defaults to ``None``.
            ``BackwardPrefetch.BACKWARD_PRE``: prefetch the next FSDP module's parameters before its backward;
            ``BackwardPrefetch.BACKWARD_POST``: prefetch the next FSDP module's parameters after its backward.
        forward_prefetch (bool): whether to prefetch the next FSDP module's parameters during forward;
            defaults to ``False``
        fsdp_group_size (int): size of one FSDP group (the number of ranks participating in all_gather);
            different FSDP groups do data parallelism with each other. If ``None``, all ranks belong to
            the same FSDP group; defaults to ``None``

    Examples:

        Training:

        .. code-block:: python

            from haiscale.fsdp import FullyShardedDataParallel

            model = MyModel().cuda()
            model = FullyShardedDataParallel(model)
            optimizer = torch.optim.AdamW(model.parameters(), lr=0.01)

            # training ...
            for epoch in range(epochs):
                for step, (x, y) in enumerate(dataloader):
                    # training
                    optimizer.zero_grad()
                    output = model(x)
                    loss_fn(y, output).backward()
                    optimizer.step()

        Saving a full checkpoint:

        .. code-block:: python

            model = FullyShardedDataParallel(model)

            with model.summon_full_params():
                if rank == 0:
                    state = model.state_dict()
                    torch.save(state, 'model.pt')

            # an FSDP model loads the full checkpoint
            state = torch.load('model.pt', map_location='cpu')
            with model.summon_full_params():
                model.load_state_dict(state)

            # a non-FSDP model can also load it directly
            model2 = MyModel().cuda()
            model2.load_state_dict(state)

        Each GPU saves its own local checkpoint (saving is faster this way):

        .. code-block:: python

            # the FSDP instance on every GPU saves its own checkpoint
            model = FullyShardedDataParallel(model)
            state = model.state_dict()
            rank = dist.get_rank()
            torch.save(state, f'model{rank}.pt')

            # the FSDP instance on every GPU needs to load its own checkpoint
            state = torch.load(f'model{rank}.pt', map_location='cpu')
            model.load_state_dict(state)
    """

    FSDP_MODULE_NAME = '_fsdp_orig_module'
    ALIGN_BITS = 128

    def __init__(
        self,
        module,
        process_group=None,
        auto_wrap_policy=None,
        reshard_after_forward=True,
        backward_prefetch=None,
        forward_prefetch=False,
        fsdp_group_size=None,
        _is_root=True,
        _fsdp_group=None,
        _ddp_group=None,
    ):
        super().__init__()

        assert isinstance(module, nn.Module)
        self.add_module(self.FSDP_MODULE_NAME, module)
        self.group = process_group or _get_default_group()

        if _fsdp_group is not None:
            self.fsdp_group = _fsdp_group
            self.ddp_group = _ddp_group
        else:
            assert fsdp_group_size is None or fsdp_group_size <= self.group.size()
            if fsdp_group_size is not None and fsdp_group_size < self.group.size():
                assert self.group is _get_default_group()
                self.fsdp_group = self._create_fsdp_group(self.group, fsdp_group_size)
                self.ddp_group = self._create_ddp_group(self.group, fsdp_group_size)
            else:
                self.fsdp_group = self.group
                self.ddp_group = None

        # broadcast model from rank-0
        self.is_root = _is_root
        if self.is_root:
            self._sync_params_and_buffers()

        if auto_wrap_policy is not None:
            try:
                from torch.distributed.fsdp.wrap import _recursive_wrap
            except ModuleNotFoundError as e:
                raise RuntimeError("auto_wrap_policy requires torch >= 1.12") from e

            _recursive_wrap(
                module,
                auto_wrap_policy,
                FullyShardedDataParallel,
                ignored_modules=set(),
                ignored_params=set(),
                only_wrap_children=True,
                # the following are FSDP kwargs
                process_group=process_group,
                reshard_after_forward=reshard_after_forward,
                forward_prefetch=forward_prefetch,
                backward_prefetch=backward_prefetch,
                fsdp_group_size=fsdp_group_size,
                _is_root=False,
                _fsdp_group=self.fsdp_group,
                _ddp_group=self.ddp_group,
            )

        self.flat_param = self._flatten_parameters()
        self.flat_param.make_params_as_view()

        # shard the parameters
        self.flat_param.shard()

        self.reshard_after_forward = reshard_after_forward and (not self.is_root)
        self.post_backward_hooks = []
        self.callback_queued = False
        self.pre_backward_hook_called = False
        self.forward_prefetch = forward_prefetch
        self.backward_prefetch = backward_prefetch
        self.fsdp_index = None
        self.state = FSDPState.IDLE

        # hooks for save/load state dict
        self._register_state_dict_hooks()

        if self.is_root:
            # state shared by all FSDP submodules
            self.all_gather_stream = torch.cuda.Stream()
            self.reduce_scatter_stream = torch.cuda.Stream()
            self.fsdp_graph_order = []
            self.logger = FSDPLogger(enable=False)

            for m in self.fsdp_modules():
                m.fsdp_graph_order = self.fsdp_graph_order
                m.all_gather_stream = self.all_gather_stream
                m.reduce_scatter_stream = self.reduce_scatter_stream
                m.logger = self.logger

    def forward(self, *args, **kwargs):
        assert self.state == FSDPState.IDLE

        if self.is_root:
            self.logger.record_start("fwd")
        self.logger.record_start("pre_fwd")

        # we assume the forward order is the same in every iteration
        if self not in self.fsdp_graph_order:
            self.fsdp_index = len(self.fsdp_graph_order)
            self.fsdp_graph_order.append(self)

        if self.is_root:
            assert self.fsdp_index == 0, self.fsdp_index

        # wait for optimizer.step() to finish
        stream = torch.cuda.current_stream()
        if self.fsdp_index == 0:
            self.all_gather_stream.wait_stream(stream)

        # all-gather
        self.logger.record_start("fwd_allgather")
        self._rebuild_full_params()
        stream.wait_stream(self.all_gather_stream)

        if self.forward_prefetch and self.fsdp_index != len(self.fsdp_graph_order) - 1:
            next_fsdp = self.fsdp_graph_order[self.fsdp_index + 1]
            next_fsdp._rebuild_full_params()
        self.logger.record_end("fwd_allgather")

        # if there was an in-place operation (e.g. optimizer.step) after torch.split,
        # the old views are no longer usable, so we rebuild the views here
        self.flat_param.make_params_as_view()
        self.logger.record_end("pre_fwd")

        # forward
        outputs = self.module(*args, **kwargs)

        self.logger.record_start("post_fwd")

        # p.expand_as can only be called once the full data has been obtained via all-gather,
        # so the hooks have to be registered during forward, not in the constructor
        self._register_post_backward_hooks()

        if self.reshard_after_forward:
            self.flat_param.shard()

        # there is no real pre-backward hook; we can only attach hooks to the output tensors
        self._register_pre_backward_hooks(outputs)

        self.state = FSDPState.POST_FORWARD
        if not torch.is_grad_enabled():
            self.state = FSDPState.IDLE

        self.logger.record_end("post_fwd")
        if self.is_root:
            self.logger.record_end("fwd")

        return outputs

    ##########################################
    # utils
    ##########################################

    def fsdp_modules(self):
        for m in self.modules():
            if isinstance(m, FullyShardedDataParallel):
                yield m

    @property
    def module(self):
        return getattr(self, self.FSDP_MODULE_NAME)

    def enable_logger(self, enable=True):
        self.logger.enable = enable

    def _sync_params_and_buffers(self):
        bucket_size = 250 * (1 << 20)  # 250 MB
        module_states = list(self.module.parameters()) + list(self.module.buffers())
        module_states = [p.data for p in module_states]
        dist._broadcast_coalesced(self.group, module_states, bucket_size, 0)

    def extra_repr(self) -> str:
        flat_param_size = self.flat_param.numel()
        return f'flattened parameter size: {flat_param_size}\n' \
               f'root: {self.is_root}\n' \
               f'fsdp_index: {self.fsdp_index}'

    def _create_ddp_group(self, group, fsdp_group_size):
        assert group.size() % fsdp_group_size == 0
        ddp_group_size = group.size() // fsdp_group_size
        cur_fsdp_rank = group.rank() % fsdp_group_size

        cur_ddp_group = None
        for fsdp_rank in range(fsdp_group_size):
            ranks = [fsdp_rank + j * fsdp_group_size for j in range(ddp_group_size)]
            ddp_group = dist.new_group(ranks)
            if fsdp_rank == cur_fsdp_rank:
                cur_ddp_group = ddp_group

        assert cur_ddp_group is not None
        return cur_ddp_group

    def _create_fsdp_group(self, group, fsdp_group_size):
        assert group.size() % fsdp_group_size == 0
        num_fsdp_groups = group.size() // fsdp_group_size

        cur_fsdp_group = None
        for i in range(num_fsdp_groups):
            ranks = list(range(i * fsdp_group_size, (i + 1) * fsdp_group_size))
            fsdp_group = dist.new_group(ranks)
            if group.rank() // fsdp_group_size == i:
                cur_fsdp_group = fsdp_group

        assert cur_fsdp_group is not None
        return cur_fsdp_group
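
    # For illustration (not executed): with 8 ranks in total and fsdp_group_size=4,
    # _create_fsdp_group builds the sharding groups
    #     [0, 1, 2, 3] and [4, 5, 6, 7]
    # while _create_ddp_group builds the data-parallel groups
    #     [0, 4], [1, 5], [2, 6], [3, 7]
    # i.e. parameters are sharded inside an FSDP group, and gradients are
    # additionally all-reduced across the matching DDP group (see the
    # post-backward hook below).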

    ##########################################
    # save/load state_dict
    ##########################################

    @contextlib.contextmanager
    def summon_full_params(self):
        """
        all_gather the full parameters so that the user can save a complete checkpoint.

        The saved checkpoint is equivalent to `fsdp_model.module.state_dict()`.

        Examples:

        .. code-block:: python

            from haiscale.fsdp import FullyShardedDataParallel

            model = MyModel().cuda()
            model = FullyShardedDataParallel(model)

            # save checkpoint
            with model.summon_full_params():
                if rank == 0:
                    state = model.state_dict()
                    torch.save(state, 'model.pt')

            # load checkpoint
            state = torch.load('model.pt', map_location='cpu')
            with model.summon_full_params():
                model.load_state_dict(state)

            # this could also work
            model2 = MyModel().cuda()
            model2.load_state_dict(state)
        """
        with contextlib.ExitStack() as stack:
            for m in self.fsdp_modules():
                stack.enter_context(m._unflatten_parameters())
            yield
        return
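
    # For illustration (not executed): while summon_full_params() is active,
    # the state-dict hooks below strip the wrapper prefix on save, e.g. a key
    #     '_fsdp_orig_module.layer.weight'  ->  'layer.weight'
    # and restore it again on load, so the checkpoint looks like a plain
    # `module.state_dict()`.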

    def _register_state_dict_hooks(self):
        self._register_state_dict_hook(self._post_state_dict_hook)
        self._register_load_state_dict_pre_hook(self._pre_load_state_dict_hook)

    def _post_state_dict_hook(self, module, state_dict, prefix, *args, **kwargs):
        if self.state == FSDPState.UNFLATTEN_PARAM:
            old_prefix = prefix + self.FSDP_MODULE_NAME + '.'
            self._replace_state_prefix(state_dict, old_prefix, prefix)
        return state_dict

    def _pre_load_state_dict_hook(self, state_dict, prefix, *args, **kwargs):
        if self.state == FSDPState.UNFLATTEN_PARAM:
            new_prefix = prefix + self.FSDP_MODULE_NAME + '.'
            self._replace_state_prefix(state_dict, prefix, new_prefix)

    @staticmethod
    def _replace_state_prefix(state_dict, old_prefix, new_prefix):
        for key in list(state_dict.keys()):
            if key.startswith(old_prefix):
                new_key = new_prefix + key[len(old_prefix):]
                state_dict[new_key] = state_dict.pop(key)

    ##########################################
    # backward hooks
    ##########################################

    def _register_post_backward_hooks(self):
        if not torch.is_grad_enabled():
            return

        # NOTE: keep a reference to grad_acc so that it is not destroyed,
        # otherwise the hook would stop firing
        p = self.flat_param
        p_tmp = p.expand_as(p)
        grad_acc = p_tmp.grad_fn.next_functions[0][0]
        hook = self._make_post_backward_hook(p)
        handle = grad_acc.register_hook(hook)
        self.post_backward_hooks.append((grad_acc, handle))

    def _make_post_backward_hook(self, param):
        assert isinstance(param, FlattenedParameter), f'type {type(param)}'

        def hook(*unused):
            # 1. free full parameters
            # 2. reduce-scatter gradients (TODO: hfreduce.async_reduce_scatter)
            self.reduce_scatter_stream.wait_stream(torch.cuda.current_stream())

            # after backward we no longer need the full parameters, drop them
            param.shard()

            # prefetch in post backward hook
            if self.backward_prefetch == BackwardPrefetch.BACKWARD_POST and self.fsdp_index > 0:
                prev_fsdp = self.fsdp_graph_order[self.fsdp_index - 1]
                prev_fsdp._rebuild_full_params()

            with torch.cuda.stream(self.reduce_scatter_stream):
                # do the reduce-scatter; this effectively shards the gradient as well,
                # further reducing memory usage
                grad = param.grad
                new_grad = grad.new_empty(param.local_size)

                # step 1: reduce_scatter across fsdp ranks
                reduce_scatter_base(new_grad, grad, group=self.fsdp_group)

                # step 2: all_reduce across ddp ranks
                if self.ddp_group is not None:
                    dist.all_reduce(new_grad, group=self.ddp_group)

                # step 3: average the gradients
                new_grad.div_(self.group.size())
                param.grad = new_grad

                # free grad's memory only after the reduce has finished
                grad.record_stream(self.reduce_scatter_stream)

        return hook

    def _finish_backward(self):
        for m in self.fsdp_modules():
            for _, handle in m.post_backward_hooks:
                handle.remove()
            m.post_backward_hooks = []
            m.pre_backward_hook_called = False
            m.state = FSDPState.IDLE

        self.callback_queued = False

        self.logger.record_start("reduce_scatter")
        stream = torch.cuda.current_stream()
        stream.wait_stream(self.reduce_scatter_stream)
        self.logger.record_end("reduce_scatter")
        self.logger.record_end("bwd")

        if self.is_root:
            self.logger.print_statistic()
            self.logger.step()

    def _register_pre_backward_hooks(self, outputs):
        if not torch.is_grad_enabled():
            return

        def fn(tensor):
            tensor.register_hook(self._pre_backward_hook)

        apply_to_tensors(outputs, fn)

    def _pre_backward_hook(self, *unused):
        # only needs to run once
        if self.pre_backward_hook_called:
            return
        self.pre_backward_hook_called = True

        if self.is_root:
            self.logger.record_start("bwd")
        self.logger.record_start("bwd_allgather")

        # if the parameters were re-sharded after forward, all-gather them again before backward
        if self.reshard_after_forward:
            self._rebuild_full_params()

        stream = torch.cuda.current_stream()
        stream.wait_stream(self.all_gather_stream)

        if self.backward_prefetch == BackwardPrefetch.BACKWARD_PRE and self.fsdp_index > 0:
            prev_fsdp = self.fsdp_graph_order[self.fsdp_index - 1]
            prev_fsdp._rebuild_full_params()
        self.logger.record_end("bwd_allgather")

        # during backward the gradient is accumulated into p.grad,
        # so grad must have the same size as the original (full) p
        grad = self.flat_param.grad
        if grad is not None and grad.size() != self.flat_param.full_data.size():
            self.flat_param.grad = None

        # the model may be empty, in which case the post_backward_hook is never called,
        # so register the callback in the pre_backward_hook
        if self.is_root and not self.callback_queued:
            Variable._execution_engine.queue_callback(self._finish_backward)
            self.callback_queued = True

        self.state = FSDPState.IN_BACKWARD

    ##########################################
    # flatten/unflatten parameters
    ##########################################

    def _rebuild_full_params(self):
        if not self.flat_param.is_sharded:
            return
        with torch.cuda.stream(self.all_gather_stream):
            self.flat_param.unshard()

    def _flatten_parameters(self):
        # NOTE: with everything set to 1 this should give the same result as torch
        p0 = next(self.module.parameters())
        BS = self.ALIGN_BITS // p0.element_size()

        params = set()
        param_infos = []  # list of (submodule, name, shape, size)
        for m in self.module.modules():
            for n, p in m.named_parameters(recurse=False):
                if isinstance(p, FlattenedParameter):
                    continue
                assert p not in params, "Shared parameters are not supported now"
                assert p.dtype == torch.float32, "Only float32 is supported"
                assert p.device == p0.device
                params.add(p)
                size = align_to(p.numel(), BS)
                param_infos.append((m, n, p.shape, size))

        flat_size = sum(size for m, n, shape, size in param_infos)
        flat_size = align_to(flat_size, self.fsdp_group.size() * BS)
        flat_param = p0.data.new_zeros(flat_size)
        flat_param = FlattenedParameter(flat_param, param_infos, requires_grad=True, group=self.fsdp_group)
        return flat_param

    @contextlib.contextmanager
    def _unflatten_parameters(self):
        assert self.state == FSDPState.IDLE
        try:
            self.state = FSDPState.UNFLATTEN_PARAM
            self._rebuild_full_params()
            stream = torch.cuda.current_stream()
            stream.wait_stream(self.all_gather_stream)
            self.flat_param.make_views_as_param()
            self.orig_flat_param = [self.flat_param]
            del self.flat_param
            yield
        finally:
            self.flat_param = self.orig_flat_param[0]
            del self.orig_flat_param
            self.flat_param.make_params_as_view()
            self.flat_param.shard()
            self.state = FSDPState.IDLE
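

# Worked example for FullyShardedDataParallel._flatten_parameters above
# (illustrative only): with ALIGN_BITS = 128 and float32 parameters
# (element_size() == 4), BS == 32, so a parameter with 100 elements occupies
# align_to(100, 32) == 128 slots in the flat buffer, and the total flat size
# is rounded up to a multiple of fsdp_group.size() * BS so that it can be
# split evenly across the ranks of the FSDP group.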


def apply_to_tensors(x, fn):
    if torch.is_tensor(x):
        return fn(x)
    if isinstance(x, dict):
        return {key: apply_to_tensors(value, fn) for key, value in x.items()}
    if isinstance(x, (list, tuple, set)):
        return type(x)(apply_to_tensors(el, fn) for el in x)
    if x is None:
        return None
    raise TypeError(f"Unsupported type {type(x)}")


def align_to(n, bs):
    return (n + bs - 1) // bs * bs
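

# Quick sanity examples for the helpers above (shown for illustration only):
#
#   align_to(5, 4)    == 8    # round 5 up to the next multiple of 4
#   align_to(128, 32) == 128  # already aligned, unchanged
#
#   # apply_to_tensors walks nested dicts/lists/tuples/sets, applies fn to
#   # every tensor it finds and preserves the container structure:
#   doubled = apply_to_tensors({'a': [torch.ones(2), None]}, lambda t: t * 2)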