Skip to content

Commit

Permalink
chore: copy implementation from juju#1104
Browse files Browse the repository at this point in the history
  • Loading branch information
dimaqq committed Jan 7, 2025
1 parent 35ba68b commit c961baf
Show file tree
Hide file tree
Showing 4 changed files with 234 additions and 6 deletions.
108 changes: 108 additions & 0 deletions juju/_sync.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
# Copyright 2024 Canonical Ltd.
# Licensed under the Apache V2, see LICENCE file for details.
from __future__ import annotations

import asyncio
import dataclasses
import functools
import logging
import threading
from typing import (
Any,
Callable,
Coroutine,
Generic,
TypeVar,
)

from typing_extensions import Self

import juju.client.connection
import juju.model

R = TypeVar("R")


@dataclasses.dataclass
class SyncCacheLine(Generic[R]):
value: R | None
exception: Exception | None


def cache_until_await(f: Callable[..., R]) -> Callable[..., R]:
@functools.wraps(f)
def inner(self: juju.model.ModelEntity, *args, **kwargs) -> R:
try:
assert isinstance(self, juju.model.ModelEntity)
cached: SyncCacheLine[R] = self._sync_cache.setdefault(
f.__name__,
SyncCacheLine(None, None),
)

if cached.value is None and cached.exception is None:
asyncio.get_running_loop().call_soon(self._sync_cache.clear)
try:
cached.value = f(self, *args, **kwargs)
except Exception as e:
cached.exception = e

if cached.exception:
raise cached.exception

assert cached.value is not None
return cached.value
except AttributeError as e:
# The decorated functions are commonly used in @property's
# where the class or base class declares __getattr__ too.
# Python data model has is that AttributeError is special
# in this case, so wrap it into something else.
raise Exception(repr(e)) from e

return inner


class ThreadedAsyncRunner(threading.Thread):
_conn: juju.client.connection.Connection | None
_loop: asyncio.AbstractEventLoop

@classmethod
def new_connected(cls, *, connection_kwargs: dict[str, Any]) -> Self:
rv = cls()
rv.start()
try:
rv._conn = asyncio.run_coroutine_threadsafe(
juju.client.connection.Connection.connect(**connection_kwargs), # type: ignore[reportUnknownMemberType]
rv._loop,
).result()
return rv
except Exception:
logging.exception("Helper thread failed to connect")
# TODO: .stop vs .close
rv._loop.stop()
rv.join()
raise

def call(self, coro: Coroutine[None, None, R]) -> R:
return asyncio.run_coroutine_threadsafe(coro, self._loop).result()

def stop(self) -> None:
if self._conn:
self.call(self._conn.close())
self._conn = None
self._loop.call_soon_threadsafe(self._loop.stop)
self.join()

@property
def connection(self) -> juju.client.connection.Connection:
assert self._conn
return self._conn

def __init__(self) -> None:
super().__init__()
self._conn = None
self._loop = asyncio.new_event_loop()

def run(self) -> None:
asyncio.set_event_loop(self._loop)
self._loop.run_forever()
self._loop.close()
54 changes: 50 additions & 4 deletions juju/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,18 @@
from pathlib import Path

from typing_extensions import deprecated
from typing_extensions import reveal_type as reveal_type # FIXME temp

from . import model, tag, utils
from ._sync import cache_until_await
from .annotationhelper import _get_annotations, _set_annotations
from .bundle import get_charm_series, is_local_charm
from .client import _definitions, client
from .client._definitions import (
ApplicationGetResults,
ApplicationResult,
Value,
)
from .errors import JujuApplicationConfigError, JujuError
from .origin import Channel
from .placement import parse as parse_placement
Expand Down Expand Up @@ -44,7 +51,9 @@ def name(self) -> str:

@property
def exposed(self) -> bool:
return self.safe_data["exposed"]
rv = self._application_info().exposed
assert rv is not None
return rv

@property
@deprecated("Application.owner_tag is deprecated and will be removed in v4")
Expand All @@ -60,9 +69,22 @@ def life(self) -> str:
def min_units(self) -> int:
return self.safe_data["min-units"]

# Well, this attribute is lovely:
# - not used in integration tests, as far as I can see
# - not used in zaza*tests
# - not used in openstack upgrader
# - no unit tests in this repo
# - no integration tests in this repo
# Why was it here in the first place?
# @property
# def constraints(self) -> dict[str, str | int | bool]:
# return FIXME_to_dict(self.constraints_object)

@property
def constraints(self) -> dict[str, str | int | bool]:
return self.safe_data["constraints"]
def constraints_object(self) -> Value:
rv = self._application_get().constraints
assert isinstance(rv, Value) # FIXME #1249
return rv

@property
@deprecated("Application.subordinate is deprecated and will be removed in v4")
Expand All @@ -76,6 +98,28 @@ def subordinate(self) -> bool:
def workload_version(self) -> str:
return self.safe_data["workload-version"]

@cache_until_await
def _application_get(self) -> ApplicationGetResults:
return self.model._sync_call(
self.model._sync_application_facade.Get(
application=self.name,
)
)

@cache_until_await
def _application_info(self) -> ApplicationResult:
first = self.model._sync_call(
self.model._sync_application_facade.ApplicationsInfo(
entities=[client.Entity(self.tag)],
)
).results[0]
# This API can get a bunch of results for a bunch of entities, or "tags"
# For each, either .result or .error is set by Juju, and an exception is
# raised on any .error by juju.client.connection.Connection.rpc()
assert first
assert first.result
return first.result

@property
def _unit_match_pattern(self):
return rf"^{self.entity_id}.*$"
Expand Down Expand Up @@ -643,7 +687,9 @@ def charm_name(self) -> str:
:return str: The name of the charm
"""
return URL.parse(self.safe_data["charm-url"]).name
rv = self._application_get().charm
assert isinstance(rv, str) # FIXME #1249
return rv

@property
@deprecated("Application.charm_url is deprecated and will be removed in v4")
Expand Down
34 changes: 34 additions & 0 deletions juju/client/protocols.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Copyright 2025 Canonical Ltd.
# Licensed under the Apache V2, see LICENCE file for details.
from __future__ import annotations

from typing import Protocol

from juju.client._definitions import (
ApplicationGetResults,
ApplicationInfoResults,
Entity,
)


class ApplicationFacadeProtocol(Protocol):
async def Get(self, application=None, branch=None) -> ApplicationGetResults: ... # noqa: N802

# jRRC Params={"entities":[{"tag": "yada-yada"}]}
# codegen unpacks top-level keys into keyword arguments
async def ApplicationsInfo( # noqa: N802
self, entities: list[Entity]
) -> ApplicationInfoResults: ...

# etc...
# etc...
# etc...
# etc...
# etc...
# etc...


class CharmsFacadeProtocol(Protocol): ...


class UniterFacadeProtocol(Protocol): ...
44 changes: 42 additions & 2 deletions juju/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,28 @@
from datetime import datetime, timedelta
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Iterable, Literal, Mapping, overload
from typing import (
TYPE_CHECKING,
Any,
Coroutine,
Iterable,
Literal,
Mapping,
TypeVar,
overload,
)

import websockets
import yaml
from typing_extensions import deprecated

from .. import provisioner, tag, utils
from .._sync import SyncCacheLine as SyncCacheLine
from .._sync import ThreadedAsyncRunner
from ..annotationhelper import _get_annotations, _set_annotations
from ..bundle import BundleHandler, get_charm_series, is_local_charm
from ..charmhub import CharmHub
from ..client import client, connection, connector
from ..client import client, connection, connector, protocols
from ..client._definitions import ApplicationStatus as ApplicationStatus
from ..client._definitions import MachineStatus as MachineStatus
from ..client._definitions import UnitStatus as UnitStatus
Expand Down Expand Up @@ -76,6 +87,8 @@
from ..remoteapplication import ApplicationOffer, RemoteApplication
from ..unit import Unit

R = TypeVar("R")

log = logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -645,6 +658,7 @@ class Model:

connector: connector.Connector
state: ModelState
_sync: ThreadedAsyncRunner | None = None

def __init__(
self,
Expand Down Expand Up @@ -686,6 +700,28 @@ def __init__(
Schema.CHARM_HUB: CharmhubDeployType(self._resolve_charm),
}

def _sync_call(self, coro: Coroutine[None, None, R]) -> R:
assert self._sync
return self._sync.call(coro)

@property
def _sync_application_facade(self) -> protocols.ApplicationFacadeProtocol:
"""An ApplicationFacade suitable for ._sync.call(...)"""
assert self._sync
return client.ApplicationFacade.from_connection(self._sync.connection)

@property
def _sync_charms_facade(self) -> protocols.CharmsFacadeProtocol:
assert self._sync
return client.CharmsFacade.from_connection(self._sync.connection)

# FIXME uniter facade is gone now... I hope it was not needed
# @property
# def _sync_uniter_facade(self) -> protocols.UniterFacadeProtocol:
# """A UniterFacade suitable for ._sync.call(...)"""
# assert self._sync
# return client.UniterFacade.from_connection(self._sync.connection)

def is_connected(self):
"""Reports whether the Model is currently connected."""
return self._connector.is_connected()
Expand Down Expand Up @@ -809,6 +845,10 @@ async def connect(self, *args, **kwargs):
if not is_debug_log_conn:
await self._after_connect(model_name, model_uuid)

self._sync = ThreadedAsyncRunner.new_connected(
connection_kwargs=self._connector._kwargs_cache
)

async def connect_model(self, model_name, **kwargs):
""".. deprecated:: 0.6.2
Use ``connect(model_name=model_name)`` instead.
Expand Down

0 comments on commit c961baf

Please sign in to comment.