import datetime
import functools
import pickle
try:
import torch
import torch.distributed.distributed_c10d as tdist
except ImportError as e:
torch = None
tdist = None
import zmq
import threading
from zmq.utils.monitor import recv_monitor_message # requires libzmq >= 4.0
HF_ZMQ_BACKEND = 'hf_zmq'
# Not meant to override a globally initialized process group; intended for library use.
HF_ZMQ_LOCAL_BACKEND = 'hf_zmq_local'
HF_BARRIER_BACKEND = HF_ZMQ_BACKEND  # for backward compatibility
# Module-level state: active backend name, and (for hf_zmq) the global zmq group.
_backend = None
_zmq_pg = None
class ZmqServer(object):
    """Coordination server hosted by rank 0 of a ZmqProcessGroup.

    Binds three sockets on consecutive ports:

    * PUB    on ``port - 2``: broadcast responses to every client,
    * ROUTER on ``port - 1``: targeted responses to a single client,
    * PULL   on ``port``:     requests pushed by all clients.

    A daemon thread serves one collective (barrier / gather / broadcast)
    at a time; each collective is identified by a monotonically
    increasing sequence number ``sn`` that clients echo in requests.
    """

    def __init__(self, host, port, world_size):
        # NOTE(review): `host` is unused — the server binds on all interfaces.
        self.world_size = world_size
        self.zc = zmq.Context()
        self.pub_s = self.zc.socket(zmq.PUB)
        self.pub_s.bind(f"tcp://*:{port - 2}")
        self.router_s = self.zc.socket(zmq.ROUTER)
        self.router_s.bind(f"tcp://*:{port - 1}")
        self.pull_s = self.zc.socket(zmq.PULL)
        self.pull_s.bind(f"tcp://*:{port}")
        self.calling_method = None  # collective currently in flight, or None
        self.called_ranks = {}      # rank -> payload received for the current collective
        self.sn = 0                 # sequence number of the current collective
        # Dispatch table: request method name -> bound handler method.
        self.handlers = {method: getattr(self, '_handle_' + method) for method in ['barrier', 'gather', 'broadcast']}
        self.th = threading.Thread(target=self._run)
        self.th.setDaemon(True)  # deprecated alias of `self.th.daemon = True`
        self.th.start()

    def _run(self):
        """Server loop: wait for all clients to connect, then serve collectives forever."""
        # Block until all `world_size` SUB clients have completed the zmq
        # handshake, so no PUB response can be lost to a late subscriber.
        self._wait_for_all_clients(self.pub_s)
        # Learn each rank's ROUTER identity frames for targeted replies.
        self.router_clients = self._collect_all_clients(self.router_s)
        while True:
            req = self.pull_s.recv_pyobj()
            # Drop stale/early requests whose sequence number does not match.
            if req['sn'] != self.sn:
                continue
            method = req['method']
            # Only one collective may be in flight; drop mismatched methods.
            if self.calling_method is not None and method != self.calling_method:
                continue
            self.calling_method = method
            self.called_ranks[req['rank']] = req['payload']
            # Respond only once every rank has checked in.
            if len(self.called_ranks) != self.world_size:
                continue
            self.handlers[method](req)
            self._finish_op()

    def _handle_barrier(self, req):
        # All ranks arrived: release everyone via the PUB socket.
        self._send_resp('barrier')

    def _handle_gather(self, req):
        dst = req['dst']
        # Send the full rank -> payload mapping to the destination rank only ...
        self._send_resp('gather', payload=self.called_ranks, to=dst, dst=dst)
        # ... then release all ranks with a payload-less broadcast.
        self._send_resp('gather')

    def _handle_broadcast(self, req):
        src = req['src']
        # NOTE(review): the length check inspects the *payload object* sent by
        # `src`, not the number of participating ranks (already guaranteed by
        # the caller). A payload without __len__ would raise TypeError here,
        # and a falsy result nulls the broadcast, which trips the client-side
        # `assert obj is not None`. Looks like a typo for
        # `len(self.called_ranks)` — confirm intent.
        if self.called_ranks[src] is None or len(self.called_ranks[src]) != self.world_size:
            self.called_ranks[src] = None
        self._send_resp('broadcast', payload=self.called_ranks[src], src=src)

    def _send_resp(self, method, payload=None, to=None, **extra):
        """Send a response; broadcast via PUB unless `to` names a specific rank."""
        resp = {'method': method, 'sn': self.sn, 'payload': payload}
        resp.update(extra)
        if to is None:
            self.pub_s.send_pyobj(resp)
        else:
            # ROUTER send: identity frames of the target rank + pickled body.
            self.router_s.send_multipart(self.router_clients[to] + [pickle.dumps(resp)])

    def _finish_op(self):
        """Reset per-collective state and advance the sequence number."""
        self.calling_method = None
        self.called_ranks = {}
        self.sn += 1

    def _wait_for_all_clients(self, sock):
        """Block until `world_size` clients completed the zmq handshake on `sock`."""
        clients = 0
        events_socket = sock.get_monitor_socket(events=zmq.EVENT_HANDSHAKE_SUCCEEDED)  # only accept this event
        while clients < self.world_size:
            recv_monitor_message(events_socket)  # this will block until a handshake was successful
            clients += 1

    def _collect_all_clients(self, sock):
        """Receive one 'conn' registration per rank on the ROUTER socket.

        Returns a dict mapping rank -> identity frames, used later to
        address a reply to one specific rank.
        """
        clients = {}
        while len(clients) < self.world_size:
            msg = sock.recv_multipart()
            payload = pickle.loads(msg[-1])
            if payload['method'] != 'conn':
                continue
            rank = payload['rank']
            clients[rank] = msg[:-1]
        return clients
class ZmqProcessGroup(object):
    """Client-side process group implementing barrier/gather/broadcast over zmq.

    Rank 0 additionally hosts the ZmqServer in-process. Each client keeps
    a local sequence number ``sn`` that must stay in lock-step with the
    server's; every request carries it and every response is validated
    against it.
    """

    def __init__(self, init_method, world_size, rank):
        # init_method has the form 'tcp://<host>:<port>'.
        host, port = init_method[len('tcp://'):].split(':')
        port = int(port)
        self.world_size = world_size
        self.rank = rank
        self.sn = 0  # local sequence number, kept in sync with the server
        if rank == 0:
            self.server = ZmqServer(host, port, world_size)
        self.zc = zmq.Context()
        # SUB socket: receives responses broadcast by the server.
        self.sub_s = self.zc.socket(zmq.SUB)
        self.sub_s.connect(f'tcp://{host}:{port - 2}')
        self.sub_s.subscribe('')
        # REQ socket: registers this rank's identity with the server's ROUTER
        # and later receives rank-targeted replies (e.g. the gather result).
        self.req_s = self.zc.socket(zmq.REQ)
        self.req_s.connect(f'tcp://{host}:{port - 1}')
        self.req_s.send_pyobj({'method': 'conn', 'rank': rank})
        # PUSH socket: sends requests to the server.
        self.push_s = self.zc.socket(zmq.PUSH)
        self.push_s.connect(f'tcp://{host}:{port}')
        # Initial barrier guarantees every rank is connected before returning.
        self.barrier()

    def get_rank(self):
        # Rank of this process within the group.
        return self.rank

    def get_world_size(self):
        # Total number of processes in the group.
        return self.world_size

    def barrier(self, *a, **k):
        # Extra args (group, async_op, ...) are accepted for API
        # compatibility and ignored.
        self._req_server('barrier')

    def gather(self, obj, gather_list=None, dst=0, *a, **k):
        """Gather each rank's `obj` into `gather_list` on rank `dst`."""
        if self.rank == dst:
            assert gather_list is not None and len(gather_list) == self.world_size
            # The destination receives the rank -> payload mapping on its REQ socket.
            resp = self._req_server('gather', payload=obj, sock=self.req_s, dst=dst)
            # Undo the sn increment done by _check_resp: the PUB release below
            # is sent by the server with the same sn as the targeted reply.
            self.sn -= 1
            assert resp['dst'] == self.rank
            for r, o in resp['payload'].items():
                gather_list[r] = o
            # Consume the payload-less release broadcast (increments sn again).
            self._check_resp('gather')
        else:
            self._req_server('gather', payload=obj, dst=dst)

    def broadcast(self, objs, src=0, *a, **k):
        """Broadcast the single item in `objs` from rank `src` to all ranks."""
        assert len(objs) == 1
        # update the objs list for src rank too, to make it a clone of the original object
        resp = self._req_server('broadcast', payload=objs[0] if self.rank == src else None, src=src)
        obj = resp['payload']
        assert obj is not None
        objs[0] = obj

    def _req_server(self, method, payload=None, sock=None, **extra):
        """Send one request and block until the matching response arrives."""
        self._send_req(method, payload=payload, **extra)
        return self._check_resp(method, sock)

    def _check_resp(self, method, sock=None):
        """Receive a response for `method`, verify method/sn, and advance sn."""
        if sock is None:
            sock = self.sub_s  # default: listen on the broadcast channel
        while True:
            resp = sock.recv_pyobj()
            if resp:
                assert resp['method'] == method and resp['sn'] == self.sn, (resp, self.sn)
                self.sn += 1
                return resp

    def _send_req(self, method, **extra):
        # Every request carries our rank and the current sequence number.
        req = {'method': method, 'rank': self.rank, 'sn': self.sn}
        req.update(extra)
        self.push_s.send_pyobj(req)
def init_process_group(backend,
                       init_method=None,
                       timeout=datetime.timedelta(seconds=1800),
                       world_size=-1,
                       rank=-1,
                       store=None,
                       group_name='',
                       pg_options=None):
    """Initialize the default process group; analogous to
    ``torch.distributed.init_process_group``.

    Besides the regular torch backends, two zmq-based backends are
    supported:

    * ``HF_ZMQ_BACKEND`` (``'hf_zmq'``): collectives run over zmq and the
      resulting group becomes the module-global one.
    * ``HF_ZMQ_LOCAL_BACKEND`` (``'hf_zmq_local'``): same implementation,
      but the group is returned to the caller instead of replacing the
      global group — intended for library use.

    Args:
        backend (str): one of the zmq backends above, or any backend name
            accepted by ``torch.distributed``.
        init_method (str): rendezvous URL, e.g. ``'tcp://host:port'``.
        timeout (datetime.timedelta): forwarded to torch only.
        world_size (int): number of participating processes.
        rank (int): rank of the current process.
        store: forwarded to torch only.
        group_name (str): forwarded to torch only.
        pg_options: accepted for signature compatibility.
            NOTE(review): currently not forwarded to torch — confirm intended.

    Returns:
        ``ZmqProcessGroup`` when ``backend == HF_ZMQ_LOCAL_BACKEND``,
        otherwise ``None``.

    Examples:

    .. code-block:: python

        import hfai.distributed as dist
        dist.init_process_group(backend=dist.HF_ZMQ_BACKEND)
    """
    global _backend
    if backend == HF_ZMQ_BACKEND or backend == HF_ZMQ_LOCAL_BACKEND:
        # For the local backend _backend is set to 'hf_zmq_local', which the
        # _check_backend dispatcher deliberately does NOT match, so module-level
        # collectives keep using torch (the local group does not override them).
        _backend = backend
        pg = ZmqProcessGroup(init_method, world_size, rank)
        if backend == HF_ZMQ_LOCAL_BACKEND:
            return pg
        global _zmq_pg
        _zmq_pg = pg
    else:
        _backend = tdist.Backend(backend)
        tdist.init_process_group(
            backend,
            init_method=init_method,
            timeout=timeout,
            world_size=world_size,
            rank=rank,
            store=store,
            group_name=group_name,
        )
def _check_backend(f):
fname = f.__name__
@functools.wraps(f)
def g(*args, **kwargs):
global _backend
if _backend == HF_ZMQ_BACKEND:
global _zmq_pg
return getattr(_zmq_pg, fname)(*args, **kwargs)
else:
return f(*args, **kwargs)
return g
def is_initialized():
    """Return whether a default process group has been initialized
    (either the zmq-backed one or torch's)."""
    if _backend != HF_ZMQ_BACKEND:
        return tdist.is_initialized()
    return _zmq_pg is not None
@_check_backend
def get_rank(group=None):
    """Return the rank of the current process in `group` (default: WORLD)."""
    target = group if group else tdist.GroupMember.WORLD
    return tdist.get_rank(group=target)
@_check_backend
def get_world_size(group=None):
    """Return the number of processes in `group` (default: WORLD)."""
    target = group if group else tdist.GroupMember.WORLD
    return tdist.get_world_size(group=target)
@_check_backend
def barrier(group=None, async_op=False):
    """Synchronize all processes; analogous to ``torch.distributed.barrier``.

    With the ``hf_zmq`` backend the barrier runs over zmq (the decorator
    routes the call there); otherwise this forwards to
    ``torch.distributed.barrier`` with the same parameters.

    Args:
        group: process group to work on (default: the WORLD group).
        async_op (bool): whether this op should be an async op (torch path only).

    Examples:

    .. code-block:: python

        import hfai.distributed as dist
        dist.barrier()
    """
    group = group or tdist.GroupMember.WORLD
    tdist.barrier(group, async_op)
@_check_backend
def gather(obj, gather_list=None, dst=0, group=None, async_op=False):
    """Gather tensors from all ranks onto rank ``dst``.

    Counterpart of ``torch.distributed.gather``; the ``group`` and
    ``async_op`` parameters are accepted for signature compatibility but
    not supported yet.

    Args:
        obj (torch.Tensor): tensor contributed by this rank.
        gather_list (list[torch.Tensor]): on ``dst``, a pre-allocated list
            of ``world_size`` tensors to receive into; must be ``None``
            on every other rank.
        dst (int): destination rank.
        group: must be ``None`` (unsupported).
        async_op (bool): must be ``False`` (unsupported).
    """
    assert group is None
    assert not async_op, '暂未支持异步调用gather'
    world_size = tdist.get_world_size()
    rank = tdist.get_rank()
    if rank == dst:
        assert gather_list is not None
        # Keep a detached copy of our own contribution.
        gather_list[rank] = torch.clone(obj)
        # Post all receives asynchronously, then wait for them together.
        work_list = []
        for i in reversed(range(world_size)):
            if i != rank:
                work_list.append(tdist.irecv(gather_list[i], src=i))
        for work in work_list:
            work.wait()
    else:
        assert gather_list is None
        work = tdist.isend(obj, dst=dst)
        work.wait()
    # Keep all ranks in step before returning.
    tdist.barrier()
@_check_backend
def broadcast(objs, src=0, group=None, async_op=False):
    """Broadcast a single python object from rank ``src`` to all ranks.

    zmq-backed counterpart of torch's broadcast; ``async_op`` and custom
    ``group`` values are not supported by the zmq backend.

    Args:
        objs (list[object]): single-item list. On rank ``src`` it holds the
            object to send; on every rank (including ``src``) the item is
            replaced by the received clone.
        src (int): source rank.
        group: forwarded to torch when the torch backend is active.
        async_op (bool): forwarded to torch when the torch backend is active.

    Returns:
        None

    Examples:

    .. code-block:: python

        import hfai.distributed as dist

        objs = ['to broadcast']
        dist.broadcast(objs, src=0)
    """
    group = group or tdist.GroupMember.WORLD
    # NOTE(review): `torch.distributed.broadcast` expects a tensor, not a list
    # of python objects — for the torch backend this likely needs
    # `tdist.broadcast_object_list` instead; confirm against callers before
    # changing.
    tdist.broadcast(objs, src, group=group, async_op=async_op)