Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reuse ZDO Initializers to create Endpoint objects on EZSP device #599

Merged
merged 7 commits into from
Dec 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 21 additions & 7 deletions bellows/zigbee/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ class ControllerApplication(zigpy.application.ControllerApplication):
def __init__(self, config: dict):
super().__init__(config)
self._ctrl_event = asyncio.Event()
self._created_device_endpoints: list[zdo_t.SimpleDescriptor] = []
self._ezsp = None
self._multicast = None
self._mfg_id_task: asyncio.Task | None = None
Expand Down Expand Up @@ -116,9 +117,12 @@ async def add_endpoint(self, descriptor: zdo_t.SimpleDescriptor) -> None:
descriptor.input_clusters,
descriptor.output_clusters,
)

if status != t.EmberStatus.SUCCESS:
raise StackAlreadyRunning()

self._created_device_endpoints.append(descriptor)

async def cleanup_tc_link_key(self, ieee: t.EUI64) -> None:
"""Remove tc link_key for the given device."""
(index,) = await self._ezsp.findKeyTableEntry(ieee, True)
Expand Down Expand Up @@ -150,6 +154,8 @@ async def connect(self) -> None:
raise

self._ezsp = ezsp

self._created_device_endpoints.clear()
await self.register_endpoints()

async def _ensure_network_running(self) -> bool:
Expand Down Expand Up @@ -198,10 +204,15 @@ async def start_network(self):
ezsp.add_callback(self.ezsp_callback_handler)
self.controller_event.set()

group_membership = {}

try:
db_device = self.get_device(ieee=self.state.node_info.ieee)
except KeyError:
db_device = None
pass
else:
if 1 in db_device.endpoints:
group_membership = db_device.endpoints[1].member_of

ezsp_device = zigpy.device.Device(
application=self,
Expand All @@ -210,15 +221,18 @@ async def start_network(self):
)
self.devices[self.state.node_info.ieee] = ezsp_device

# The coordinator device does not respond to attribute reads
ezsp_device.endpoints[1] = EZSPEndpoint(ezsp_device, 1)
ezsp_device.model = ezsp_device.endpoints[1].model
ezsp_device.manufacturer = ezsp_device.endpoints[1].manufacturer
# The coordinator device does not respond to attribute reads so we have to
# divine the internal NCP state.
for zdo_desc in self._created_device_endpoints:
ep = EZSPEndpoint(ezsp_device, zdo_desc.endpoint, zdo_desc)
ezsp_device.endpoints[zdo_desc.endpoint] = ep
ezsp_device.model = ep.model
ezsp_device.manufacturer = ep.manufacturer

await ezsp_device.schedule_initialize()

# Group membership is stored in the database for EZSP coordinators
if db_device is not None and 1 in db_device.endpoints:
ezsp_device.endpoints[1].member_of.update(db_device.endpoints[1].member_of)
ezsp_device.endpoints[1].member_of.update(group_membership)

self._multicast = bellows.multicast.Multicast(ezsp)
await self._multicast.startup(ezsp_device)
Expand Down
41 changes: 35 additions & 6 deletions bellows/zigbee/device.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,51 @@
from __future__ import annotations

import logging
import typing

import zigpy.device
import zigpy.endpoint
import zigpy.util
import zigpy.zdo
import zigpy.profiles.zgp
import zigpy.profiles.zha
import zigpy.profiles.zll
import zigpy.zdo.types as zdo_t

import bellows.types as t

if typing.TYPE_CHECKING:
import zigpy.application # pragma: no cover

LOGGER = logging.getLogger(__name__)

PROFILE_TO_DEVICE_TYPE = {
zigpy.profiles.zha.PROFILE_ID: zigpy.profiles.zha.DeviceType,
zigpy.profiles.zll.PROFILE_ID: zigpy.profiles.zll.DeviceType,
zigpy.profiles.zgp.PROFILE_ID: zigpy.profiles.zgp.DeviceType,
}


class EZSPEndpoint(zigpy.endpoint.Endpoint):
def __init__(
self,
device: zigpy.device.Device,
endpoint_id: int,
descriptor: zdo_t.SimpleDescriptor,
) -> None:
super().__init__(device, endpoint_id)

self.profile_id = descriptor.profile

if self.profile_id in PROFILE_TO_DEVICE_TYPE:
self.device_type = PROFILE_TO_DEVICE_TYPE[self.profile_id](
descriptor.device_type
)
else:
self.device_type = descriptor.device_type

for cluster in descriptor.input_clusters:
self.add_input_cluster(cluster)

for cluster in descriptor.output_clusters:
self.add_output_cluster(cluster)

self.status = zigpy.endpoint.Status.ZDO_INIT

@property
def manufacturer(self) -> str:
"""Manufacturer."""
Expand Down
90 changes: 52 additions & 38 deletions tests/test_application.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import contextlib
import logging
from unittest.mock import AsyncMock, MagicMock, PropertyMock, call, patch, sentinel

Expand Down Expand Up @@ -114,7 +115,6 @@ def aps():
return f


@patch("zigpy.device.Device._initialize", new=AsyncMock())
def _create_app_for_startup(
app,
nwk_type,
Expand Down Expand Up @@ -206,6 +206,14 @@ async def mock_leave(*args, **kwargs):
),
]
)
ezsp_mock.getMulticastTableEntry = AsyncMock(
return_value=[
t.EmberStatus.SUCCESS,
t.EmberMulticastTableEntry(multicastId=0x0000, endpoint=0, networkIndex=0),
]
)
ezsp_mock.setMulticastTableEntry = AsyncMock(return_value=[t.EmberStatus.SUCCESS])

app.permit = AsyncMock()

def form_network():
Expand All @@ -220,10 +228,11 @@ def form_network():
return ezsp_mock


async def _test_startup(
@contextlib.contextmanager
def mock_for_startup(
app,
nwk_type,
ieee,
nwk_type=t.EmberNodeType.COORDINATOR,
auto_form=False,
init=0,
ezsp_version=4,
Expand All @@ -234,10 +243,25 @@ async def _test_startup(
app, nwk_type, ieee, auto_form, init, ezsp_version, board_info, network_state
)

p1 = patch("bellows.ezsp.EZSP", return_value=ezsp_mock)
p2 = patch.object(bellows.multicast.Multicast, "startup")
with patch("bellows.ezsp.EZSP", return_value=ezsp_mock), patch(
"zigpy.device.Device._initialize", new=AsyncMock()
):
yield ezsp_mock


with p1, p2 as multicast_mock:
async def _test_startup(
app,
nwk_type,
ieee,
auto_form=False,
init=0,
ezsp_version=4,
board_info=True,
network_state=t.EmberNetworkStatus.JOINED_NETWORK,
):
with mock_for_startup(
app, ieee, nwk_type, auto_form, init, ezsp_version, board_info, network_state
) as ezsp_mock:
await app.startup(auto_form=auto_form)

if ezsp_version > 6:
Expand All @@ -247,7 +271,6 @@ async def _test_startup(

assert ezsp_mock.write_config.call_count == 1
assert ezsp_mock.addEndpoint.call_count >= 2
assert multicast_mock.await_count == 1


async def test_startup(app, ieee):
Expand Down Expand Up @@ -1166,7 +1189,7 @@ async def test_shutdown(app):
@pytest.fixture
def coordinator(app, ieee):
dev = zigpy.device.Device(app, ieee, 0x0000)
dev.endpoints[1] = bellows.zigbee.device.EZSPEndpoint(dev, 1)
dev.endpoints[1] = bellows.zigbee.device.EZSPEndpoint(dev, 1, MagicMock())
dev.model = dev.endpoints[1].model
dev.manufacturer = dev.endpoints[1].manufacturer

Expand Down Expand Up @@ -1505,42 +1528,32 @@ async def test_ensure_network_running_not_joined_success(app):

async def test_startup_coordinator_existing_groups_joined(app, ieee):
"""Coordinator joins groups loaded from the database."""
with mock_for_startup(app, ieee) as ezsp_mock:
await app.connect()

app._ensure_network_running = AsyncMock()
app._ezsp.update_policies = AsyncMock()
app.load_network_info = AsyncMock()
app.state.node_info.ieee = ieee

db_device = app.add_device(ieee, 0x0000)
db_ep = db_device.add_endpoint(1)

app.groups.add_group(0x1234, "Group Name", suppress_event=True)
app.groups[0x1234].add_member(db_ep, suppress_event=True)
db_device = app.add_device(ieee, 0x0000)
db_ep = db_device.add_endpoint(1)

p1 = patch.object(bellows.multicast.Multicast, "_initialize")
p2 = patch.object(bellows.multicast.Multicast, "subscribe")
app.groups.add_group(0x1234, "Group Name", suppress_event=True)
app.groups[0x1234].add_member(db_ep, suppress_event=True)

with p1 as p1, p2 as p2:
await app.start_network()

p2.assert_called_once_with(0x1234)
assert ezsp_mock.setMulticastTableEntry.mock_calls == [
call(
0,
t.EmberMulticastTableEntry(multicastId=0x1234, endpoint=1, networkIndex=0),
)
]


async def test_startup_new_coordinator_no_groups_joined(app, ieee):
"""Coordinator freshy added to the database has no groups to join."""

app._ensure_network_running = AsyncMock()
app._ezsp.update_policies = AsyncMock()
app.load_network_info = AsyncMock()
app.state.node_info.ieee = ieee

p1 = patch.object(bellows.multicast.Multicast, "_initialize")
p2 = patch.object(bellows.multicast.Multicast, "subscribe")

with p1 as p1, p2 as p2:
with mock_for_startup(app, ieee) as ezsp_mock:
await app.connect()
await app.start_network()

p2.assert_not_called()
assert ezsp_mock.setMulticastTableEntry.mock_calls == []


@pytest.mark.parametrize(
Expand Down Expand Up @@ -1628,22 +1641,23 @@ async def test_connect_failure(
assert len(ezsp_mock.close.mock_calls) == 1


async def test_repair_tclk_partner_ieee(app: ControllerApplication) -> None:
async def test_repair_tclk_partner_ieee(
app: ControllerApplication, ieee: zigpy_t.EUI64
) -> None:
"""Test that EZSP is reset after repairing TCLK."""
app._ensure_network_running = AsyncMock()
app._reset = AsyncMock()
app.load_network_info = AsyncMock()

with patch(
with mock_for_startup(app, ieee), patch(
"bellows.zigbee.repairs.fix_invalid_tclk_partner_ieee",
AsyncMock(return_value=False),
):
await app.connect()
await app.start_network()

assert len(app._reset.mock_calls) == 0
app._reset.reset_mock()

with patch(
with mock_for_startup(app, ieee), patch(
"bellows.zigbee.repairs.fix_invalid_tclk_partner_ieee",
AsyncMock(return_value=True),
):
Expand Down