# Source code for pykern.api.server

"""WebSocket Quest server

:copyright: Copyright (c) 2025 RadiaSoft LLC.  All Rights Reserved.
:license: http://www.apache.org/licenses/LICENSE-2.0.html
"""

from pykern.pkcollections import PKDict
from pykern.pkdebug import pkdc, pkdlog, pkdp, pkdexc
from pykern.api import util
import asyncio
import importlib
import inspect
import pykern.pkasyncio
import pykern.quest
import pykern.util
import tornado.websocket
import re


# Matches API method names of the form ``<METHOD_PREFIX><name>``; group 1 is
# the API name. The prefix is escaped so it is matched literally even if it
# ever contains regex metacharacters.
_API_NAME_RE = re.compile(rf"^{re.escape(pykern.quest.API.METHOD_PREFIX)}(\w+)")


class Session(pykern.quest.Attr):
    """Per-client state kept on the server for the life of a connection.

    Nothing here is persisted; the state is lost when the server
    terminates. This may change.
    """

    ATTR_KEY = "session"
    IS_SINGLETON = True

    def handle_on_close(self):
        """Give each stored value a chance to clean up when the client goes away.

        Values are visited newest-first (LIFO). Any value that exposes a
        truthy ``session_end`` attribute has it called with no arguments.
        """
        # Work on a snapshot so callbacks may modify the session safely
        snapshot = list(self.values())
        for value in reversed(snapshot):
            end = getattr(value, "session_end", None)
            if end:
                end()
class Subscription(pykern.quest.Attr):
    """EXPERIMENTAL: quest attribute handed to subscription APIs.

    Wraps the server-side message object that started the subscription so
    the API can push results back to the client via `result_put`.
    """

    ATTR_KEY = "subscription"

    def __init__(self, server_msg):
        super().__init__(None, _server_msg=server_msg)

    def result_put(self, api_result):
        """Forward one subscription result to the originating server message.

        Args:
            api_result (object): value to deliver; validation is performed by
                the server message's ``subscription_result_put``
        """
        target = self._server_msg
        target.subscription_result_put(api_result)
def start(api_classes, attr_classes, http_config, coros=()):
    """Create a `pkasyncio.Loop`, attach a `_Server` to it, then run the loop.

    Args:
        api_classes (Iterable): `pykern.quest.API` subclasses to be dispatched
        attr_classes (Iterable): `pykern.quest.Attr` subclasses to create API instance
        http_config (PKDict): `pkasyncio.Loop.http_server` arg; must contain ``api_uri``
        coros (Iterable): coroutines given to `pkasyncio.Loop.run` before the loop starts
    """
    loop = pykern.pkasyncio.Loop()
    _Server(loop, api_classes, attr_classes, http_config)
    if coros:
        loop.run(*coros)
    loop.start()
class _Server:
    """Builds the API dispatch table and registers the WebSocket endpoint.

    Args:
        loop (pykern.pkasyncio.Loop): loop whose ``http_server`` is started
        api_classes (Iterable): `pykern.quest.API` subclasses to be dispatched;
            ``pykern.api.auth_api.AuthAPI`` is added when none of them supplies
            the API named ``util.AUTH_API_NAME``
        attr_classes (Iterable): `pykern.quest.Attr` subclasses to create API instance
        http_config (PKDict): `pkasyncio.Loop.http_server` arg; must contain
            ``api_uri``. A copy is modified; the caller's value is not.
    """

    def __init__(self, loop, api_classes, attr_classes, http_config):
        def _api_class_funcs():
            # Yield every API method; add AuthAPI if no class provides auth
            have_auth = False
            for c in api_classes:
                for r in _api_class_funcs1(c):
                    if r.name == util.AUTH_API_NAME:
                        have_auth = True
                    yield r
            if not have_auth:
                yield from _api_class_funcs1(
                    importlib.import_module("pykern.api.auth_api").AuthAPI
                )

        def _api_class_funcs1(clazz):
            for n, o in inspect.getmembers(clazz, predicate=inspect.isfunction):
                if not (m := _API_NAME_RE.search(n)):
                    continue
                yield PKDict(
                    class_=clazz,
                    func=o,
                    func_name=n,
                    is_subscription=util.is_subscription(o),
                    name=m.group(1),
                )

        def _api_map():
            rv = PKDict()
            for a in _api_class_funcs():
                if a.name in rv:
                    raise AssertionError(
                        f"duplicate api={a.name} class={a.class_.__name__}"
                    )
                # don't need to save func
                if not inspect.iscoroutinefunction(a.pkdel("func")):
                    raise AssertionError(
                        f"api_func={a.func_name} is not async class={a.class_.__name__}"
                    )
                rv[a.name] = a
            return rv

        h = http_config.copy().pksetdefault(uri_map=[])
        self.loop = loop
        self.api_map = _api_map()
        # Materialize once: _quest_start iterates this on every call, which
        # would silently yield nothing after the first call for a generator
        self.attr_classes = tuple(attr_classes)
        # Copy so the caller's uri_map list is not mutated by the append
        h.uri_map = h.uri_map[:]
        h.uri_map.append((h.api_uri, _ServerHandler, PKDict(server=self)))
        # Removed from the copy before it is handed to http_server
        self.api_uri = h.pkdel("api_uri")
        h.log_function = self._log_end
        self._ws_id = 0
        loop.http_server(h)

    def handle_get(self, handler):
        """Log the HTTP GET received by `_ServerHandler` before WebSocket handling."""
        self._log(handler, "ws-get")

    def handle_open(self, handler):
        """Assign a ws_id and create the `_ServerConnection` for ``handler``.

        Returns:
            _ServerConnection: new connection, or None if creation failed (logged)
        """
        try:
            self._ws_id += 1
            handler.pykern_api_context.ws_id = self._ws_id
            return _ServerConnection(self, handler, ws_id=self._ws_id)
        except Exception as e:
            pkdlog("exception={} stack={}", e, pkdexc())
            self._log(handler, "open-error", "exception={}", [e])
            return None

    def _log(self, handler, which, fmt="", args=None):
        """Log via ``loop.http_log``, prefixed with error/ws_id from the handler context."""

        def _add(key, value):
            nonlocal f
            if value is not None:
                f += (" " if f else "") + key + "={}"
                a.append(value)

        f = ""
        a = []
        if x := getattr(handler, "pykern_api_context", None):
            _add("error", x.pkdel("error"))
            _add("ws_id", x.get("ws_id"))
        if fmt:
            # Only separate with a space when there is a prefix to separate from
            f += (" " if f else "") + fmt
            a.extend(args or ())
        self.loop.http_log(handler, which, f, a)

    def _log_end(self, handler, *args, **kwargs):
        # Installed as http_config.log_function (presumably tornado's
        # per-request log hook); WebSocket requests get the ws_id context
        if isinstance(handler, _ServerHandler):
            self._log(handler, "ws-end")
        else:
            self.loop.http_log(handler)


class _ServerConnection:
    """State for one open WebSocket: pending messages, session, and peer address."""

    def __init__(self, server, handler, ws_id):
        self.server = server
        self.handler = handler
        self.ws_id = ws_id
        self.pending_msgs = []
        self._destroyed = False
        self.session = Session(None)
        self.remote_peer = server.loop.remote_peer(handler.request)
        self.log("ws-open")

    def destroy(self):
        """End pending messages, close the socket, then end the session.

        Later calls are no-ops.
        """
        if self._destroyed:
            return
        self._destroyed = True
        x = list(self.pending_msgs)
        self.pending_msgs = []
        while x:
            # Reversed so end in opposite order of calls (LIFO)
            m = x.pop()
            try:
                m.destroy()
            except Exception as e:
                # Keep going so the socket is still closed and session ended
                pkdlog("exception={} {} stack={}", e, m, pkdexc())
        self.handler.close()
        self.handler = None
        if s := self.session:
            self.session = None
            # Last since it calls out of this module
            s.handle_on_close()

    def handle_on_close(self):
        self.destroy()

    async def handle_on_message(self, msg):
        """Process one incoming message; destroy the connection on fatal errors."""
        if self._destroyed:
            return
        m = None
        try:
            m = _ServerMsg(self)
            self.pending_msgs.append(m)
            if not await m.process(msg):
                self.destroy()
        except Exception as e:
            pkdlog("exception={} stack={}", e, pkdexc())
            self.log("msg-error", m, "unhandled exception={}", [e])
            self.destroy()
        finally:
            try:
                # destroy() already cleared pending_msgs if it ran
                if not self._destroyed and m is not None:
                    self.pending_msgs.remove(m)
            except Exception as e:
                pkdlog("exception={} stack={}", e, pkdexc())

    def log(self, which, call=None, fmt="", args=None):
        if fmt:
            fmt = " " + fmt
        pkdlog(
            "{} ip={} ws={} {}" + fmt,
            which,
            self.remote_peer,
            self.ws_id,
            call,
            *(args if args else ()),
        )


class _ServerHandler(tornado.websocket.WebSocketHandler):
    """Tornado handler that delegates to `_Server` / `_ServerConnection`."""

    def initialize(self, server):
        # Since part of a global space, need to prefix
        self.pykern_api_server = server
        self.pykern_api_context = PKDict()
        self.pykern_api_connection = None

    async def get(self, *args, **kwargs):
        try:
            self.pykern_api_server.handle_get(self)
            return await super().get(*args, **kwargs)
        except Exception as e:
            pkdlog("exception={} stack={}", e, pkdexc())

    async def on_message(self, msg):
        try:
            # WebSocketHandler only allows one on_message at a time, so run
            # each message in its own task to allow concurrent processing
            pykern.pkasyncio.create_task(
                self.pykern_api_connection.handle_on_message(msg)
            )
        except Exception as e:
            pkdlog("exception={} stack={}", e, pkdexc())

    def on_close(self):
        try:
            if not (c := self.pykern_api_connection):
                return
            self.pykern_api_connection = None
            c.handle_on_close()
        except Exception as e:
            pkdlog("exception={} stack={}", e, pkdexc())

    def open(self):
        try:
            self.pykern_api_connection = self.pykern_api_server.handle_open(self)
            if self.pykern_api_connection is None:
                # handle_open failed (already logged); without a connection
                # every message would fail, so don't leave the socket open
                self.close()
        except Exception as e:
            pkdlog("exception={} stack={}", e, pkdexc())


class _ServerMsg:
    """One client message (call, subscribe, or unsubscribe) and its reply."""

    def __init__(self, connection):
        self._connection = connection
        self._call = None
        self._qcall = None
        self._api = None
        self._destroyed = False

    def destroy(self, unsubscribe=False):
        """Mark destroyed and end any running quest (in_error unless unsubscribing)."""
        if self._destroyed:
            return
        self._destroyed = True
        if not (c := self._qcall):
            return
        self._qcall = None
        c.quest_end(in_error=not unsubscribe)

    async def process(self, msg):
        """Unpack, dispatch, and reply to ``msg``.

        Returns:
            bool: False if the connection should be destroyed
        """
        try:
            self._call, e = util.msg_unpack(msg, "server")
            if e:
                self._log("unpack-error", "error={}", [e])
                return False
            if r := self._parse():
                pass
            elif self._call.msg_kind.is_unsubscribe():
                self._unsubscribe(self._call.call_id)
                return True
            elif self._call.msg_kind.is_subscribe():
                r = await self._do_call(Subscription(self))
                if r is not None and not isinstance(r, Exception):
                    r = util.APICallError(
                        f"return type={type(r)} from subscription api must be None"
                    )
            else:
                r = await self._do_call(None)
            if self._destroyed:
                return True
            self._reply(r)
            if isinstance(r, util.APIProtocolError):
                # Log the protocol error itself (e is the falsy unpack error)
                self._log("protocol-error", "exception={}", [r])
                return False
            self._log("end")
            return True
        except Exception as e:
            pkdlog("exception={} {} stack={}", e, self, pkdexc())
            self._log("process-error", "exception={}", [e])
            self._reply(e)
            return False
        finally:
            self.destroy()

    def subscription_result_put(self, api_result):
        """Send a subscription result; raises APIDisconnected if the reply failed."""
        if isinstance(api_result, Exception):
            raise util.APICallError(
                f"api_result type={type(api_result)} may not be an exception"
            )
        if api_result is None:
            raise util.APICallError("api_result may not be None")
        self._reply(api_result)
        if self._destroyed:
            raise util.APIDisconnected()

    async def _do_call(self, sub):
        """Run the API method in a quest; exceptions are returned, not raised."""
        try:
            # Let quest.start see the exception
            with self._quest_start(sub) as c:
                try:
                    self._qcall = c
                    return await getattr(c, self._api.func_name)(self._call.api_args)
                finally:
                    self._qcall = None
        except Exception as e:
            pkdlog("exception={} {} stack={}", e, self, pkdexc())
            if not isinstance(e, pykern.util.APIError):
                e = util.APICallError(f"unhandled_exception={e}")
            return e

    def _log(self, which, *args):
        return self._connection.log(which, self, *args)

    def _parse(self):
        """Validate the unpacked call; return an error object or None."""

        def _args():
            if self._call.get("api_args") is None:
                return util.APIProtocolError("missing msg field api_args")
            return None

        def _name():
            if not (n := self._call.get("api_name")):
                return util.APIProtocolError("missing msg field api_name")
            if a := self._connection.server.api_map.get(n):
                self._api = a
                return None
            return util.APINotFound(n)

        def _kind():
            k = self._call.msg_kind
            if k.is_unsubscribe():
                return None
            if r := _name():
                return r
            if k.is_subscribe():
                if not self._api.is_subscription:
                    return util.APIKindError(
                        f"cannot subscribe non-subscription api={self._call.api_name}"
                    )
            elif k.is_call():
                if self._api.is_subscription:
                    return util.APIKindError(
                        f"call_api on subscription api={self._call.api_name}"
                    )
            else:
                raise AssertionError(f"invalid {k} returned from msg_unpack")
            return _args()

        self._log(self._call.msg_kind.name.lower())
        return _kind()

    def _quest_start(self, sub=None):
        a = list(self._connection.server.attr_classes)
        a.append(self._connection.session)
        if sub:
            a.append(sub)
        return pykern.quest.start(self._api.class_, a)

    def _reply(self, call_rv):
        """Pack and send the reply for ``call_rv``; destroy self on send failure."""
        try:
            # "is None": "== None" misbehaves for results with a custom __eq__
            if call_rv is None:
                if self._call.msg_kind.is_subscribe():
                    r = PKDict(msg_kind=util.MsgKind.UNSUBSCRIBE)
                else:
                    r = PKDict(api_result=None, api_error="missing reply")
            elif not isinstance(call_rv, Exception):
                r = PKDict(api_result=call_rv, api_error=None)
            elif isinstance(call_rv, pykern.util.APIError):
                r = PKDict(
                    api_result=None,
                    api_error=f"{call_rv.__class__.__name__}: {call_rv}",
                )
            else:
                r = PKDict(api_result=None, api_error=f"unhandled_exception={call_rv}")
            r.pksetdefault(msg_kind=util.MsgKind.REPLY)
            r.call_id = self._call.call_id
            self._connection.handler.write_message(util.msg_pack(r), binary=True)
            self._log("reply")
        except Exception as e:
            pkdlog("exception={} {} stack={}", e, self, pkdexc())
            self._log("reply-error", "exception={}", [e])
            self.destroy()

    def __str__(self):
        def _destroyed():
            return " DESTROYED" if self._destroyed else ""

        def _info(c):
            if not c:
                return
            k = c.get("msg_kind")
            i = c.get("call_id")
            if not (k and i):
                return
            rv = f"{k.name.lower()}#{i}"
            if a := c.get("api_name"):
                rv += " " + a
            return f"<{rv}{_destroyed()}>"

        return _info(self._call) or f"<{self.__class__.__name__}{_destroyed()}>"

    def _unsubscribe(self, call_id):
        """End the pending subscription whose call_id matches ``call_id``."""
        for m in self._connection.pending_msgs:
            if m._call and m._call.get("call_id") == call_id and m._api:
                if m._api.is_subscription:
                    m.destroy(unsubscribe=True)