diff --git a/juju/_sync.py b/juju/_sync.py new file mode 100644 index 00000000..b2a32cb9 --- /dev/null +++ b/juju/_sync.py @@ -0,0 +1,108 @@ +# Copyright 2024 Canonical Ltd. +# Licensed under the Apache V2, see LICENCE file for details. +from __future__ import annotations + +import asyncio +import dataclasses +import functools +import logging +import threading +from typing import ( + Any, + Callable, + Coroutine, + Generic, + TypeVar, +) + +from typing_extensions import Self + +import juju.client.connection +import juju.model + +R = TypeVar("R") + + +@dataclasses.dataclass +class SyncCacheLine(Generic[R]): + value: R | None + exception: Exception | None + + +def cache_until_await(f: Callable[..., R]) -> Callable[..., R]: + @functools.wraps(f) + def inner(self: juju.model.ModelEntity, *args, **kwargs) -> R: + try: + assert isinstance(self, juju.model.ModelEntity) + cached: SyncCacheLine[R] = self._sync_cache.setdefault( + f.__name__, + SyncCacheLine(None, None), + ) + + if cached.value is None and cached.exception is None: + asyncio.get_running_loop().call_soon(self._sync_cache.clear) + try: + cached.value = f(self, *args, **kwargs) + except Exception as e: + cached.exception = e + + if cached.exception: + raise cached.exception + + assert cached.value is not None + return cached.value + except AttributeError as e: + # The decorated functions are commonly used in @property's + # where the class or base class declares __getattr__ too. + # Python data model has is that AttributeError is special + # in this case, so wrap it into something else. + raise Exception(repr(e)) from e + + return inner + + +class ThreadedAsyncRunner(threading.Thread): + _conn: juju.client.connection.Connection | None + _loop: asyncio.AbstractEventLoop + + @classmethod + def new_connected(cls, *, connection_kwargs: dict[str, Any]) -> Self: + rv = cls() + rv.start() + try: + rv._conn = asyncio.run_coroutine_threadsafe( + juju.client.connection.Connection.connect(**connection_kwargs), # type: ignore[reportUnknownMemberType] + rv._loop, + ).result() + return rv + except Exception: + logging.exception("Helper thread failed to connect") + # TODO: .stop vs .close + rv._loop.stop() + rv.join() + raise + + def call(self, coro: Coroutine[None, None, R]) -> R: + return asyncio.run_coroutine_threadsafe(coro, self._loop).result() + + def stop(self) -> None: + if self._conn: + self.call(self._conn.close()) + self._conn = None + self._loop.call_soon_threadsafe(self._loop.stop) + self.join() + + @property + def connection(self) -> juju.client.connection.Connection: + assert self._conn + return self._conn + + def __init__(self) -> None: + super().__init__() + self._conn = None + self._loop = asyncio.new_event_loop() + + def run(self) -> None: + asyncio.set_event_loop(self._loop) + self._loop.run_forever() + self._loop.close() diff --git a/juju/application.py b/juju/application.py index b5efd172..5f5a2576 100644 --- a/juju/application.py +++ b/juju/application.py @@ -9,11 +9,18 @@ from pathlib import Path from typing_extensions import deprecated +from typing_extensions import reveal_type as reveal_type # FIXME temp from . import model, tag, utils +from ._sync import cache_until_await from .annotationhelper import _get_annotations, _set_annotations from .bundle import get_charm_series, is_local_charm from .client import _definitions, client +from .client._definitions import ( + ApplicationGetResults, + ApplicationResult, + Value, +) from .errors import JujuApplicationConfigError, JujuError from .origin import Channel from .placement import parse as parse_placement @@ -44,7 +51,9 @@ def name(self) -> str: @property def exposed(self) -> bool: - return self.safe_data["exposed"] + rv = self._application_info().exposed + assert rv is not None + return rv @property @deprecated("Application.owner_tag is deprecated and will be removed in v4") @@ -60,9 +69,22 @@ def life(self) -> str: def min_units(self) -> int: return self.safe_data["min-units"] + # Well, this attribute is lovely: + # - not used in integration tests, as far as I can see + # - not used in zaza*tests + # - not used in openstack upgrader + # - no unit tests in this repo + # - no integration tests in this repo + # Why was it here in the first place? + # @property + # def constraints(self) -> dict[str, str | int | bool]: + # return FIXME_to_dict(self.constraints_object) + @property - def constraints(self) -> dict[str, str | int | bool]: - return self.safe_data["constraints"] + def constraints_object(self) -> Value: + rv = self._application_get().constraints + assert isinstance(rv, Value) # FIXME #1249 + return rv @property @deprecated("Application.subordinate is deprecated and will be removed in v4") @@ -76,6 +98,28 @@ def subordinate(self) -> bool: def workload_version(self) -> str: return self.safe_data["workload-version"] + @cache_until_await + def _application_get(self) -> ApplicationGetResults: + return self.model._sync_call( + self.model._sync_application_facade.Get( + application=self.name, + ) + ) + + @cache_until_await + def _application_info(self) -> ApplicationResult: + first = self.model._sync_call( + self.model._sync_application_facade.ApplicationsInfo( + entities=[client.Entity(self.tag)], + ) + ).results[0] + # This API can get a bunch of results for a bunch of entities, or "tags" + # For each, either .result or .error is set by Juju, and an exception is + # raised on any .error by juju.client.connection.Connection.rpc() + assert first + assert first.result + return first.result + @property def _unit_match_pattern(self): return rf"^{self.entity_id}.*$" @@ -643,7 +687,9 @@ def charm_name(self) -> str: :return str: The name of the charm """ - return URL.parse(self.safe_data["charm-url"]).name + rv = self._application_get().charm + assert isinstance(rv, str) # FIXME #1249 + return rv @property @deprecated("Application.charm_url is deprecated and will be removed in v4") diff --git a/juju/client/protocols.py b/juju/client/protocols.py new file mode 100644 index 00000000..0d689cea --- /dev/null +++ b/juju/client/protocols.py @@ -0,0 +1,34 @@ +# Copyright 2025 Canonical Ltd. +# Licensed under the Apache V2, see LICENCE file for details. +from __future__ import annotations + +from typing import Protocol + +from juju.client._definitions import ( + ApplicationGetResults, + ApplicationInfoResults, + Entity, +) + + +class ApplicationFacadeProtocol(Protocol): + async def Get(self, application=None, branch=None) -> ApplicationGetResults: ... # noqa: N802 + + # jRRC Params={"entities":[{"tag": "yada-yada"}]} + # codegen unpacks top-level keys into keyword arguments + async def ApplicationsInfo( # noqa: N802 + self, entities: list[Entity] + ) -> ApplicationInfoResults: ... + + # etc... + # etc... + # etc... + # etc... + # etc... + # etc... + + +class CharmsFacadeProtocol(Protocol): ... + + +class UniterFacadeProtocol(Protocol): ... diff --git a/juju/model/__init__.py b/juju/model/__init__.py index 826cff88..b1dfb74d 100644 --- a/juju/model/__init__.py +++ b/juju/model/__init__.py @@ -23,17 +23,28 @@ from datetime import datetime, timedelta from functools import partial from pathlib import Path -from typing import TYPE_CHECKING, Any, Iterable, Literal, Mapping, overload +from typing import ( + TYPE_CHECKING, + Any, + Coroutine, + Iterable, + Literal, + Mapping, + TypeVar, + overload, +) import websockets import yaml from typing_extensions import deprecated from .. import provisioner, tag, utils +from .._sync import SyncCacheLine as SyncCacheLine +from .._sync import ThreadedAsyncRunner from ..annotationhelper import _get_annotations, _set_annotations from ..bundle import BundleHandler, get_charm_series, is_local_charm from ..charmhub import CharmHub -from ..client import client, connection, connector +from ..client import client, connection, connector, protocols from ..client._definitions import ApplicationStatus as ApplicationStatus from ..client._definitions import MachineStatus as MachineStatus from ..client._definitions import UnitStatus as UnitStatus @@ -76,6 +87,8 @@ from ..remoteapplication import ApplicationOffer, RemoteApplication from ..unit import Unit +R = TypeVar("R") + log = logger = logging.getLogger(__name__) @@ -645,6 +658,7 @@ class Model: connector: connector.Connector state: ModelState + _sync: ThreadedAsyncRunner | None = None def __init__( self, @@ -686,6 +700,28 @@ def __init__( Schema.CHARM_HUB: CharmhubDeployType(self._resolve_charm), } + def _sync_call(self, coro: Coroutine[None, None, R]) -> R: + assert self._sync + return self._sync.call(coro) + + @property + def _sync_application_facade(self) -> protocols.ApplicationFacadeProtocol: + """An ApplicationFacade suitable for ._sync.call(...)""" + assert self._sync + return client.ApplicationFacade.from_connection(self._sync.connection) + + @property + def _sync_charms_facade(self) -> protocols.CharmsFacadeProtocol: + assert self._sync + return client.CharmsFacade.from_connection(self._sync.connection) + + # FIXME uniter facade is gone now... I hope it was not needed + # @property + # def _sync_uniter_facade(self) -> protocols.UniterFacadeProtocol: + # """A UniterFacade suitable for ._sync.call(...)""" + # assert self._sync + # return client.UniterFacade.from_connection(self._sync.connection) + def is_connected(self): """Reports whether the Model is currently connected.""" return self._connector.is_connected() @@ -809,6 +845,10 @@ async def connect(self, *args, **kwargs): if not is_debug_log_conn: await self._after_connect(model_name, model_uuid) + self._sync = ThreadedAsyncRunner.new_connected( + connection_kwargs=self._connector._kwargs_cache + ) + async def connect_model(self, model_name, **kwargs): """.. deprecated:: 0.6.2 Use ``connect(model_name=model_name)`` instead.