# Source code for hfai.distributed.zmq_backend

import datetime
import functools
import pickle
import threading

# torch is optional: the pure-zmq backends work without it, so guard the import
try:
    import torch
    import torch.distributed.distributed_c10d as tdist
except ImportError as e:
    torch = None
    tdist = None

import zmq
from zmq.utils.monitor import recv_monitor_message  # requires libzmq >= 4.0

# Backend names accepted by init_process_group().
HF_ZMQ_BACKEND = 'hf_zmq'
HF_ZMQ_LOCAL_BACKEND = 'hf_zmq_local' # not to override globally inited process group, for library use
HF_BARRIER_BACKEND = HF_ZMQ_BACKEND # for backward compatibility

# Module-level state: the name of the active backend, and — when the backend is
# HF_ZMQ_BACKEND — the globally shared ZmqProcessGroup instance.
_backend = None
_zmq_pg = None

class ZmqServer(object):
    """Rank-0 side of the zmq-based process group.

    Owns three sockets (ports derived from the init port ``p``):

    * ``PUB``    on ``p - 2`` — broadcasts a response to every client,
    * ``ROUTER`` on ``p - 1`` — sends a response to one specific client,
    * ``PULL``   on ``p``     — receives requests pushed by all clients.

    A daemon thread serializes collective ops: it waits until all
    ``world_size`` ranks have sent the same request for the current serial
    number ``sn``, dispatches the matching ``_handle_*`` method, then
    advances ``sn``.

    NOTE(review): this class was reconstructed from a garbled extraction;
    the sn/duplicate-response protocol should be verified against upstream.
    """

    def __init__(self, host, port, world_size):
        self.world_size = world_size

        self.zc = zmq.Context()
        self.pub_s = self.zc.socket(zmq.PUB)
        self.pub_s.bind(f"tcp://*:{port - 2}")

        self.router_s = self.zc.socket(zmq.ROUTER)
        self.router_s.bind(f"tcp://*:{port - 1}")

        self.pull_s = self.zc.socket(zmq.PULL)
        self.pull_s.bind(f"tcp://*:{port}")

        # State of the collective op currently in flight.
        self.calling_method = None   # name of the op all ranks must agree on
        self.called_ranks = {}       # rank -> payload sent for the current op
        self.sn = 0                  # serial number, advanced after each op

        self.handlers = {method: getattr(self, '_handle_' + method)
                         for method in ['barrier', 'gather', 'broadcast']}

        # Daemon thread so a dangling server never blocks interpreter exit.
        self.thread = threading.Thread(target=self._run, daemon=True)
        self.thread.start()

    def _run(self):
        # First wait until every rank has announced itself on the ROUTER
        # socket, remembering its identity frames for targeted replies.
        self.router_clients = self._collect_all_clients(self.router_s)

        while True:
            req = self.pull_s.recv_pyobj()
            if req['sn'] != self.sn:
                # Stale or duplicate request from a previous op — drop it.
                continue

            method = req['method']
            if self.calling_method is not None and method != self.calling_method:
                # NOTE(review): ranks disagree on which collective is running;
                # original handling of this branch was lost — confirm upstream.
                raise RuntimeError(
                    f'mismatched collective call: {method} vs {self.calling_method}')

            self.calling_method = method
            self.called_ranks[req['rank']] = req['payload']
            if len(self.called_ranks) != self.world_size:
                continue  # not everyone has arrived yet

            # All ranks arrived: run the op and start a new serial number.
            self.handlers[method](req)
            self._finish_op()

    def _handle_barrier(self, req):
        # Everyone arrived; publish the release message to all ranks.
        self._send_resp('barrier')

    def _handle_gather(self, req):
        dst = req['dst']
        # Route the collected {rank: payload} dict to dst only, then publish a
        # payload-free copy so the non-dst ranks advance their sn as well.
        self._send_resp('gather', payload=self.called_ranks, to=dst, dst=dst)
        self._send_resp('gather', dst=dst)

    def _handle_broadcast(self, req):
        src = req['src']
        # NOTE(review): original also length-checked the src payload here;
        # that check looked garbled in the extraction — confirm upstream.
        assert self.called_ranks[src] is not None
        self._send_resp('broadcast', payload=self.called_ranks[src], src=src)

    def _send_resp(self, method, payload=None, to=None, **extra):
        """Send a response: publish to everyone (``to is None``) or route to one rank."""
        resp = {'method': method, 'sn': self.sn, 'payload': payload}
        resp.update(extra)

        if to is None:
            self.pub_s.send_pyobj(resp)
        else:
            # Prepend the stored identity frame(s) so ROUTER delivers to `to`.
            self.router_s.send_multipart(self.router_clients[to] + [pickle.dumps(resp)])

    def _finish_op(self):
        # Reset per-op state and advance the serial number.
        self.calling_method = None
        self.called_ranks = {}
        self.sn += 1

    def _wait_for_all_clients(self, sock):
        clients = 0
        events_socket = sock.get_monitor_socket(events=zmq.EVENT_HANDSHAKE_SUCCEEDED)  # only accept this event
        while clients < self.world_size:
            recv_monitor_message(events_socket)  # this will block until a handshake was successful
            clients += 1

    def _collect_all_clients(self, sock):
        """Receive one 'conn' handshake per rank; return {rank: identity frames}."""
        clients = {}
        while len(clients) < self.world_size:
            msg = sock.recv_multipart()
            payload = pickle.loads(msg[-1])
            if payload['method'] != 'conn':
                continue  # ignore anything that is not the connection handshake
            rank = payload['rank']
            clients[rank] = msg[:-1]  # identity frame(s), reused for routed replies
        return clients

class ZmqProcessGroup(object):
    """Client side of the zmq process group; rank 0 also hosts the ZmqServer.

    ``init_method`` is a ``tcp://host:port`` URL; the client connects a SUB
    socket (published responses), a REQ socket (per-rank routed responses and
    the connection handshake) and a PUSH socket (requests to the server).
    All clients keep a serial number ``sn`` in lockstep with the server.

    NOTE(review): reconstructed from a garbled extraction; verify the
    sn bookkeeping in ``gather`` against upstream.
    """

    def __init__(self, init_method, world_size, rank):
        host, port = init_method[len('tcp://'):].split(':')
        port = int(port)

        self.world_size = world_size
        self.rank = rank
        self.sn = 0  # must stay in lockstep with the server's serial number

        if rank == 0:
            self.server = ZmqServer(host, port, world_size)

        self.zc = zmq.Context()
        self.sub_s = self.zc.socket(zmq.SUB)
        self.sub_s.setsockopt(zmq.SUBSCRIBE, b'')  # receive every published response
        self.sub_s.connect(f'tcp://{host}:{port - 2}')

        self.req_s = self.zc.socket(zmq.REQ)
        self.req_s.connect(f'tcp://{host}:{port - 1}')
        # Handshake: lets the server record this rank's ROUTER identity.
        self.req_s.send_pyobj({'method': 'conn', 'rank': rank})

        self.push_s = self.zc.socket(zmq.PUSH)
        self.push_s.connect(f'tcp://{host}:{port}')

    def get_rank(self):
        return self.rank

    def get_world_size(self):
        return self.world_size

    def barrier(self, *a, **k):
        # Blocks until every rank has called barrier(); extra args are
        # accepted (and ignored) for signature compatibility with torch.
        self._req_server('barrier')

    def gather(self, obj, gather_list=None, dst=0, *a, **k):
        """Gather every rank's ``obj`` into ``gather_list`` on rank ``dst``."""
        if self.rank == dst:
            assert gather_list is not None and len(gather_list) == self.world_size

            # The data-carrying response is routed to us on the REQ socket.
            # NOTE(review): zmq.REQ enforces strict send/recv alternation, so
            # receiving here more than once per 'conn' send looks fragile for
            # repeated gathers — confirm against upstream.
            resp = self._req_server('gather', payload=obj, sock=self.req_s, dst=dst)
            # A payload-free copy was also published to everyone; rewind sn
            # and consume it from the SUB socket to stay in lockstep.
            self.sn -= 1
            self._check_resp('gather')
            assert resp['dst'] == self.rank

            for r, o in resp['payload'].items():
                gather_list[r] = o
        else:
            # Non-dst ranks just contribute and wait for the published copy.
            self._req_server('gather', payload=obj, dst=dst)

    def broadcast(self, objs, src=0, *a, **k):
        """Broadcast ``objs[0]`` from rank ``src``; result stored back into ``objs[0]``."""
        assert len(objs) == 1

        # update the objs list for src rank too, to make it a clone of the original object
        resp = self._req_server('broadcast', payload=objs[0] if self.rank == src else None, src=src)
        obj = resp['payload']
        assert obj is not None
        objs[0] = obj

    def _req_server(self, method, payload=None, sock=None, **extra):
        """Send a request and block for the matching response."""
        self._send_req(method, payload=payload, **extra)
        return self._check_resp(method, sock)

    def _check_resp(self, method, sock=None):
        """Receive the response for ``method`` at the current sn, skipping stale ones."""
        if sock is None:
            sock = self.sub_s

        while True:
            resp = sock.recv_pyobj()
            if resp['sn'] < self.sn:
                continue  # stale published message from an earlier op — drop
            assert resp['method'] == method and resp['sn'] == self.sn, (resp, method, self.sn)
            self.sn += 1
            return resp

    def _send_req(self, method, **extra):
        req = {'method': method, 'rank': self.rank, 'sn': self.sn}
        req.update(extra)
        self.push_s.send_pyobj(req)

def init_process_group(backend, init_method=None, timeout=datetime.timedelta(seconds=1800),
                       world_size=-1, rank=-1, store=None, group_name='', pg_options=None):
    """Drop-in analogue of ``torch.distributed.init_process_group``.

    In addition to the torch backends, ``backend`` may be ``hf_zmq`` (stores a
    global zmq process group) or ``hf_zmq_local`` (returns the group to the
    caller without touching global state, for library use). Usage is otherwise
    identical to ``torch.distributed.init_process_group``.

    Examples:

    .. code-block:: python

        import hfai.distributed as dist
        dist.init_process_group(backend=dist.HF_ZMQ_BACKEND)
    """
    global _backend
    if backend == HF_ZMQ_BACKEND or backend == HF_ZMQ_LOCAL_BACKEND:
        _backend = backend
        pg = ZmqProcessGroup(init_method, world_size, rank)
        if backend == HF_ZMQ_LOCAL_BACKEND:
            # Local use: do not override the globally initialized process group.
            return pg
        else:
            global _zmq_pg
            _zmq_pg = pg
    else:
        _backend = tdist.Backend(backend)
        # NOTE(review): pg_options is accepted for signature compatibility but
        # is not forwarded to torch — confirm this is intentional.
        tdist.init_process_group(
            backend,
            init_method=init_method,
            timeout=timeout,
            world_size=world_size,
            rank=rank,
            store=store,
            group_name=group_name,
        )
def _check_backend(f):
    """Decorator: route the call to the global zmq process group when the
    ``hf_zmq`` backend is active; otherwise run the wrapped torch-based body."""
    fname = f.__name__

    @functools.wraps(f)
    def g(*args, **kwargs):
        global _backend
        if _backend == HF_ZMQ_BACKEND:
            global _zmq_pg
            # Dispatch by name to the method of the same name on the group.
            return getattr(_zmq_pg, fname)(*args, **kwargs)
        else:
            return f(*args, **kwargs)

    return g


def is_initialized():
    """Whether a process group (zmq or torch) has been initialized."""
    if _backend == HF_ZMQ_BACKEND:
        return _zmq_pg is not None
    else:
        return tdist.is_initialized()


@_check_backend
def get_rank(group=None):
    group = group or tdist.GroupMember.WORLD
    return tdist.get_rank(group=group)


@_check_backend
def get_world_size(group=None):
    group = group or tdist.GroupMember.WORLD
    return tdist.get_world_size(group=group)
@_check_backend
def barrier(group=None, async_op=False):
    """Drop-in analogue of ``torch.distributed.barrier``.

    With the ``hf_zmq`` backend the barrier is implemented over zmq (via the
    ``_check_backend`` dispatch); the parameters match
    ``torch.distributed.barrier``.

    Examples:

    .. code-block:: python

        import hfai.distributed as dist
        dist.barrier()
    """
    group = group or tdist.GroupMember.WORLD
    tdist.barrier(group, async_op)
@_check_backend
def gather(obj, gather_list=None, dst=0, group=None, async_op=False):
    """Drop-in analogue of ``torch.distributed.gather``.

    Compared with torch's gather, the ``async_op`` and ``group`` parameters
    are not supported yet; the remaining parameters are identical. On the
    torch path ``obj`` must be a tensor and ``gather_list`` pre-allocated
    tensors on rank ``dst``.
    """
    assert group is None
    assert not async_op, '暂未支持异步调用gather'

    world_size = tdist.get_world_size()
    rank = tdist.get_rank()
    if rank == dst:
        assert gather_list is not None
        gather_list[rank] = torch.clone(obj)
        # Post all receives first, then wait, so senders don't serialize.
        work_list = []
        for i in reversed(range(world_size)):
            if i != rank:
                work_list.append(tdist.irecv(gather_list[i], src=i))
        for work in work_list:
            work.wait()
    else:
        assert gather_list is None
        work = tdist.isend(obj, dst=dst)
        work.wait()
    tdist.barrier()
@_check_backend
def broadcast(objs, src=0, group=None, async_op=False):
    """zmq-backed broadcast; compared with torch's broadcast the ``async_op``
    and ``group`` parameters are not supported on the ``hf_zmq`` path.

    Args:
        objs (list[object]): single-item list. On rank ``src`` the item is the
            object to broadcast; on every other rank the item is replaced with
            the received object.
        src (int): source rank

    Returns:
        None

    Examples:

    .. code-block:: python

        import hfai.distributed as dist
        objs = ['to broadcast']
        dist.broadcast(objs, src=0)
    """
    group = group or tdist.GroupMember.WORLD
    # NOTE(review): on the torch path this forwards the list itself to
    # tdist.broadcast, which expects a tensor — confirm callers pass tensors
    # when a torch backend is active.
    tdist.broadcast(objs, src, group=group, async_op=async_op)