From c7e8a6fa3dbf2d67bdc84e4f1db91044d598fe35 Mon Sep 17 00:00:00 2001 From: JJTech0130 Date: Fri, 17 May 2024 23:41:30 -0400 Subject: [PATCH 01/19] apns: implement scoped app tokens --- pypush/apns/_util.py | 4 +++- pypush/apns/lifecycle.py | 51 ++++++++++++++++++++++++++++++---------- pypush/apns/protocol.py | 21 +++++++++++++++++ pypush/apns/transport.py | 2 ++ tests/test_apns.py | 15 ++++++++++++ 5 files changed, 79 insertions(+), 14 deletions(-) diff --git a/pypush/apns/_util.py b/pypush/apns/_util.py index 09e9574..49e7a84 100644 --- a/pypush/apns/_util.py +++ b/pypush/apns/_util.py @@ -13,11 +13,13 @@ def __init__(self): self.streams: list[ObjectSendStream[T]] = [] async def broadcast(self, packet): + logging.debug(f"Broadcasting {packet} to {len(self.streams)} streams") for stream in self.streams: try: await stream.send(packet) except anyio.BrokenResourceError: - self.streams.remove(stream) + logging.error("Broken resource error") + #self.streams.remove(stream) @asynccontextmanager async def open_stream(self): diff --git a/pypush/apns/lifecycle.py b/pypush/apns/lifecycle.py index 49b4fcf..61bedfd 100644 --- a/pypush/apns/lifecycle.py +++ b/pypush/apns/lifecycle.py @@ -6,6 +6,7 @@ import time import typing from contextlib import asynccontextmanager +from hashlib import sha1 import anyio from anyio.abc import TaskGroup @@ -44,6 +45,8 @@ def __init__( self.private_key = private_key self.base_token = token + self.connected = anyio.Event() + self._conn = None self._tg = task_group self._broadcast = _util.BroadcastStream[protocol.Command]() @@ -74,6 +77,8 @@ async def _ping_task(self): @_util.exponential_backoff async def reconnect(self): async with self._reconnect_lock: # Prevent weird situations where multiple reconnects are happening at once + if self.connected.is_set(): + self.connected = anyio.Event() if self._conn is not None: logging.warning("Closing existing connection") await self._conn.aclose() @@ -93,7 +98,7 @@ async def reconnect(self): protocol.ConnectCommand( push_token=self.base_token, state=1, - flags=69, + flags=65, #69 certificate=cert, nonce=nonce, signature=signature, @@ -107,6 +112,7 @@ async def reconnect(self): self.base_token = ack.token else: assert ack.token == self.base_token + self.connected.set() async def aclose(self): if self._conn is not None: @@ -115,27 +121,46 @@ async def aclose(self): T = typing.TypeVar("T", bound=protocol.Command) - async def receive_stream( - self, filter: typing.Type[T], max: int = -1 - ) -> typing.AsyncIterator[T]: + # async def receive_stream( + # self, filter: typing.Type[T], max: int = -1 + # ) -> typing.AsyncIterator[T]: + # async with self._broadcast.open_stream() as stream: + # async for command in stream: + # if isinstance(command, filter): + # max -= 1 + # yield command + # if max == 0: + # break + # logging.error("Stream ended") # BUG: Will never happen, async iterators don't autoclose + + async def receive(self, filter: typing.Type[T]) -> T: async with self._broadcast.open_stream() as stream: async for command in stream: if isinstance(command, filter): - yield command - max -= 1 - if max == 0: - break - - async def receive(self, filter: typing.Type[T]) -> T: - async for command in self.receive_stream(filter, 1): - return command - raise ValueError("No matching command received") + return command + raise ValueError("Did not receive expected command") async def send(self, command: protocol.Command): try: assert self._conn is not None + if not self.connected.is_set(): + await self.connected.wait() await self._conn.send(command) except Exception as e: logging.warning(f"Error sending command, reconnecting") await self.reconnect() await self.send(command) + + async def filter(self, topics: list[str]): + await self.connected.wait() + assert self.base_token is not None + await self.send(protocol.FilterCommand(token=self.base_token, enabled_topic_hashes=[sha1(topic.encode()).digest() for topic in topics])) + + async def request_scoped_token(self, topic: str) -> bytes: + topic_hash = sha1(topic.encode()).digest() + await self.connected.wait() # Need to wait for connection so that base token is set + assert self.base_token is not None + await self.send(protocol.ScopedTokenCommand(token=self.base_token, topic=topic_hash)) + ack = await self.receive(protocol.ScopedTokenAck) + assert ack.status == 0 + return ack.scoped_token diff --git a/pypush/apns/protocol.py b/pypush/apns/protocol.py index ea0f7d3..ce04998 100644 --- a/pypush/apns/protocol.py +++ b/pypush/apns/protocol.py @@ -228,6 +228,25 @@ class SendMessageAck(Command): token: Optional[bytes] = fid(1, default=None) unknown6: Optional[bytes] = fid(6, default=None) +@command +@dataclass +class ScopedTokenCommand(Command): + PacketType = Packet.Type.ScopedToken + + token: bytes = fid(1) + topic: bytes = fid(2) + app_id: Optional[bytes] = fid(3, default=None) + +@command +@dataclass +class ScopedTokenAck(Command): + PacketType = Packet.Type.ScopedTokenAck + + status: int = fid(1) + scoped_token: bytes = fid(2) + topic: bytes = fid(3) + app_id: Optional[bytes] = fid(4, default=None) + @dataclass class UnknownCommand(Command): @@ -259,6 +278,8 @@ def command_from_packet(packet: Packet) -> Command: Packet.Type.SetState: SetStateCommand, Packet.Type.SendMessage: SendMessageCommand, Packet.Type.SendMessageAck: SendMessageAck, + Packet.Type.ScopedToken: ScopedTokenCommand, + Packet.Type.ScopedTokenAck: ScopedTokenAck, # Add other mappings here... } command_class = command_classes.get(packet.id, None) diff --git a/pypush/apns/transport.py b/pypush/apns/transport.py index 864f3eb..2b8c40c 100644 --- a/pypush/apns/transport.py +++ b/pypush/apns/transport.py @@ -30,6 +30,8 @@ class Type(Enum): KeepAlive = 12 KeepAliveAck = 13 NoStorage = 14 + ScopedToken = 17 + ScopedTokenAck = 18 SetState = 20 UNKNOWN = "Unknown" diff --git a/tests/test_apns.py b/tests/test_apns.py index 3b24508..a49f924 100644 --- a/tests/test_apns.py +++ b/tests/test_apns.py @@ -40,3 +40,18 @@ async def test_shorthand(): *await apns.activate(), courier="localhost" ) as connection: await connection.receive(apns.protocol.ConnectAck) + +@pytest.mark.asyncio +async def test_scoped_token(): + async with apns.create_apns_connection( + *await apns.activate(), courier="1-courier.sandbox.push.apple.com" + ) as connection: + token = await connection.request_scoped_token("dev.jjtech.pypush.tests") + logging.warning(f"Got token: {token.hex()}") + await connection.filter(["dev.jjtech.pypush.tests"]) + logging.warning(f"waiting on topic 'dev.jjtech.pypush.tests'") + async with connection._broadcast.open_stream() as stream: + async for command in stream: + if isinstance(command, apns.protocol.SendMessageCommand) and command.topic == "dev.jjtech.pypush.tests" and command.token == token: + logging.warning(f"Got message: {command.payload.decode()}") + break From f543014c8fd4b89fdc9379496eaf3552e05aa62b Mon Sep 17 00:00:00 2001 From: JJTech0130 Date: Sat, 18 May 2024 11:21:39 -0400 Subject: [PATCH 02/19] apns: lifecycle: await first connection event before yielding --- pypush/apns/lifecycle.py | 15 ++++++--------- tests/test_apns.py | 7 ++----- 2 files changed, 8 insertions(+), 14 deletions(-) diff --git a/pypush/apns/lifecycle.py b/pypush/apns/lifecycle.py index 61bedfd..cd3ae29 100644 --- a/pypush/apns/lifecycle.py +++ b/pypush/apns/lifecycle.py @@ -25,7 +25,9 @@ async def create_apns_connection( courier: typing.Optional[str] = None, ): async with anyio.create_task_group() as tg: - conn = Connection(tg, certificate, private_key, token, courier) + conn = Connection(tg, certificate, private_key, token, courier)\ + # Await connected for first time here, so that base token is set + await conn._connected.wait() yield conn tg.cancel_scope.cancel() # Cancel the task group when the context manager exits await conn.aclose() # Make sure to close the connection after the task group is cancelled @@ -45,7 +47,7 @@ def __init__( self.private_key = private_key self.base_token = token - self.connected = anyio.Event() + self._connected = anyio.Event() # Set when the connection is first established self._conn = None self._tg = task_group @@ -77,8 +79,6 @@ async def _ping_task(self): @_util.exponential_backoff async def reconnect(self): async with self._reconnect_lock: # Prevent weird situations where multiple reconnects are happening at once - if self.connected.is_set(): - self.connected = anyio.Event() if self._conn is not None: logging.warning("Closing existing connection") await self._conn.aclose() @@ -112,7 +112,8 @@ async def reconnect(self): self.base_token = ack.token else: assert ack.token == self.base_token - self.connected.set() + if not self._connected.is_set(): + self._connected.set() async def aclose(self): if self._conn is not None: @@ -143,8 +144,6 @@ async def receive(self, filter: typing.Type[T]) -> T: async def send(self, command: protocol.Command): try: assert self._conn is not None - if not self.connected.is_set(): - await self.connected.wait() await self._conn.send(command) except Exception as e: logging.warning(f"Error sending command, reconnecting") @@ -152,13 +151,11 @@ async def send(self, command: protocol.Command): await self.send(command) async def filter(self, topics: list[str]): - await self.connected.wait() assert self.base_token is not None await self.send(protocol.FilterCommand(token=self.base_token, enabled_topic_hashes=[sha1(topic.encode()).digest() for topic in topics])) async def request_scoped_token(self, topic: str) -> bytes: topic_hash = sha1(topic.encode()).digest() - await self.connected.wait() # Need to wait for connection so that base token is set assert self.base_token is not None await self.send(protocol.ScopedTokenCommand(token=self.base_token, topic=topic_hash)) ack = await self.receive(protocol.ScopedTokenAck) diff --git a/tests/test_apns.py b/tests/test_apns.py index a49f924..4978fe8 100644 --- a/tests/test_apns.py +++ b/tests/test_apns.py @@ -29,17 +29,14 @@ async def test_lifecycle_2(): async with apns.create_apns_connection( certificate, key, courier="localhost" ) as connection: - await connection.receive( - apns.protocol.ConnectAck - ) # Just wait until the initial connection is established. Don't do this in real code plz. - + pass @pytest.mark.asyncio async def test_shorthand(): async with apns.create_apns_connection( *await apns.activate(), courier="localhost" ) as connection: - await connection.receive(apns.protocol.ConnectAck) + pass @pytest.mark.asyncio async def test_scoped_token(): From 7782fb8dad7d629df82b602e79cbfe48f0c527f3 Mon Sep 17 00:00:00 2001 From: JJTech0130 Date: Sat, 18 May 2024 12:07:42 -0400 Subject: [PATCH 03/19] apns: tests: add app notification test --- pypush/apns/_util.py | 2 +- pypush/apns/lifecycle.py | 22 +++++-- pypush/apns/protocol.py | 7 ++- pypush/apns/transport.py | 5 +- tests/assets/dev.jjtech.pypush.tests.pem | 75 ++++++++++++++++++++++++ tests/test_apns.py | 62 +++++++++++++++----- 6 files changed, 145 insertions(+), 28 deletions(-) create mode 100644 tests/assets/dev.jjtech.pypush.tests.pem diff --git a/pypush/apns/_util.py b/pypush/apns/_util.py index 49e7a84..c98d8d0 100644 --- a/pypush/apns/_util.py +++ b/pypush/apns/_util.py @@ -19,7 +19,7 @@ async def broadcast(self, packet): await stream.send(packet) except anyio.BrokenResourceError: logging.error("Broken resource error") - #self.streams.remove(stream) + # self.streams.remove(stream) @asynccontextmanager async def open_stream(self): diff --git a/pypush/apns/lifecycle.py b/pypush/apns/lifecycle.py index cd3ae29..a8c0155 100644 --- a/pypush/apns/lifecycle.py +++ b/pypush/apns/lifecycle.py @@ -25,8 +25,9 @@ async def create_apns_connection( courier: typing.Optional[str] = None, ): async with anyio.create_task_group() as tg: - conn = Connection(tg, certificate, private_key, token, courier)\ - # Await connected for first time here, so that base token is set + conn = Connection( + tg, certificate, private_key, token, courier + ) # Await connected for first time here, so that base token is set await conn._connected.wait() yield conn tg.cancel_scope.cancel() # Cancel the task group when the context manager exits @@ -47,7 +48,7 @@ def __init__( self.private_key = private_key self.base_token = token - self._connected = anyio.Event() # Set when the connection is first established + self._connected = anyio.Event() # Set when the connection is first established self._conn = None self._tg = task_group @@ -98,7 +99,7 @@ async def reconnect(self): protocol.ConnectCommand( push_token=self.base_token, state=1, - flags=65, #69 + flags=65, # 69 certificate=cert, nonce=nonce, signature=signature, @@ -152,12 +153,21 @@ async def send(self, command: protocol.Command): async def filter(self, topics: list[str]): assert self.base_token is not None - await self.send(protocol.FilterCommand(token=self.base_token, enabled_topic_hashes=[sha1(topic.encode()).digest() for topic in topics])) + await self.send( + protocol.FilterCommand( + token=self.base_token, + enabled_topic_hashes=[ + sha1(topic.encode()).digest() for topic in topics + ], + ) + ) async def request_scoped_token(self, topic: str) -> bytes: topic_hash = sha1(topic.encode()).digest() assert self.base_token is not None - await self.send(protocol.ScopedTokenCommand(token=self.base_token, topic=topic_hash)) + await self.send( + protocol.ScopedTokenCommand(token=self.base_token, topic=topic_hash) + ) ack = await self.receive(protocol.ScopedTokenAck) assert ack.status == 0 return ack.scoped_token diff --git a/pypush/apns/protocol.py b/pypush/apns/protocol.py index ce04998..244cda3 100644 --- a/pypush/apns/protocol.py +++ b/pypush/apns/protocol.py @@ -2,7 +2,7 @@ from hashlib import sha1 from typing import Optional, Union -from anyio.abc import ByteStream, ObjectStream +from anyio.abc import ObjectStream from pypush.apns._protocol import command, fid from pypush.apns.transport import Packet @@ -140,6 +140,7 @@ class KeepAliveAck(Command): PacketType = Packet.Type.KeepAliveAck unknown: Optional[int] = fid(1) + @command @dataclass class SetStateCommand(Command): @@ -228,6 +229,7 @@ class SendMessageAck(Command): token: Optional[bytes] = fid(1, default=None) unknown6: Optional[bytes] = fid(6, default=None) + @command @dataclass class ScopedTokenCommand(Command): @@ -237,6 +239,7 @@ class ScopedTokenCommand(Command): topic: bytes = fid(2) app_id: Optional[bytes] = fid(3, default=None) + @command @dataclass class ScopedTokenAck(Command): @@ -259,7 +262,7 @@ def from_packet(cls, packet: Packet): def to_packet(self) -> Packet: return Packet(id=self.id, fields=self.fields) - + def __repr__(self): if self.id.value in [29, 30, 32]: return f"UnknownCommand(id={self.id}, fields=[SUPPRESSED])" diff --git a/pypush/apns/transport.py b/pypush/apns/transport.py index 2b8c40c..ab5abd2 100644 --- a/pypush/apns/transport.py +++ b/pypush/apns/transport.py @@ -40,20 +40,19 @@ def __new__(cls, value): obj = object.__new__(cls) obj._value_ = value return obj - + @classmethod def _missing_(cls, value): # Handle unknown values instance = cls.UNKNOWN instance._value_ = value # Assign the unknown value return instance - + def __str__(self): if self is Packet.Type.UNKNOWN: return f"Unknown({self._value_})" return self.name - id: Type fields: list[Field] diff --git a/tests/assets/dev.jjtech.pypush.tests.pem b/tests/assets/dev.jjtech.pypush.tests.pem new file mode 100644 index 0000000..0188045 --- /dev/null +++ b/tests/assets/dev.jjtech.pypush.tests.pem @@ -0,0 +1,75 @@ +Bag Attributes + friendlyName: Apple Sandbox Push Services: dev.jjtech.pypush.tests + localKeyID: 0A C9 4D 65 F1 39 44 73 5F A8 05 BC B9 00 47 14 2C 12 9A F3 +subject=UID=dev.jjtech.pypush.tests, CN=Apple Sandbox Push Services: dev.jjtech.pypush.tests, OU=C4492JYJR3, C=US +issuer=CN=Apple Worldwide Developer Relations Certification Authority, OU=G4, O=Apple Inc., C=US +-----BEGIN CERTIFICATE----- +MIIGnzCCBYegAwIBAgIQRLQgelpeA0ozi3PDbx2ZmTANBgkqhkiG9w0BAQsFADB1 +MUQwQgYDVQQDDDtBcHBsZSBXb3JsZHdpZGUgRGV2ZWxvcGVyIFJlbGF0aW9ucyBD +ZXJ0aWZpY2F0aW9uIEF1dGhvcml0eTELMAkGA1UECwwCRzQxEzARBgNVBAoMCkFw +cGxlIEluYy4xCzAJBgNVBAYTAlVTMB4XDTI0MDUxNjAwMTUwM1oXDTI1MDYxNTAw +MTUwMlowgYoxJzAlBgoJkiaJk/IsZAEBDBdkZXYuamp0ZWNoLnB5cHVzaC50ZXN0 +czE9MDsGA1UEAww0QXBwbGUgU2FuZGJveCBQdXNoIFNlcnZpY2VzOiBkZXYuamp0 +ZWNoLnB5cHVzaC50ZXN0czETMBEGA1UECwwKQzQ0OTJKWUpSMzELMAkGA1UEBhMC +VVMwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQD3BvhGnrBtXpVLVvdi +HFHYeu58MKBD/vyw3A+a4PXnCXskSdEZDydBXJnKa1OeIqn/7TG5/6iiWGR+pcYa +XK6kCka8fxpuWgk4/H7C2EN9Atv/XgJit3RSUFdKVN1dvG5cDX5yvFcu7xSt8J+Y +RHuqM2YGwor1bZNUCi46n144dntB9rEV2ZgLwrHc2ofo/STbdstGKMJHkhg0GVcI +0IzGderz1Ga1UXB8yhr+CvQthjcm74G+aQJZfuMsGwXI06wbKOJQPtCPdAD0taBW +rdHivETxRw3WhPzmiwQLUruOmXEo5+bgl1NhnPCLJn374LWaxQEzpnW2HhP6p8mC +TzZhAgMBAAGjggMTMIIDDzAMBgNVHRMBAf8EAjAAMB8GA1UdIwQYMBaAFFvZ+h3n +mhoLo5l2IlCGPpHIW3eoMHAGCCsGAQUFBwEBBGQwYjAtBggrBgEFBQcwAoYhaHR0 +cDovL2NlcnRzLmFwcGxlLmNvbS93d2RyZzQuZGVyMDEGCCsGAQUFBzABhiVodHRw +Oi8vb2NzcC5hcHBsZS5jb20vb2NzcDAzLXd3ZHJnNDAzMIIBHgYDVR0gBIIBFTCC +AREwggENBgkqhkiG92NkBQEwgf8wgcMGCCsGAQUFBwICMIG2DIGzUmVsaWFuY2Ug +b24gdGhpcyBjZXJ0aWZpY2F0ZSBieSBhbnkgcGFydHkgYXNzdW1lcyBhY2NlcHRh +bmNlIG9mIHRoZSB0aGVuIGFwcGxpY2FibGUgc3RhbmRhcmQgdGVybXMgYW5kIGNv +bmRpdGlvbnMgb2YgdXNlLCBjZXJ0aWZpY2F0ZSBwb2xpY3kgYW5kIGNlcnRpZmlj +YXRpb24gcHJhY3RpY2Ugc3RhdGVtZW50cy4wNwYIKwYBBQUHAgEWK2h0dHBzOi8v +d3d3LmFwcGxlLmNvbS9jZXJ0aWZpY2F0ZWF1dGhvcml0eS8wEwYDVR0lBAwwCgYI +KwYBBQUHAwIwMgYDVR0fBCswKTAnoCWgI4YhaHR0cDovL2NybC5hcHBsZS5jb20v +d3dkcmc0LTMuY3JsMB0GA1UdDgQWBBQKyU1l8TlEc1+oBby5AEcULBKa8zAOBgNV +HQ8BAf8EBAMCB4Awgb8GCiqGSIb3Y2QGAwYEgbAwga0MF2Rldi5qanRlY2gucHlw +dXNoLnRlc3RzMAcMBXRvcGljDBxkZXYuamp0ZWNoLnB5cHVzaC50ZXN0cy52b2lw +MAYMBHZvaXAMJGRldi5qanRlY2gucHlwdXNoLnRlc3RzLmNvbXBsaWNhdGlvbjAO +DAxjb21wbGljYXRpb24MIGRldi5qanRlY2gucHlwdXNoLnRlc3RzLnZvaXAtcHR0 +MAsMCS52b2lwLXB0dDAQBgoqhkiG92NkBgMBBAIFADANBgkqhkiG9w0BAQsFAAOC +AQEAwQac2q1BMnAH1vdZgfDunc+b7SKO6rJIG6w/wl4211YyNBBS5oabQnQDfB8y +8iOeWnoWXry60gI2fwWN/rRaQn4QCy72jNeTGz/T/s2jwoGj89114JjcBhRAHvQl +/HN4QjSt5rWVRcxTE4cKKbJIqVCm7Uq9VROgbxXrmsZsRnyk1ASvLGboibtGbmty +wmXZWns5NXNDbv1wP+PF5HSFXtDWodPYnhvzJe0s9lRvo4yGAt1KL5mNaZM3kKp0 +74kdzKK/iT7954EQK4ZWPQbDnS1A+/BzHQjK0rWTwjDQkbKvNE9bb+KJbNHH3+DX +5s0ybZYoG5meGKUplwu7A2bfFw== +-----END CERTIFICATE----- +Bag Attributes + friendlyName: Apple Sandbox Push Services: dev.jjtech.pypush.tests Private Key + localKeyID: 0A C9 4D 65 F1 39 44 73 5F A8 05 BC B9 00 47 14 2C 12 9A F3 +Key Attributes: +-----BEGIN PRIVATE KEY----- +MIIEvwIBADANBgkqhkiG9w0BAQEFAASCBKkwggSlAgEAAoIBAQD3BvhGnrBtXpVL +VvdiHFHYeu58MKBD/vyw3A+a4PXnCXskSdEZDydBXJnKa1OeIqn/7TG5/6iiWGR+ +pcYaXK6kCka8fxpuWgk4/H7C2EN9Atv/XgJit3RSUFdKVN1dvG5cDX5yvFcu7xSt +8J+YRHuqM2YGwor1bZNUCi46n144dntB9rEV2ZgLwrHc2ofo/STbdstGKMJHkhg0 +GVcI0IzGderz1Ga1UXB8yhr+CvQthjcm74G+aQJZfuMsGwXI06wbKOJQPtCPdAD0 +taBWrdHivETxRw3WhPzmiwQLUruOmXEo5+bgl1NhnPCLJn374LWaxQEzpnW2HhP6 +p8mCTzZhAgMBAAECggEBAKADb8eu+3GdFvAagVyYI5wq5Vik1uu0vFKD+cfFeQQT +bCTxe/TTkAYSybwJEb0Zjy0spE1rgfzHbTFsiIqDBs1TqsZnPuPEhrzXMfVcyTqt +I3yjlMAFPeAkEqcfmdUiPgp64zHHNmI8lBSoDXlAwypY6PnwArtAI3MItTFcElhX +gWB44xVGuJRjRP4UVqXg0ML/Ic2yuYT9DRsDRilYhm8RGRSHkdZKdzCicMZcLtC7 +bs6/evmIrk9V5AzF6YiXlfT0dOp6yy9mFwhLljXF3Z2/LdrOTAmhLPQRMbUrJrcW +ZPd0kMybGIlEoprQEA/6nZkdtIiDo2OJtufCs8g+nJECgYEA/+v4uTJzEI1igKOB +myJtADECZAsJUaJaKSAM7VHn1hNOKgNLhUHOuroWvIWEhEomWeMvCbZIG42eOwNW +BXGtG7ruT79E6655dljU6E/029FaxONqXXCTD9ZPh031R293KcydMwgBJJ0pvFJE +14HWmMRAG0auPygMRhXubXU1ndMCgYEA9xpNWrl9poTjsZDNqvu60nYcq0W1escw +ovmb87uxZ5u8fC8T1F3AVMYj4v0dTyA4F0mZenY+nri/hJBuanWVxa5Liu0fGnBr +tEa2rzCMaajoDTNMKSygFz6CIMZbbZhozy0+9DHcRcC6b2UtIgB/+/ZQtrTvQ8Ea +i6viarkq1nsCgYBznYAM8mynEqhoYvV/RyslBf8FgTLhjU3b/F26rODmhmwucLSi +a9tf4ge5fTwjo3f17btnUND8mZrdICGxbex9dZKJtmgFbRn0TCdLGCwPTmIKRo7b +zaqyYeglwSNI9WNJH+X4kuopR1L+f9AX59ExzJ8Fc4XuhEIfO3MuQeBJ/wKBgQDa +8AgH0X/+EZJ42rcPvxiprxL5wbrpPSHf1M+T5gJqrXcUhNXJ/QMTWbekP+Y/HGn2 +YDTHZ4tWMJUoTJw4YVTBoQu33R8I2wDi6yCkGpzeZVStlXzuomZ6Ed1UUsvhT//V +SN6VmLP1ba0CVB/oF49OXNDpAWlZm/f8NuBW9Rd6jwKBgQDi495IOjLJ8SvWRJLT +c9AUmO7IVgipWvr51cF9IYxkzXIVIQIh1usy2NsrBxshAD+FbbWFVBfoptdKBZVK +J8u+Ou4gTxs8SdGKGZWZpUMEKJbPsq8lE2aU3mBXiWcFRxYpu+n7nKap0Lla/xBD +v77FY1M3FxGR6rNqPJQ9rRLFbA== +-----END PRIVATE KEY----- diff --git a/tests/test_apns.py b/tests/test_apns.py index 4978fe8..86b866c 100644 --- a/tests/test_apns.py +++ b/tests/test_apns.py @@ -1,18 +1,13 @@ -import pytest -from pypush import apns -import asyncio - -# from aioapns import * +import logging import uuid -import anyio - -# from pypush.apns import _util -# from pypush.apns import albert, lifecycle, protocol -from pypush import apns +from pathlib import Path -import logging +import httpx +import pytest from rich.logging import RichHandler +from pypush import apns + logging.basicConfig(level=logging.DEBUG, handlers=[RichHandler()], format="%(message)s") @@ -31,6 +26,7 @@ async def test_lifecycle_2(): ) as connection: pass + @pytest.mark.asyncio async def test_shorthand(): async with apns.create_apns_connection( @@ -38,17 +34,51 @@ async def test_shorthand(): ) as connection: pass + +ASSETS_DIR = Path(__file__).parent / "assets" + + +async def send_test_notification(device_token, payload=b"hello, world"): + async with httpx.AsyncClient( + cert=str(ASSETS_DIR / "dev.jjtech.pypush.tests.pem"), http2=True + ) as client: + # Use the certificate and key from above + response = await client.post( + f"https://api.sandbox.push.apple.com/3/device/{device_token}", + content=payload, + headers={ + "apns-topic": "dev.jjtech.pypush.tests", + "apns-push-type": "alert", + "apns-priority": "10", + }, + ) + assert response.status_code == 200 + + @pytest.mark.asyncio async def test_scoped_token(): async with apns.create_apns_connection( *await apns.activate(), courier="1-courier.sandbox.push.apple.com" ) as connection: + token = await connection.request_scoped_token("dev.jjtech.pypush.tests") - logging.warning(f"Got token: {token.hex()}") + + logging.debug(f"Got token: {token.hex()}") await connection.filter(["dev.jjtech.pypush.tests"]) - logging.warning(f"waiting on topic 'dev.jjtech.pypush.tests'") + logging.debug(f"waiting on topic 'dev.jjtech.pypush.tests'") + + test_message = f"test-message-{uuid.uuid4().hex}" + + await send_test_notification(token.hex(), test_message.encode()) + logging.debug(f"Sent message: {test_message}") + async with connection._broadcast.open_stream() as stream: async for command in stream: - if isinstance(command, apns.protocol.SendMessageCommand) and command.topic == "dev.jjtech.pypush.tests" and command.token == token: - logging.warning(f"Got message: {command.payload.decode()}") - break + if ( + isinstance(command, apns.protocol.SendMessageCommand) + and command.topic == "dev.jjtech.pypush.tests" + and command.token == token + ): + logging.debug(f"Got message: {command.payload.decode()}") + if command.payload == test_message.encode(): + break From 36f535103b9a9968aa34e3eb3d08bc70cfc6a436 Mon Sep 17 00:00:00 2001 From: JJTech0130 Date: Sat, 18 May 2024 14:39:43 -0400 Subject: [PATCH 04/19] apns: lifecycle: filtered streams --- pypush/apns/_util.py | 23 +++++++++++++++++++++-- pypush/apns/lifecycle.py | 16 +++++----------- tests/test_apns.py | 15 ++++++--------- 3 files changed, 32 insertions(+), 22 deletions(-) diff --git a/pypush/apns/_util.py b/pypush/apns/_util.py index c98d8d0..91d55ee 100644 --- a/pypush/apns/_util.py +++ b/pypush/apns/_util.py @@ -1,9 +1,9 @@ import logging from contextlib import asynccontextmanager -from typing import Generic, TypeVar +from typing import Generic, TypeVar, Type import anyio -from anyio.abc import ObjectSendStream +from anyio.abc import ObjectSendStream, ObjectReceiveStream T = TypeVar("T") @@ -31,6 +31,25 @@ async def open_stream(self): await send.aclose() +W = TypeVar("W") +F = TypeVar("F", covariant=True) + + +class FilteredStream(ObjectReceiveStream[F]): + def __init__(self, source: ObjectReceiveStream[W], filter: Type[F]): + self.source = source + self.filter = filter + + async def receive(self) -> F: + async for item in self.source: + if isinstance(item, self.filter): + return item + raise anyio.EndOfStream + + async def aclose(self): + await self.source.aclose() + + def exponential_backoff(f): async def wrapper(*args, **kwargs): backoff = 1 diff --git a/pypush/apns/lifecycle.py b/pypush/apns/lifecycle.py index a8c0155..150cc8a 100644 --- a/pypush/apns/lifecycle.py +++ b/pypush/apns/lifecycle.py @@ -7,6 +7,7 @@ import typing from contextlib import asynccontextmanager from hashlib import sha1 +from dataclasses import dataclass import anyio from anyio.abc import TaskGroup @@ -123,17 +124,10 @@ async def aclose(self): T = typing.TypeVar("T", bound=protocol.Command) - # async def receive_stream( - # self, filter: typing.Type[T], max: int = -1 - # ) -> typing.AsyncIterator[T]: - # async with self._broadcast.open_stream() as stream: - # async for command in stream: - # if isinstance(command, filter): - # max -= 1 - # yield command - # if max == 0: - # break - # logging.error("Stream ended") # BUG: Will never happen, async iterators don't autoclose + @asynccontextmanager + async def receive_stream(self, filter: typing.Type[T]): + async with self._broadcast.open_stream() as stream: + yield _util.FilteredStream(stream, filter) async def receive(self, filter: typing.Type[T]) -> T: async with self._broadcast.open_stream() as stream: diff --git a/tests/test_apns.py b/tests/test_apns.py index 86b866c..a2a8b91 100644 --- a/tests/test_apns.py +++ b/tests/test_apns.py @@ -72,13 +72,10 @@ async def test_scoped_token(): await send_test_notification(token.hex(), test_message.encode()) logging.debug(f"Sent message: {test_message}") - async with connection._broadcast.open_stream() as stream: + async with connection.receive_stream( + apns.protocol.SendMessageCommand + ) as stream: async for command in stream: - if ( - isinstance(command, apns.protocol.SendMessageCommand) - and command.topic == "dev.jjtech.pypush.tests" - and command.token == token - ): - logging.debug(f"Got message: {command.payload.decode()}") - if command.payload == test_message.encode(): - break + logging.debug(f"Got message: {command.payload.decode()}") + if command.payload == test_message.encode(): + break From 96b75ff307fd9f77f18fdd8e113879dfd2ac9e32 Mon Sep 17 00:00:00 2001 From: JJTech0130 Date: Sat, 18 May 2024 15:24:19 -0400 Subject: [PATCH 05/19] apns: use receive_stream over receive to avoid race --- pypush/apns/_util.py | 4 +++- pypush/apns/lifecycle.py | 26 +++++++++++++++++++++++--- tests/test_apns.py | 31 +++++++++++++------------------ 3 files changed, 39 insertions(+), 22 deletions(-) diff --git a/pypush/apns/_util.py b/pypush/apns/_util.py index 91d55ee..98a41c6 100644 --- a/pypush/apns/_util.py +++ b/pypush/apns/_util.py @@ -23,7 +23,9 @@ async def broadcast(self, packet): @asynccontextmanager async def open_stream(self): - send, recv = anyio.create_memory_object_stream[T]() + # 1000 seems like a reasonable number, if more than 1000 messages come in before someone deals with them it will + # start stalling the APNs connection itself + send, recv = anyio.create_memory_object_stream[T](max_buffer_size=1000) self.streams.append(send) async with recv: yield recv diff --git a/pypush/apns/lifecycle.py b/pypush/apns/lifecycle.py index 150cc8a..0e6d5ee 100644 --- a/pypush/apns/lifecycle.py +++ b/pypush/apns/lifecycle.py @@ -49,6 +49,8 @@ def __init__( self.private_key = private_key self.base_token = token + self._filters: dict[str, int] = {} # topic -> use count + self._connected = anyio.Event() # Set when the connection is first established self._conn = None @@ -130,6 +132,11 @@ async def receive_stream(self, filter: typing.Type[T]): yield _util.FilteredStream(stream, filter) async def receive(self, filter: typing.Type[T]) -> T: + """ + WARNING: If the actions between whatever triggered the thing you want to receive and this call might take a long time, + + you should use `receive_stream` instead, as any messages that arrive in between will be lost! + """ async with self._broadcast.open_stream() as stream: async for command in stream: if isinstance(command, filter): @@ -145,18 +152,31 @@ async def send(self, command: protocol.Command): await self.reconnect() await self.send(command) - async def filter(self, topics: list[str]): + async def _update_filter(self): assert self.base_token is not None await self.send( protocol.FilterCommand( token=self.base_token, enabled_topic_hashes=[ - sha1(topic.encode()).digest() for topic in topics + sha1(topic.encode()).digest() for topic in self._filters ], ) ) - async def request_scoped_token(self, topic: str) -> bytes: + @asynccontextmanager + async def _filter(self, topics: list[str]): + assert self.base_token is not None + for topic in topics: + self._filters[topic] = self._filters.get(topic, 0) + 1 + await self._update_filter() + yield + for topic in topics: + self._filters[topic] -= 1 + if self._filters[topic] == 0: + del self._filters[topic] + await self._update_filter() + + async def mint_scoped_token(self, topic: str) -> bytes: topic_hash = sha1(topic.encode()).digest() assert self.base_token is not None await self.send( diff --git a/tests/test_apns.py b/tests/test_apns.py index a2a8b91..d758ec8 100644 --- a/tests/test_apns.py +++ b/tests/test_apns.py @@ -61,21 +61,16 @@ async def test_scoped_token(): *await apns.activate(), courier="1-courier.sandbox.push.apple.com" ) as connection: - token = await connection.request_scoped_token("dev.jjtech.pypush.tests") - - logging.debug(f"Got token: {token.hex()}") - await connection.filter(["dev.jjtech.pypush.tests"]) - logging.debug(f"waiting on topic 'dev.jjtech.pypush.tests'") - - test_message = f"test-message-{uuid.uuid4().hex}" - - await send_test_notification(token.hex(), test_message.encode()) - logging.debug(f"Sent message: {test_message}") - - async with connection.receive_stream( - apns.protocol.SendMessageCommand - ) as stream: - async for command in stream: - logging.debug(f"Got message: {command.payload.decode()}") - if command.payload == test_message.encode(): - break + token = await connection.mint_scoped_token("dev.jjtech.pypush.tests") + + async with connection._filter(["dev.jjtech.pypush.tests"]): + test_message = f"test-message-{uuid.uuid4().hex}" + + # Must use a receive_stream because the notification might arrive before the HTTP response + async with connection.receive_stream( + apns.protocol.SendMessageCommand + ) as stream: + await send_test_notification(token.hex(), test_message.encode()) + async for command in stream: + if command.payload == test_message.encode(): + break From 17d0b6c4de6d573e8e5f10062fa272bc4f6751ee Mon Sep 17 00:00:00 2001 From: JJTech0130 Date: Sat, 18 May 2024 18:32:35 -0400 Subject: [PATCH 06/19] apns: refactor + BroadcastStream now has a backlog to eliminate common races + FilteredStream now takes any Filter + Filters are now chainable --- pypush/apns/__init__.py | 3 ++- pypush/apns/_util.py | 28 ++++++++++++++++++++------ pypush/apns/filters.py | 43 ++++++++++++++++++++++++++++++++++++++++ pypush/apns/lifecycle.py | 26 ++++++++++-------------- tests/test_apns.py | 16 +++++++-------- 5 files changed, 86 insertions(+), 30 deletions(-) create mode 100644 pypush/apns/filters.py diff --git a/pypush/apns/__init__.py b/pypush/apns/__init__.py index ff6398a..b8d79d9 100644 --- a/pypush/apns/__init__.py +++ b/pypush/apns/__init__.py @@ -1,5 +1,6 @@ -__all__ = ["protocol", "create_apns_connection", "activate"] +__all__ = ["protocol", "create_apns_connection", "activate", "filters"] from . import protocol from .lifecycle import create_apns_connection from .albert import activate +from . import filters diff --git a/pypush/apns/_util.py b/pypush/apns/_util.py index 98a41c6..e9816fa 100644 --- a/pypush/apns/_util.py +++ b/pypush/apns/_util.py @@ -1,16 +1,20 @@ import logging from contextlib import asynccontextmanager -from typing import Generic, TypeVar, Type +from typing import Generic, TypeVar, Type, Callable, Optional import anyio from anyio.abc import ObjectSendStream, ObjectReceiveStream +from . import filters + T = TypeVar("T") class BroadcastStream(Generic[T]): - def __init__(self): + def __init__(self, backlog: int = 50): self.streams: list[ObjectSendStream[T]] = [] + self.backlog: list[T] = [] + self._backlog_size = backlog async def broadcast(self, packet): logging.debug(f"Broadcasting {packet} to {len(self.streams)} streams") @@ -20,12 +24,18 @@ async def broadcast(self, packet): except anyio.BrokenResourceError: logging.error("Broken resource error") # self.streams.remove(stream) + # If we have a backlog, add the packet to it + if len(self.backlog) >= self._backlog_size: + self.backlog.pop(0) + self.backlog.append(packet) @asynccontextmanager async def open_stream(self): # 1000 seems like a reasonable number, if more than 1000 messages come in before someone deals with them it will # start stalling the APNs connection itself send, recv = anyio.create_memory_object_stream[T](max_buffer_size=1000) + for packet in self.backlog: + await send.send(packet) self.streams.append(send) async with recv: yield recv @@ -34,18 +44,24 @@ async def open_stream(self): W = TypeVar("W") -F = TypeVar("F", covariant=True) +F = TypeVar("F") class FilteredStream(ObjectReceiveStream[F]): - def __init__(self, source: ObjectReceiveStream[W], filter: Type[F]): + """ + A stream that filters out unwanted items + + filter should return None if the item should be filtered out, otherwise it should return the item or a modified version of it + """ + + def __init__(self, source: ObjectReceiveStream[W], filter: filters.Filter[W, F]): self.source = source self.filter = filter async def receive(self) -> F: async for item in self.source: - if isinstance(item, self.filter): - return item + if (filtered := self.filter(item)) is not None: + return filtered raise anyio.EndOfStream async def aclose(self): diff --git a/pypush/apns/filters.py b/pypush/apns/filters.py new file mode 100644 index 0000000..0b9f5fc --- /dev/null +++ b/pypush/apns/filters.py @@ -0,0 +1,43 @@ +from pypush.apns import protocol +from typing import TypeVar, Optional, Type, Callable + +# def chain(*filters): +# def filter(command: protocol.Command) -> Optional[protocol.Command]: +# for f in filters: +# command = f(command) +# if command is None: +# return None +# return command +# return filter + +T1 = TypeVar("T1") +T2 = TypeVar("T2") +Filter = Callable[[T1], Optional[T2]] +# typing.Callable[[protocol.Command], typing.Optional[T]] + +# Chain with proper types so that subsequent filters only need to take output type of previous filter +T_IN = TypeVar("T_IN", bound=protocol.Command) +T_MIDDLE = TypeVar("T_MIDDLE", bound=protocol.Command) +T_OUT = TypeVar("T_OUT", bound=protocol.Command) + + +def chain(first: Filter[T_IN, T_MIDDLE], second: Filter[T_MIDDLE, T_OUT]): + def filter(command: T_IN) -> Optional[T_OUT]: + filtered = first(command) + if filtered is None: + return None + return second(filtered) + + return filter + + +T = TypeVar("T", bound=protocol.Command) + + +def cmd(type: Type[T]) -> Filter[protocol.Command, T]: + def filter(command: protocol.Command) -> Optional[T]: + if isinstance(command, type): + return command + return None + + return filter diff --git a/pypush/apns/lifecycle.py b/pypush/apns/lifecycle.py index 0e6d5ee..9004f4f 100644 --- a/pypush/apns/lifecycle.py +++ b/pypush/apns/lifecycle.py @@ -15,7 +15,7 @@ from cryptography.hazmat.primitives import hashes, serialization from cryptography.hazmat.primitives.asymmetric import padding, rsa -from . import protocol, transport, _util +from . import protocol, transport, _util, filters @asynccontextmanager @@ -78,7 +78,7 @@ async def _ping_task(self): await anyio.sleep(30) logging.debug("Sending keepalive") await self.send(protocol.KeepAliveCommand()) - await self.receive(protocol.KeepAliveAck) + await self.receive(filters.cmd(protocol.KeepAliveAck)) @_util.exponential_backoff async def reconnect(self): @@ -109,7 +109,7 @@ async def reconnect(self): ) ) self._tg.start_soon(self._receive_task) - ack = await self.receive(protocol.ConnectAck) + ack = await self.receive(filters.cmd(protocol.ConnectAck)) logging.debug(f"Connected with ack: {ack}") assert ack.status == 0 if self.base_token is None: @@ -124,23 +124,19 @@ async def aclose(self): await self._conn.aclose() # Note: Will be reopened if task group is still running and ping task is still running - T = typing.TypeVar("T", bound=protocol.Command) + T = typing.TypeVar("T") @asynccontextmanager - async def receive_stream(self, filter: typing.Type[T]): + async def receive_stream( + self, filter: filters.Filter[protocol.Command, T] = lambda c: c + ): async with self._broadcast.open_stream() as stream: yield _util.FilteredStream(stream, filter) - async def receive(self, filter: typing.Type[T]) -> T: - """ - WARNING: If the actions between whatever triggered the thing you want to receive and this call might take a long time, - - you should use `receive_stream` instead, as any messages that arrive in between will be lost! - """ - async with self._broadcast.open_stream() as stream: + async def receive(self, filter: filters.Filter[protocol.Command, T]): + async with self.receive_stream(filter) as stream: async for command in stream: - if isinstance(command, filter): - return command + return command raise ValueError("Did not receive expected command") async def send(self, command: protocol.Command): @@ -182,6 +178,6 @@ async def mint_scoped_token(self, topic: str) -> bytes: await self.send( protocol.ScopedTokenCommand(token=self.base_token, topic=topic_hash) ) - ack = await self.receive(protocol.ScopedTokenAck) + ack = await self.receive(filters.cmd(protocol.ScopedTokenAck)) assert ack.status == 0 return ack.scoped_token diff --git a/tests/test_apns.py b/tests/test_apns.py index d758ec8..f79bce9 100644 --- a/tests/test_apns.py +++ b/tests/test_apns.py @@ -66,11 +66,11 @@ async def test_scoped_token(): async with connection._filter(["dev.jjtech.pypush.tests"]): test_message = f"test-message-{uuid.uuid4().hex}" - # Must use a receive_stream because the notification might arrive before the HTTP response - async with connection.receive_stream( - apns.protocol.SendMessageCommand - ) as stream: - await send_test_notification(token.hex(), test_message.encode()) - async for command in stream: - if command.payload == test_message.encode(): - break + await send_test_notification(token.hex(), test_message.encode()) + + resp = await connection.receive( + apns.filters.chain( + apns.filters.cmd(apns.protocol.SendMessageCommand), + lambda x: x if x.payload == test_message.encode() else None, + ) + ) From a065fdfbf25c8eb4d72f2bca6c7ae67c8541ca43 Mon Sep 17 00:00:00 2001 From: JJTech0130 Date: Sat, 18 May 2024 18:47:52 -0400 Subject: [PATCH 07/19] apns: fix a few potential backlog issues --- pypush/apns/_util.py | 7 ++--- pypush/apns/lifecycle.py | 57 +++++++++++++++++++++++++++++----------- tests/test_apns.py | 2 +- 3 files changed, 46 insertions(+), 20 deletions(-) diff --git a/pypush/apns/_util.py b/pypush/apns/_util.py index e9816fa..e4a318d 100644 --- a/pypush/apns/_util.py +++ b/pypush/apns/_util.py @@ -30,12 +30,13 @@ async def broadcast(self, packet): self.backlog.append(packet) @asynccontextmanager - async def open_stream(self): + async def open_stream(self, backlog: bool = True): # 1000 seems like a reasonable number, if more than 1000 messages come in before someone deals with them it will # start stalling the APNs connection itself send, recv = anyio.create_memory_object_stream[T](max_buffer_size=1000) - for packet in self.backlog: - await send.send(packet) + if backlog: + for packet in self.backlog: + await send.send(packet) self.streams.append(send) async with recv: yield recv diff --git a/pypush/apns/lifecycle.py b/pypush/apns/lifecycle.py index 9004f4f..9267660 100644 --- a/pypush/apns/lifecycle.py +++ b/pypush/apns/lifecycle.py @@ -7,7 +7,6 @@ import typing from contextlib import asynccontextmanager from hashlib import sha1 -from dataclasses import dataclass import anyio from anyio.abc import TaskGroup @@ -77,8 +76,10 @@ async def _ping_task(self): while True: await anyio.sleep(30) logging.debug("Sending keepalive") - await self.send(protocol.KeepAliveCommand()) - await self.receive(filters.cmd(protocol.KeepAliveAck)) + await self._send(protocol.KeepAliveCommand()) + await self._receive( + filters.cmd(protocol.KeepAliveAck), backlog=False + ) # Explicitly disable the backlog since we don't want to receive old acks @_util.exponential_backoff async def reconnect(self): @@ -86,7 +87,10 @@ async def reconnect(self): if self._conn is not None: logging.warning("Closing existing connection") await self._conn.aclose() - self._conn = protocol.CommandStream( + + self._broadcast.backlog = [] # Clear the backlog + + conn = protocol.CommandStream( await transport.create_courier_connection(courier=self.courier) ) cert = self.certificate.public_bytes(serialization.Encoding.DER) @@ -98,7 +102,7 @@ async def reconnect(self): signature = b"\x01\x01" + self.private_key.sign( nonce, padding.PKCS1v15(), hashes.SHA1() ) - await self._conn.send( + await conn.send( protocol.ConnectCommand( push_token=self.base_token, state=1, @@ -108,8 +112,25 @@ async def reconnect(self): signature=signature, ) ) + + # Don't set self._conn until we've sent the connect command + self._conn = conn + self._tg.start_soon(self._receive_task) - ack = await self.receive(filters.cmd(protocol.ConnectAck)) + ack = await self._receive( + filters.chain( + filters.cmd(protocol.ConnectAck), + lambda c: ( + c + if ( + c.token == self.base_token + if self.base_token is not None + else True + ) + else None + ), + ) + ) logging.debug(f"Connected with ack: {ack}") assert ack.status == 0 if self.base_token is None: @@ -127,30 +148,34 @@ async def aclose(self): T = typing.TypeVar("T") @asynccontextmanager - async def receive_stream( - self, filter: filters.Filter[protocol.Command, T] = lambda c: c + async def _receive_stream( + self, + filter: filters.Filter[protocol.Command, T] = lambda c: c, + backlog: bool = True, ): - async with self._broadcast.open_stream() as stream: + async with self._broadcast.open_stream(backlog) as stream: yield _util.FilteredStream(stream, filter) - async def receive(self, filter: filters.Filter[protocol.Command, T]): - async with self.receive_stream(filter) as stream: + async def _receive( + self, filter: filters.Filter[protocol.Command, T], backlog: bool = True + ): + async with self._receive_stream(filter, backlog) as stream: async for command in stream: return command raise ValueError("Did not receive expected command") - async def send(self, command: protocol.Command): + async def _send(self, command: protocol.Command): try: assert self._conn is not None await self._conn.send(command) except Exception as e: logging.warning(f"Error sending command, reconnecting") await self.reconnect() - await self.send(command) + await self._send(command) async def _update_filter(self): assert self.base_token is not None - await self.send( + await self._send( protocol.FilterCommand( token=self.base_token, enabled_topic_hashes=[ @@ -175,9 +200,9 @@ async def _filter(self, topics: list[str]): async def mint_scoped_token(self, topic: str) -> bytes: topic_hash = sha1(topic.encode()).digest() assert self.base_token is not None - await self.send( + await self._send( protocol.ScopedTokenCommand(token=self.base_token, topic=topic_hash) ) - ack = await self.receive(filters.cmd(protocol.ScopedTokenAck)) + ack = await self._receive(filters.cmd(protocol.ScopedTokenAck)) assert ack.status == 0 return ack.scoped_token diff --git a/tests/test_apns.py b/tests/test_apns.py index f79bce9..b5c4c70 100644 --- a/tests/test_apns.py +++ b/tests/test_apns.py @@ -68,7 +68,7 @@ async def test_scoped_token(): await send_test_notification(token.hex(), test_message.encode()) - resp = await connection.receive( + resp = await connection._receive( apns.filters.chain( apns.filters.cmd(apns.protocol.SendMessageCommand), lambda x: x if x.payload == test_message.encode() else None, From dfbea52376f1b6e86d5234cc44eb219b13bb0d49 Mon Sep 17 00:00:00 2001 From: JJTech0130 Date: Sat, 18 May 2024 19:58:42 -0400 Subject: [PATCH 08/19] apsn: add filters and notification receive API --- pypush/apns/__init__.py | 5 ++--- pypush/apns/filters.py | 19 ++++++++----------- pypush/apns/lifecycle.py | 41 +++++++++++++++++++++++++++++++++++++++- tests/test_apns.py | 16 +++++++--------- 4 files changed, 57 insertions(+), 24 deletions(-) diff --git a/pypush/apns/__init__.py b/pypush/apns/__init__.py index b8d79d9..3c954b8 100644 --- a/pypush/apns/__init__.py +++ b/pypush/apns/__init__.py @@ -1,6 +1,5 @@ __all__ = ["protocol", "create_apns_connection", "activate", "filters"] -from . import protocol -from .lifecycle import create_apns_connection +from . import filters, protocol from .albert import activate -from . import filters +from .lifecycle import create_apns_connection diff --git a/pypush/apns/filters.py b/pypush/apns/filters.py index 0b9f5fc..148ebd0 100644 --- a/pypush/apns/filters.py +++ b/pypush/apns/filters.py @@ -1,19 +1,11 @@ -from pypush.apns import protocol -from typing import TypeVar, Optional, Type, Callable +import logging +from typing import Callable, Optional, Type, TypeVar -# def chain(*filters): -# def filter(command: protocol.Command) -> Optional[protocol.Command]: -# for f in filters: -# command = f(command) -# if command is None: -# return None -# return command -# return filter +from pypush.apns import protocol T1 = TypeVar("T1") T2 = TypeVar("T2") Filter = Callable[[T1], Optional[T2]] -# typing.Callable[[protocol.Command], typing.Optional[T]] # Chain with proper types so that subsequent filters only need to take output type of previous filter T_IN = TypeVar("T_IN", bound=protocol.Command) @@ -23,6 +15,7 @@ def chain(first: Filter[T_IN, T_MIDDLE], second: Filter[T_MIDDLE, T_OUT]): def filter(command: T_IN) -> Optional[T_OUT]: + logging.debug(f"Filtering {command} with {first} and {second}") filtered = first(command) if filtered is None: return None @@ -41,3 +34,7 @@ def filter(command: protocol.Command) -> Optional[T]: return None return filter + + +ALL = lambda c: c +NONE = lambda _: None diff --git a/pypush/apns/lifecycle.py b/pypush/apns/lifecycle.py index 9267660..6f50b59 100644 --- a/pypush/apns/lifecycle.py +++ b/pypush/apns/lifecycle.py @@ -14,7 +14,7 @@ from cryptography.hazmat.primitives import hashes, serialization from cryptography.hazmat.primitives.asymmetric import padding, rsa -from . import protocol, transport, _util, filters +from . import _util, filters, protocol, transport @asynccontextmanager @@ -206,3 +206,42 @@ async def mint_scoped_token(self, topic: str) -> bytes: ack = await self._receive(filters.cmd(protocol.ScopedTokenAck)) assert ack.status == 0 return ack.scoped_token + + @asynccontextmanager + async def notification_stream( + self, + topic: str, + token: typing.Optional[bytes] = None, + filter: filters.Filter[ + protocol.SendMessageCommand, protocol.SendMessageCommand + ] = filters.ALL, + ): + if token is None: + token = self.base_token + async with self._filter([topic]): + async with self._receive_stream( + filters.chain( + filters.chain( + filters.chain( + filters.cmd(protocol.SendMessageCommand), + lambda c: c if c.token == token else None, + ), + lambda c: (c if c.topic == topic else None), + ), + filter, + ) + ) as stream: + yield stream + + async def expect_notification( + self, + topic: str, + token: typing.Optional[bytes] = None, + filter: filters.Filter[ + protocol.SendMessageCommand, protocol.SendMessageCommand + ] = filters.ALL, + ) -> protocol.SendMessageCommand: + async with self.notification_stream(topic, token, filter) as stream: + async for command in stream: + return command + raise ValueError("Did not receive expected notification") diff --git a/tests/test_apns.py b/tests/test_apns.py index b5c4c70..820a63e 100644 --- a/tests/test_apns.py +++ b/tests/test_apns.py @@ -63,14 +63,12 @@ async def test_scoped_token(): token = await connection.mint_scoped_token("dev.jjtech.pypush.tests") - async with connection._filter(["dev.jjtech.pypush.tests"]): - test_message = f"test-message-{uuid.uuid4().hex}" + test_message = f"test-message-{uuid.uuid4().hex}" - await send_test_notification(token.hex(), test_message.encode()) + await send_test_notification(token.hex(), test_message.encode()) - resp = await connection._receive( - apns.filters.chain( - apns.filters.cmd(apns.protocol.SendMessageCommand), - lambda x: x if x.payload == test_message.encode() else None, - ) - ) + await connection.expect_notification( + "dev.jjtech.pypush.tests", + token, + lambda c: c if c.payload == test_message.encode() else None, + ) From 12e06cd3b38fb5074f94b891512b679c3ca2ba31 Mon Sep 17 00:00:00 2001 From: JJTech0130 Date: Sat, 18 May 2024 20:01:20 -0400 Subject: [PATCH 09/19] pypush: format + isort --- pypush/apns/_protocol.py | 2 +- pypush/apns/_util.py | 4 ++-- pypush/apns/albert.py | 2 +- pypush/cli/_frida.py | 3 ++- pypush/cli/proxy.py | 3 +-- 5 files changed, 7 insertions(+), 7 deletions(-) diff --git a/pypush/apns/_protocol.py b/pypush/apns/_protocol.py index 140d9e3..0d79a33 100644 --- a/pypush/apns/_protocol.py +++ b/pypush/apns/_protocol.py @@ -3,7 +3,7 @@ import logging from dataclasses import MISSING, field from dataclasses import fields as dataclass_fields -from typing import Any, TypeVar, get_origin, get_args, Union +from typing import Any, TypeVar, Union, get_args, get_origin from pypush.apns.transport import Packet diff --git a/pypush/apns/_util.py b/pypush/apns/_util.py index e4a318d..30bc634 100644 --- a/pypush/apns/_util.py +++ b/pypush/apns/_util.py @@ -1,9 +1,9 @@ import logging from contextlib import asynccontextmanager -from typing import Generic, TypeVar, Type, Callable, Optional +from typing import Callable, Generic, Optional, Type, TypeVar import anyio -from anyio.abc import ObjectSendStream, ObjectReceiveStream +from anyio.abc import ObjectReceiveStream, ObjectSendStream from . import filters diff --git a/pypush/apns/albert.py b/pypush/apns/albert.py index 024e449..459e79c 100644 --- a/pypush/apns/albert.py +++ b/pypush/apns/albert.py @@ -4,7 +4,7 @@ import re import uuid from base64 import b64decode -from typing import Tuple, Optional +from typing import Optional, Tuple import httpx from cryptography import x509 diff --git a/pypush/cli/_frida.py b/pypush/cli/_frida.py index dc30ce5..3a71ae4 100644 --- a/pypush/cli/_frida.py +++ b/pypush/cli/_frida.py @@ -1,6 +1,7 @@ -import frida import logging +import frida + def attach_to_apsd() -> frida.core.Session: frida.kill("apsd") diff --git a/pypush/cli/proxy.py b/pypush/cli/proxy.py index b801d43..50f2315 100644 --- a/pypush/cli/proxy.py +++ b/pypush/cli/proxy.py @@ -15,8 +15,7 @@ from cryptography.hazmat.primitives.serialization import Encoding, PublicFormat # from pypush import apns -from pypush.apns import transport -from pypush.apns import protocol +from pypush.apns import protocol, transport from . import _frida From d99c8d2896c216423d5bca29ef28cbe703968a43 Mon Sep 17 00:00:00 2001 From: JJTech0130 Date: Sat, 18 May 2024 20:12:28 -0400 Subject: [PATCH 10/19] apns: ACK expected notifications automatically notification_stream does NOT automatically ACK, since the stream may later be filtered etc. --- pypush/apns/lifecycle.py | 13 ++++++++++--- pypush/apns/transport.py | 9 ++++++++- tests/test_apns.py | 4 ++-- 3 files changed, 20 insertions(+), 6 deletions(-) diff --git a/pypush/apns/lifecycle.py b/pypush/apns/lifecycle.py index 6f50b59..d6ca59e 100644 --- a/pypush/apns/lifecycle.py +++ b/pypush/apns/lifecycle.py @@ -233,6 +233,13 @@ async def notification_stream( ) as stream: yield stream + async def ack(self, command: protocol.SendMessageCommand, status: int = 0): + await self._send( + protocol.SendMessageAck( + status=status, token=command.token, id=command.id + ) + ) + async def expect_notification( self, topic: str, @@ -242,6 +249,6 @@ async def expect_notification( ] = filters.ALL, ) -> protocol.SendMessageCommand: async with self.notification_stream(topic, token, filter) as stream: - async for command in stream: - return command - raise ValueError("Did not receive expected notification") + command = await stream.receive() + await self.ack(command) + return command \ No newline at end of file diff --git a/pypush/apns/transport.py b/pypush/apns/transport.py index ab5abd2..17774c1 100644 --- a/pypush/apns/transport.py +++ b/pypush/apns/transport.py @@ -66,13 +66,20 @@ async def create_courier_connection( context = ssl.create_default_context() context.set_alpn_protocols(ALPN) + # Special case for local testing + if courier == "sandbox-localhost": + courier = "localhost" + sni = "courier.sandbox.push.apple.com" + else: + sni = "courier.push.apple.com" + # TODO: Verify courier certificate context.check_hostname = False context.verify_mode = ssl.CERT_NONE return PacketStream( await anyio.connect_tcp( - courier, COURIER_PORT, ssl_context=context, tls_standard_compatible=False + courier, COURIER_PORT, ssl_context=context, tls_standard_compatible=False, tls_hostname=sni ) ) diff --git a/tests/test_apns.py b/tests/test_apns.py index 820a63e..6c10bd1 100644 --- a/tests/test_apns.py +++ b/tests/test_apns.py @@ -22,7 +22,7 @@ async def test_activate(): @pytest.mark.asyncio async def test_lifecycle_2(): async with apns.create_apns_connection( - certificate, key, courier="localhost" + certificate, key ) as connection: pass @@ -58,7 +58,7 @@ async def send_test_notification(device_token, payload=b"hello, world"): @pytest.mark.asyncio async def test_scoped_token(): async with apns.create_apns_connection( - *await apns.activate(), courier="1-courier.sandbox.push.apple.com" + *await apns.activate(), courier="sandbox-localhost" ) as connection: token = await connection.mint_scoped_token("dev.jjtech.pypush.tests") From 85909072d4ff699ab6f9c44319a0650f08acfd51 Mon Sep 17 00:00:00 2001 From: JJTech0130 Date: Sat, 18 May 2024 20:14:00 -0400 Subject: [PATCH 11/19] apns: fix sandbox indication via SNI --- pypush/apns/transport.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pypush/apns/transport.py b/pypush/apns/transport.py index 17774c1..c153768 100644 --- a/pypush/apns/transport.py +++ b/pypush/apns/transport.py @@ -71,7 +71,7 @@ async def create_courier_connection( courier = "localhost" sni = "courier.sandbox.push.apple.com" else: - sni = "courier.push.apple.com" + sni = courier # TODO: Verify courier certificate context.check_hostname = False From 815baf595c8263bddf105b92ab7f8c9c8e40c210 Mon Sep 17 00:00:00 2001 From: JJTech0130 Date: Sat, 18 May 2024 20:39:08 -0400 Subject: [PATCH 12/19] apns: add notifications command --- pypush/cli/__init__.py | 35 +++++++++++++++++++++++++++++------ pypush/cli/pushclient.py | 0 2 files changed, 29 insertions(+), 6 deletions(-) delete mode 100644 pypush/cli/pushclient.py diff --git a/pypush/cli/__init__.py b/pypush/cli/__init__.py index 83e70a0..cc2523a 100644 --- a/pypush/cli/__init__.py +++ b/pypush/cli/__init__.py @@ -5,8 +5,11 @@ from typing_extensions import Annotated from . import proxy as _proxy +from pypush import apns +import anyio +from asyncio import CancelledError -logging.basicConfig(level=logging.DEBUG, handlers=[RichHandler()], format="%(message)s") +logging.basicConfig(level=logging.INFO, handlers=[RichHandler()], format="%(message)s") app = typer.Typer() @@ -22,12 +25,14 @@ def proxy( Attach requires SIP to be disabled and to be running as root """ - - _proxy.main(attach) + try: + _proxy.main(attach) + except CancelledError: + pass @app.command() -def client( +def notifications( topic: Annotated[str, typer.Argument(help="app topic to listen on")], sandbox: Annotated[ bool, typer.Option("--sandbox/--production", help="APNs courier to use") @@ -36,8 +41,26 @@ def client( """ Connect to the APNs courier and listen for app notifications on the given topic """ - typer.echo("Running APNs client") - raise NotImplementedError("Not implemented yet") + logging.getLogger("httpx").setLevel(logging.WARNING) + try: + anyio.run(notifications_async, topic, sandbox) + except CancelledError: + pass + +async def notifications_async(topic: str, sandbox: bool): + async with apns.create_apns_connection( + *await apns.activate(), courier="1-courier.sandbox.push.apple.com" if sandbox else "1-courier.push.apple.com" + ) as connection: + + token = await connection.mint_scoped_token(topic) + + async with connection.notification_stream(topic, token) as stream: + logging.info(f"Listening for notifications on topic {topic} ({'sandbox' if sandbox else 'production'})") + logging.info(f"Token: {token.hex()}") + + async for notification in stream: + await connection.ack(notification) + logging.info(notification.payload.decode()) def main(): diff --git a/pypush/cli/pushclient.py b/pypush/cli/pushclient.py deleted file mode 100644 index e69de29..0000000 From bd30ee9e1f14283086e5eda6596a2694e149f645 Mon Sep 17 00:00:00 2001 From: JJTech0130 Date: Sat, 18 May 2024 20:55:09 -0400 Subject: [PATCH 13/19] apns: send lock, update filters after reconnection --- pypush/apns/lifecycle.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/pypush/apns/lifecycle.py b/pypush/apns/lifecycle.py index d6ca59e..42360ad 100644 --- a/pypush/apns/lifecycle.py +++ b/pypush/apns/lifecycle.py @@ -56,6 +56,7 @@ def __init__( self._tg = task_group self._broadcast = _util.BroadcastStream[protocol.Command]() self._reconnect_lock = anyio.Lock() + self._send_lock = anyio.Lock() if courier is None: # Pick a random courier server from 1 to 50 @@ -139,6 +140,8 @@ async def reconnect(self): assert ack.token == self.base_token if not self._connected.is_set(): self._connected.set() + + await self._update_filter() async def aclose(self): if self._conn is not None: @@ -166,8 +169,9 @@ async def _receive( async def _send(self, command: protocol.Command): try: - assert self._conn is not None - await self._conn.send(command) + async with self._send_lock: + assert self._conn is not None + await self._conn.send(command) except Exception as e: logging.warning(f"Error sending command, reconnecting") await self.reconnect() From 037fc28287b62e9ab2293e2d14478c4a917ad5df Mon Sep 17 00:00:00 2001 From: JJTech0130 Date: Sun, 19 May 2024 13:01:58 -0400 Subject: [PATCH 14/19] apns: easier sandbox courier selection --- pypush/apns/lifecycle.py | 10 +++++++--- pypush/apns/transport.py | 7 +++---- pypush/cli/proxy.py | 2 +- tests/test_apns.py | 11 +---------- 4 files changed, 12 insertions(+), 18 deletions(-) diff --git a/pypush/apns/lifecycle.py b/pypush/apns/lifecycle.py index 42360ad..1d5f7d1 100644 --- a/pypush/apns/lifecycle.py +++ b/pypush/apns/lifecycle.py @@ -22,11 +22,12 @@ async def create_apns_connection( certificate: x509.Certificate, private_key: rsa.RSAPrivateKey, token: typing.Optional[bytes] = None, + sandbox: bool = False, courier: typing.Optional[str] = None, ): async with anyio.create_task_group() as tg: conn = Connection( - tg, certificate, private_key, token, courier + tg, certificate, private_key, token, sandbox, courier ) # Await connected for first time here, so that base token is set await conn._connected.wait() yield conn @@ -41,6 +42,7 @@ def __init__( certificate: x509.Certificate, private_key: rsa.RSAPrivateKey, token: typing.Optional[bytes] = None, + sandbox: bool = False, courier: typing.Optional[str] = None, ): @@ -58,9 +60,11 @@ def __init__( self._reconnect_lock = anyio.Lock() self._send_lock = anyio.Lock() + self.sandbox = sandbox if courier is None: # Pick a random courier server from 1 to 50 - courier = f"{random.randint(1, 50)}-courier.push.apple.com" + courier = f"{random.randint(1, 50)}-courier.push.apple.com" if not sandbox else f"{random.randint(1, 10)}-courier.sandbox.push.apple.com" + logging.debug(f"Using courier: {courier}") self.courier = courier self._tg.start_soon(self.reconnect) @@ -92,7 +96,7 @@ async def reconnect(self): self._broadcast.backlog = [] # Clear the backlog conn = protocol.CommandStream( - await transport.create_courier_connection(courier=self.courier) + await transport.create_courier_connection(self.sandbox, self.courier) ) cert = self.certificate.public_bytes(serialization.Encoding.DER) nonce = ( diff --git a/pypush/apns/transport.py b/pypush/apns/transport.py index c153768..bdf54ae 100644 --- a/pypush/apns/transport.py +++ b/pypush/apns/transport.py @@ -61,17 +61,16 @@ def fields_for_id(self, id: int) -> list[bytes]: async def create_courier_connection( + sandbox: bool = False, courier: str = "1-courier.push.apple.com", ) -> PacketStream: context = ssl.create_default_context() context.set_alpn_protocols(ALPN) - # Special case for local testing - if courier == "sandbox-localhost": - courier = "localhost" + if sandbox: sni = "courier.sandbox.push.apple.com" else: - sni = courier + sni = "courier.push.apple.com" # TODO: Verify courier certificate context.check_hostname = False diff --git a/pypush/cli/proxy.py b/pypush/cli/proxy.py index 50f2315..e611112 100644 --- a/pypush/cli/proxy.py +++ b/pypush/cli/proxy.py @@ -70,7 +70,7 @@ async def handle(client: TLSStream): else "1-courier.sandbox.push.apple.com" ) name = f"prod-{connection_cnt}" if not sandbox else f"sandbox-{connection_cnt}" - async with await transport.create_courier_connection(forward) as conn: + async with await transport.create_courier_connection(sandbox, forward) as conn: logging.debug("Connected to courier") async with anyio.create_task_group() as tg: tg.start_soon(forward_packets, client_pkt, conn, f"client-{name}") diff --git a/tests/test_apns.py b/tests/test_apns.py index 6c10bd1..a3b6232 100644 --- a/tests/test_apns.py +++ b/tests/test_apns.py @@ -26,15 +26,6 @@ async def test_lifecycle_2(): ) as connection: pass - -@pytest.mark.asyncio -async def test_shorthand(): - async with apns.create_apns_connection( - *await apns.activate(), courier="localhost" - ) as connection: - pass - - ASSETS_DIR = Path(__file__).parent / "assets" @@ -58,7 +49,7 @@ async def send_test_notification(device_token, payload=b"hello, world"): @pytest.mark.asyncio async def test_scoped_token(): async with apns.create_apns_connection( - *await apns.activate(), courier="sandbox-localhost" + *await apns.activate(), sandbox=True ) as connection: token = await connection.mint_scoped_token("dev.jjtech.pypush.tests") From be8e3934d524d96407ea8829fb4b3f7cae574ba0 Mon Sep 17 00:00:00 2001 From: JJTech0130 Date: Sun, 19 May 2024 13:24:50 -0400 Subject: [PATCH 15/19] pypush: ruff check and format --- pyproject.toml | 6 +++++- pypush/apns/_protocol.py | 12 +++++------ pypush/apns/_util.py | 2 +- pypush/apns/albert.py | 4 ++-- pypush/apns/filters.py | 8 ++++++-- pypush/apns/lifecycle.py | 44 +++++++++++++++++++++------------------- pypush/apns/protocol.py | 23 +++++++-------------- pypush/apns/transport.py | 11 +++++----- pypush/cli/__init__.py | 27 +++++++++++++----------- pypush/cli/proxy.py | 3 +-- tests/test_apns.py | 6 ++---- 11 files changed, 74 insertions(+), 72 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 05602b2..287497e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,4 +34,8 @@ version_file = "pypush/_version.py" [tool.pytest.ini_options] minversion = "6.0" addopts = ["-ra", "-q"] -testpaths = ["tests"] \ No newline at end of file +testpaths = ["tests"] + +[tool.ruff.lint] +select = ["E", "F", "B", "SIM", "I"] +ignore = ["E501", "B010"] \ No newline at end of file diff --git a/pypush/apns/_protocol.py b/pypush/apns/_protocol.py index 0d79a33..bd6c4b5 100644 --- a/pypush/apns/_protocol.py +++ b/pypush/apns/_protocol.py @@ -67,14 +67,14 @@ def from_packet(cls, packet: Packet): ) # Check for extra fields - for field in packet.fields: - if field.id not in [ + for current_field in packet.fields: + if current_field.id not in [ f.metadata["packet_id"] for f in dataclass_fields(cls) if f.metadata is not None and "packet_id" in f.metadata ]: logging.warning( - f"Unexpected field with packet ID {field.id} in packet {packet}" + f"Unexpected field with packet ID {current_field.id} in packet {packet}" ) return cls(**field_values) @@ -122,15 +122,15 @@ def fid( :param byte_len: The length of the field in bytes (for int fields) :param default: The default value of the field """ - if not default == MISSING and not default_factory == MISSING: + if default != MISSING and default_factory != MISSING: raise ValueError("Cannot specify both default and default_factory") - if not default == MISSING: + if default != MISSING: return field( metadata={"packet_id": packet_id, "packet_bytes": byte_len}, default=default, repr=repr, ) - if not default_factory == MISSING: + if default_factory != MISSING: return field( metadata={"packet_id": packet_id, "packet_bytes": byte_len}, default_factory=default_factory, diff --git a/pypush/apns/_util.py b/pypush/apns/_util.py index 30bc634..3564892 100644 --- a/pypush/apns/_util.py +++ b/pypush/apns/_util.py @@ -1,6 +1,6 @@ import logging from contextlib import asynccontextmanager -from typing import Callable, Generic, Optional, Type, TypeVar +from typing import Generic, TypeVar import anyio from anyio.abc import ObjectReceiveStream, ObjectSendStream diff --git a/pypush/apns/albert.py b/pypush/apns/albert.py index 459e79c..3706807 100644 --- a/pypush/apns/albert.py +++ b/pypush/apns/albert.py @@ -96,10 +96,10 @@ async def activate( try: protocol = re.search("(.*)", resp.text).group(1) # type: ignore - except AttributeError: + except AttributeError as e: # Search for error text between and error = re.search("(.*)", resp.text).group(1) # type: ignore - raise Exception(f"Failed to get certificate from Albert: {error}") + raise Exception(f"Failed to get certificate from Albert: {error}") from e protocol = plistlib.loads(protocol.encode("utf-8")) diff --git a/pypush/apns/filters.py b/pypush/apns/filters.py index 148ebd0..63bb784 100644 --- a/pypush/apns/filters.py +++ b/pypush/apns/filters.py @@ -36,5 +36,9 @@ def filter(command: protocol.Command) -> Optional[T]: return filter -ALL = lambda c: c -NONE = lambda _: None +def ALL(c): + return c + + +def NONE(_): + return None diff --git a/pypush/apns/lifecycle.py b/pypush/apns/lifecycle.py index 1d5f7d1..e094bc1 100644 --- a/pypush/apns/lifecycle.py +++ b/pypush/apns/lifecycle.py @@ -32,7 +32,9 @@ async def create_apns_connection( await conn._connected.wait() yield conn tg.cancel_scope.cancel() # Cancel the task group when the context manager exits - await conn.aclose() # Make sure to close the connection after the task group is cancelled + await ( + conn.aclose() + ) # Make sure to close the connection after the task group is cancelled class Connection: @@ -45,7 +47,6 @@ def __init__( sandbox: bool = False, courier: typing.Optional[str] = None, ): - self.certificate = certificate self.private_key = private_key self.base_token = token @@ -63,7 +64,11 @@ def __init__( self.sandbox = sandbox if courier is None: # Pick a random courier server from 1 to 50 - courier = f"{random.randint(1, 50)}-courier.push.apple.com" if not sandbox else f"{random.randint(1, 10)}-courier.sandbox.push.apple.com" + courier = ( + f"{random.randint(1, 50)}-courier.push.apple.com" + if not sandbox + else f"{random.randint(1, 10)}-courier.sandbox.push.apple.com" + ) logging.debug(f"Using courier: {courier}") self.courier = courier @@ -144,7 +149,7 @@ async def reconnect(self): assert ack.token == self.base_token if not self._connected.is_set(): self._connected.set() - + await self._update_filter() async def aclose(self): @@ -176,8 +181,8 @@ async def _send(self, command: protocol.Command): async with self._send_lock: assert self._conn is not None await self._conn.send(command) - except Exception as e: - logging.warning(f"Error sending command, reconnecting") + except Exception: + logging.warning("Error sending command, reconnecting") await self.reconnect() await self._send(command) @@ -226,26 +231,23 @@ async def notification_stream( ): if token is None: token = self.base_token - async with self._filter([topic]): - async with self._receive_stream( + async with self._filter([topic]), self._receive_stream( + filters.chain( filters.chain( filters.chain( - filters.chain( - filters.cmd(protocol.SendMessageCommand), - lambda c: c if c.token == token else None, - ), - lambda c: (c if c.topic == topic else None), + filters.cmd(protocol.SendMessageCommand), + lambda c: c if c.token == token else None, ), - filter, - ) - ) as stream: - yield stream + lambda c: (c if c.topic == topic else None), + ), + filter, + ) + ) as stream: + yield stream async def ack(self, command: protocol.SendMessageCommand, status: int = 0): await self._send( - protocol.SendMessageAck( - status=status, token=command.token, id=command.id - ) + protocol.SendMessageAck(status=status, token=command.token, id=command.id) ) async def expect_notification( @@ -259,4 +261,4 @@ async def expect_notification( async with self.notification_stream(topic, token, filter) as stream: command = await stream.receive() await self.ack(command) - return command \ No newline at end of file + return command diff --git a/pypush/apns/protocol.py b/pypush/apns/protocol.py index 244cda3..147119c 100644 --- a/pypush/apns/protocol.py +++ b/pypush/apns/protocol.py @@ -87,12 +87,7 @@ class FilterCommand(Command): def _lookup_hashes(self, hashes: Optional[list[bytes]]): return ( - [ - KNOWN_TOPICS_LOOKUP[hash] if hash in KNOWN_TOPICS_LOOKUP else hash - for hash in hashes - ] - if hashes - else [] + [KNOWN_TOPICS_LOOKUP.get(hash, hash) for hash in hashes] if hashes else [] ) @property @@ -183,7 +178,7 @@ def __post_init__(self): ) and not (self._token_topic_1 is not None and self._token_topic_2 is not None): raise ValueError("topic, token, and outgoing must be set.") - if self.outgoing == True: + if self.outgoing is True: assert self.topic and self.token self._token_topic_1 = ( sha1(self.topic.encode()).digest() @@ -191,7 +186,7 @@ def __post_init__(self): else self.topic ) self._token_topic_2 = self.token - elif self.outgoing == False: + elif self.outgoing is False: assert self.topic and self.token self._token_topic_1 = self.token self._token_topic_2 = ( @@ -202,18 +197,14 @@ def __post_init__(self): else: assert self._token_topic_1 and self._token_topic_2 if len(self._token_topic_1) == 20: # SHA1 hash, topic - self.topic = ( - KNOWN_TOPICS_LOOKUP[self._token_topic_1] - if self._token_topic_1 in KNOWN_TOPICS_LOOKUP - else self._token_topic_1 + self.topic = KNOWN_TOPICS_LOOKUP.get( + self._token_topic_1, self._token_topic_1 ) self.token = self._token_topic_2 self.outgoing = True else: - self.topic = ( - KNOWN_TOPICS_LOOKUP[self._token_topic_2] - if self._token_topic_2 in KNOWN_TOPICS_LOOKUP - else self._token_topic_2 + self.topic = KNOWN_TOPICS_LOOKUP.get( + self._token_topic_2, self._token_topic_2 ) self.token = self._token_topic_1 self.outgoing = False diff --git a/pypush/apns/transport.py b/pypush/apns/transport.py index bdf54ae..d0de86b 100644 --- a/pypush/apns/transport.py +++ b/pypush/apns/transport.py @@ -67,10 +67,7 @@ async def create_courier_connection( context = ssl.create_default_context() context.set_alpn_protocols(ALPN) - if sandbox: - sni = "courier.sandbox.push.apple.com" - else: - sni = "courier.push.apple.com" + sni = "courier.sandbox.push.apple.com" if sandbox else "courier.push.apple.com" # TODO: Verify courier certificate context.check_hostname = False @@ -78,7 +75,11 @@ async def create_courier_connection( return PacketStream( await anyio.connect_tcp( - courier, COURIER_PORT, ssl_context=context, tls_standard_compatible=False, tls_hostname=sni + courier, + COURIER_PORT, + ssl_context=context, + tls_standard_compatible=False, + tls_hostname=sni, ) ) diff --git a/pypush/cli/__init__.py b/pypush/cli/__init__.py index cc2523a..1495dd0 100644 --- a/pypush/cli/__init__.py +++ b/pypush/cli/__init__.py @@ -1,13 +1,15 @@ +import contextlib import logging +from asyncio import CancelledError +import anyio import typer from rich.logging import RichHandler from typing_extensions import Annotated -from . import proxy as _proxy from pypush import apns -import anyio -from asyncio import CancelledError + +from . import proxy as _proxy logging.basicConfig(level=logging.INFO, handlers=[RichHandler()], format="%(message)s") @@ -25,10 +27,8 @@ def proxy( Attach requires SIP to be disabled and to be running as root """ - try: + with contextlib.suppress(CancelledError): _proxy.main(attach) - except CancelledError: - pass @app.command() @@ -42,20 +42,23 @@ def notifications( Connect to the APNs courier and listen for app notifications on the given topic """ logging.getLogger("httpx").setLevel(logging.WARNING) - try: + with contextlib.suppress(CancelledError): anyio.run(notifications_async, topic, sandbox) - except CancelledError: - pass + async def notifications_async(topic: str, sandbox: bool): async with apns.create_apns_connection( - *await apns.activate(), courier="1-courier.sandbox.push.apple.com" if sandbox else "1-courier.push.apple.com" + *await apns.activate(), + courier="1-courier.sandbox.push.apple.com" + if sandbox + else "1-courier.push.apple.com", ) as connection: - token = await connection.mint_scoped_token(topic) async with connection.notification_stream(topic, token) as stream: - logging.info(f"Listening for notifications on topic {topic} ({'sandbox' if sandbox else 'production'})") + logging.info( + f"Listening for notifications on topic {topic} ({'sandbox' if sandbox else 'production'})" + ) logging.info(f"Token: {token.hex()}") async for notification in stream: diff --git a/pypush/cli/proxy.py b/pypush/cli/proxy.py index e611112..8c43bc4 100644 --- a/pypush/cli/proxy.py +++ b/pypush/cli/proxy.py @@ -2,7 +2,6 @@ import logging import ssl import tempfile -from typing import Optional import anyio import anyio.abc @@ -12,7 +11,7 @@ from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives.asymmetric import rsa from cryptography.hazmat.primitives.hashes import SHA256 -from cryptography.hazmat.primitives.serialization import Encoding, PublicFormat +from cryptography.hazmat.primitives.serialization import Encoding # from pypush import apns from pypush.apns import protocol, transport diff --git a/tests/test_apns.py b/tests/test_apns.py index a3b6232..501e862 100644 --- a/tests/test_apns.py +++ b/tests/test_apns.py @@ -21,11 +21,10 @@ async def test_activate(): @pytest.mark.asyncio async def test_lifecycle_2(): - async with apns.create_apns_connection( - certificate, key - ) as connection: + async with apns.create_apns_connection(certificate, key) as _: pass + ASSETS_DIR = Path(__file__).parent / "assets" @@ -51,7 +50,6 @@ async def test_scoped_token(): async with apns.create_apns_connection( *await apns.activate(), sandbox=True ) as connection: - token = await connection.mint_scoped_token("dev.jjtech.pypush.tests") test_message = f"test-message-{uuid.uuid4().hex}" From 954b59a87f16bf251f08127b3c074ace756332eb Mon Sep 17 00:00:00 2001 From: JJTech Date: Sun, 19 May 2024 13:28:53 -0400 Subject: [PATCH 16/19] ci: enforce ruff lints --- .github/workflows/ruff.yml | 8 ++++++++ 1 file changed, 8 insertions(+) create mode 100644 .github/workflows/ruff.yml diff --git a/.github/workflows/ruff.yml b/.github/workflows/ruff.yml new file mode 100644 index 0000000..b268138 --- /dev/null +++ b/.github/workflows/ruff.yml @@ -0,0 +1,8 @@ +name: Ruff +on: [push, pull_request] +jobs: + ruff: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: chartboost/ruff-action@v1 From f2790c0d257e1af35c95db30cc17122eab1ff4ab Mon Sep 17 00:00:00 2001 From: JJTech Date: Sun, 19 May 2024 13:40:06 -0400 Subject: [PATCH 17/19] ci: enforce pyright typing --- .github/workflows/pyright.yml | 8 ++++++++ 1 file changed, 8 insertions(+) create mode 100644 .github/workflows/pyright.yml diff --git a/.github/workflows/pyright.yml b/.github/workflows/pyright.yml new file mode 100644 index 0000000..dd114c3 --- /dev/null +++ b/.github/workflows/pyright.yml @@ -0,0 +1,8 @@ +name: Pyright +on: [push, pull_request] +jobs: + pyright: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: jakebailey/pyright-action@v2 From 4d40ed0cb713f73e6f4d556eaa4a622a87dcd054 Mon Sep 17 00:00:00 2001 From: JJTech Date: Sun, 19 May 2024 13:43:03 -0400 Subject: [PATCH 18/19] ci: use venv with deps for pyright --- .github/workflows/pyright.yml | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/.github/workflows/pyright.yml b/.github/workflows/pyright.yml index dd114c3..bfad777 100644 --- a/.github/workflows/pyright.yml +++ b/.github/workflows/pyright.yml @@ -5,4 +5,14 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 + - uses: actions/setup-python@v4 + with: + cache: 'pip' + + - run: | + python -m venv .venv + source .venv/bin/activate + pip install -e '.[test,cli]' + + - run: echo "$PWD/.venv/bin" >> $GITHUB_PATH - uses: jakebailey/pyright-action@v2 From 8d59b7ea057fdc20ed258518350008818cc43719 Mon Sep 17 00:00:00 2001 From: JJTech0130 Date: Sun, 19 May 2024 14:45:14 -0400 Subject: [PATCH 19/19] apns: async base_token property makes awaiting _connected an implementation detail --- pypush/apns/lifecycle.py | 33 ++++++++++++++++++--------------- 1 file changed, 18 insertions(+), 15 deletions(-) diff --git a/pypush/apns/lifecycle.py b/pypush/apns/lifecycle.py index e094bc1..23d3f94 100644 --- a/pypush/apns/lifecycle.py +++ b/pypush/apns/lifecycle.py @@ -29,7 +29,6 @@ async def create_apns_connection( conn = Connection( tg, certificate, private_key, token, sandbox, courier ) # Await connected for first time here, so that base token is set - await conn._connected.wait() yield conn tg.cancel_scope.cancel() # Cancel the task group when the context manager exits await ( @@ -49,11 +48,11 @@ def __init__( ): self.certificate = certificate self.private_key = private_key - self.base_token = token + self._base_token = token self._filters: dict[str, int] = {} # topic -> use count - self._connected = anyio.Event() # Set when the connection is first established + self._connected = anyio.Event() # Only use for base_token property self._conn = None self._tg = task_group @@ -75,6 +74,13 @@ def __init__( self._tg.start_soon(self.reconnect) self._tg.start_soon(self._ping_task) + @property + async def base_token(self) -> bytes: + if self._base_token is None: + await self._connected.wait() + assert self._base_token is not None + return self._base_token + async def _receive_task(self): assert self._conn is not None async for command in self._conn: @@ -114,7 +120,7 @@ async def reconnect(self): ) await conn.send( protocol.ConnectCommand( - push_token=self.base_token, + push_token=self._base_token, state=1, flags=65, # 69 certificate=cert, @@ -133,8 +139,8 @@ async def reconnect(self): lambda c: ( c if ( - c.token == self.base_token - if self.base_token is not None + c.token == self._base_token + if self._base_token is not None else True ) else None @@ -143,10 +149,10 @@ async def reconnect(self): ) logging.debug(f"Connected with ack: {ack}") assert ack.status == 0 - if self.base_token is None: - self.base_token = ack.token + if self._base_token is None: + self._base_token = ack.token else: - assert ack.token == self.base_token + assert ack.token == self._base_token if not self._connected.is_set(): self._connected.set() @@ -187,10 +193,9 @@ async def _send(self, command: protocol.Command): await self._send(command) async def _update_filter(self): - assert self.base_token is not None await self._send( protocol.FilterCommand( - token=self.base_token, + token=await self.base_token, enabled_topic_hashes=[ sha1(topic.encode()).digest() for topic in self._filters ], @@ -199,7 +204,6 @@ async def _update_filter(self): @asynccontextmanager async def _filter(self, topics: list[str]): - assert self.base_token is not None for topic in topics: self._filters[topic] = self._filters.get(topic, 0) + 1 await self._update_filter() @@ -212,9 +216,8 @@ async def _filter(self, topics: list[str]): async def mint_scoped_token(self, topic: str) -> bytes: topic_hash = sha1(topic.encode()).digest() - assert self.base_token is not None await self._send( - protocol.ScopedTokenCommand(token=self.base_token, topic=topic_hash) + protocol.ScopedTokenCommand(token=await self.base_token, topic=topic_hash) ) ack = await self._receive(filters.cmd(protocol.ScopedTokenAck)) assert ack.status == 0 @@ -230,7 +233,7 @@ async def notification_stream( ] = filters.ALL, ): if token is None: - token = self.base_token + token = await self.base_token async with self._filter([topic]), self._receive_stream( filters.chain( filters.chain(