Skip to content

Commit

Permalink
Rework SubDevices Connection
Browse files Browse the repository at this point in the history
  • Loading branch information
xZetsubou committed Oct 16, 2023
1 parent b8d1a95 commit 99ddac0
Show file tree
Hide file tree
Showing 4 changed files with 205 additions and 135 deletions.
40 changes: 27 additions & 13 deletions custom_components/localtuya/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
DATA_DISCOVERY,
DOMAIN,
TUYA_DEVICES,
CONF_NODE_ID,
)
from .discovery import TuyaDiscovery

Expand Down Expand Up @@ -237,15 +238,25 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry):
res = await tuya_api.async_get_devices_list()
hass.data[DOMAIN][entry.entry_id][DATA_CLOUD] = tuya_api

async def setup_entities(device_ids):
async def setup_entities(devices: dict):
platforms = set()
for dev_id in device_ids:
for dev_id, config in devices.items():
host = config.get(CONF_HOST)
entities = entry.data[CONF_DEVICES][dev_id][CONF_ENTITIES]
platforms = platforms.union(
set(entity[CONF_PLATFORM] for entity in entities)
)

hass.data[DOMAIN][entry.entry_id][TUYA_DEVICES][dev_id] = TuyaDevice(
if node_id := config.get(CONF_NODE_ID):
# Setup sub device as gateway if no gateway not exist.
if host not in hass.data[DOMAIN][entry.entry_id][TUYA_DEVICES]:
hass.data[DOMAIN][entry.entry_id][TUYA_DEVICES][host] = TuyaDevice(
hass, entry, dev_id, True
)

host = f"{host}_{node_id}"

hass.data[DOMAIN][entry.entry_id][TUYA_DEVICES][host] = TuyaDevice(
hass, entry, dev_id
)

Expand All @@ -257,16 +268,18 @@ async def setup_entities(device_ids):
device.async_connect()
for device in hass.data[DOMAIN][entry.entry_id][TUYA_DEVICES].values()
]
try:
await asyncio.wait_for(asyncio.gather(*connect_task), 1)
except:
# If there is device that isn't connected to network it will return failed Initialization.
...
await asyncio.wait_for(asyncio.gather(*connect_task), 5)
# await asyncio.gather(*connect_task)
# try:
# await asyncio.wait_for(asyncio.gather(*connect_task), 1)
# except:
# # If there is device that isn't connected to network it will return failed Initialization.
# ...

await setup_entities(entry.data[CONF_DEVICES])

await setup_entities(entry.data[CONF_DEVICES].keys())
# callback back to unsub listener
unsub_listener = entry.add_update_listener(update_listener)

hass.data[DOMAIN][entry.entry_id].update({UNSUB_LISTENER: unsub_listener})

# Add reconnect trigger every 1mins to reconnect if device not connected.
Expand Down Expand Up @@ -332,7 +345,8 @@ async def async_remove_config_entry_device(
)
return True

await hass.data[DOMAIN][config_entry.entry_id][TUYA_DEVICES][dev_id].close()
host = config_entry.data[CONF_DEVICES][dev_id][CONF_HOST]
await hass.data[DOMAIN][config_entry.entry_id][TUYA_DEVICES][host].close()

new_data = config_entry.data.copy()
new_data[CONF_DEVICES].pop(dev_id)
Expand All @@ -353,9 +367,9 @@ def reconnectTask(hass: HomeAssistant, entry: ConfigEntry):

async def _async_reconnect(now):
"""Try connecting to devices not already connected to."""
for devID, dev in hass.data[DOMAIN][entry.entry_id][TUYA_DEVICES].items():
for host, dev in hass.data[DOMAIN][entry.entry_id][TUYA_DEVICES].items():
if not dev.connected:
hass.create_task(dev.async_connect())
hass.async_create_task(dev.async_connect())

hass.data[DOMAIN][entry.entry_id][RECONNECT_TASK] = async_track_time_interval(
hass, _async_reconnect, RECONNECT_INTERVAL
Expand Down
151 changes: 99 additions & 52 deletions custom_components/localtuya/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,20 +85,20 @@ async def async_setup_entry(
entities = []

for dev_id in config_entry.data[CONF_DEVICES]:
# entities_to_setup = prepare_setup_entities(
# hass, config_entry.data[dev_id], domain
# )
dev_entry = config_entry.data[CONF_DEVICES][dev_id]
dev_entry: dict = config_entry.data[CONF_DEVICES][dev_id]
host = dev_entry.get(CONF_HOST)

entities_to_setup = [
entity
for entity in dev_entry[CONF_ENTITIES]
if entity[CONF_PLATFORM] == domain
]

if entities_to_setup:
tuyainterface = hass.data[DOMAIN][config_entry.entry_id][TUYA_DEVICES][
dev_id
]
if node_id := dev_entry.get(CONF_NODE_ID):
host = f"{host}_{node_id}"

tuyainterface = hass.data[DOMAIN][config_entry.entry_id][TUYA_DEVICES][host]

dps_config_fields = list(get_dps_for_platform(flow_schema))

Expand Down Expand Up @@ -146,15 +146,25 @@ def async_config_entry_by_device_id(hass, device_id):


class TuyaDevice(pytuya.TuyaListener, pytuya.ContextualLogger):
"""Cache wrapper for pytuya.TuyaInterface."""

def __init__(self, hass: HomeAssistant, config_entry: ConfigEntry, dev_id: str):
"""Cache wrapper for pytuya.TuyaInterface, and Sub Devices."""

def __init__(
self,
hass: HomeAssistant,
config_entry: ConfigEntry,
dev_id: str,
gateway=False,
):
"""Initialize the cache."""
super().__init__()
self._hass = hass
self._config_entry = config_entry
self._dev_config_entry: dict = config_entry.data[CONF_DEVICES][dev_id].copy()
self._interface = None
# For SubDevices
self._node_id: str = not gateway and self._dev_config_entry.get(CONF_NODE_ID)
self._gwateway: TuyaDevice = None
self._sub_devices = {}
self._status = {}
self.dps_to_request = {}
self._is_closing = False
Expand Down Expand Up @@ -195,6 +205,21 @@ def connected(self):
"""Return if connected to device."""
return self._interface is not None

def get_gateway(self):
"""Search for gateway device if not exist create one"""
if not self._node_id:
return
entry_id = self._config_entry.entry_id
node_host = self._dev_config_entry.get(CONF_HOST)
devices: dict = self._hass.data[DOMAIN][entry_id][TUYA_DEVICES]

# Sub to gateway.
gateway = devices.get(node_host)
gateway._sub_devices[self._node_id] = self
for dev_ip, device in devices.items():
if dev_ip == node_host:
return device

async def async_connect(self):
"""Connect to device if not already connected."""
# self.info("async_connect: %d %r %r", self._is_closing, self._connect_task, self._interface)
Expand All @@ -204,26 +229,33 @@ async def async_connect(self):

async def _make_connection(self):
"""Subscribe localtuya entity events."""
self.info("Trying to connect to %s...", self._dev_config_entry[CONF_HOST])
self._connect_task = True
host = self._dev_config_entry.get(CONF_HOST)
name = self._dev_config_entry.get(CONF_FRIENDLY_NAME)

try:
self._interface = await pytuya.connect(
self._dev_config_entry[CONF_HOST],
self._dev_config_entry[CONF_DEVICE_ID],
self._local_key,
float(self._dev_config_entry[CONF_PROTOCOL_VERSION]),
self._dev_config_entry.get(CONF_ENABLE_DEBUG, False),
self._dev_config_entry.get(CONF_NODE_ID, None),
self,
)
if self._node_id:
gateway = self._gwateway
self._gwateway = self.get_gateway() if not gateway else gateway
if not self._gwateway.connected:
return
self._interface = self._gwateway._interface
self.info(f"Connect Sub Device {name} through gateway {host}")
else:
self.info("Trying to connect to %s...", host)
self._interface = await pytuya.connect(
self._dev_config_entry[CONF_HOST],
self._dev_config_entry[CONF_DEVICE_ID],
self._local_key,
float(self._dev_config_entry[CONF_PROTOCOL_VERSION]),
self._dev_config_entry.get(CONF_ENABLE_DEBUG, False),
self,
)
self._interface.add_dps_to_request(self.dps_to_request)
except Exception as ex: # pylint: disable=broad-except
self.warning(
f"Failed to connect to {self._dev_config_entry[CONF_HOST]}: %s", ex
)
if self._interface is not None:
await self._interface.close()
self._interface = None
self.warning(f"Failed to connect to {host}: {ex}")

await self.abort_connect()

if self._interface is not None:
try:
Expand All @@ -239,33 +271,33 @@ async def _make_connection(self):
self._interface.set_updatedps_list(self._default_reset_dpids)

# Reset the interface
await self._interface.reset(self._default_reset_dpids)
await self._interface.reset(
self._default_reset_dpids, cid=self._node_id
)

self.debug("Retrieving initial state")
status = await self._interface.status()

status = await self._interface.status(cid=self._node_id)
if status is None:
raise Exception("Failed to retrieve status")

self._interface.start_heartbeat()
if not self._node_id:
self._interface.start_heartbeat()
self.status_updated(status)

except UnicodeDecodeError as e: # pylint: disable=broad-except
self.exception(
f"Connect to {self._dev_config_entry[CONF_HOST]} failed: %s",
type(e),
)
if self._interface is not None:
await self._interface.close()
self._interface = None
self.exception(f"Connect to {host} failed: {type(e)}")

await self.abort_connect()

except Exception as e: # pylint: disable=broad-except
self.exception(f"Connect to {self._dev_config_entry[CONF_HOST]} failed")
if isinstance(e, ValueError):
self.warning(f"Connect to {name} failed: {e}")
else:
self.warning(f"Connect to {host} failed: {e}")
if "json.decode" in str(type(e)):
self.warning(f"Initial state update failed {e}, trying key update")
await self.update_local_key()

if self._interface is not None:
await self._interface.close()
self._interface = None
await self.abort_connect()

if self._interface is not None:
# Attempt to restore status for all entities that need to first set
Expand Down Expand Up @@ -297,10 +329,24 @@ def _new_entity_handler(entity_id):
)

self._is_closing = False
self.info(f"Successfully connected to {self._dev_config_entry[CONF_HOST]}")
self.info(f"Successfully connected to {name if self._node_id else host}")
if self._sub_devices:
connect_sub_devices = [
device.async_connect() for device in self._sub_devices.values()
]
await asyncio.gather(*connect_sub_devices)

self._connect_task = None

async def abort_connect(self):
"""Abort the connect process to the interface[device]"""
if self._node_id:
self._interface = None

if self._interface is not None:
await self._interface.close()
self._interface = None

async def update_local_key(self):
"""Retrieve updated local_key from Cloud API and update the config_entry."""
dev_id = self._dev_config_entry[CONF_DEVICE_ID]
Expand All @@ -321,7 +367,7 @@ async def update_local_key(self):
async def _async_refresh(self, _now):
if self._interface is not None:
self.debug("Refreshing dps for device")
await self._interface.update_dps()
await self._interface.update_dps(cid=self._node_id)

async def close(self):
"""Close connection and stop re-connect loop."""
Expand All @@ -344,7 +390,7 @@ async def set_dp(self, state, dp_index):
"""Change value of a DP of the Tuya device."""
if self._interface is not None:
try:
await self._interface.set_dp(state, dp_index)
await self._interface.set_dp(state, dp_index, cid=self._node_id)
except Exception: # pylint: disable=broad-except
self.debug("Failed to set DP %d to %s", dp_index, str(state))
else:
Expand All @@ -356,7 +402,7 @@ async def set_dps(self, states):
"""Change value of a DPs of the Tuya device."""
if self._interface is not None:
try:
await self._interface.set_dps(states)
await self._interface.set_dps(states, cid=self._node_id)
except Exception: # pylint: disable=broad-except
self.debug("Failed to set DPs %r", states)
else:
Expand All @@ -365,8 +411,9 @@ async def set_dps(self, states):
)

@callback
def status_updated(self, status):
def status_updated(self, status: dict):
"""Device updated status."""
status = status.get(self._node_id) if self._node_id else status.get("parent")
self._handle_event(self._status, status)
self._status.update(status)
self._dispatch_status()
Expand All @@ -375,12 +422,12 @@ def _dispatch_status(self):
signal = f"localtuya_{self._dev_config_entry[CONF_DEVICE_ID]}"
async_dispatcher_send(self._hass, signal, self._status)

def _handle_event(self, old_status, new_status):
def _handle_event(self, old_status, new_status, deviceID=None):
"""Handle events in HA when devices updated."""

def fire_event(event, data: dict):
event_data = {
CONF_DEVICE_ID: self._dev_config_entry[CONF_DEVICE_ID],
CONF_DEVICE_ID: deviceID or self._dev_config_entry[CONF_DEVICE_ID],
CONF_TYPE: event,
}
event_data.update(data)
Expand Down Expand Up @@ -421,15 +468,15 @@ def disconnected(self):
self._unsub_interval()
self._unsub_interval = None
self._interface = None

if self._connect_task is not None:
# self._connect_task.cancel()
self._connect_task = None
self.warning("Disconnected - waiting for discovery broadcast")
# If it's disconnect by unexpected error.
if self._is_closing is not True:
self._is_closing = True
self._hass.create_task(self.async_connect())
else:
self.warning("Disconnected - waiting for discovery broadcast")


class LocalTuyaEntity(RestoreEntity, pytuya.ContextualLogger):
Expand Down Expand Up @@ -588,7 +635,7 @@ def device_class(self):
def dps(self, dp_index):
"""Return cached value for DPS index."""
value = self._status.get(str(dp_index))
if value is None:
if value is None and not self._dev_config_entry.get(CONF_NODE_ID):
self.warning(
"Entity %s is requesting unknown DPS index %s",
self.entity_id,
Expand Down
Loading

0 comments on commit 99ddac0

Please sign in to comment.