diff --git a/custom_components/localtuya/__init__.py b/custom_components/localtuya/__init__.py index 60c898891..ee267dcee 100644 --- a/custom_components/localtuya/__init__.py +++ b/custom_components/localtuya/__init__.py @@ -39,6 +39,7 @@ DATA_DISCOVERY, DOMAIN, TUYA_DEVICES, + CONF_NODE_ID, ) from .discovery import TuyaDiscovery @@ -237,15 +238,25 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry): res = await tuya_api.async_get_devices_list() hass.data[DOMAIN][entry.entry_id][DATA_CLOUD] = tuya_api - async def setup_entities(device_ids): + async def setup_entities(devices: dict): platforms = set() - for dev_id in device_ids: + for dev_id, config in devices.items(): + host = config.get(CONF_HOST) entities = entry.data[CONF_DEVICES][dev_id][CONF_ENTITIES] platforms = platforms.union( set(entity[CONF_PLATFORM] for entity in entities) ) - hass.data[DOMAIN][entry.entry_id][TUYA_DEVICES][dev_id] = TuyaDevice( + if node_id := config.get(CONF_NODE_ID): + # Setup sub device as gateway if no gateway not exist. + if host not in hass.data[DOMAIN][entry.entry_id][TUYA_DEVICES]: + hass.data[DOMAIN][entry.entry_id][TUYA_DEVICES][host] = TuyaDevice( + hass, entry, dev_id, True + ) + + host = f"{host}_{node_id}" + + hass.data[DOMAIN][entry.entry_id][TUYA_DEVICES][host] = TuyaDevice( hass, entry, dev_id ) @@ -257,16 +268,18 @@ async def setup_entities(device_ids): device.async_connect() for device in hass.data[DOMAIN][entry.entry_id][TUYA_DEVICES].values() ] - try: - await asyncio.wait_for(asyncio.gather(*connect_task), 1) - except: - # If there is device that isn't connected to network it will return failed Initialization. - ... + await asyncio.wait_for(asyncio.gather(*connect_task), 5) + # await asyncio.gather(*connect_task) + # try: + # await asyncio.wait_for(asyncio.gather(*connect_task), 1) + # except: + # # If there is device that isn't connected to network it will return failed Initialization. + # ... + + await setup_entities(entry.data[CONF_DEVICES]) - await setup_entities(entry.data[CONF_DEVICES].keys()) # callback back to unsub listener unsub_listener = entry.add_update_listener(update_listener) - hass.data[DOMAIN][entry.entry_id].update({UNSUB_LISTENER: unsub_listener}) # Add reconnect trigger every 1mins to reconnect if device not connected. @@ -332,7 +345,8 @@ async def async_remove_config_entry_device( ) return True - await hass.data[DOMAIN][config_entry.entry_id][TUYA_DEVICES][dev_id].close() + host = config_entry.data[CONF_DEVICES][dev_id][CONF_HOST] + await hass.data[DOMAIN][config_entry.entry_id][TUYA_DEVICES][host].close() new_data = config_entry.data.copy() new_data[CONF_DEVICES].pop(dev_id) @@ -353,9 +367,9 @@ def reconnectTask(hass: HomeAssistant, entry: ConfigEntry): async def _async_reconnect(now): """Try connecting to devices not already connected to.""" - for devID, dev in hass.data[DOMAIN][entry.entry_id][TUYA_DEVICES].items(): + for host, dev in hass.data[DOMAIN][entry.entry_id][TUYA_DEVICES].items(): if not dev.connected: - hass.create_task(dev.async_connect()) + hass.async_create_task(dev.async_connect()) hass.data[DOMAIN][entry.entry_id][RECONNECT_TASK] = async_track_time_interval( hass, _async_reconnect, RECONNECT_INTERVAL diff --git a/custom_components/localtuya/common.py b/custom_components/localtuya/common.py index 7228b4bae..da515dcd5 100644 --- a/custom_components/localtuya/common.py +++ b/custom_components/localtuya/common.py @@ -85,10 +85,9 @@ async def async_setup_entry( entities = [] for dev_id in config_entry.data[CONF_DEVICES]: - # entities_to_setup = prepare_setup_entities( - # hass, config_entry.data[dev_id], domain - # ) - dev_entry = config_entry.data[CONF_DEVICES][dev_id] + dev_entry: dict = config_entry.data[CONF_DEVICES][dev_id] + host = dev_entry.get(CONF_HOST) + entities_to_setup = [ entity for entity in dev_entry[CONF_ENTITIES] @@ -96,9 +95,10 @@ async def async_setup_entry( ] if entities_to_setup: - tuyainterface = hass.data[DOMAIN][config_entry.entry_id][TUYA_DEVICES][ - dev_id - ] + if node_id := dev_entry.get(CONF_NODE_ID): + host = f"{host}_{node_id}" + + tuyainterface = hass.data[DOMAIN][config_entry.entry_id][TUYA_DEVICES][host] dps_config_fields = list(get_dps_for_platform(flow_schema)) @@ -146,15 +146,25 @@ def async_config_entry_by_device_id(hass, device_id): class TuyaDevice(pytuya.TuyaListener, pytuya.ContextualLogger): - """Cache wrapper for pytuya.TuyaInterface.""" - - def __init__(self, hass: HomeAssistant, config_entry: ConfigEntry, dev_id: str): + """Cache wrapper for pytuya.TuyaInterface, and Sub Devices.""" + + def __init__( + self, + hass: HomeAssistant, + config_entry: ConfigEntry, + dev_id: str, + gateway=False, + ): """Initialize the cache.""" super().__init__() self._hass = hass self._config_entry = config_entry self._dev_config_entry: dict = config_entry.data[CONF_DEVICES][dev_id].copy() self._interface = None + # For SubDevices + self._node_id: str = not gateway and self._dev_config_entry.get(CONF_NODE_ID) + self._gwateway: TuyaDevice = None + self._sub_devices = {} self._status = {} self.dps_to_request = {} self._is_closing = False @@ -195,6 +205,21 @@ def connected(self): """Return if connected to device.""" return self._interface is not None + def get_gateway(self): + """Search for gateway device if not exist create one""" + if not self._node_id: + return + entry_id = self._config_entry.entry_id + node_host = self._dev_config_entry.get(CONF_HOST) + devices: dict = self._hass.data[DOMAIN][entry_id][TUYA_DEVICES] + + # Sub to gateway. + gateway = devices.get(node_host) + gateway._sub_devices[self._node_id] = self + for dev_ip, device in devices.items(): + if dev_ip == node_host: + return device + async def async_connect(self): """Connect to device if not already connected.""" # self.info("async_connect: %d %r %r", self._is_closing, self._connect_task, self._interface) @@ -204,26 +229,33 @@ async def async_connect(self): async def _make_connection(self): """Subscribe localtuya entity events.""" - self.info("Trying to connect to %s...", self._dev_config_entry[CONF_HOST]) self._connect_task = True + host = self._dev_config_entry.get(CONF_HOST) + name = self._dev_config_entry.get(CONF_FRIENDLY_NAME) + try: - self._interface = await pytuya.connect( - self._dev_config_entry[CONF_HOST], - self._dev_config_entry[CONF_DEVICE_ID], - self._local_key, - float(self._dev_config_entry[CONF_PROTOCOL_VERSION]), - self._dev_config_entry.get(CONF_ENABLE_DEBUG, False), - self._dev_config_entry.get(CONF_NODE_ID, None), - self, - ) + if self._node_id: + gateway = self._gwateway + self._gwateway = self.get_gateway() if not gateway else gateway + if not self._gwateway.connected: + return + self._interface = self._gwateway._interface + self.info(f"Connect Sub Device {name} through gateway {host}") + else: + self.info("Trying to connect to %s...", host) + self._interface = await pytuya.connect( + self._dev_config_entry[CONF_HOST], + self._dev_config_entry[CONF_DEVICE_ID], + self._local_key, + float(self._dev_config_entry[CONF_PROTOCOL_VERSION]), + self._dev_config_entry.get(CONF_ENABLE_DEBUG, False), + self, + ) self._interface.add_dps_to_request(self.dps_to_request) except Exception as ex: # pylint: disable=broad-except - self.warning( - f"Failed to connect to {self._dev_config_entry[CONF_HOST]}: %s", ex - ) - if self._interface is not None: - await self._interface.close() - self._interface = None + self.warning(f"Failed to connect to {host}: {ex}") + + await self.abort_connect() if self._interface is not None: try: @@ -239,33 +271,33 @@ async def _make_connection(self): self._interface.set_updatedps_list(self._default_reset_dpids) # Reset the interface - await self._interface.reset(self._default_reset_dpids) + await self._interface.reset( + self._default_reset_dpids, cid=self._node_id + ) self.debug("Retrieving initial state") - status = await self._interface.status() + + status = await self._interface.status(cid=self._node_id) if status is None: raise Exception("Failed to retrieve status") - - self._interface.start_heartbeat() + if not self._node_id: + self._interface.start_heartbeat() self.status_updated(status) except UnicodeDecodeError as e: # pylint: disable=broad-except - self.exception( - f"Connect to {self._dev_config_entry[CONF_HOST]} failed: %s", - type(e), - ) - if self._interface is not None: - await self._interface.close() - self._interface = None + self.exception(f"Connect to {host} failed: {type(e)}") + + await self.abort_connect() except Exception as e: # pylint: disable=broad-except - self.exception(f"Connect to {self._dev_config_entry[CONF_HOST]} failed") + if isinstance(e, ValueError): + self.warning(f"Connect to {name} failed: {e}") + else: + self.warning(f"Connect to {host} failed: {e}") if "json.decode" in str(type(e)): + self.warning(f"Initial state update failed {e}, trying key update") await self.update_local_key() - - if self._interface is not None: - await self._interface.close() - self._interface = None + await self.abort_connect() if self._interface is not None: # Attempt to restore status for all entities that need to first set @@ -297,10 +329,24 @@ def _new_entity_handler(entity_id): ) self._is_closing = False - self.info(f"Successfully connected to {self._dev_config_entry[CONF_HOST]}") + self.info(f"Successfully connected to {name if self._node_id else host}") + if self._sub_devices: + connect_sub_devices = [ + device.async_connect() for device in self._sub_devices.values() + ] + await asyncio.gather(*connect_sub_devices) self._connect_task = None + async def abort_connect(self): + """Abort the connect process to the interface[device]""" + if self._node_id: + self._interface = None + + if self._interface is not None: + await self._interface.close() + self._interface = None + async def update_local_key(self): """Retrieve updated local_key from Cloud API and update the config_entry.""" dev_id = self._dev_config_entry[CONF_DEVICE_ID] @@ -321,7 +367,7 @@ async def update_local_key(self): async def _async_refresh(self, _now): if self._interface is not None: self.debug("Refreshing dps for device") - await self._interface.update_dps() + await self._interface.update_dps(cid=self._node_id) async def close(self): """Close connection and stop re-connect loop.""" @@ -344,7 +390,7 @@ async def set_dp(self, state, dp_index): """Change value of a DP of the Tuya device.""" if self._interface is not None: try: - await self._interface.set_dp(state, dp_index) + await self._interface.set_dp(state, dp_index, cid=self._node_id) except Exception: # pylint: disable=broad-except self.debug("Failed to set DP %d to %s", dp_index, str(state)) else: @@ -356,7 +402,7 @@ async def set_dps(self, states): """Change value of a DPs of the Tuya device.""" if self._interface is not None: try: - await self._interface.set_dps(states) + await self._interface.set_dps(states, cid=self._node_id) except Exception: # pylint: disable=broad-except self.debug("Failed to set DPs %r", states) else: @@ -365,8 +411,9 @@ async def set_dps(self, states): ) @callback - def status_updated(self, status): + def status_updated(self, status: dict): """Device updated status.""" + status = status.get(self._node_id) if self._node_id else status.get("parent") self._handle_event(self._status, status) self._status.update(status) self._dispatch_status() @@ -375,12 +422,12 @@ def _dispatch_status(self): signal = f"localtuya_{self._dev_config_entry[CONF_DEVICE_ID]}" async_dispatcher_send(self._hass, signal, self._status) - def _handle_event(self, old_status, new_status): + def _handle_event(self, old_status, new_status, deviceID=None): """Handle events in HA when devices updated.""" def fire_event(event, data: dict): event_data = { - CONF_DEVICE_ID: self._dev_config_entry[CONF_DEVICE_ID], + CONF_DEVICE_ID: deviceID or self._dev_config_entry[CONF_DEVICE_ID], CONF_TYPE: event, } event_data.update(data) @@ -421,15 +468,15 @@ def disconnected(self): self._unsub_interval() self._unsub_interval = None self._interface = None - if self._connect_task is not None: # self._connect_task.cancel() self._connect_task = None - self.warning("Disconnected - waiting for discovery broadcast") # If it's disconnect by unexpected error. if self._is_closing is not True: self._is_closing = True self._hass.create_task(self.async_connect()) + else: + self.warning("Disconnected - waiting for discovery broadcast") class LocalTuyaEntity(RestoreEntity, pytuya.ContextualLogger): @@ -588,7 +635,7 @@ def device_class(self): def dps(self, dp_index): """Return cached value for DPS index.""" value = self._status.get(str(dp_index)) - if value is None: + if value is None and not self._dev_config_entry.get(CONF_NODE_ID): self.warning( "Entity %s is requesting unknown DPS index %s", self.entity_id, diff --git a/custom_components/localtuya/config_flow.py b/custom_components/localtuya/config_flow.py index f7507f342..f61726aa2 100644 --- a/custom_components/localtuya/config_flow.py +++ b/custom_components/localtuya/config_flow.py @@ -61,6 +61,7 @@ ENTITY_CATEGORY, DEFAULT_CATEGORIES, SUPPORTED_PROTOCOL_VERSIONS, + TUYA_DEVICES, ) from .discovery import discover @@ -336,40 +337,47 @@ async def validate_input(hass: core.HomeAssistant, entry_id, data): detected_dps = {} error = None interface = None - reset_ids = None + close = True + + cid = data.get(CONF_NODE_ID, None) + localtuya_devices = hass.data[DOMAIN][entry_id][TUYA_DEVICES] try: conf_protocol = data[CONF_PROTOCOL_VERSION] auto_protocol = conf_protocol == "auto" - # If 'auto' will be loop through supported protocols. - for ver in SUPPORTED_PROTOCOL_VERSIONS: - try: - version = ver if auto_protocol else conf_protocol - interface = await pytuya.connect( - data[CONF_HOST], - data[CONF_DEVICE_ID], - data[CONF_LOCAL_KEY], - float(version), - data[CONF_ENABLE_DEBUG], - data.get(CONF_NODE_ID, None), - ) - - # Break the loop if input isn't auto. - if not auto_protocol: - break + # If sub device we will search if gateway is existed if not create new connection. + if cid and (existed_interface := localtuya_devices.get(data[CONF_HOST])): + interface = existed_interface._interface + close = False + else: + # If 'auto' will be loop through supported protocols. + for ver in SUPPORTED_PROTOCOL_VERSIONS: + try: + version = ver if auto_protocol else conf_protocol + interface = await pytuya.connect( + data[CONF_HOST], + data[CONF_DEVICE_ID], + data[CONF_LOCAL_KEY], + float(version), + data[CONF_ENABLE_DEBUG], + ) - detected_dps = await interface.detect_available_dps() - # If Auto: using DPS detected we will assume this is the correct version if dps found. - if len(detected_dps) > 0: - # Set the conf_protocol to the worked version to return it and update self.device_data. - conf_protocol = version - break - # If connection to host is failed raise wrong address. - except OSError as ex: - if ex.errno == errno.EHOSTUNREACH: - raise CannotConnect - except: - continue + # Break the loop if input isn't auto. + if not auto_protocol: + break + + detected_dps = await interface.detect_available_dps(cid=cid) + # If Auto: using DPS detected we will assume this is the correct version if dps found. + if len(detected_dps) > 0: + # Set the conf_protocol to the worked version to return it and update self.device_data. + conf_protocol = version + break + # If connection to host is failed raise wrong address. + except OSError as ex: + if ex.errno == errno.EHOSTUNREACH: + raise CannotConnect + except: + continue if CONF_RESET_DPIDS in data: reset_ids_str = data[CONF_RESET_DPIDS].split(",") @@ -392,14 +400,15 @@ async def validate_input(hass: core.HomeAssistant, entry_id, data): interface.set_updatedps_list(reset_ids) # Reset the interface - await interface.reset(reset_ids) + await interface.reset(reset_ids, cid=cid) # Detect any other non-manual DPS strings - detected_dps = await interface.detect_available_dps() + + detected_dps = await interface.detect_available_dps(cid=cid) except ValueError as ex: error = ex except Exception as ex: # pylint: disable=broad-except - _LOGGER.debug("No DPS able to be detected") + _LOGGER.debug(f"No DPS able to be detected {ex}") detected_dps = {} # if manual DPs are set, merge these. @@ -421,7 +430,7 @@ async def validate_input(hass: core.HomeAssistant, entry_id, data): except ValueError as ex: raise InvalidAuth from ex finally: - if interface: + if interface and close: await interface.close() # Indicate an error if no datapoints found as the rest of the flow diff --git a/custom_components/localtuya/core/pytuya/__init__.py b/custom_components/localtuya/core/pytuya/__init__.py index 64edd3606..a90371364 100644 --- a/custom_components/localtuya/core/pytuya/__init__.py +++ b/custom_components/localtuya/core/pytuya/__init__.py @@ -738,7 +738,6 @@ def __init__( self, dev_id, local_key, - node_id, protocol_version, enable_debug, on_connected, @@ -759,7 +758,6 @@ def __init__( self.loop = asyncio.get_running_loop() self.set_logger(_LOGGER, dev_id, enable_debug) self.id = dev_id - self.node_id = node_id self.local_key = local_key.encode("latin1") self.real_local_key = self.local_key self.dev_type = "type_0a" @@ -817,14 +815,17 @@ def _status_update(msg): self.seqno = msg.seqno + 1 decoded_message: dict = self._decode_payload(msg.payload) - if "cid" in decoded_message and decoded_message["cid"] != self.node_id: - return - if "dps" in decoded_message: - self.dps_cache.update(decoded_message["dps"]) + if cid := decoded_message.get("cid"): + self.dps_cache.update({cid: decoded_message["dps"]}) + else: + self.dps_cache.update({"parent": decoded_message["dps"]}) listener = self.listener and self.listener() if listener is not None: + if cid: + listener = listener._sub_devices.get(cid, listener) + listener.status_updated(self.dps_cache) return MessageDispatcher( @@ -940,7 +941,7 @@ async def exchange_quick(self, payload, recv_retries): ) return None - async def exchange(self, command, dps=None): + async def exchange(self, command, dps=None, nodeID=None): """Send and receive a message, returning response from device.""" if self.version >= 3.4 and self.real_local_key == self.local_key: self.debug("3.4 or 3.5 device: negotiating a new session key") @@ -951,7 +952,7 @@ async def exchange(self, command, dps=None): command, self.dev_type, ) - payload = self._generate_payload(command, dps) + payload = self._generate_payload(command, dps, nodeId=nodeID) real_cmd = payload.cmd dev_type = self.dev_type # self.debug("Exchange: payload %r %r", payload.cmd, payload.payload) @@ -987,27 +988,30 @@ async def exchange(self, command, dps=None): dev_type, self.dev_type, ) - return await self.exchange(command, dps) + return await self.exchange(command, dps, nodeID=nodeID) return payload - async def status(self): + async def status(self, cid=None): """Return device status.""" - status = await self.exchange(DP_QUERY) + status: dict = await self.exchange(command=DP_QUERY, nodeID=cid) + + if cid and status and "dps" in status: + self.dps_cache.update({cid: status["dps"]}) + elif status and "dps" in status: + self.dps_cache.update({"parent": status["dps"]}) - if status and "dps" in status: - self.dps_cache.update(status["dps"]) return self.dps_cache async def heartbeat(self): """Send a heartbeat message.""" return await self.exchange(HEART_BEAT) - async def reset(self, dpIds=None): + async def reset(self, dpIds=None, cid=None): """Send a reset message (3.3 only).""" if self.version == 3.3: self.dev_type = "type_0a" self.debug("reset switching to dev_type %s", self.dev_type) - return await self.exchange(UPDATEDPS, dpIds) + return await self.exchange(UPDATEDPS, dpIds, nodeID=cid) return True @@ -1015,7 +1019,7 @@ def set_updatedps_list(self, update_list): """Set the DPS to be requested with the update command.""" self.dps_whitelist = update_list - async def update_dps(self, dps=None): + async def update_dps(self, dps=None, cid=None): """ Request device to update index. @@ -1025,18 +1029,20 @@ async def update_dps(self, dps=None): if self.version in UPDATE_DPS_LIST: if dps is None: if not self.dps_cache: - await self.detect_available_dps() + await self.detect_available_dps(cid=cid) if self.dps_cache: - dps = [int(dp) for dp in self.dps_cache] + if cid and cid in self.dps_cache: + dps = [int(dp) for dp in self.dps_cache[cid]] + else: + dps = [int(dp) for dp in self.dps_cache["parent"]] # filter non whitelisted dps dps = list(set(dps).intersection(set(self.dps_whitelist))) - self.debug("updatedps() entry (dps %s, dps_cache %s)", dps, self.dps_cache) - payload = self._generate_payload(UPDATEDPS, dps) + payload = self._generate_payload(UPDATEDPS, dps, nodeId=cid) enc_payload = self._encode_message(payload) self.transport.write(enc_payload) return True - async def set_dp(self, value, dp_index): + async def set_dp(self, value, dp_index, cid=None): """ Set value (may be any type: bool, int or string) of any dps index. @@ -1044,13 +1050,13 @@ async def set_dp(self, value, dp_index): dp_index(int): dps index to set value: new value for the dps index """ - return await self.exchange(CONTROL, {str(dp_index): value}) + return await self.exchange(CONTROL, {str(dp_index): value}, nodeID=cid) - async def set_dps(self, dps): + async def set_dps(self, dps, cid=None): """Set values for a set of datapoints.""" - return await self.exchange(CONTROL, dps) + return await self.exchange(CONTROL, dps, nodeID=cid) - async def detect_available_dps(self): + async def detect_available_dps(self, cid=None): """Return which datapoints are supported by the device.""" # type_0d devices need a sort of bruteforce querying in order to detect the # list of available dps experience shows that the dps available are usually @@ -1066,18 +1072,17 @@ async def detect_available_dps(self): self.dps_to_request = {"1": None} self.add_dps_to_request(range(*dps_range)) try: - data = await self.status() + data = await self.status(cid=cid) except Exception as ex: self.exception("Failed to get status: %s", ex) raise if "dps" in data: - self.dps_cache.update(data["dps"]) + self.dps_cache.update({"parent": data["dps"]}) if self.dev_type == "type_0a": return self.dps_cache - self.debug("Detected dps: %s", self.dps_cache) self.dps_to_request = self.dps_cache - return self.dps_cache + return self.dps_cache.get(cid) or self.dps_cache.get("parent") def add_dps_to_request(self, dp_indicies): """Add a datapoint (DP) to be included in requests.""" @@ -1162,10 +1167,7 @@ def _decode_payload(self, payload): if len(payload) == 0: # No respones probably worng Local_Key raise ValueError("Connected but no respones localkey is incorrect?") if "devid not" in payload: # DeviceID Not found. - if self.node_id: - raise ValueError("Node_ID is incorrect!") - else: - raise ValueError("DeviceID Not found") + raise ValueError("DeviceID Not found") else: raise DecodeError( "could not decrypt data: wrong local_key? (exception: %s)" % ex @@ -1384,7 +1386,7 @@ def _generate_payload( else: json_data["uid"] = self.id if "cid" in json_data: - if cid := nodeId or self.node_id: + if cid := nodeId: json_data["cid"] = cid # for <= 3.3 we don't need `gwID`, `devID` and `uid` in payload. for k in ["gwId", "devId", "uid"]: @@ -1394,7 +1396,7 @@ def _generate_payload( del json_data["cid"] if "data" in json_data and "cid" in json_data["data"]: # "cid" is inside "data" For 3.4 and 3.5 versions. - if cid := nodeId or self.node_id: + if cid := nodeId: json_data["data"]["cid"] = cid else: del json_data["data"]["cid"] @@ -1435,7 +1437,6 @@ async def connect( local_key, protocol_version, enable_debug, - node_id=None, listener=None, port=6668, timeout=5, @@ -1447,7 +1448,6 @@ async def connect( lambda: TuyaProtocol( device_id, local_key, - node_id, protocol_version, enable_debug, on_connected,