Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rewrite: APNs: Scoped App Tokens #101

Merged
merged 19 commits into from
May 19, 2024
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions .github/workflows/pyright.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
name: Pyright
on: [push, pull_request]
jobs:
pyright:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
with:
cache: 'pip'

- run: |
python -m venv .venv
source .venv/bin/activate
pip install -e '.[test,cli]'

- run: echo "$PWD/.venv/bin" >> $GITHUB_PATH
- uses: jakebailey/pyright-action@v2
8 changes: 8 additions & 0 deletions .github/workflows/ruff.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
name: Ruff
on: [push, pull_request]
jobs:
ruff:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: chartboost/ruff-action@v1
6 changes: 5 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,8 @@ version_file = "pypush/_version.py"
[tool.pytest.ini_options]
minversion = "6.0"
addopts = ["-ra", "-q"]
testpaths = ["tests"]
testpaths = ["tests"]

[tool.ruff.lint]
select = ["E", "F", "B", "SIM", "I"]
ignore = ["E501", "B010"]
6 changes: 3 additions & 3 deletions pypush/apns/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
__all__ = ["protocol", "create_apns_connection", "activate"]
__all__ = ["protocol", "create_apns_connection", "activate", "filters"]

from . import protocol
from .lifecycle import create_apns_connection
from . import filters, protocol
from .albert import activate
from .lifecycle import create_apns_connection
14 changes: 7 additions & 7 deletions pypush/apns/_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import logging
from dataclasses import MISSING, field
from dataclasses import fields as dataclass_fields
from typing import Any, TypeVar, get_origin, get_args, Union
from typing import Any, TypeVar, Union, get_args, get_origin

from pypush.apns.transport import Packet

Expand Down Expand Up @@ -67,14 +67,14 @@ def from_packet(cls, packet: Packet):
)

# Check for extra fields
for field in packet.fields:
if field.id not in [
for current_field in packet.fields:
if current_field.id not in [
f.metadata["packet_id"]
for f in dataclass_fields(cls)
if f.metadata is not None and "packet_id" in f.metadata
]:
logging.warning(
f"Unexpected field with packet ID {field.id} in packet {packet}"
f"Unexpected field with packet ID {current_field.id} in packet {packet}"
)
return cls(**field_values)

Expand Down Expand Up @@ -122,15 +122,15 @@ def fid(
:param byte_len: The length of the field in bytes (for int fields)
:param default: The default value of the field
"""
if not default == MISSING and not default_factory == MISSING:
if default != MISSING and default_factory != MISSING:
raise ValueError("Cannot specify both default and default_factory")
if not default == MISSING:
if default != MISSING:
return field(
metadata={"packet_id": packet_id, "packet_bytes": byte_len},
default=default,
repr=repr,
)
if not default_factory == MISSING:
if default_factory != MISSING:
return field(
metadata={"packet_id": packet_id, "packet_bytes": byte_len},
default_factory=default_factory,
Expand Down
50 changes: 45 additions & 5 deletions pypush/apns/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,32 +3,72 @@
from typing import Generic, TypeVar

import anyio
from anyio.abc import ObjectSendStream
from anyio.abc import ObjectReceiveStream, ObjectSendStream

from . import filters

T = TypeVar("T")


class BroadcastStream(Generic[T]):
def __init__(self):
def __init__(self, backlog: int = 50):
self.streams: list[ObjectSendStream[T]] = []
self.backlog: list[T] = []
self._backlog_size = backlog

async def broadcast(self, packet):
logging.debug(f"Broadcasting {packet} to {len(self.streams)} streams")
for stream in self.streams:
try:
await stream.send(packet)
except anyio.BrokenResourceError:
self.streams.remove(stream)
logging.error("Broken resource error")
# self.streams.remove(stream)
# If we have a backlog, add the packet to it
if len(self.backlog) >= self._backlog_size:
self.backlog.pop(0)
self.backlog.append(packet)

@asynccontextmanager
async def open_stream(self):
send, recv = anyio.create_memory_object_stream[T]()
async def open_stream(self, backlog: bool = True):
# 1000 seems like a reasonable number, if more than 1000 messages come in before someone deals with them it will
# start stalling the APNs connection itself
send, recv = anyio.create_memory_object_stream[T](max_buffer_size=1000)
if backlog:
for packet in self.backlog:
await send.send(packet)
self.streams.append(send)
async with recv:
yield recv
self.streams.remove(send)
await send.aclose()


W = TypeVar("W")
F = TypeVar("F")


class FilteredStream(ObjectReceiveStream[F]):
"""
A stream that filters out unwanted items

filter should return None if the item should be filtered out, otherwise it should return the item or a modified version of it
"""

def __init__(self, source: ObjectReceiveStream[W], filter: filters.Filter[W, F]):
self.source = source
self.filter = filter

async def receive(self) -> F:
async for item in self.source:
if (filtered := self.filter(item)) is not None:
return filtered
raise anyio.EndOfStream

async def aclose(self):
await self.source.aclose()


def exponential_backoff(f):
async def wrapper(*args, **kwargs):
backoff = 1
Expand Down
6 changes: 3 additions & 3 deletions pypush/apns/albert.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import re
import uuid
from base64 import b64decode
from typing import Tuple, Optional
from typing import Optional, Tuple

import httpx
from cryptography import x509
Expand Down Expand Up @@ -96,10 +96,10 @@ async def activate(

try:
protocol = re.search("<Protocol>(.*)</Protocol>", resp.text).group(1) # type: ignore
except AttributeError:
except AttributeError as e:
# Search for error text between <b> and </b>
error = re.search("<b>(.*)</b>", resp.text).group(1) # type: ignore
raise Exception(f"Failed to get certificate from Albert: {error}")
raise Exception(f"Failed to get certificate from Albert: {error}") from e

protocol = plistlib.loads(protocol.encode("utf-8"))

Expand Down
44 changes: 44 additions & 0 deletions pypush/apns/filters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import logging
from typing import Callable, Optional, Type, TypeVar

from pypush.apns import protocol

T1 = TypeVar("T1")
T2 = TypeVar("T2")
Filter = Callable[[T1], Optional[T2]]

# Chain with proper types so that subsequent filters only need to take output type of previous filter
T_IN = TypeVar("T_IN", bound=protocol.Command)
T_MIDDLE = TypeVar("T_MIDDLE", bound=protocol.Command)
T_OUT = TypeVar("T_OUT", bound=protocol.Command)


def chain(first: Filter[T_IN, T_MIDDLE], second: Filter[T_MIDDLE, T_OUT]):
def filter(command: T_IN) -> Optional[T_OUT]:
logging.debug(f"Filtering {command} with {first} and {second}")
filtered = first(command)
if filtered is None:
return None
return second(filtered)

return filter


T = TypeVar("T", bound=protocol.Command)


def cmd(type: Type[T]) -> Filter[protocol.Command, T]:
def filter(command: protocol.Command) -> Optional[T]:
if isinstance(command, type):
return command
return None

return filter


def ALL(c):
return c


def NONE(_):
return None
Loading
Loading