# Source code for hfai.pl.strategies.hfreduce_spawn_bind_numa
import logging
import pytorch_lightning
from .strategy_utils import bind_numa, check_numa, unwrap_lightning_module_hfai
from pytorch_lightning.utilities.rank_zero import rank_zero_only
from pytorch_lightning.strategies.ddp_spawn import DDPSpawnStrategy
from pytorch_lightning.strategies.strategy_registry import StrategyRegistry as Registry
from hfai.nn.parallel import DistributedDataParallel
from pytorch_lightning.overrides.distributed import prepare_for_backward
from pytorch_lightning.utilities.distributed import register_ddp_comm_hook
# Logger for this module, named after the module itself.
log = logging.getLogger(name=__name__)
def _release_tuple(version: str) -> tuple:
    """Return the leading ``(major, minor, patch)`` numbers of a version string.

    Missing components are padded with ``0``. Anything after the numeric
    release segment (``rc1``, ``.post0``, ``+local``, ...) is ignored, so
    ``"1.7.0rc1"`` -> ``(1, 7, 0)`` and ``"1.10"`` -> ``(1, 10, 0)``.
    A string without a leading number yields ``(0, 0, 0)``.
    """
    import re

    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        return (0, 0, 0)
    numbers = [int(part) for part in match.group().split(".")[:3]]
    return tuple(numbers + [0] * (3 - len(numbers)))


class HFReduceSpawnStrategyBindNuma(DDPSpawnStrategy):
    """
    A DDP spawn strategy that binds each process to a NUMA node and uses hfreduce
    (``hfai.nn.parallel.DistributedDataParallel``). Supports
    ``1.6.0 <= pytorch_lightning.__version__ <= 1.7.6``.

    Examples:

    .. code-block:: python

        from hfai.pl import HFAIEnvironment

        trainer = pytorch_lightning.Trainer(
            max_epochs=3,
            gpus=8,
            strategy="hfreduce_spawn_bind_numa",  # hfai supports ddp_bind_numa, ddp_spawn_bind_numa, hfreduce_bind_numa, hfreduce_spawn_bind_numa
            plugins=[HFAIEnvironment()]  # define the hfai environment and pass it in as a plugin
        )
        model_module = ToyNetModule()
        trainer.fit(
            model_module
        )
    """

    strategy_name = "hfreduce_spawn_bind_numa"

    def _configure_launcher(self) -> None:
        """Install the process launcher for this strategy.

        With pytorch_lightning < 1.7.0 the parent class's launcher is kept.
        Newer versions use hfai's ``_MultiProcessingLauncherHF`` instead. The
        original note says it deals with a worker output of ``None``.
        """
        # Compare the numbers, not the strings: lexicographic comparison
        # would treat "1.10.0" as older than "1.7.0".
        if _release_tuple(pytorch_lightning.__version__) < (1, 7, 0):
            super()._configure_launcher()
            return
        # Imported here because it is only needed on the >= 1.7.0 path.
        from .launchers.multiprocessing_hf import _MultiProcessingLauncherHF

        self._launcher = _MultiProcessingLauncherHF(self, start_method=self._start_method)

    def set_world_ranks(self, process_idx: int = 0) -> None:
        """Record this process's ranks and bind it to its NUMA node.

        The local rank is set from ``process_idx``. If a cluster environment
        is configured, this method then:

        - binds the process to a NUMA node and verifies the binding;
        - sets the global rank to ``node_rank * num_processes + local_rank``;
        - sets the world size to ``num_nodes * num_processes``;
        - updates ``rank_zero_only.rank``.

        Args:
            process_idx: index of the spawned process on this node; used as
                the local rank.

        Raises:
            RuntimeError: if ``check_numa`` reports that the NUMA binding did
                not succeed.
        """
        self._local_rank = process_idx
        if self.cluster_environment is None:
            return
        bind_numa(self.cluster_environment)
        # Explicit check rather than ``assert`` so it is not stripped under ``python -O``.
        if not check_numa(self.cluster_environment):
            raise RuntimeError("failed to bind the process to its NUMA node")
        self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank)
        self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
        rank_zero_only.rank = self.cluster_environment.global_rank()

    def _setup_model(self, model) -> DistributedDataParallel:
        """Wrap ``model`` in hfai's :class:`hfai.nn.parallel.DistributedDataParallel` (hfreduce).

        ``device_ids`` comes from ``determine_ddp_device_ids()``. The extra
        keyword arguments in ``self._ddp_kwargs`` are forwarded unchanged.
        """
        return DistributedDataParallel(module=model, device_ids=self.determine_ddp_device_ids(), **self._ddp_kwargs)

    def _register_ddp_hooks(self) -> None:
        """Register no DDP communication hooks, and reject any configured hook.

        This override replaces the parent's hook registration with a check.
        NOTE(review): this presumably exists because hfai's
        DistributedDataParallel does not support torch comm hooks. Confirm.

        Raises:
            ValueError: if a DDP communication hook was configured on the strategy.
        """
        # Explicit raise rather than ``assert`` so the check survives ``python -O``.
        if self._ddp_comm_hook is not None:
            raise ValueError('_ddp_comm_hook in strategy should be None.')

    def pre_backward(self, closure_loss) -> None:
        """Do nothing: this intentionally overrides the parent's pre-backward step.

        NOTE(review): this presumably works because hfai's
        DistributedDataParallel needs no ``prepare_for_backward`` call. Confirm.
        """

    @property
    def lightning_module(self) -> "pytorch_lightning.LightningModule":
        """Return the ``LightningModule`` unwrapped from ``self._model`` by ``unwrap_lightning_module_hfai``."""
        return unwrap_lightning_module_hfai(self._model)
# Register the strategy under its name so it can be selected with
# ``pytorch_lightning.Trainer(strategy="hfreduce_spawn_bind_numa")``.
Registry.register(
    "hfreduce_spawn_bind_numa",
    HFReduceSpawnStrategyBindNuma,
    description="HFReduce strategy with `start_method` `spawn` and `bind_numa`",
)