diff --git a/src/radical/utils/__init__.py b/src/radical/utils/__init__.py index 3d795e639..8b7e57454 100644 --- a/src/radical/utils/__init__.py +++ b/src/radical/utils/__init__.py @@ -52,6 +52,7 @@ from .typeddict import TypedDict, TypedDictMeta, as_dict from .config import Config, DefaultConfig +from .zmq import Message from .zmq import Bridge from .zmq import Queue, Putter, Getter from .zmq import PubSub, Publisher, Subscriber diff --git a/src/radical/utils/configs/utils_default.json b/src/radical/utils/configs/utils_default.json index 5e54cbbf0..8522e6eff 100644 --- a/src/radical/utils/configs/utils_default.json +++ b/src/radical/utils/configs/utils_default.json @@ -6,7 +6,7 @@ "report" : "${RADICAL_DEFAULT_REPORT:TRUE}", "report_tgt" : "${RADICAL_DEFAULT_REPORT_TGT:stderr}", "report_dir" : "${RADICAL_DEFAULT_REPORT_DIR:$PWD}", - "profile" : "${RADICAL_DEFAULT_PROFILE:TRUE}", + "profile" : "${RADICAL_DEFAULT_PROFILE:FALSE}", "profile_dir": "${RADICAL_DEFAULT_PROFILE_DIR:$PWD}" } diff --git a/src/radical/utils/heartbeat.py b/src/radical/utils/heartbeat.py index c472dda2a..9bd9030e4 100644 --- a/src/radical/utils/heartbeat.py +++ b/src/radical/utils/heartbeat.py @@ -64,12 +64,14 @@ def __init__(self, uid, timeout, interval=1, beat_cb=None, term_cb=None, if not self._log: self._log = Logger('radical.utils.heartbeat') + self._log.debug('=== hb %s create', self._uid) + # -------------------------------------------------------------------------- # def start(self): - self._log.debug('start heartbeat') + self._log.debug('=== hb %s start', self._uid) self._watcher = mt.Thread(target=self._watch) self._watcher.daemon = True self._watcher.start() @@ -101,41 +103,56 @@ def dump(self, log): log.debug('hb dump %s: \n%s', self._uid, pprint.pformat(self._tstamps)) + # -------------------------------------------------------------------------- + # + def watch(self, uid): + + with self._lock: + if uid not in self._tstamps: + self._log.debug('=== hb %s watch %s', self._uid, uid) + self._tstamps[uid] = None + + # -------------------------------------------------------------------------- # def _watch(self): # initial heartbeat without delay if self._beat_cb: + self._log.debug('=== hb %s beat cb init', self._uid) self._beat_cb() while not self._term.is_set(): + self._log.debug('=== hb %s loop %s', self._uid, self._interval) + time.sleep(self._interval) now = time.time() if self._beat_cb: + self._log.debug('=== hb %s beat cb', self._uid) self._beat_cb() # avoid iteration over changing dict with self._lock: uids = list(self._tstamps.keys()) + self._log.debug('=== hb %s uids %s', self._uid, uids) for uid in uids: - # self._log.debug('hb %s check %s', self._uid, uid) + self._log.debug('=== hb %s check %s', self._uid, uid) with self._lock: last = self._tstamps.get(uid) if last is None: - self._log.warn('hb %s[%s]: never seen', self._uid, uid) + self._log.warn('=== hb %s inval %s', self._uid, uid) continue if now - last > self._timeout: if self._log: - self._log.warn('hb %s[%s]: %.1f - %.1f > %1.f: timeout', + self._log.warn('=== hb %s tout %s: %.1f - %.1f > %1.f', self._uid, uid, now, last, self._timeout) ret = None @@ -148,9 +165,10 @@ def _watch(self): # avoiding termination ret = True - if ret is None: + if ret in [None, False]: # could not recover: abandon mothership - self._log.warn('hb fail %s: fatal (%d)', uid, self._pid) + self._log.warn('=== hb %s fail %s: fatal (%d)', + self._uid, uid, self._pid) os.kill(self._pid, signal.SIGTERM) time.sleep(1) os.kill(self._pid, signal.SIGKILL) @@ -161,8 +179,9 @@ def _watch(self): # information for the old uid and register a new # heartbeat for the new one, so that we can immediately # begin to watch it. - self._log.info('hb recover %s -> %s (%s)', - uid, ret, self._term_cb) + assert isinstance(ret, str) + self._log.info('=== hb %s recov %s -> %s (%s)', + self._uid, uid, ret, self._term_cb) with self._lock: del self._tstamps[uid] self._tstamps[ret] = time.time() @@ -178,8 +197,8 @@ def beat(self, uid=None, timestamp=None): if not uid: uid = 'default' - # self._log.debug('hb %s beat [%s]', self._uid, uid) with self._lock: + self._log.debug('hb %s beat [%s]', self._uid, uid) self._tstamps[uid] = timestamp @@ -233,11 +252,11 @@ def wait_startup(self, uids=None, timeout=None): self._log.debug('wait time: %s', nok) break - time.sleep(0.05) + time.sleep(0.25) if len(ok) != len(uids): nok = [uid for uid in uids if uid not in ok] - self._log.debug('wait fail: %s', nok) + self._log.error('wait fail: %s', nok) return nok else: diff --git a/src/radical/utils/ids.py b/src/radical/utils/ids.py index 8fffeab43..8fdcdbcd1 100644 --- a/src/radical/utils/ids.py +++ b/src/radical/utils/ids.py @@ -108,7 +108,7 @@ def reset_counter(self, prefix, reset_all_others=False): # ------------------------------------------------------------------------------ # -def generate_id(prefix, mode=ID_SIMPLE, ns=None): +def generate_id(prefix: str, mode=ID_SIMPLE, ns=None): """ Generate a human readable, sequential ID for the given prefix. @@ -183,8 +183,8 @@ def generate_id(prefix, mode=ID_SIMPLE, ns=None): and will, for `ID_PRIVATE`, revert to `ID_UUID`. """ - if not prefix or not isinstance(prefix, str): - raise TypeError("ID generation expect prefix in basestring type") + if not isinstance(prefix, str): + raise TypeError('"prefix" must be a string, not %s' % type(prefix)) if _cache['dockerized'] and mode == ID_PRIVATE: mode = ID_UUID diff --git a/src/radical/utils/logger.py b/src/radical/utils/logger.py index 455626244..9aef1968d 100644 --- a/src/radical/utils/logger.py +++ b/src/radical/utils/logger.py @@ -331,6 +331,7 @@ def _ensure_handler(self): p = self._path n = self._name for t in self._targets: + if t in ['0', 'null'] : h = logging.NullHandler() elif t in ['-', '1', 'stdout']: h = ColorStreamHandler(sys.stdout) elif t in ['=', '2', 'stderr']: h = ColorStreamHandler(sys.stderr) diff --git a/src/radical/utils/profile.py b/src/radical/utils/profile.py index be26d2d3a..31825aaba 100644 --- a/src/radical/utils/profile.py +++ b/src/radical/utils/profile.py @@ -211,6 +211,10 @@ def __init__(self, name, ns=None, path=None): except OSError: pass # already exists + # don't open the file on disabled profilers + if not self._enabled: + return + # we set `buffering` to `1` to force line buffering. That is not idea # performance wise - but will not do an `fsync()` after writes, so OS # level buffering should still apply. This is supposed to shield diff --git a/src/radical/utils/shell.py b/src/radical/utils/shell.py index 2510b4cc6..f9f2e527d 100644 --- a/src/radical/utils/shell.py +++ b/src/radical/utils/shell.py @@ -39,7 +39,8 @@ def sh_quote(data): # ------------------------------------------------------------------------------ # -def sh_callout(cmd, stdout=True, stderr=True, shell=False, env=None): +def sh_callout(cmd, stdout=True, stderr=True, + shell=False, env=None, cwd=None): ''' call a shell command, return `[stdout, stderr, retval]`. ''' @@ -54,7 +55,8 @@ def sh_callout(cmd, stdout=True, stderr=True, shell=False, env=None): if stderr : stderr = sp.PIPE else : stderr = None - p = sp.Popen(cmd, stdout=stdout, stderr=stderr, shell=shell, env=env) + p = sp.Popen(cmd, stdout=stdout, stderr=stderr, + shell=shell, env=env, cwd=cwd) if not stdout and not stderr: ret = p.wait() @@ -67,7 +69,8 @@ def sh_callout(cmd, stdout=True, stderr=True, shell=False, env=None): # ------------------------------------------------------------------------------ # -def sh_callout_bg(cmd, stdout=None, stderr=None, shell=False, env=None): +def sh_callout_bg(cmd, stdout=None, stderr=None, + shell=False, env=None, cwd=None): ''' call a shell command in the background. Do not attempt to pipe STDOUT/ERR, but only support writing to named files. @@ -84,7 +87,7 @@ def sh_callout_bg(cmd, stdout=None, stderr=None, shell=False, env=None): # convert string into arg list if needed if not shell and is_string(cmd): cmd = shlex.split(cmd) - sp.Popen(cmd, stdout=stdout, stderr=stderr, shell=shell, env=env) + sp.Popen(cmd, stdout=stdout, stderr=stderr, shell=shell, env=env, cwd=cwd) return @@ -92,7 +95,7 @@ def sh_callout_bg(cmd, stdout=None, stderr=None, shell=False, env=None): # ------------------------------------------------------------------------------ # def sh_callout_async(cmd, stdin=True, stdout=True, stderr=True, - shell=False, env=None): + shell=False, env=None, cwd=None): ''' Run a command, and capture stdout/stderr if so flagged. The call will @@ -110,6 +113,9 @@ def sh_callout_async(cmd, stdin=True, stdout=True, stderr=True, shell: True, False [default] - pass to popen + cwd: string + - working directory for command to run in + PROC: - PROC.stdout : `queue.Queue` instance delivering stdout lines - PROC.stderr : `queue.Queue` instance delivering stderr lines @@ -133,7 +139,7 @@ class _P(object): ''' # ---------------------------------------------------------------------- - def __init__(self, cmd, stdin, stdout, stderr, shell, env): + def __init__(self, cmd, stdin, stdout, stderr, shell, env, cwd): cmd = cmd.strip() @@ -165,6 +171,7 @@ def __init__(self, cmd, stdin, stdout, stderr, shell, env): stderr=self._err_w, shell=shell, env=env, + cwd=cwd, bufsize=1) t = mt.Thread(target=self._watch) @@ -277,7 +284,7 @@ def _watch(self): # -------------------------------------------------------------------------- return _P(cmd=cmd, stdin=stdin, stdout=stdout, stderr=stderr, - shell=shell, env=env) + shell=shell, env=env, cwd=cwd) # ------------------------------------------------------------------------------ diff --git a/src/radical/utils/typeddict.py b/src/radical/utils/typeddict.py index 7573781bd..f30091380 100644 --- a/src/radical/utils/typeddict.py +++ b/src/radical/utils/typeddict.py @@ -107,7 +107,7 @@ class TypedDict(dict, metaclass=TypedDictMeta): # -------------------------------------------------------------------------- # - def __init__(self, from_dict=None): + def __init__(self, from_dict=None, **kwargs): ''' Create a typed dictionary (tree) from `from_dict`. @@ -131,10 +131,19 @@ def __init__(self, from_dict=None): verify Names with a leading underscore are not supported. + + Supplied `from_dict` and kwargs are used to initialize the object + data -- the `kwargs` take preceedence over the `from_dict` if both + are specified (note that `from_dict` and `self` are invalid + `kwargs`). ''' + self.update(copy.deepcopy(self._defaults)) self.update(from_dict) + if kwargs: + self.update(kwargs) + # -------------------------------------------------------------------------- # @@ -297,8 +306,7 @@ def __str__(self): return str(self._data) def __repr__(self): - return '<%s object, schema keys: %s>' % \ - (type(self).__qualname__, tuple(self._schema.keys())) + return '%s: %s' % (type(self).__qualname__, str(self)) # -------------------------------------------------------------------------- diff --git a/src/radical/utils/zmq/__init__.py b/src/radical/utils/zmq/__init__.py index d42a28e6c..d6d2d4047 100644 --- a/src/radical/utils/zmq/__init__.py +++ b/src/radical/utils/zmq/__init__.py @@ -7,12 +7,13 @@ from .bridge import Bridge -from .queue import Queue, Putter, Getter -from .pubsub import PubSub, Publisher, Subscriber +from .queue import Queue, Putter, Getter, test_queue +from .pubsub import PubSub, Publisher, Subscriber, test_pubsub from .pipe import Pipe, MODE_PUSH, MODE_PULL from .client import Client from .server import Server from .registry import Registry, RegistryClient +from .message import Message # ------------------------------------------------------------------------------ diff --git a/src/radical/utils/zmq/bridge.py b/src/radical/utils/zmq/bridge.py index ef1aa913c..06be1b26b 100644 --- a/src/radical/utils/zmq/bridge.py +++ b/src/radical/utils/zmq/bridge.py @@ -3,11 +3,17 @@ import threading as mt +from typing import Optional + from ..logger import Logger from ..profile import Profiler from ..config import Config from ..json_io import read_json, write_json +QUEUE = 'QUEUE' +PUBSUB = 'PUBSUB' +UNKNOWN = 'UNKNOWN' + # ------------------------------------------------------------------------------ # @@ -43,13 +49,14 @@ def __init__(self, cfg): self._channel = self._cfg.channel self._uid = self._cfg.uid self._pwd = self._cfg.path + + if not self._pwd: + self._pwd = os.getcwd() + self._log = Logger(name=self._uid, ns='radical.utils', level=self._cfg.log_lvl, path=self._pwd) self._prof = Profiler(name=self._uid, path=self._pwd) - if self._pwd is None: - self._pwd = os.getcwd() - if 'hb' in self._uid or 'heartbeat' in self._uid: self._prof.disable() else: @@ -131,25 +138,33 @@ def start(self): # -------------------------------------------------------------------------- # @staticmethod - def create(cfg): + def create(channel : str, + kind : Optional[str] = None, + cfg : Optional[dict] = None): + + # FIXME: add other config parameters: batch size, log level, etc. # NOTE: I'd rather have this as class data than as stack data, but # python stumbles over circular imports at that point :/ # Another option though is to discover and dynamically load # components. + from .pubsub import PubSub from .queue import Queue - _btypemap = {'pubsub' : PubSub, - 'queue' : Queue} + _btypemap = {PUBSUB: PubSub, + QUEUE : Queue} - kind = cfg['kind'] + if not kind: + if 'queue' in channel.lower(): kind = QUEUE + elif 'pubsub' in channel.lower(): kind = PUBSUB + else : kind = UNKNOWN if kind not in _btypemap: raise ValueError('unknown bridge type (%s)' % kind) btype = _btypemap[kind] - bridge = btype(cfg) + bridge = btype(channel, cfg=cfg) return bridge diff --git a/src/radical/utils/zmq/message.py b/src/radical/utils/zmq/message.py new file mode 100644 index 000000000..7269e9821 --- /dev/null +++ b/src/radical/utils/zmq/message.py @@ -0,0 +1,59 @@ + +from typing import Optional, Dict, Any + +import msgpack + +from ..typeddict import TypedDict + + +# ------------------------------------------------------------------------------ +# +class Message(TypedDict): + + _schema = { + '_msg_type': str, + } + + _defaults = { + '_msg_type': None, + } + + _msg_types = dict() + + + # -------------------------------------------------------------------------- + # + def _verify(self): + assert self._msg_type + + + @staticmethod + def register_msg_type(msg_type, msg_class): + Message._msg_types[msg_type] = msg_class + + + @staticmethod + def deserialize(data: Dict[str, Any]): + + msg_type = data.get('_msg_type') + + if msg_type is None: + raise ValueError('no message type defined') + + if msg_type not in Message._msg_types: + known = list(Message._msg_types.keys()) + raise ValueError('unknown message type [%s]: %s' % (msg_type, known)) + + return Message._msg_types[msg_type](from_dict=data) + + + def packb(self): + return msgpack.packb(self) + + @staticmethod + def unpackb(bdata): + return Message.deserialize(msgpack.unpackb(bdata)) + + +# ------------------------------------------------------------------------------ + diff --git a/src/radical/utils/zmq/pubsub.py b/src/radical/utils/zmq/pubsub.py index 993d48ffa..ebb4949c6 100644 --- a/src/radical/utils/zmq/pubsub.py +++ b/src/radical/utils/zmq/pubsub.py @@ -1,21 +1,25 @@ # pylint: disable=protected-access import zmq +import time import msgpack import threading as mt +from typing import Optional + from ..atfork import atfork from ..config import Config from ..ids import generate_id, ID_CUSTOM from ..url import Url -from ..misc import is_string, as_string, as_bytes, as_list, noop +from ..misc import as_string, as_bytes, as_list, noop from ..host import get_hostip from ..logger import Logger from ..profile import Profiler +from ..debug import get_stacktrace, get_caller_name, print_stacktrace from .bridge import Bridge -from .utils import no_intr # , log_bulk +from .utils import no_intr , log_bulk # ------------------------------------------------------------------------------ @@ -37,30 +41,20 @@ def _atfork_child(): # ------------------------------------------------------------------------------ # -# Notifications between components are based on pubsub channels. Those channels -# have different scope (bound to the channel name). Only one specific topic is -# predefined: 'state' will be used for unit state updates. -# class PubSub(Bridge): # -------------------------------------------------------------------------- # - def __init__(self, cfg=None, channel=None): - - if cfg and not channel and is_string(cfg): - # allow construction with only channel name - channel = cfg - cfg = None + def __init__(self, channel: str, cfg: Optional[dict] = None): - if cfg : cfg = Config(cfg=cfg) - elif channel: cfg = Config(cfg={'channel': channel}) - else: raise RuntimeError('PubSub needs cfg or channel parameter') - - if not cfg.channel: - raise ValueError('no channel name provided for pubsub') + if cfg: + # create deep copy + cfg = Config(cfg=cfg) + else: + cfg = Config() if not cfg.uid: - cfg.uid = generate_id('%s.bridge.%%(counter)04d' % cfg.channel, + cfg.uid = generate_id('%s.bridge.%%(counter)04d' % channel, ID_CUSTOM) super(PubSub, self).__init__(cfg) @@ -104,19 +98,19 @@ def _bridge_initialize(self): self._lock = mt.Lock() self._ctx = zmq.Context.instance() # rely on GC for destruction - self._pub = self._ctx.socket(zmq.XSUB) - self._pub.linger = _LINGER_TIMEOUT - self._pub.hwm = _HIGH_WATER_MARK - self._pub.bind('tcp://*:*') + self._xpub = self._ctx.socket(zmq.XSUB) + self._xpub.linger = _LINGER_TIMEOUT + self._xpub.hwm = _HIGH_WATER_MARK + self._xpub.bind('tcp://*:*') - self._sub = self._ctx.socket(zmq.XPUB) - self._sub.linger = _LINGER_TIMEOUT - self._sub.hwm = _HIGH_WATER_MARK - self._sub.bind('tcp://*:*') + self._xsub = self._ctx.socket(zmq.XPUB) + self._xsub.linger = _LINGER_TIMEOUT + self._xsub.hwm = _HIGH_WATER_MARK + self._xsub.bind('tcp://*:*') # communicate the bridge ports to the parent process - _addr_pub = as_string(self._pub.getsockopt(zmq.LAST_ENDPOINT)) - _addr_sub = as_string(self._sub.getsockopt(zmq.LAST_ENDPOINT)) + _addr_pub = as_string(self._xpub.getsockopt(zmq.LAST_ENDPOINT)) + _addr_sub = as_string(self._xsub.getsockopt(zmq.LAST_ENDPOINT)) # store addresses self._addr_pub = Url(_addr_pub) @@ -129,10 +123,13 @@ def _bridge_initialize(self): self._log.info('bridge pub on %s: %s', self._uid, self._addr_pub) self._log.info(' sub on %s: %s', self._uid, self._addr_sub) + # make sure bind is active + time.sleep(0.1) + # start polling for messages self._poll = zmq.Poller() - self._poll.register(self._pub, zmq.POLLIN) - self._poll.register(self._sub, zmq.POLLIN) + self._poll.register(self._xpub, zmq.POLLIN) + self._poll.register(self._xsub, zmq.POLLIN) # -------------------------------------------------------------------------- @@ -151,28 +148,28 @@ def _bridge_work(self): # timeout in ms socks = dict(self._poll.poll(timeout=10)) - if self._sub in socks: + if self._xsub in socks: # if the sub socket signals a message, it's likely # a topic subscription. Forward that to the pub # channel, so the bridge subscribes for the respective # message topic. - msg = self._sub.recv() - self._pub.send(msg) + msg = self._xsub.recv() + self._xpub.send(msg) self._prof.prof('subscribe', uid=self._uid, msg=msg) - # log_bulk(self._log, '~~1 %s' % self.channel, [msg]) + # log_bulk(self._log, '~~1 %s' % self.uid, [msg]) - if self._pub in socks: + if self._xpub in socks: # if the pub socket signals a message, get the message # and forward it to the sub channel, no questions asked. - msg = self._pub.recv() - self._sub.send(msg) + msg = self._xpub.recv() + self._xsub.send(msg) # self._prof.prof('msg_fwd', uid=self._uid, msg=msg) - # log_bulk(self._log, '<> %s' % self.channel, [msg]) + # log_bulk(self._log, '<> %s' % self.uid, [msg]) # ------------------------------------------------------------------------------ @@ -190,18 +187,19 @@ def __init__(self, channel, url=None, log=None, prof=None, path=None): self._lock = mt.Lock() # FIXME: no uid ns - self._uid = generate_id('%s.pub.%s' % (self._channel, - '%(counter)04d'), ID_CUSTOM) + self._uid = generate_id('%s.pub.%s' % (self._channel, + '%(counter)04d'), ID_CUSTOM) if not self._url: self._url = Bridge.get_config(channel, path).pub if not log: - self._log = Logger(name=self._uid, ns='radical.utils.zmq') - # level='debug') + self._log = Logger(name=self._uid, ns='radical.utils.zmq', + path=path) if not prof: - self._prof = Profiler(name=self._uid, ns='radical.utils.zmq') + self._prof = Profiler(name=self._uid, ns='radical.utils.zmq', + path=path) self._prof.disable() if 'hb' in self._uid or 'heartbeat' in self._uid: @@ -215,6 +213,8 @@ def __init__(self, channel, url=None, log=None, prof=None, path=None): self._socket.hwm = _HIGH_WATER_MARK self._socket.connect(self._url) + time.sleep(0.1) + # -------------------------------------------------------------------------- # @@ -242,6 +242,7 @@ def put(self, topic, msg): assert isinstance(topic, str), 'invalid topic type' # self._log.debug('=== put %s : %s: %s', topic, self.channel, msg) + # self._log.debug('=== put %s: %s', msg, get_stacktrace()) # self._prof.prof('put', uid=self._uid, msg=msg) # log_bulk(self._log, '-> %s' % topic, [msg]) @@ -275,7 +276,7 @@ def _get_nowait(socket, lock, timeout, log, prof): topic, bmsg = data.split(b' ', 1) msg = msgpack.unpackb(bmsg) - # log.debug(' <- %s: %s', topic, msg) + log.debug(' <- %s: %s', topic, msg) return [as_string(topic), as_string(msg)] @@ -293,7 +294,7 @@ def _listener(sock, lock, term, callbacks, log, prof): # this list is dynamic topic, msg = Subscriber._get_nowait(sock, lock, 500, log, prof) - # log.debug(' <- %s: %s', topic, msg) + log.debug(' <- %s: %s', topic, msg) if topic: for cb, _lock in callbacks: @@ -304,6 +305,11 @@ def _listener(sock, lock, term, callbacks, log, prof): cb(topic, msg) else: cb(topic, msg) + except SystemExit: + log.info('callback called sys.exit') + term.set() + break + except: log.exception('callback error') except: @@ -339,6 +345,9 @@ def __init__(self, channel, url=None, topic=None, cb=None, self._uid = generate_id('%s.sub.%s' % (self._channel, '%(counter)04d'), ID_CUSTOM) + if not self._topics: + self._topics = [] + if not self._url: self._url = Bridge.get_config(channel, path).sub @@ -363,6 +372,8 @@ def __init__(self, channel, url=None, topic=None, cb=None, self._sock.hwm = _HIGH_WATER_MARK self._sock.connect(self._url) + time.sleep(0.1) + # only allow `get()` and `get_nowait()` self._interactive = True @@ -446,6 +457,7 @@ def subscribe(self, topic, cb=None, lock=None): # log_bulk(self._log, '~~2 %s' % topic, [topic]) with self._lock: + # self._log.debug('==== subscribe for %s', topic) no_intr(self._sock.setsockopt, zmq.SUBSCRIBE, as_bytes(topic)) if topic not in self._topics: @@ -518,5 +530,73 @@ def get_nowait(self, timeout=None): return [None, None] +# ------------------------------------------------------------------------------ +# +def test_pubsub(channel, addr_pub, addr_sub): + + return {} + + topic = 'test' + + c_a = 1 + c_b = 2 + data = dict() + + for i in 'ABCD': + data[i] = dict() + for j in 'AB': + data[i][j] = 0 + + def cb(uid, topic, msg): + if 'idx' not in msg: + return + if msg['idx'] is None: + return False + data[uid][msg['src']] += 1 + + cb_C = lambda t,m: cb('C', t, m) + cb_D = lambda t,m: cb('D', t, m) + + Subscriber(channel=channel, url=addr_sub, topic=topic, cb=cb_C) + Subscriber(channel=channel, url=addr_sub, topic=topic, cb=cb_D) + + # -------------------------------------------------------------------------- + def work_pub(uid, n, delay): + + pub = Publisher(channel=channel, url=addr_pub) + idx = 0 + + while idx < n: + time.sleep(delay) + pub.put(topic, {'src': uid, 'idx': idx}) + idx += 1 + data[uid][uid] += 1 + + # send EOF + pub.put(topic, {'src': uid, 'idx': None}) + # -------------------------------------------------------------------------- + + t_a = mt.Thread(target=work_pub, args=['A', c_a, 0.001]) + t_b = mt.Thread(target=work_pub, args=['B', c_b, 0.001]) + + t_a.start() + t_b.start() + + t_a.join() + t_b.join() + + time.sleep(0.1) + + assert data['A']['A'] == c_a + assert data['B']['B'] == c_b + + assert data['C']['A'] + data['C']['B'] + \ + data['D']['A'] + data['D']['B'] == 2 * (c_a + c_b) + + # print('==== %.1f %s [%s]' % (time.time(), channel, get_caller_name())) + + return data + + # ------------------------------------------------------------------------------ diff --git a/src/radical/utils/zmq/queue.py b/src/radical/utils/zmq/queue.py index 68ee02b25..9914fb19e 100644 --- a/src/radical/utils/zmq/queue.py +++ b/src/radical/utils/zmq/queue.py @@ -7,11 +7,13 @@ import threading as mt +from typing import Optional + from ..atfork import atfork from ..config import Config from ..ids import generate_id, ID_CUSTOM from ..url import Url -from ..misc import is_string, as_string, as_bytes, as_list, noop +from ..misc import as_string, as_bytes, as_list, noop from ..host import get_hostip from ..logger import Logger from ..profile import Profiler @@ -20,11 +22,11 @@ from .bridge import Bridge from .utils import no_intr +# NOTE: the log bulk method is frequently called and slow # from .utils import log_bulk # from .utils import prof_bulk -# FIXME: the log bulk method is frequently called and slow # -------------------------------------------------------------------------- # @@ -85,7 +87,7 @@ def _atfork_child(): # class Queue(Bridge): - def __init__(self, cfg=None, channel=None): + def __init__(self, channel: str, cfg: Optional[dict] = None): ''' This Queue type sets up an zmq channel of this kind: @@ -106,17 +108,17 @@ def __init__(self, cfg=None, channel=None): addresses as obj.addr_put and obj.addr_get. ''' - if cfg and not channel and is_string(cfg): - # allow construction with only channel name - channel = cfg - cfg = None - - if cfg : cfg = Config(cfg=cfg) - elif channel: cfg = Config(cfg={'channel': channel}) - else: raise RuntimeError('Queue needs cfg or channel parameter') + if cfg: + # create deep copy + cfg = Config(cfg=cfg) + else: + cfg = Config() - if not cfg.channel: - raise ValueError('no channel name provided for queue') + # ensure channel is set in config + if cfg.channel: + assert cfg.channel == channel + else: + cfg.channel = channel if not cfg.uid: cfg.uid = generate_id('%s.bridge.%%(counter)04d' % cfg.channel, @@ -240,7 +242,7 @@ def _bridge_work(self): msgs = msgpack.unpackb(data[1]) # prof_bulk(self._prof, 'poll_put_recv', msgs) # log_bulk(self._log, '<> %s' % qname, msgs) - self._log.debug('put %s: %s ! ', qname, len(msgs)) + # self._log.debug('put %s: %s ! ', qname, len(msgs)) if qname not in buf: buf[qname] = list() @@ -270,6 +272,8 @@ def _bridge_work(self): if qname in buf: msgs = buf[qname][:self._bulk_size] else: + # self._log.debug('get: %s not in %s', qname, + # list(buf.keys())) msgs = list() # log_bulk(self._log, '>< %s' % qname, msgs) @@ -323,7 +327,7 @@ def __init__(self, channel, url=None, log=None, prof=None, path=None): ID_CUSTOM) if not self._url: - self._url = Bridge.get_config(channel, path).put + self._url = Bridge.get_config(channel, path).get('put') if not self._url: raise ValueError('no contact url specified, no config found') @@ -412,6 +416,7 @@ def _get_nowait(url, qname=None, timeout=None, uid=None): # timeout in ms if not info['requested']: # send the request *once* per recieval (got lock above) + # FIXME: why is this sent repeatedly? # logger.debug('=== => from %s[%s]', uid, qname) no_intr(info['socket'].send, as_bytes(qname)) info['requested'] = True @@ -460,7 +465,6 @@ def _listener(url, qname=None, uid=None): continue msgs = Getter._get_nowait(url, qname=qname, timeout=500, uid=uid) - BULK = True if msgs: @@ -547,7 +551,7 @@ def __init__(self, channel, url=None, cb=None, ID_CUSTOM) if not self._url: - self._url = Bridge.get_config(channel, path).get + self._url = Bridge.get_config(channel, path).get('get') if not self._url: raise ValueError('no contact url specified, no config found') @@ -735,5 +739,70 @@ def get_nowait(self, qname=None, timeout=None): # timeout in ms return None +# ------------------------------------------------------------------------------ +# +def test_queue(channel, addr_pub, addr_sub): + + c_a = 200 + c_b = 400 + data = dict() + + for i in 'ABCD': + data[i] = dict() + for j in 'AB': + data[i][j] = 0 + + def cb(uid, msg): + if msg['idx'] is None: + return False + data[uid][msg['src']] += 1 + + cb_C = lambda t,m: cb('C', m) + cb_D = lambda t,m: cb('D', m) + + Getter(channel=channel, url=addr_sub, cb=cb_C) + Getter(channel=channel, url=addr_sub, cb=cb_D) + + # -------------------------------------------------------------------------- + def work_pub(uid, n, delay): + + pub = Putter(channel=channel, url=addr_pub) + idx = 0 + + while idx < n: + time.sleep(delay) + pub.put({'src': uid, + 'idx': idx}) + idx += 1 + data[uid][uid] += 1 + + # send EOF + pub.put({'src': uid, + 'idx': None}) + # -------------------------------------------------------------------------- + + t_a = mt.Thread(target=work_pub, args=['A', c_a, 0.001]) + t_b = mt.Thread(target=work_pub, args=['B', c_b, 0.001]) + + t_a.start() + t_b.start() + + t_a.join() + t_b.join() + + time.sleep(0.1) + + import pprint + pprint.pprint(data) + + assert data['A']['A'] == c_a + assert data['B']['B'] == c_b + + assert data['C']['A'] + data['C']['B'] + \ + data['D']['A'] + data['D']['B'] == 2 * (c_a + c_b) + + return data + + # ------------------------------------------------------------------------------ diff --git a/src/radical/utils/zmq/registry.py b/src/radical/utils/zmq/registry.py index d9331ac6d..9f143f257 100644 --- a/src/radical/utils/zmq/registry.py +++ b/src/radical/utils/zmq/registry.py @@ -1,4 +1,5 @@ +import atexit import shelve from typing import List, Optional, Any @@ -9,6 +10,18 @@ from .server import Server from .client import Client +_registries = list() + + +# ------------------------------------------------------------------------------ +# +def _flush_registries(): + for _reg in _registries: + _reg.stop() + + +atexit.register(_flush_registries) + # ------------------------------------------------------------------------------ # @@ -22,9 +35,10 @@ class Registry(Server): # def __init__(self, url : Optional[str] = None, uid : Optional[str] = None, + path : Optional[str] = None, persistent: bool = False) -> None: - super().__init__(url=url, uid=uid) + super().__init__(url=url, uid=uid, path=path) if persistent: self._data = shelve.open('%s.db' % self._uid, writeback=True) @@ -35,15 +49,27 @@ def __init__(self, url : Optional[str] = None, self.register_request('get', self.get) self.register_request('keys', self.keys) self.register_request('del', self.delitem) + self.register_request('dump', self.dump) # -------------------------------------------------------------------------- # - def stop(self) -> None: + def dump(self, name: str = None) -> None: if isinstance(self._data, dict): - write_json(self._data, '%s.json' % self._uid) - else: + if name: + write_json(self._data, '%s.%s.json' % (self._uid, name)) + else: + write_json(self._data, '%s.json' % self._uid) + + + # -------------------------------------------------------------------------- + # + def stop(self) -> None: + + self.dump() + + if isinstance(self._data, shelve.Shelf): self._data.close() super().stop() @@ -60,7 +86,7 @@ def put(self, key: str, val: Any) -> None: for elem in path: - if elem not in this: + if elem not in this or this[elem] is None: this[elem] = dict() this = this[elem] @@ -81,28 +107,51 @@ def get(self, key: str) -> Optional[str]: leaf = elems[-1] for elem in path: - - this = this.get(elem) + this = this.get(elem, {}) if not this: - return None + break + + if this is None: + this = dict() - return this.get(leaf) + val = this.get(leaf) + return val # -------------------------------------------------------------------------- # - def keys(self) -> List[str]: + def keys(self, pwd: Optional[str] = None) -> List[str]: + + this = self._data - return list(self._data.keys()) + if pwd: + path = pwd.split('.') + for elem in path: + this = this.get(elem, {}) + if not this: + break + + if this is None: + this = dict() + + return list(this.keys()) # -------------------------------------------------------------------------- # def delitem(self, key: str) -> None: - del self._data[key] - if not isinstance(self._data, dict): - self._data.sync() + this = self._data + + if key: + path = key.split('.') + for elem in path[:-1]: + this = this.get(elem, {}) + if not this: + break + + if this: + del this[path[-1]] # ------------------------------------------------------------------------------ @@ -117,14 +166,29 @@ class RegistryClient(Client, DictMixin): # -------------------------------------------------------------------------- # - def __init__(self, url: str) -> None: + def __init__(self, url: str, + pwd: Optional[str] = None) -> None: + + self._url = url + self._pwd = pwd super().__init__(url=url) + # -------------------------------------------------------------------------- + # + def dump(self, name: str = None) -> None: + + return self.request(cmd='dump', name=name) + + # -------------------------------------------------------------------------- # verbose API - def get(self, key: str, default: Optional[str] = None) -> Optional[Any]: + def get(self, key : str, + default: Optional[Any] = None) -> Optional[Any]: + + if self._pwd: + key = self._pwd + '.' + key try: return self.request(cmd='get', key=key) @@ -134,7 +198,11 @@ def get(self, key: str, default: Optional[str] = None) -> Optional[Any]: def put(self, key: str, val: Any) -> None: + + if self._pwd: + key = self._pwd + '.' + key ret = self.request(cmd='put', key=key, val=val) + assert ret is None return ret @@ -142,20 +210,26 @@ def put(self, key: str, # -------------------------------------------------------------------------- # dict mixin API def __getitem__(self, key: str) -> Optional[Any]: + return self.get(key) def __setitem__(self, key: str, val: Any) -> None: + return self.put(key, val) def __delitem__(self, key: str) -> None: + + if self._pwd: + key = self._pwd + '.' + key ret = self.request(cmd='del', key=key) assert ret is None def keys(self) -> List[str]: - ret = self.request(cmd='keys') + + ret = self.request(cmd='keys', pwd=self._pwd) assert isinstance(ret, list) return ret diff --git a/src/radical/utils/zmq/server.py b/src/radical/utils/zmq/server.py index 7f36516aa..3a83f6fe5 100644 --- a/src/radical/utils/zmq/server.py +++ b/src/radical/utils/zmq/server.py @@ -271,6 +271,7 @@ def _work(self) -> None: while not self._term.is_set(): + event = dict(no_intr(self._poll.poll, timeout=100)) if self._sock not in event: diff --git a/tests/unittests/test_heartbeat.py b/tests/unittests/test_heartbeat.py index ad740196d..82c581d69 100755 --- a/tests/unittests/test_heartbeat.py +++ b/tests/unittests/test_heartbeat.py @@ -106,9 +106,11 @@ def proc(): try: while True: - if time.time() < t0 + 3: hb.beat('short') - if time.time() < t0 + 5: hb.beat('long') - time.sleep(0.05) + if time.time() < t0 + 3: hb.beat() + elif time.time() < t0 + 5: hb.beat() + else: break + time.sleep(0.1) + while True: time.sleep(1) @@ -127,7 +129,7 @@ def proc(): assert p.is_alive() # but it should have a zero exit value after 2 more seconds - time.sleep(2) + time.sleep(6) assert not p.is_alive() assert p.exitcode @@ -140,7 +142,7 @@ def proc(): # run tests if called directly if __name__ == "__main__": - test_hb_default() + # test_hb_default() test_hb_uid() diff --git a/tests/unittests/test_profiler.py b/tests/unittests/test_profiler.py index b42b26127..9f5ff901a 100755 --- a/tests/unittests/test_profiler.py +++ b/tests/unittests/test_profiler.py @@ -143,8 +143,6 @@ def _assert_profiler(key, val, res): if k.startswith('RADICAL'): del os.environ[k] - _assert_profiler('', '', True) - for val, res in [ ['false', False], ['', True ], diff --git a/tests/unittests/test_typeddict.py b/tests/unittests/test_typeddict.py index 784afbffe..b3d80fb97 100644 --- a/tests/unittests/test_typeddict.py +++ b/tests/unittests/test_typeddict.py @@ -346,7 +346,7 @@ class TDSchemed(TypedDict): # `__str__` method checked self.assertEqual('%s' % tds, '{}') # `__repr__` method checked - self.assertIn('TDSchemed object, schema keys', '%r' % tds) + self.assertIn('TDSchemed: ', '%r' % tds) # -------------------------------------------------------------------------- # @@ -480,11 +480,13 @@ class TD2Base(TD1Base): _cast = False _schema = { - 'base_int': float + 'base_int': float, + 'sub_bool': bool } _defaults = { - 'base_int': .5 + 'base_int': .5, + 'sub_bool': True } class TD3Base(TD2Base): @@ -510,6 +512,9 @@ class TD3Base(TD2Base): self.assertIs(getattr(TD2Base, '_schema')['base_int'], float) self.assertIs(getattr(TD1Base, '_schema')['base_int'], int) + self.assertIs(getattr(TD2Base, '_schema')['sub_bool'], bool) + self.assertIs(getattr(TD3Base, '_schema')['sub_bool'], bool) + # inherited "_self_default" from TD1Base (default value is False) self.assertTrue(getattr(TD3Base, '_self_default')) @@ -520,7 +525,7 @@ class TD3Base(TD2Base): self.assertTrue(getattr(TD1Base, '_cast')) # inherited from TD1Base ("_schema") - td3 = TD3Base({'base_int': 10, 'base_str': 20}) + td3 = TD3Base({'base_int': 10, 'base_str': 20, 'sub_bool': False}) # exception due to `TD3Base._cast = False` (inherited from TD2Base) with self.assertRaises(TypeError): td3.verify() @@ -532,8 +537,10 @@ class TD3Base(TD2Base): td3.verify() self.assertIsInstance(td3.base_int, float) self.assertIsInstance(td3.base_str, str) + self.assertIsInstance(td3.sub_bool, bool) self.assertEqual(td3.base_int, 10.) self.assertEqual(td3.base_str, '20') + self.assertEqual(td3.sub_bool, False) # -------------------------------------------------------------------------- # diff --git a/tests/unittests/test_zmq_pubsub.py b/tests/unittests/test_zmq_pubsub.py index 401db5a15..58bdd0990 100755 --- a/tests/unittests/test_zmq_pubsub.py +++ b/tests/unittests/test_zmq_pubsub.py @@ -40,9 +40,15 @@ def test_zmq_pubsub(): 'stall_hwm': 1, }) - b = ru.zmq.PubSub(cfg) + b = ru.zmq.PubSub('test', cfg) b.start() + assert b.type_in == 'pub' + assert b.type_out == 'sub' + + assert b.addr_in == b.addr_pub + assert b.addr_out == b.addr_sub + assert b.addr_in != b.addr_out assert b.addr_in == b.addr_pub assert b.addr_out == b.addr_sub @@ -103,6 +109,9 @@ def work_pub(uid, n, delay): assert data['C']['A'] + data['C']['B'] + \ data['D']['A'] + data['D']['B'] == 2 * (c_a + c_b) + import pprint + pprint.pprint(data) + # ------------------------------------------------------------------------------ # run tests if called directly diff --git a/tests/unittests/test_zmq_queue.py b/tests/unittests/test_zmq_queue.py index 12bd68bc7..e8b648f23 100755 --- a/tests/unittests/test_zmq_queue.py +++ b/tests/unittests/test_zmq_queue.py @@ -265,7 +265,7 @@ def get_msg_a(msgs): data['get'][uid] = list() data['get'][uid].append(uid) - b = ru.zmq.Queue(cfg) + b = ru.zmq.Queue('test', cfg) b.start() assert b.addr_in != b.addr_out diff --git a/tests/unittests/test_zmq_registry.py b/tests/unittests/test_zmq_registry.py index 90b98af0a..5455662d5 100755 --- a/tests/unittests/test_zmq_registry.py +++ b/tests/unittests/test_zmq_registry.py @@ -18,20 +18,19 @@ def test_zmq_registry(mocked_prof): try: assert r.addr - c = ru.zmq.RegistryClient(url=r.addr) + c = ru.zmq.RegistryClient(url=r.addr, pwd='oops') c.put('foo.bar.buz', {'biz': 11}) assert c.get('foo') == {'bar': {'buz': {'biz': 11}}} - assert c.get('foo.bar.buz.biz') == 11 + assert c.get('foo.bar.buz.biz') == 11 assert c.get('foo.bar.buz.biz.boz') is None - assert c.get('foo') == {'bar': {'buz': {'biz': 11}}} c.put('foo.bar.buz', {'biz': 42}) assert c.get('foo.bar.buz.biz') == 42 - assert c['foo.bar.buz.biz'] == 42 + assert c['foo.bar.buz.biz'] == 42 assert c['foo']['bar']['buz']['biz'] == 42 - assert c['foo.bar.buz.biz.boz'] is None + assert c['foo.bar.buz.biz.boz'] is None assert 'foo' in c assert c.keys() == ['foo'] @@ -40,10 +39,18 @@ def test_zmq_registry(mocked_prof): assert c.keys() == [] finally: + if c: + c.close() + try: + c = ru.zmq.RegistryClient(url=r.addr) + assert c.keys() == ['oops'] + + finally: if c: c.close() + r.dump() r.stop() r.wait() diff --git a/tests/unittests/test_zmq_server.py b/tests/unittests/test_zmq_server.py old mode 100644 new mode 100755 index c5897b03c..8ad860d9a --- a/tests/unittests/test_zmq_server.py +++ b/tests/unittests/test_zmq_server.py @@ -178,7 +178,7 @@ def test_server_class(self, mocked_profiler, mocked_logger): c.request('no_registered_cmd') with self.assertRaisesRegex(RuntimeError, - '.* _test_0.* takes 1 positional argument'): + '.*_test_0.* takes 1 positional argument'): c.request('test_0', None) ret = c.request('test_0')