Source code for hfai.multiprocessing.spawn
import multiprocessing
from torch.multiprocessing.spawn import ProcessContext, _wrap
from hfai.client import bind_hf_except_hook
from hfai.utils import which_numa
from hfai._C.multiprocessing import numa
def _hf_wrap(fn, i, args, error_queue, bind_numa):
    """Entry point run inside each child process.

    Optionally binds the current process to a NUMA node, then hands control
    to torch's ``_wrap`` helper, which is given ``fn``, ``i``, ``args`` and
    ``error_queue`` unchanged.

    Args:
        fn: Callable forwarded to ``_wrap``.
        i: Index of this process; also the argument given to ``which_numa``
            to select the NUMA node.
        args: Extra arguments forwarded to ``_wrap``.
        error_queue: Queue object forwarded to ``_wrap``.
        bind_numa: When truthy, bind to the node returned by
            ``which_numa(i)`` before ``fn`` is run.
    """
    if bind_numa:
        # Bind first so the binding is in effect before any user code runs.
        node = which_numa(i)
        numa.bind_numa(node)

    _wrap(fn, i, args, error_queue)
def start_processes(fn, args=(), nprocs=1, join=True, daemon=False, start_method='spawn', bind_numa=True):
    """Launch ``nprocs`` child processes that each run ``fn`` through ``_hf_wrap``.

    Args:
        fn: Callable to run in every child; passed to ``_hf_wrap`` together
            with the process index.
        args: Extra arguments passed along to ``_hf_wrap`` for ``fn``.
        nprocs: Number of processes to start.
        join: If true, block until ``ProcessContext.join()`` reports
            completion; if false, return the context immediately.
        daemon: Value for the ``daemon`` flag of every created process.
        start_method: Name handed to ``multiprocessing.get_context``.
        bind_numa: Forwarded to ``_hf_wrap``; enables NUMA binding in the
            children.

    Returns:
        The ``ProcessContext`` when ``join`` is false, otherwise ``None``.
    """
    ctx = multiprocessing.get_context(start_method)
    # Install the hfai exception hook on this context's Process class
    # before any child is created.
    bind_hf_except_hook(ctx.Process)

    workers = []
    queues = []
    for rank in range(nprocs):
        # One dedicated queue per child, handed to it as its error queue.
        queue = ctx.SimpleQueue()
        worker = ctx.Process(
            target=_hf_wrap,
            args=(fn, rank, args, queue, bind_numa),
            daemon=daemon,
        )
        worker.start()
        queues.append(queue)
        workers.append(worker)

    context = ProcessContext(workers, queues)
    if join:
        # Keep calling join() until it returns True or raises an exception.
        while not context.join():
            pass
        return None
    return context
def spawn(fn, args=(), nprocs=1, join=True, daemon=False, bind_numa=True):
    """
    Works like ``torch.multiprocessing.spawn``, but can additionally bind
    each child process to a NUMA node automatically.

    NUMA binding assumes that the ``i``-th process corresponds to the
    ``i``-th GPU.

    Args:
        fn: function to run in every child process
        args: extra arguments forwarded for ``fn``
        nprocs (int): number of processes to start
        join (bool): whether to block until the processes are joined
        daemon (bool): daemon flag for the created processes
        bind_numa (bool): whether to bind NUMA, ``True`` by default

    Returns:
        Whatever ``start_processes`` returns for the given ``join`` value.

    Examples:

    .. code-block:: python

        import torch
        import hfai

        def main(gpu_id):
            torch.cuda.set_device(gpu_id)
            # ......

        if __name__ == "__main__":
            ngpus = torch.cuda.device_count()
            hfai.multiprocessing.spawn(main, args=(), nprocs=ngpus, bind_numa=True)

    """
    return start_processes(
        fn,
        args=args,
        nprocs=nprocs,
        join=join,
        daemon=daemon,
        start_method='spawn',
        bind_numa=bind_numa,
    )
def fork(fn, args=(), nprocs=1, join=True, daemon=False, bind_numa=True):
    """
    Same as :func:`spawn`, except that the child processes are started with
    the ``fork`` start method.

    NUMA binding assumes that the ``i``-th process corresponds to the
    ``i``-th GPU.

    Args:
        fn: function to run in every child process
        args: extra arguments forwarded for ``fn``
        nprocs (int): number of processes to start
        join (bool): whether to block until the processes are joined
        daemon (bool): daemon flag for the created processes
        bind_numa (bool): whether to bind NUMA, ``True`` by default

    Returns:
        Whatever ``start_processes`` returns for the given ``join`` value.

    Examples:

    .. code-block:: python

        import torch
        import hfai

        def main(gpu_id):
            torch.cuda.set_device(gpu_id)
            # ......

        if __name__ == "__main__":
            # Calling CUDA functions here would cause errors in the child processes.
            ngpus = hfai.utils.num_gpus()
            hfai.multiprocessing.fork(main, args=(), nprocs=ngpus, bind_numa=True)

    """
    return start_processes(
        fn,
        args=args,
        nprocs=nprocs,
        join=join,
        daemon=daemon,
        start_method="fork",
        bind_numa=bind_numa,
    )