diff --git a/CHANGELOG.md b/CHANGELOG.md index 0289ceed..e0a1945d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,10 @@ Sections ### Developers --> +## [4.9.1] - 2023-10-25 + +- Fix handling of explict close. [#467](https://github.com/ikalchev/HAP-python/pull/467) + ## [4.9.0] - 2023-10-15 - Hashing of accessories no longer includes their values, resulting in more reliable syncs between diff --git a/pyhap/accessory.py b/pyhap/accessory.py index f8a95f57..902fc7f7 100644 --- a/pyhap/accessory.py +++ b/pyhap/accessory.py @@ -241,7 +241,9 @@ def to_HAP(self, include_value: bool = True) -> Dict[str, Any]: """ return { HAP_REPR_AID: self.aid, - HAP_REPR_SERVICES: [s.to_HAP(include_value=include_value) for s in self.services], + HAP_REPR_SERVICES: [ + s.to_HAP(include_value=include_value) for s in self.services + ], } def setup_message(self): @@ -391,7 +393,10 @@ def to_HAP(self, include_value: bool = True) -> List[Dict[str, Any]]: .. seealso:: Accessory.to_HAP """ - return [acc.to_HAP(include_value=include_value) for acc in (super(), *self.accessories.values())] + return [ + acc.to_HAP(include_value=include_value) + for acc in (super(), *self.accessories.values()) + ] def get_characteristic(self, aid: int, iid: int) -> Optional["Characteristic"]: """.. seealso:: Accessory.to_HAP""" diff --git a/pyhap/const.py b/pyhap/const.py index e4c72e1c..9f081eef 100644 --- a/pyhap/const.py +++ b/pyhap/const.py @@ -1,7 +1,7 @@ """This module contains constants used by other modules.""" MAJOR_VERSION = 4 MINOR_VERSION = 9 -PATCH_VERSION = 0 +PATCH_VERSION = 1 __short_version__ = f"{MAJOR_VERSION}.{MINOR_VERSION}" __version__ = f"{__short_version__}.{PATCH_VERSION}" REQUIRED_PYTHON_VER = (3, 7) diff --git a/pyhap/hap_protocol.py b/pyhap/hap_protocol.py index b54db5e1..0f51dec9 100644 --- a/pyhap/hap_protocol.py +++ b/pyhap/hap_protocol.py @@ -46,7 +46,7 @@ def __init__( connections: Dict[str, "HAPServerProtocol"], accessory_driver: "AccessoryDriver", ) -> None: - self.loop: asyncio.AbstractEventLoop = loop + self.loop = loop self.conn = h11.Connection(h11.SERVER) self.connections = connections self.accessory_driver = accessory_driver @@ -55,7 +55,7 @@ def __init__( self.transport: Optional[asyncio.Transport] = None self.request: Optional[h11.Request] = None - self.request_body: Optional[bytes] = None + self.request_body: List[bytes] = [] self.response: Optional[HAPResponse] = None self.last_activity: Optional[float] = None @@ -246,27 +246,33 @@ def _process_one_event(self) -> bool: logger.debug( "%s (%s): h11 Event: %s", self.peername, self.handler.client_uuid, event ) - if event in (h11.NEED_DATA, h11.ConnectionClosed): + if event is h11.NEED_DATA: return False if event is h11.PAUSED: self.conn.start_next_cycle() return True - if isinstance(event, h11.Request): + event_type = type(event) + if event_type is h11.ConnectionClosed: + return False + + if event_type is h11.Request: self.request = event - self.request_body = b"" + self.request_body = [] return True - if isinstance(event, h11.Data): - self.request_body += event.data + if event_type is h11.Data: + if TYPE_CHECKING: + assert isinstance(event, h11.Data) # nosec + self.request_body.append(event.data) return True - if isinstance(event, h11.EndOfMessage): - response = self.handler.dispatch(self.request, bytes(self.request_body)) + if event_type is h11.EndOfMessage: + response = self.handler.dispatch(self.request, b"".join(self.request_body)) self._process_response(response) self.request = None - self.request_body = None + self.request_body = [] return True return self._handle_invalid_conn_state(f"Unexpected event: {event}") diff --git a/pyhap/hap_server.py b/pyhap/hap_server.py index 5016355e..1e8414a8 100644 --- a/pyhap/hap_server.py +++ b/pyhap/hap_server.py @@ -3,12 +3,17 @@ The HAPServer is the point of contact to and from the world. """ +import asyncio import logging import time +from typing import TYPE_CHECKING, Dict, Optional, Tuple from .hap_protocol import HAPServerProtocol from .util import callback +if TYPE_CHECKING: + from .accessory_driver import AccessoryDriver + logger = logging.getLogger(__name__) IDLE_CONNECTION_CHECK_INTERVAL_SECONDS = 120 @@ -28,17 +33,18 @@ class HAPServer: implements exclusive access to the send methods. """ - def __init__(self, addr_port, accessory_handler): + def __init__( + self, addr_port: Tuple[str, int], accessory_handler: "AccessoryDriver" + ) -> None: """Create a HAP Server.""" self._addr_port = addr_port - self.connections = {} # (address, port): socket + self.connections: Dict[Tuple[str, int], HAPServerProtocol] = {} self.accessory_handler = accessory_handler - self.server = None - self._serve_task = None - self._connection_cleanup = None - self.loop = None + self.server: Optional[asyncio.Server] = None + self._connection_cleanup: Optional[asyncio.TimerHandle] = None + self.loop: Optional[asyncio.AbstractEventLoop] = None - async def async_start(self, loop): + async def async_start(self, loop: asyncio.AbstractEventLoop) -> None: """Start the http-hap server.""" self.loop = loop self.server = await loop.create_server( @@ -49,7 +55,7 @@ async def async_start(self, loop): self.async_cleanup_connections() @callback - def async_cleanup_connections(self): + def async_cleanup_connections(self) -> None: """Cleanup stale connections.""" now = time.time() for hap_proto in list(self.connections.values()): @@ -59,7 +65,7 @@ def async_cleanup_connections(self): ) @callback - def async_stop(self): + def async_stop(self) -> None: """Stop the server. This method must be run in the event loop. @@ -70,10 +76,12 @@ def async_stop(self): self.server.close() self.connections.clear() - def push_event(self, data, client_addr, immediate=False): + def push_event( + self, data: bytes, client_addr: Tuple[str, int], immediate: bool = False + ) -> bool: """Queue an event to the current connection with the provided data. - :param data: The charateristic changes + :param data: The characteristic changes :type data: dict :param client_addr: A client (address, port) tuple to which to send the data. diff --git a/pyhap/iid_manager.py b/pyhap/iid_manager.py index 855680e6..e571342c 100644 --- a/pyhap/iid_manager.py +++ b/pyhap/iid_manager.py @@ -1,5 +1,12 @@ """Module for the IIDManager class.""" import logging +from typing import TYPE_CHECKING, Dict, Optional, Union + +if TYPE_CHECKING: + from .characteristic import Characteristic + from .service import Service + + ServiceOrCharType = Union[Service, Characteristic] logger = logging.getLogger(__name__) @@ -7,13 +14,13 @@ class IIDManager: """Maintains a mapping between Service/Characteristic objects and IIDs.""" - def __init__(self): + def __init__(self) -> None: """Initialize an empty instance.""" self.counter = 0 - self.iids = {} - self.objs = {} + self.iids: Dict["ServiceOrCharType", int] = {} + self.objs: Dict[int, "ServiceOrCharType"] = {} - def assign(self, obj): + def assign(self, obj: "ServiceOrCharType") -> None: """Assign an IID to given object. Print warning if already assigned. :param obj: The object that will be assigned an IID. @@ -32,7 +39,7 @@ def assign(self, obj): self.iids[obj] = iid self.objs[iid] = obj - def get_iid_for_obj(self, obj): + def get_iid_for_obj(self, obj: "ServiceOrCharType") -> int: """Get the IID for the given object. Override this method to provide custom IID assignment. @@ -40,15 +47,15 @@ def get_iid_for_obj(self, obj): self.counter += 1 return self.counter - def get_obj(self, iid): + def get_obj(self, iid: int) -> "ServiceOrCharType": """Get the object that is assigned the given IID.""" return self.objs.get(iid) - def get_iid(self, obj): + def get_iid(self, obj: "ServiceOrCharType") -> int: """Get the IID assigned to the given object.""" return self.iids.get(obj) - def remove_obj(self, obj): + def remove_obj(self, obj: "ServiceOrCharType") -> Optional[int]: """Remove an object from the IID list.""" iid = self.iids.pop(obj, None) if iid is None: @@ -57,7 +64,7 @@ def remove_obj(self, obj): del self.objs[iid] return iid - def remove_iid(self, iid): + def remove_iid(self, iid: int) -> Optional["ServiceOrCharType"]: """Remove an object with an IID from the IID list.""" obj = self.objs.pop(iid, None) if obj is None: diff --git a/tests/test_hap_protocol.py b/tests/test_hap_protocol.py index 0d2a58c3..06bd71b9 100644 --- a/tests/test_hap_protocol.py +++ b/tests/test_hap_protocol.py @@ -8,9 +8,30 @@ from pyhap import hap_handler, hap_protocol from pyhap.accessory import Accessory, Bridge +from pyhap.accessory_driver import AccessoryDriver from pyhap.hap_handler import HAPResponse +class MockTransport(asyncio.Transport): # pylint: disable=abstract-method + """A mock transport.""" + + _is_closing: bool = False + + def set_write_buffer_limits(self, high=None, low=None): + """Set the write buffer limits.""" + + def write_eof(self) -> None: + """Write EOF to the stream.""" + + def close(self) -> None: + """Close the stream.""" + self._is_closing = True + + def is_closing(self) -> bool: + """Return True if the transport is closing or closed.""" + return self._is_closing + + class MockHAPCrypto: """Mock HAPCrypto that only returns plaintext.""" @@ -734,3 +755,42 @@ async def test_does_not_timeout(driver): assert writer.call_args_list[0][0][0].startswith(b"HTTP/1.1 200 OK\r\n") is True hap_proto.check_idle(time.time()) assert hap_proto_close.called is False + + +def test_explicit_close(driver: AccessoryDriver): + """Test an explicit connection close.""" + loop = MagicMock() + + transport = MockTransport() + connections = {} + + acc = Accessory(driver, "TestAcc", aid=1) + assert acc.aid == 1 + service = acc.driver.loader.get_service("TemperatureSensor") + acc.add_service(service) + driver.add_accessory(acc) + + hap_proto = hap_protocol.HAPServerProtocol(loop, connections, driver) + hap_proto.connection_made(transport) + + hap_proto.hap_crypto = MockHAPCrypto() + hap_proto.handler.is_encrypted = True + assert hap_proto.transport.is_closing() is False + + with patch.object(hap_proto.transport, "write") as writer: + hap_proto.data_received( + b"GET /characteristics?id=3762173001.7 HTTP/1.1\r\nHost: HASS\\032Bridge\\032YPHW\\032B223AD._hap._tcp.local\r\n\r\n" # pylint: disable=line-too-long + ) + hap_proto.data_received( + b"GET /characteristics?id=1.5 HTTP/1.1\r\nConnection: close\r\nHost: HASS\\032Bridge\\032YPHW\\032B223AD._hap._tcp.local\r\n\r\n" # pylint: disable=line-too-long + ) + + assert b"Content-Length:" in writer.call_args_list[0][0][0] + assert b"Transfer-Encoding: chunked\r\n\r\n" not in writer.call_args_list[0][0][0] + assert b"-70402" in writer.call_args_list[0][0][0] + + assert b"Content-Length:" in writer.call_args_list[1][0][0] + assert b"Transfer-Encoding: chunked\r\n\r\n" not in writer.call_args_list[1][0][0] + assert b"TestAcc" in writer.call_args_list[1][0][0] + + assert hap_proto.transport.is_closing() is True