import contextlib
from enum import Enum, auto
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.distributed as dist
from torch.distributed.distributed_c10d import _get_default_group
from .param import FlattenedParameter
from .communication import reduce_scatter_base
from .logger import FSDPLogger
# TODO: verify these assumptions
# We assume:
# [x] 1. fp32 parameters only
# [x] 2. no shared parameters
# [x] 3. only one forward pass per iteration
# TODO
# - [ ] support gradient accumulation
class FSDPState(Enum):
IDLE = auto()  # before forward / after backward
IN_FORWARD = auto()  # forward in progress
POST_FORWARD = auto()  # after forward, before backward
IN_BACKWARD = auto()  # backward in progress
PRE_BACKWARD = auto()  # pre_backward_hook is being called
POST_BACKWARD = auto()  # post_backward_hook is being called
UNFLATTEN_PARAM = auto()  # parameters are unflattened (inside summon_full_params)
class BackwardPrefetch(Enum):
BACKWARD_PRE = auto()
BACKWARD_POST = auto()
class FullyShardedDataParallel(nn.Module):
"""
FullyShardedDataParallel, used in the same way as PyTorch FSDP.
Args:
module (torch.nn.Module): the user's PyTorch model
process_group (ProcessGroup): the global process group; defaults to ``None``
auto_wrap_policy (Callable[nn.Module, bool, int]):
``auto_wrap_policy`` decides how parameters are grouped into FSDP units, just like the parameter of the same name in PyTorch FSDP; defaults to ``None``.
reshard_after_forward (bool): whether to re-shard the parameters after forward; defaults to ``True``
backward_prefetch (BackwardPrefetch):
how to prefetch the next FSDP module's parameters during backward; ``None`` disables prefetching; defaults to ``None``.
``BackwardPrefetch.BACKWARD_PRE``: prefetch the next FSDP module's parameters before its backward;
``BackwardPrefetch.BACKWARD_POST``: prefetch the next FSDP module's parameters after its backward.
forward_prefetch (bool): whether to prefetch the next FSDP module's parameters during forward; defaults to ``False``
fsdp_group_size (int): the size of one FSDP group (the number of ranks in an all_gather); data parallelism is applied across different FSDP groups;
if ``None``, all ranks belong to the same FSDP group; defaults to ``None``
Examples:
Training:
.. code-block:: python
from haiscale.fsdp import FullyShardedDataParallel
model = MyModel().cuda()
model = FullyShardedDataParallel(model)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.01)
# training ...
for epoch in range(epochs):
for step, (x, y) in enumerate(dataloader):
# training
optimizer.zero_grad()
output = model(x)
loss_fn(y, output).backward()
optimizer.step()
Saving a full checkpoint:
.. code-block:: python
model = FullyShardedDataParallel(model)
with model.summon_full_params():
if rank == 0:
state = model.state_dict()
torch.save(state, 'model.pt')
# load the full checkpoint into the FSDP model
state = torch.load('model.pt', map_location='cpu')
with model.summon_full_params():
model.load_state_dict(state)
# a non-FSDP model can also load it directly
model2 = MyModel().cuda()
model2.load_state_dict(state)
Saving a local checkpoint on every GPU (faster to save):
.. code-block:: python
# the FSDP instance on every GPU saves its own checkpoint
model = FullyShardedDataParallel(model)
state = model.state_dict()
rank = dist.get_rank()
torch.save(state, f'model{rank}.pt')
# the FSDP instance on every GPU must load its own checkpoint
state = torch.load(f'model{rank}.pt', map_location='cpu')
model.load_state_dict(state)
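Hybrid sharding within a node (a minimal sketch, not part of the original examples; it assumes 16 GPUs with 8 per node, and uses PyTorch's ``size_based_auto_wrap_policy`` helper, which requires torch >= 1.12):
.. code-block:: python
import functools
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
from haiscale.fsdp import FullyShardedDataParallel
# wrap every submodule with at least 1M parameters into its own FSDP unit
policy = functools.partial(size_based_auto_wrap_policy, min_num_params=int(1e6))
# shard parameters across the 8 GPUs of a node, run data parallelism across nodes
model = FullyShardedDataParallel(MyModel().cuda(), auto_wrap_policy=policy, fsdp_group_size=8)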
"""
FSDP_MODULE_NAME = '_fsdp_orig_module'
ALIGN_BITS = 128
def __init__(
self,
module,
process_group=None,
auto_wrap_policy=None,
reshard_after_forward=True,
backward_prefetch=None,
forward_prefetch=False,
fsdp_group_size=None,
_is_root=True,
_fsdp_group=None,
_ddp_group=None,
):
super().__init__()
assert isinstance(module, nn.Module)
self.add_module(self.FSDP_MODULE_NAME, module)
self.group = process_group or _get_default_group()
if _fsdp_group is not None:
self.fsdp_group = _fsdp_group
self.ddp_group = _ddp_group
else:
assert fsdp_group_size is None or fsdp_group_size <= self.group.size()
if fsdp_group_size is not None and fsdp_group_size < self.group.size():
assert self.group is _get_default_group()
self.fsdp_group = self._create_fsdp_group(self.group, fsdp_group_size)
self.ddp_group = self._create_ddp_group(self.group, fsdp_group_size)
else:
self.fsdp_group = self.group
self.ddp_group = None
# broadcast model from rank-0
self.is_root = _is_root
if self.is_root:
self._sync_params_and_buffers()
if auto_wrap_policy is not None:
try:
from torch.distributed.fsdp.wrap import _recursive_wrap
except ModuleNotFoundError as e:
raise RuntimeError("auto_wrap_policy requires torch >= 1.12") from e
_recursive_wrap(
module, auto_wrap_policy,
FullyShardedDataParallel,
ignored_modules=set(),
ignored_params=set(),
only_wrap_children=True,
# the following are FSDP kwargs
process_group=process_group,
reshard_after_forward=reshard_after_forward,
forward_prefetch=forward_prefetch,
backward_prefetch=backward_prefetch,
fsdp_group_size=fsdp_group_size,
_is_root=False,
_fsdp_group=self.fsdp_group,
_ddp_group=self.ddp_group,
)
self.flat_param = self._flatten_parameters()
self.flat_param.make_params_as_view()
# shard the flattened parameter
self.flat_param.shard()
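# the root FSDP module keeps its full parameters after forward, since they are needed again right at the start of backward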
self.reshard_after_forward = reshard_after_forward and (not self.is_root)
self.post_backward_hooks = []
self.callback_queued = False
self.pre_backward_hook_called = False
self.forward_prefetch = forward_prefetch
self.backward_prefetch = backward_prefetch
self.fsdp_index = None
self.state = FSDPState.IDLE
# hooks for save/load state dict
self._register_state_dict_hooks()
if self.is_root:
# shared state across all FSDP submodules
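# all-gather and reduce-scatter run on dedicated CUDA streams so communication can overlap with compute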
self.all_gather_stream = torch.cuda.Stream()
self.reduce_scatter_stream = torch.cuda.Stream()
self.fsdp_graph_order = []
self.logger = FSDPLogger(enable=False)
for m in self.fsdp_modules():
m.fsdp_graph_order = self.fsdp_graph_order
m.all_gather_stream = self.all_gather_stream
m.reduce_scatter_stream = self.reduce_scatter_stream
m.logger = self.logger
def forward(self, *args, **kwargs):
assert self.state == FSDPState.IDLE
if self.is_root:
self.logger.record_start("fwd")
self.logger.record_start("pre_fwd")
# we assume the forward order is the same in every iteration
if self not in self.fsdp_graph_order:
self.fsdp_index = len(self.fsdp_graph_order)
self.fsdp_graph_order.append(self)
if self.is_root:
assert self.fsdp_index == 0, self.fsdp_index
# wait for optimizer.step() to finish
stream = torch.cuda.current_stream()
if self.fsdp_index == 0:
self.all_gather_stream.wait_stream(stream)
# all-gather
self.logger.record_start("fwd_allgather")
self._rebuild_full_params()
stream.wait_stream(self.all_gather_stream)
if self.forward_prefetch and self.fsdp_index != len(self.fsdp_graph_order) - 1:
next_fsdp = self.fsdp_graph_order[self.fsdp_index + 1]
next_fsdp._rebuild_full_params()
self.logger.record_end("fwd_allgather")
# if an in-place op (e.g. optimizer.step) runs after torch.split, the old views become invalid,
# so we recreate the views here
self.flat_param.make_params_as_view()
self.logger.record_end("pre_fwd")
# forward
outputs = self.module(*args, **kwargs)
self.logger.record_start("post_fwd")
# p.expand_as can only be called after the full parameter has been gathered via all-gather,
# so the hooks have to be registered during forward, not in the constructor
self._register_post_backward_hooks()
if self.reshard_after_forward:
self.flat_param.shard()
# autograd has no real pre-backward hook; we can only register hooks on the output tensors
self._register_pre_backward_hooks(outputs)
self.state = FSDPState.POST_FORWARD
if not torch.is_grad_enabled():
self.state = FSDPState.IDLE
self.logger.record_end("post_fwd")
if self.is_root:
self.logger.record_end("fwd")
return outputs
##########################################
# utils
##########################################
def fsdp_modules(self):
for m in self.modules():
if isinstance(m, FullyShardedDataParallel):
yield m
@property
def module(self):
return getattr(self, self.FSDP_MODULE_NAME)
def enable_logger(self, enable=True):
self.logger.enable = enable
def _sync_params_and_buffers(self):
bucket_size = 250 * (1 << 20) # 250 MB
module_states = list(self.module.parameters()) + list(self.module.buffers())
module_states = [p.data for p in module_states]
dist._broadcast_coalesced(self.group, module_states, bucket_size, 0)
def extra_repr(self) -> str:
flat_param_size = self.flat_param.numel()
return f'flattened parameter size: {flat_param_size}\n' \
f'root: {self.is_root}\n' \
f'fsdp_index: {self.fsdp_index}'
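# Illustrative layout (assumed example, not from the original source): with 8 ranks and fsdp_group_size=4,
# _create_fsdp_group yields [0, 1, 2, 3] and [4, 5, 6, 7] (parameters sharded within each group),
# while _create_ddp_group yields [0, 4], [1, 5], [2, 6], [3, 7] (gradients all-reduced across groups).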
def _create_ddp_group(self, group, fsdp_group_size):
assert group.size() % fsdp_group_size == 0
ddp_group_size = group.size() // fsdp_group_size
cur_fsdp_rank = group.rank() % fsdp_group_size
cur_ddp_group = None
for fsdp_rank in range(fsdp_group_size):
ranks = [fsdp_rank + j * fsdp_group_size for j in range(ddp_group_size)]
ddp_group = dist.new_group(ranks)
if fsdp_rank == cur_fsdp_rank:
cur_ddp_group = ddp_group
assert cur_ddp_group is not None
return cur_ddp_group
def _create_fsdp_group(self, group, fsdp_group_size):
assert group.size() % fsdp_group_size == 0
num_fsdp_groups = group.size() // fsdp_group_size
cur_fsdp_group = None
for i in range(num_fsdp_groups):
ranks = list(range(i * fsdp_group_size, (i + 1) * fsdp_group_size))
fsdp_group = dist.new_group(ranks)
if group.rank() // fsdp_group_size == i:
cur_fsdp_group = fsdp_group
assert cur_fsdp_group is not None
return cur_fsdp_group
##########################################
# save/load state_dict
##########################################
@contextlib.contextmanager
def summon_full_params(self):
"""
all_gather the full parameters so that the user can save a complete checkpoint.
The saved checkpoint is equivalent to `fsdp_model.module.state_dict()`.
Examples:
.. code-block:: python
from haiscale.fsdp import FullyShardedDataParallel
model = MyModel().cuda()
model = FullyShardedDataParallel(model)
# save checkpoint
with model.summon_full_params():
if rank == 0:
state = model.state_dict()
torch.save(state, 'model.pt')
# load checkpoint
state = torch.load('model.pt', map_location='cpu')
with model.summon_full_params():
model.load_state_dict(state)
# this could also work
model2 = MyModel().cuda()
model2.load_state_dict(state)
"""
with contextlib.ExitStack() as stack:
for m in self.fsdp_modules():
stack.enter_context(m._unflatten_parameters())
yield
return
def _register_state_dict_hooks(self):
self._register_state_dict_hook(self._post_state_dict_hook)
self._register_load_state_dict_pre_hook(self._pre_load_state_dict_hook)
def _post_state_dict_hook(self, module, state_dict, prefix, *args, **kwargs):
if self.state == FSDPState.UNFLATTEN_PARAM:
old_prefix = prefix + self.FSDP_MODULE_NAME + '.'
self._replace_state_prefix(state_dict, old_prefix, prefix)
return state_dict
def _pre_load_state_dict_hook(self, state_dict, prefix, *args, **kwargs):
if self.state == FSDPState.UNFLATTEN_PARAM:
new_prefix = prefix + self.FSDP_MODULE_NAME + '.'
self._replace_state_prefix(state_dict, prefix, new_prefix)
@staticmethod
def _replace_state_prefix(state_dict, old_prefix, new_prefix):
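# e.g. replacing prefix 'layer1._fsdp_orig_module.' with 'layer1.' maps the key
# 'layer1._fsdp_orig_module.weight' to 'layer1.weight' (example names are hypothetical)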
for key in list(state_dict.keys()):
if key.startswith(old_prefix):
new_key = new_prefix + key[len(old_prefix):]
state_dict[new_key] = state_dict.pop(key)
##########################################
# backward hooks
##########################################
def _register_post_backward_hooks(self):
if not torch.is_grad_enabled():
return
# NOTE: keep a reference to grad_acc so it is not garbage-collected, otherwise the hook stops firing
p = self.flat_param
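# p.expand_as(p) creates an autograd node whose next_function is p's AccumulateGrad node;
# a hook registered on that node fires once p.grad has been fully accumulated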
p_tmp = p.expand_as(p)
grad_acc = p_tmp.grad_fn.next_functions[0][0]
hook = self._make_post_backward_hook(p)
handle = grad_acc.register_hook(hook)
self.post_backward_hooks.append((grad_acc, handle))
def _make_post_backward_hook(self, param):
assert isinstance(param, FlattenedParameter), f'type {type(param)}'
def hook(*unused):
# 1. free full parameters
# 2. reduce-scatter gradients (TODO: hfreduce.async_reduce_scatter)
self.reduce_scatter_stream.wait_stream(torch.cuda.current_stream())
# the full parameters are no longer needed after backward, so free them
param.shard()
# prefetch in post backward hook
if self.backward_prefetch == BackwardPrefetch.BACKWARD_POST and self.fsdp_index > 0:
prev_fsdp = self.fsdp_graph_order[self.fsdp_index - 1]
prev_fsdp._rebuild_full_params()
with torch.cuda.stream(self.reduce_scatter_stream):
# do the reduce-scatter,
# which effectively shards the gradients and further reduces memory usage
grad = param.grad
new_grad = grad.new_empty(param.local_size)
# step 1: doing reduce_scatter across fsdp ranks
reduce_scatter_base(new_grad, grad, group=self.fsdp_group)
# step 2: doing all_reduce across ddp ranks
if self.ddp_group is not None:
dist.all_reduce(new_grad, group=self.ddp_group)
# step 3: average gradients
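# reduce_scatter sums over the fsdp group and all_reduce sums over the ddp group,
# so dividing by group.size() averages over all ranks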
new_grad.div_(self.group.size())
param.grad = new_grad
# release grad's memory only after the reduce has finished
grad.record_stream(self.reduce_scatter_stream)
return hook
def _finish_backward(self):
for m in self.fsdp_modules():
for _, handle in m.post_backward_hooks:
handle.remove()
m.post_backward_hooks = []
m.pre_backward_hook_called = False
m.state = FSDPState.IDLE
self.callback_queued = False
self.logger.record_start("reduce_scatter")
stream = torch.cuda.current_stream()
stream.wait_stream(self.reduce_scatter_stream)
self.logger.record_end("reduce_scatter")
self.logger.record_end("bwd")
if self.is_root:
self.logger.print_statistic()
self.logger.step()
def _register_pre_backward_hooks(self, outputs):
if not torch.is_grad_enabled():
return
def fn(tensor):
tensor.register_hook(self._pre_backward_hook)
apply_to_tensors(outputs, fn)
def _pre_backward_hook(self, *unused):
# this only needs to run once
if self.pre_backward_hook_called:
return
self.pre_backward_hook_called = True
if self.is_root:
self.logger.record_start("bwd")
self.logger.record_start("bwd_allgather")
# if the parameters were re-sharded after forward, all-gather them before backward
if self.reshard_after_forward:
self._rebuild_full_params()
stream = torch.cuda.current_stream()
stream.wait_stream(self.all_gather_stream)
if self.backward_prefetch == BackwardPrefetch.BACKWARD_PRE and self.fsdp_index > 0:
prev_fsdp = self.fsdp_graph_order[self.fsdp_index - 1]
prev_fsdp._rebuild_full_params()
self.logger.record_end("bwd_allgather")
# gradients are accumulated into p.grad during backward, so grad must have the same size as the original (full) p
grad = self.flat_param.grad
if grad is not None and grad.size() != self.flat_param.full_data.size():
self.flat_param.grad = None
# the module may be empty, in which case post_backward_hook would never be called,
# so the callback is registered here in pre_backward_hook
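# queue_callback schedules _finish_backward to run once the autograd engine finishes the whole backward pass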
if self.is_root and not self.callback_queued:
Variable._execution_engine.queue_callback(self._finish_backward)
self.callback_queued = True
self.state = FSDPState.IN_BACKWARD
##########################################
# flatten/unflatten parameters
##########################################
def _rebuild_full_params(self):
if not self.flat_param.is_sharded:
return
with torch.cuda.stream(self.all_gather_stream):
self.flat_param.unshard()
def _flatten_parameters(self):
# NOTE: when everything is set to 1, the result should match torch
p0 = next(self.module.parameters())
BS = self.ALIGN_BITS // p0.element_size()
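# e.g. for float32 (element_size 4) and ALIGN_BITS = 128, BS is 32,
# so a 100-element parameter occupies align_to(100, 32) = 128 slots in the flat buffer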
params = set()
param_infos = [] # list of (submodule, name, shape, size)
for m in self.module.modules():
for n, p in m.named_parameters(recurse=False):
if isinstance(p, FlattenedParameter):
continue
assert p not in params, "Shared parameters are not supported now"
assert p.dtype == torch.float32, "Only float32 is supported"
assert p.device == p0.device
params.add(p)
size = align_to(p.numel(), BS)
param_infos.append((m, n, p.shape, size))
flat_size = sum(size for m, n, shape, size in param_infos)
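# pad the total size so the flat buffer can be split evenly (and stay aligned) across the fsdp group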
flat_size = align_to(flat_size, self.fsdp_group.size() * BS)
flat_param = p0.data.new_zeros(flat_size)
flat_param = FlattenedParameter(flat_param, param_infos, requires_grad=True, group=self.fsdp_group)
return flat_param
@contextlib.contextmanager
def _unflatten_parameters(self):
assert self.state == FSDPState.IDLE
try:
self.state = FSDPState.UNFLATTEN_PARAM
self._rebuild_full_params()
stream = torch.cuda.current_stream()
stream.wait_stream(self.all_gather_stream)
self.flat_param.make_views_as_param()
self.orig_flat_param = [self.flat_param]
del self.flat_param
yield
finally:
self.flat_param = self.orig_flat_param[0]
del self.orig_flat_param
self.flat_param.make_params_as_view()
self.flat_param.shard()
self.state = FSDPState.IDLE
def apply_to_tensors(x, fn):
if torch.is_tensor(x):
return fn(x)
if isinstance(x, dict):
return {key: apply_to_tensors(value, fn) for key, value in x.items()}
if isinstance(x, (list, tuple, set)):
return type(x)(apply_to_tensors(el, fn) for el in x)
if x is None:
return None
raise TypeError(f"Unsupported type {type(x)}")
def align_to(n, bs):
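# round n up to the nearest multiple of bs, e.g. align_to(100, 32) == 128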
return (n + bs - 1) // bs * bs