# Source code for the HFAIEnvironment cluster-environment plugin for PyTorch Lightning.

import os
import socket
import pytorch_lightning
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment

# Import ``rank_zero_only`` from whichever location this Lightning release
# provides. The newer ``utilities.rank_zero`` module is tried first and the
# older ``utilities.distributed`` module is the fallback. Probing the import
# directly avoids comparing version strings lexically (where '1.10.0' sorts
# before '1.6.0') and guarantees the name is bound on every supported release.
try:
    from pytorch_lightning.utilities.rank_zero import rank_zero_only
except ImportError:
    from pytorch_lightning.utilities.distributed import rank_zero_only

class HFAIEnvironment(ClusterEnvironment):
    """Cluster environment that adapts automatically to the Hfai (Yinghuo) cluster.

    Supports ``1.5.0 <= pytorch_lightning.__version__ <= 1.7.6``. Both the
    newer ``main_address`` / ``main_port`` properties and the older
    ``master_address()`` / ``master_port()`` methods are provided.

    Examples:

    .. code-block:: python

        # NOTE(review): the module path was lost in the extracted source;
        # ``hfai.pl`` is assumed here -- confirm.
        from hfai.pl import HFAIEnvironment

        trainer = pytorch_lightning.Trainer(
            max_epochs=3,
            gpus=8,
            # hfai supports ddp_bind_numa, ddp_spawn_bind_numa,
            # hfreduce_bind_numa, hfreduce_spawn_bind_numa
            strategy="ddp_bind_numa",
            # create the Hfai environment and pass it in as a plugin
            plugins=[HFAIEnvironment()],
        )

        model_module = ToyNetModule()

        # NOTE(review): the call name was lost in the extracted source;
        # ``trainer.fit`` is assumed -- confirm.
        trainer.fit(
            model_module
        )
    """

    def __init__(self) -> None:
        super().__init__()
        # -1 is the "not resolved yet" sentinel for the rendezvous port.
        self._main_port: int = -1
        self._global_rank: int = 0
        self._world_size: int = 1

    def _resolve_main_port(self) -> int:
        """Return the rendezvous port, resolving and caching it on first use.

        ``MASTER_PORT`` from the environment wins; a free local port is only
        probed when that variable is absent, so no socket is opened needlessly.
        """
        if self._main_port == -1:
            configured = os.environ.get("MASTER_PORT")
            if configured is None:
                self._main_port = int(find_free_network_port())
            else:
                self._main_port = int(configured)
        return self._main_port

    @property
    def creates_processes_externally(self) -> bool:
        """Returns whether the cluster creates the processes or not.

        If at least :code:`LOCAL_RANK` is available as environment variable,
        Lightning assumes the user acts as the process launcher/job scheduler
        and Lightning will not launch new processes.
        """
        return "LOCAL_RANK" in os.environ

    @property
    def main_address(self) -> str:
        """Value of ``MASTER_ADDR``, or an empty string when it is unset."""
        return os.environ.get("MASTER_ADDR", "")

    def master_port(self) -> int:
        """Method-style accessor returning the same port as :attr:`main_port`."""
        return self._resolve_main_port()

    def master_address(self) -> str:
        """Method-style accessor returning the same value as :attr:`main_address`."""
        return os.environ.get("MASTER_ADDR", "")

    @property
    def main_port(self) -> int:
        """Port from ``MASTER_PORT``, or a free local port when it is unset.

        The value is resolved once and cached on the instance.
        """
        return self._resolve_main_port()

    @staticmethod
    def detect() -> bool:
        """Always report this environment as detected."""
        return True

    def world_size(self) -> int:
        """Return the world size last stored via :meth:`set_world_size` (default 1)."""
        return self._world_size

    def set_world_size(self, size: int) -> None:
        """Store ``size`` as the world size."""
        self._world_size = size

    def global_rank(self) -> int:
        """Return the global rank last stored via :meth:`set_global_rank` (default 0)."""
        return self._global_rank

    def set_global_rank(self, rank: int) -> None:
        """Store ``rank`` and mirror it onto ``rank_zero_only.rank``."""
        self._global_rank = rank
        rank_zero_only.rank = rank

    def local_rank(self) -> int:
        """Return ``LOCAL_RANK`` from the environment as an int (0 when unset)."""
        return int(os.environ.get("LOCAL_RANK", 0))

    def node_rank(self) -> int:
        """Return the node rank read from the environment.

        ``RANK`` is used when set, otherwise ``GROUP_RANK``, otherwise 0.
        """
        # NOTE(review): this reads RANK rather than NODE_RANK -- presumably the
        # cluster exports the node index under RANK; confirm.
        group_rank = os.environ.get("GROUP_RANK", 0)
        return int(os.environ.get("RANK", group_rank))

    def teardown(self) -> None:
        """Remove ``WORLD_SIZE`` from the process environment if it is present."""
        if "WORLD_SIZE" in os.environ:
            del os.environ["WORLD_SIZE"]
def find_free_network_port() -> int:
    """Finds a free port on localhost.

    It is useful in single-node training when we don't want to connect to a
    real main node but have to set the `MASTER_PORT` environment variable.

    Returns:
        A port number that the operating system reported as free at the time
        of the call. The probing socket is closed before returning, so the
        port is not reserved.
    """
    # The context manager closes the socket even if bind/getsockname raises.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
        # Binding to port 0 asks the OS to pick any available port.
        probe.bind(("", 0))
        port = probe.getsockname()[1]
    return port