Source code for pykern.api.util

"""API constants

:copyright: Copyright (c) 2025 RadiaSoft LLC.  All Rights Reserved.
:license: http://www.apache.org/licenses/LICENSE-2.0.html
"""

# Limit pykern imports
from pykern.pkcollections import PKDict
import datetime
import enum
import inspect
import msgpack
import pykern.util

#: Name of the API that authenticates connections (clients need it)
AUTH_API_NAME = "authenticate_connection"

#: Version of the AUTH API, which doubles as the version of pykern.api
AUTH_API_VERSION = 658_584_001

# Protocol code shared between client & server, not public

# A bit of type checking: message kind values are offsets from this base
_MSG_KIND_BASE = 777_500


class MsgKind(enum.Enum):
    """Kind of a protocol message.

    Each value is `_MSG_KIND_BASE` plus a small offset.
    """

    CALL = _MSG_KIND_BASE + 1
    REPLY = _MSG_KIND_BASE + 2
    SUBSCRIBE = _MSG_KIND_BASE + 3
    UNSUBSCRIBE = _MSG_KIND_BASE + 4

    def is_call(self):
        """True if this kind is `CALL`"""
        return self is MsgKind.CALL

    def is_reply(self):
        """True if this kind is `REPLY`"""
        return self is MsgKind.REPLY

    def is_subscribe(self):
        """True if this kind is `SUBSCRIBE`"""
        return self is MsgKind.SUBSCRIBE

    def is_unsubscribe(self):
        """True if this kind is `UNSUBSCRIBE`"""
        return self is MsgKind.UNSUBSCRIBE
#: Attribute set on a function by `subscription` and read by `is_subscription`
_SUBSCRIPTION_ATTR = "pykern_api_util_subscription"

#: Message kinds `msg_unpack` accepts, keyed by its `which` argument
_MSG_KIND_IS_VALID = PKDict(
    client=frozenset({MsgKind.REPLY, MsgKind.UNSUBSCRIBE}),
    server=frozenset({MsgKind.CALL, MsgKind.SUBSCRIBE, MsgKind.UNSUBSCRIBE}),
)
class APICallError(pykern.util.APIError):
    """Raised when call execution ends in exception or other error

    Args:
        error (object): passed to the base class with the format ``error={}``
    """

    def __init__(self, error):
        super().__init__("error={}", error)
class APIDisconnected(pykern.util.APIError):
    """Raised when remote server closed or other error

    Takes no arguments; the base class is initialized with an empty string.
    """

    def __init__(self):
        super().__init__("")
class APIForbidden(pykern.util.APIError):
    """Raised for forbidden or protocol error

    Takes no arguments; the base class is initialized with an empty string.
    """

    def __init__(self):
        super().__init__("")
class APIKindError(pykern.util.APIError):
    """Raised when kind mismatch

    Args:
        error (object): passed to the base class with the format ``error={}``
    """

    def __init__(self, error):
        super().__init__("error={}", error)
class APINotFound(pykern.util.APIError):
    """Raised for an object not found

    Args:
        api_name (object): passed to the base class with the format
            ``api_name={}``
    """

    def __init__(self, api_name):
        super().__init__("api_name={}", api_name)
class APIProtocolError(pykern.util.APIError):
    """Raised when protocol error at lower level

    Args:
        error (object): passed to the base class with the format ``error={}``
    """

    def __init__(self, error):
        super().__init__("error={}", error)
def is_subscription(func):
    """Report whether `func` carries the subscription api marker.

    Args:
        func (function): class api
    Returns:
        object: value of the marker attribute if `func` has it (True when
            set by `subscription`), otherwise False
    """
    try:
        return getattr(func, _SUBSCRIPTION_ATTR)
    except AttributeError:
        # No marker attribute: not a subscription api
        return False
def msg_pack(unserialized):
    """Used by client and server, not public

    Packs `unserialized` with msgpack, converting datetimes, enums, and
    objects that have ``tolist`` along the way.

    Args:
        unserialized (object): message to pack
    Returns:
        bytes: result of ``msgpack.Packer.bytes()``
    """

    def _convert(value):
        # Passed to msgpack.Packer as ``default``
        if isinstance(value, datetime.datetime):
            # Whole seconds only: the fractional part is truncated
            return int(value.timestamp())
        if isinstance(value, enum.Enum):
            return value.value
        if not hasattr(value, "tolist"):
            return value
        # pandas and numpy objects provide tolist. Should tolist need
        # arguments or not be callable, the outcome is essentially the
        # same as it would be without this branch.
        return value.tolist()

    packer = msgpack.Packer(autoreset=False, default=_convert)
    packer.pack(unserialized)
    # TODO(robnagler) getbuffer() would be better
    return packer.bytes()
def msg_unpack(serialized, which):
    """Used by client and server, not public

    Unpacks a msgpack message and validates its ``call_id`` and
    ``msg_kind`` fields. Errors are returned, not raised.

    Args:
        serialized (object): msgpack data, fed to ``msgpack.Unpacker.feed``
        which (str): key into `_MSG_KIND_IS_VALID` ("client" or "server")
            selecting the message kinds that are acceptable
    Returns:
        tuple: ``(msg, None)`` on success, with ``msg.msg_kind`` replaced
            by its `MsgKind`; ``(None, error)`` with a description otherwise
    """

    def _int(rv, field):
        i = rv[field]
        # bool is a subclass of int, so it must be excluded explicitly;
        # otherwise True would validate as a positive integer.
        if isinstance(i, bool) or not isinstance(i, int):
            return None, f"msg {field} non-integer type={type(i)}"
        if i <= 0:
            return None, f"msg {field} non-positive int={i}"
        return None

    def _kind(rv):
        if r := _int(rv, "msg_kind"):
            return r
        try:
            k = MsgKind(rv.msg_kind)
            if k not in _MSG_KIND_IS_VALID[which]:
                return None, f"{k} invalid for {which}"
            rv.msg_kind = k
            return None
        except Exception:
            # Deliberately broad so callers always get an error tuple.
            # NOTE(review): this also reports an unknown `which` as an
            # invalid msg_kind.
            return None, f"msg_kind={rv.msg_kind} not in valid"

    try:
        u = msgpack.Unpacker(
            object_hook=pykern.pkcollections.object_pairs_hook,
        )
        u.feed(serialized)
        rv = u.unpack()
    except Exception as e:
        return None, f"msgpack exception={e}"
    if not isinstance(rv, PKDict):
        return None, f"msg not dict type={type(rv)}"
    for f in "call_id", "msg_kind":
        if not rv.get(f):
            return None, f"msg missing {f} keys={list(rv.keys())}"
    # The validators return None (falsy) on success and a non-empty error
    # tuple (truthy) on failure, so the first failure wins.
    return _int(rv, "call_id") or _kind(rv) or (rv, None)
def subscription(func):
    """Decorator for api functions that can be subscribed to by clients.

    Args:
        func (function): class api
    Returns:
        function: `func` itself, marked as a subscription api
    Raises:
        AssertionError: if `func` is not a coroutine function
    """
    # Give some early feedback: reject non-coroutines at decoration time
    if inspect.iscoroutinefunction(func):
        setattr(func, _SUBSCRIPTION_ATTR, True)
        return func
    raise AssertionError(f"func={func.__name__} must be a coroutine")