Source code for hfai.pl.strategies.hfreduce_bind_numa

import logging
from .strategy_utils import bind_numa, check_numa, unwrap_lightning_module_hfai
from pytorch_lightning.utilities.rank_zero import rank_zero_only
from pytorch_lightning.strategies.ddp import DDPStrategy
from pytorch_lightning.strategies.strategy_registry import StrategyRegistry as Registry
from hfai.nn.parallel import DistributedDataParallel

log = logging.getLogger(__name__)


class HFReduceStrategyBindNuma(DDPStrategy):
    """A DDP strategy that binds NUMA nodes and uses hfreduce for communication.

    Supported for ``1.6.0 <= pytorch_lightning.__version__ <= 1.7.6``.

    Examples:

    .. code-block:: python

        from hfai.pl import HFAIEnvironment

        trainer = pytorch_lightning.Trainer(
            max_epochs=3,
            gpus=8,
            strategy="hfreduce_bind_numa",  # hfai supports ddp_bind_numa, ddp_spawn_bind_numa, hfreduce_bind_numa, hfreduce_spawn_bind_numa
            plugins=[HFAIEnvironment()]  # create the hfai environment and pass it in as a plugin
        )
        model_module = ToyNetModule()
        trainer.fit(model_module)
    """

    strategy_name = "hfreduce_bind_numa"

    def set_world_ranks(self) -> None:
        if self.cluster_environment is None:
            return
        bind_numa(self.cluster_environment)  # bind this process to its NUMA node
        assert check_numa(self.cluster_environment)  # verify the bind succeeded
        self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank)
        self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
        rank_zero_only.rank = self.cluster_environment.global_rank()

    def _setup_model(self, model):
        """Wraps the model into a :class:`hfai.nn.parallel.DistributedDataParallel` module."""
        device_ids = self.determine_ddp_device_ids()
        print(f"setting up DDP model with device ids: {device_ids}, kwargs: {self._ddp_kwargs}")
        return DistributedDataParallel(module=model, device_ids=device_ids, **self._ddp_kwargs)

    @property
    def lightning_module(self) -> "pytorch_lightning.LightningModule":
        return unwrap_lightning_module_hfai(self._model)
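``bind_numa`` and ``check_numa`` are imported from ``strategy_utils`` and are not shown on this page. Purely as an illustration of the underlying technique (not hfai's actual implementation), binding a worker to a NUMA node on Linux usually means restricting its CPU affinity to the CPUs of the node that matches its local rank; a hypothetical ``bind_to_numa_node`` helper could look like:

.. code-block:: python

    import os


    def bind_to_numa_node(local_rank: int, num_numa_nodes: int = 2) -> None:
        """Pin the current process to the CPUs of one NUMA node (Linux only).

        Illustrative sketch only -- not hfai's bind_numa from strategy_utils.
        """
        node = local_rank % num_numa_nodes
        # /sys exposes each NUMA node's CPU list, e.g. "0-31,64-95"
        with open(f"/sys/devices/system/node/node{node}/cpulist") as f:
            cpulist = f.read().strip()

        cpus = set()
        for part in cpulist.split(","):
            if "-" in part:
                lo, hi = map(int, part.split("-"))
                cpus.update(range(lo, hi + 1))
            else:
                cpus.add(int(part))

        os.sched_setaffinity(0, cpus)  # 0 == the calling process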
Registry.register("hfreduce_bind_numa", HFReduceStrategyBindNuma, description="HFReduce strategy with `bind_numa`")
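The ``Registry.register`` call above is what lets ``Trainer(strategy="hfreduce_bind_numa")`` resolve the class by name, as in the docstring example. A minimal sketch of verifying the registration, assuming that importing this module executes the call above:

.. code-block:: python

    from pytorch_lightning.strategies.strategy_registry import StrategyRegistry

    import hfai.pl.strategies.hfreduce_bind_numa  # noqa: F401  -- runs Registry.register above

    # the strategy can now be selected by name, e.g. Trainer(strategy="hfreduce_bind_numa")
    assert "hfreduce_bind_numa" in StrategyRegistry.available_strategies()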