From 4cb1f7a270ed31ba2f7dfb34f4a9b9339f7318ed Mon Sep 17 00:00:00 2001 From: Douglas Raillard Date: Mon, 15 Nov 2021 15:18:20 +0000 Subject: [PATCH 01/11] utils.misc: Make nullcontext work with asyncio Implement __aenter__ and __aexit__ on nullcontext so it can be used as an asynchronous context manager. --- devlib/utils/misc.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/devlib/utils/misc.py b/devlib/utils/misc.py index de4944b3c..eb9dfd375 100644 --- a/devlib/utils/misc.py +++ b/devlib/utils/misc.py @@ -748,8 +748,7 @@ def batch_contextmanager(f, kwargs_list): yield -@contextmanager -def nullcontext(enter_result=None): +class nullcontext: """ Backport of Python 3.7 ``contextlib.nullcontext`` @@ -761,7 +760,20 @@ def nullcontext(enter_result=None): statement, or `None` if nothing is specified. :type enter_result: object """ - yield enter_result + def __init__(self, enter_result=None): + self.enter_result = enter_result + + def __enter__(self): + return self.enter_result + + async def __aenter__(self): + return self.enter_result + + def __exit__(*_): + return + + async def __aexit__(*_): + return class tls_property: From 6dc9eba386558fea447f92b8bf64474b39aab814 Mon Sep 17 00:00:00 2001 From: Douglas Raillard Date: Thu, 7 Apr 2022 16:54:42 +0100 Subject: [PATCH 02/11] target: Fix Target.get_connection()'s busybox The conncetion returned by Target.get_connection() does not have its .busybox attribute initialized. This is expected for the first connection, but connections created for new threads should have busybox set. --- devlib/target.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/devlib/target.py b/devlib/target.py index 8fb5ce568..280005d63 100644 --- a/devlib/target.py +++ b/devlib/target.py @@ -381,7 +381,10 @@ def disconnect(self): def get_connection(self, timeout=None): if self.conn_cls is None: raise ValueError('Connection class not specified on Target creation.') - return self.conn_cls(timeout=timeout, **self.connection_settings) # pylint: disable=not-callable + conn = self.conn_cls(timeout=timeout, **self.connection_settings) # pylint: disable=not-callable + # This allows forwarding the detected busybox for connections created in new threads. + conn.busybox = self.busybox + return conn def wait_boot_complete(self, timeout=10): raise NotImplementedError() From 44cf2ef57e815478cd2d0aff89aad75f21217f64 Mon Sep 17 00:00:00 2001 From: Douglas Raillard Date: Fri, 12 Nov 2021 11:19:04 +0000 Subject: [PATCH 03/11] target: Make __getstate__ more future-proof Remove all the tls_property from the state, as they will be recreated automatically. --- devlib/target.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/devlib/target.py b/devlib/target.py index 280005d63..1e4ceae7f 100644 --- a/devlib/target.py +++ b/devlib/target.py @@ -31,6 +31,7 @@ import uuid import xml.dom.minidom import copy +import inspect from collections import namedtuple, defaultdict from contextlib import contextmanager from pipes import quote @@ -56,7 +57,7 @@ from devlib.utils.misc import memoized, isiterable, convert_new_lines, groupby_value from devlib.utils.misc import commonprefix, merge_lists from devlib.utils.misc import ABI_MAP, get_cpu_name, ranges_to_list -from devlib.utils.misc import batch_contextmanager, tls_property, nullcontext +from devlib.utils.misc import batch_contextmanager, tls_property, _BoundTLSProperty, nullcontext from devlib.utils.types import integer, boolean, bitmask, identifier, caseless_string, bytes_regex @@ -337,12 +338,18 @@ def __init__(self, self.connect() def __getstate__(self): + # tls_property will recreate the underlying value automatically upon + # access and is typically used for dynamic content that cannot be + # pickled or should not transmitted to another thread. + ignored = { + k + for k, v in inspect.getmembers(self.__class__) + if isinstance(v, _BoundTLSProperty) + } return { k: v for k, v in self.__dict__.items() - # Avoid sharing the connection instance with the original target, - # so that each target can live its own independent life - if k != '_conn' + if k not in ignored } # connection and initialization From eba0039a9b91e7f78c12989967e1e85813f4a6d3 Mon Sep 17 00:00:00 2001 From: Douglas Raillard Date: Thu, 11 Nov 2021 23:15:31 +0000 Subject: [PATCH 04/11] setup.py: cleanup dependencies in setup.py Remove dependencies that are ruled out due to the current Python minimal version requirement. --- setup.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/setup.py b/setup.py index 0365683e6..2cd427e02 100644 --- a/setup.py +++ b/setup.py @@ -90,12 +90,8 @@ 'scp', # SSH connection file transfers 'wrapt', # Basic for construction of decorator functions 'future', # Python 2-3 compatibility - 'enum34;python_version<"3.4"', # Enums for Python < 3.4 - 'contextlib2;python_version<"3.0"', # Python 3 contextlib backport for Python 2 - 'numpy<=1.16.4; python_version<"3"', - 'numpy; python_version>="3"', - 'pandas<=0.24.2; python_version<"3"', - 'pandas; python_version>"3"', + 'numpy', + 'pandas', 'lxml', # More robust xml parsing ], extras_require={ From 683413c0a3763ebd8200e395b01c2cce85376fa0 Mon Sep 17 00:00:00 2001 From: Douglas Raillard Date: Thu, 11 Nov 2021 23:16:32 +0000 Subject: [PATCH 05/11] setup.py: Require Python >= 3.7 Require Python >= 3.7 in order to have access to a fully fledged asyncio module. --- setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.py b/setup.py index 2cd427e02..88d50a970 100644 --- a/setup.py +++ b/setup.py @@ -82,6 +82,7 @@ url='https://github.com/ARM-software/devlib', license='Apache v2', maintainer='ARM Ltd.', + python_requires='>= 3.7', install_requires=[ 'python-dateutil', # converting between UTC and local time. 'pexpect>=3.3', # Send/recieve to/from device From 1c43096931ae80be903e31d48f400ddb61b504ee Mon Sep 17 00:00:00 2001 From: Douglas Raillard Date: Wed, 18 Aug 2021 10:35:36 +0100 Subject: [PATCH 06/11] utils/async: Add new utils.async module Home for async-related utilities. --- devlib/utils/asyn.py | 356 +++++++++++++++++++++++++++++++++++++++++++ setup.py | 1 + 2 files changed, 357 insertions(+) create mode 100644 devlib/utils/asyn.py diff --git a/devlib/utils/asyn.py b/devlib/utils/asyn.py new file mode 100644 index 000000000..ddda1158f --- /dev/null +++ b/devlib/utils/asyn.py @@ -0,0 +1,356 @@ +# Copyright 2013-2018 ARM Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + +""" +Async-related utilities +""" + +import abc +import asyncio +import functools +import itertools +import contextlib +import pathlib +import os.path + +# Allow nesting asyncio loops, which is necessary for: +# * Being able to call the blocking variant of a function from an async +# function for backward compat +# * Critically, run the blocking variant of a function in a Jupyter notebook +# environment, since it also uses asyncio. +# +# Maybe there is still hope for future versions of Python though: +# https://bugs.python.org/issue22239 +import nest_asyncio +nest_asyncio.apply() + + +def create_task(awaitable, name=None): + if isinstance(awaitable, asyncio.Task): + task = awaitable + else: + task = asyncio.create_task(awaitable) + if name is None: + name = getattr(awaitable, '__qualname__', None) + task.name = name + return task + + +class AsyncManager: + def __init__(self): + self.task_tree = dict() + self.resources = dict() + + def track_access(self, access): + """ + Register the given ``access`` to have been handled by the current + async task. + + :param access: Access that were done. + :type access: ConcurrentAccessBase + + This allows :func:`concurrently` to check that concurrent tasks did not + step on each other's toes. + """ + try: + task = asyncio.current_task() + except RuntimeError: + pass + else: + self.resources.setdefault(task, set()).add(access) + + async def concurrently(self, awaitables): + """ + Await concurrently for the given awaitables, and cancel them as soon as + one raises an exception. + """ + awaitables = list(awaitables) + + # Avoid creating asyncio.Tasks when it's not necessary, as it will + # disable a the blocking path optimization of Target._execute_async() + # that uses blocking calls as long as there is only one asyncio.Task + # running on the event loop. + if len(awaitables) == 1: + return [await awaitables[0]] + + tasks = list(map(create_task, awaitables)) + + current_task = asyncio.current_task() + task_tree = self.task_tree + + try: + node = task_tree[current_task] + except KeyError: + is_root_task = True + node = set() + else: + is_root_task = False + task_tree[current_task] = node + + task_tree.update({ + child: set() + for child in tasks + }) + node.update(tasks) + + try: + return await asyncio.gather(*tasks) + except BaseException: + for task in tasks: + task.cancel() + raise + finally: + + def get_children(task): + immediate_children = task_tree[task] + return frozenset( + itertools.chain( + [task], + immediate_children, + itertools.chain.from_iterable( + map(get_children, immediate_children) + ) + ) + ) + + # Get the resources created during the execution of each subtask + # (directly or indirectly) + resources = { + task: frozenset( + itertools.chain.from_iterable( + self.resources.get(child, []) + for child in get_children(task) + ) + ) + for task in tasks + } + for (task1, resources1), (task2, resources2) in itertools.combinations(resources.items(), 2): + for res1, res2 in itertools.product(resources1, resources2): + if issubclass(res2.__class__, res1.__class__) and res1.overlap_with(res2): + raise RuntimeError( + 'Overlapping resources manipulated in concurrent async tasks: {} (task {}) and {} (task {})'.format(res1, task1.name, res2, task2.name) + ) + + if is_root_task: + self.resources.clear() + task_tree.clear() + + async def map_concurrently(self, f, keys): + """ + Similar to :meth:`concurrently`, + but maps the given function ``f`` on the given ``keys``. + + :return: A dictionary with ``keys`` as keys, and function result as + values. + """ + keys = list(keys) + return dict(zip( + keys, + await self.concurrently(map(f, keys)) + )) + + +def compose(*coros): + """ + Compose coroutines, feeding the output of each as the input of the next + one. + + ``await compose(f, g)(x)`` is equivalent to ``await f(await g(x))`` + + .. note:: In Haskell, ``compose f g h`` would be equivalent to ``f <=< g <=< h`` + """ + async def f(*args, **kwargs): + empty_dict = {} + for coro in reversed(coros): + x = coro(*args, **kwargs) + # Allow mixing corountines and regular functions + if asyncio.isfuture(x): + x = await x + args = [x] + kwargs = empty_dict + + return x + return f + + +class _AsyncPolymorphicFunction: + """ + A callable that allows exposing both a synchronous and asynchronous API. + + When called, the blocking synchronous operation is called. The ```asyn`` + attribute gives access to the asynchronous version of the function, and all + the other attribute access will be redirected to the async function. + """ + def __init__(self, asyn, blocking): + self.asyn = asyn + self.blocking = blocking + + def __get__(self, *args, **kwargs): + return self.__class__( + asyn=self.asyn.__get__(*args, **kwargs), + blocking=self.blocking.__get__(*args, **kwargs), + ) + + def __call__(self, *args, **kwargs): + return self.blocking(*args, **kwargs) + + def __getattr__(self, attr): + return getattr(self.asyn, attr) + + +def asyncf(f): + """ + Decorator used to turn a coroutine into a blocking function, with an + optional asynchronous API. + + **Example**:: + + @asyncf + async def foo(x): + await do_some_async_things(x) + return x + + # Blocking call, just as if the function was synchronous, except it may + # use asynchronous code inside, e.g. to do concurrent operations. + foo(42) + + # Asynchronous API, foo.asyn being a corountine + await foo.asyn(42) + + This allows the same implementation to be both used as blocking for ease of + use and backward compatibility, or exposed as a corountine for callers that + can deal with awaitables. + """ + @functools.wraps(f) + def blocking(*args, **kwargs): + # Since run() needs a corountine, make sure we provide one + async def wrapper(): + x = f(*args, **kwargs) + # Async generators have to be consumed and accumulated in a list + # before crossing a blocking boundary. + if inspect.isasyncgen(x): + + def genf(): + asyncgen = x.__aiter__() + while True: + try: + yield asyncio.run(asyncgen.__anext__()) + except StopAsyncIteration: + return + + return genf() + else: + return await x + return asyncio.run(wrapper()) + + return _AsyncPolymorphicFunction( + asyn=f, + blocking=blocking, + ) + + +class _AsyncPolymorphicCM: + """ + Wrap an async context manager such that it exposes a synchronous API as + well for backward compatibility. + """ + def __init__(self, async_cm): + self.cm = async_cm + + def __aenter__(self, *args, **kwargs): + return self.cm.__aenter__(*args, **kwargs) + + def __aexit__(self, *args, **kwargs): + return self.cm.__aexit__(*args, **kwargs) + + def __enter__(self, *args, **kwargs): + return asyncio.run(self.cm.__aenter__(*args, **kwargs)) + + def __exit__(self, *args, **kwargs): + return asyncio.run(self.cm.__aexit__(*args, **kwargs)) + + +def asynccontextmanager(f): + """ + Same as :func:`contextlib.asynccontextmanager` except that it can also be + used with a regular ``with`` statement for backward compatibility. + """ + f = contextlib.asynccontextmanager(f) + + @functools.wraps(f) + def wrapper(*args, **kwargs): + cm = f(*args, **kwargs) + return _AsyncPolymorphicCM(cm) + + return wrapper + + +class ConcurrentAccessBase(abc.ABC): + """ + Abstract Base Class for resources tracked by :func:`concurrently`. + """ + @abc.abstractmethod + def overlap_with(self, other): + """ + Return ``True`` if the resource overlaps with the given one. + + :param other: Resources that should not overlap with ``self``. + :type other: devlib.utils.asym.ConcurrentAccessBase + + .. note:: It is guaranteed that ``other`` will be a subclass of our + class. + """ + +class PathAccess(ConcurrentAccessBase): + """ + Concurrent resource representing a file access. + + :param namespace: Identifier of the namespace of the path. One of "target" or "host". + :type namespace: str + + :param path: Normalized path to the file. + :type path: str + + :param mode: Opening mode of the file. Can be ``"r"`` for read and ``"w"`` + for writing. + :type mode: str + """ + def __init__(self, namespace, path, mode): + assert namespace in ('host', 'target') + self.namespace = namespace + assert mode in ('r', 'w') + self.mode = mode + self.path = os.path.abspath(path) if namespace == 'host' else os.path.normpath(path) + + def overlap_with(self, other): + path1 = pathlib.Path(self.path).resolve() + path2 = pathlib.Path(other.path).resolve() + return ( + self.namespace == other.namespace and + 'w' in (self.mode, other.mode) and + ( + path1 == path2 or + path1 in path2.parents or + path2 in path1.parents + ) + ) + + def __str__(self): + mode = { + 'r': 'read', + 'w': 'write', + }[self.mode] + return '{} ({})'.format(self.path, mode) diff --git a/setup.py b/setup.py index 88d50a970..1bb7bd104 100644 --- a/setup.py +++ b/setup.py @@ -94,6 +94,7 @@ 'numpy', 'pandas', 'lxml', # More robust xml parsing + 'nest_asyncio', # Allows running nested asyncio loops ], extras_require={ 'daq': ['daqpower>=2'], From 276c8e5204db2d012b9df99c4178dd999d717ea0 Mon Sep 17 00:00:00 2001 From: Douglas Raillard Date: Mon, 15 Nov 2021 14:46:28 +0000 Subject: [PATCH 07/11] target: Enable async methods Add async variants of Target methods. --- devlib/target.py | 848 ++++++++++++++++++++++++++++++----------------- 1 file changed, 550 insertions(+), 298 deletions(-) diff --git a/devlib/target.py b/devlib/target.py index 1e4ceae7f..cd4990057 100644 --- a/devlib/target.py +++ b/devlib/target.py @@ -13,6 +13,7 @@ # limitations under the License. # +import asyncio import io import base64 import functools @@ -32,6 +33,7 @@ import xml.dom.minidom import copy import inspect +import itertools from collections import namedtuple, defaultdict from contextlib import contextmanager from pipes import quote @@ -44,6 +46,7 @@ from collections import Mapping from enum import Enum +from concurrent.futures import ThreadPoolExecutor from devlib.host import LocalConnection, PACKAGE_BIN_DIRECTORY from devlib.module import get_module @@ -58,7 +61,9 @@ from devlib.utils.misc import commonprefix, merge_lists from devlib.utils.misc import ABI_MAP, get_cpu_name, ranges_to_list from devlib.utils.misc import batch_contextmanager, tls_property, _BoundTLSProperty, nullcontext +from devlib.utils.misc import strip_bash_colors from devlib.utils.types import integer, boolean, bitmask, identifier, caseless_string, bytes_regex +import devlib.utils.asyn as asyn FSTAB_ENTRY_REGEX = re.compile(r'(\S+) on (.+) type (\S+) \((\S+)\)') @@ -276,11 +281,21 @@ def shutils(self): @tls_property def _conn(self): - return self.get_connection() + try: + return self._unused_conns.pop() + except KeyError: + return self.get_connection() # Add a basic property that does not require calling to get the value conn = _conn.basic_property + @tls_property + def _async_manager(self): + return asyn.AsyncManager() + + # Add a basic property that does not require calling to get the value + async_manager = _async_manager.basic_property + def __init__(self, connection_settings=None, platform=None, @@ -293,6 +308,8 @@ def __init__(self, conn_cls=None, is_container=False ): + self._async_pool = None + self._unused_conns = set() self._is_rooted = None self.connection_settings = connection_settings or {} @@ -354,7 +371,8 @@ def __getstate__(self): # connection and initialization - def connect(self, timeout=None, check_boot_completed=True): + @asyn.asyncf + async def connect(self, timeout=None, check_boot_completed=True, max_async=50): self.platform.init_target_connection(self) # Forcefully set the thread-local value for the connection, with the # timeout we want @@ -367,23 +385,71 @@ def connect(self, timeout=None, check_boot_completed=True): self.execute('mkdir -p {}'.format(quote(self.executables_directory))) self.busybox = self.install(os.path.join(PACKAGE_BIN_DIRECTORY, self.abi, 'busybox'), timeout=30) self.conn.busybox = self.busybox + self._detect_max_async(max_async) self.platform.update_from_target(self) self._update_modules('connected') if self.platform.big_core and self.load_default_modules: self._install_module(get_module('bl')) - def check_connection(self): + def _detect_max_async(self, max_async): + self.logger.debug('Detecting max number of async commands ...') + + def make_conn(_): + try: + conn = self.get_connection() + except Exception: + return None + else: + payload = 'hello' + # Sanity check the connection, in case we managed to connect + # but it's actually unusable. + try: + res = conn.execute(f'echo {quote(payload)}') + except Exception: + return None + else: + if res.strip() == payload: + return conn + else: + return None + + # Logging needs to be disabled before the thread pool is created, + # otherwise the logging config will not be taken into account + logging.disable() + try: + # Aggressively attempt to create all the connections in parallel, + # so that this setup step does not take too much time. + with ThreadPoolExecutor(max_async) as pool: + conns = pool.map(make_conn, range(max_async)) + # Avoid polluting the log with errors coming from broken + # connections. + finally: + logging.disable(logging.NOTSET) + + conns = {conn for conn in conns if conn is not None} + + # Keep the connection so it can be reused by future threads + self._unused_conns.update(conns) + max_conns = len(conns) + + self.logger.debug(f'Detected max number of async commands: {max_conns}') + self._async_pool = ThreadPoolExecutor(max_conns) + + @asyn.asyncf + async def check_connection(self): """ Check that the connection works without obvious issues. """ - out = self.execute('true', as_root=False) + out = await self.execute.asyn('true', as_root=False) if out.strip(): raise TargetStableError('The shell seems to not be functional and adds content to stderr: {}'.format(out)) def disconnect(self): connections = self._conn.get_all_values() - for conn in connections: + for conn in itertools.chain(connections, self._unused_conns): conn.close() + if self._async_pool is not None: + self._async_pool.__exit__(None, None, None) def get_connection(self, timeout=None): if self.conn_cls is None: @@ -396,11 +462,12 @@ def get_connection(self, timeout=None): def wait_boot_complete(self, timeout=10): raise NotImplementedError() - def setup(self, executables=None): - self._setup_shutils() + @asyn.asyncf + async def setup(self, executables=None): + await self._setup_shutils.asyn() for host_exe in (executables or []): # pylint: disable=superfluous-parens - self.install(host_exe) + await self.install.asyn(host_exe) # Check for platform dependent setup procedures self.platform.setup(self) @@ -408,7 +475,7 @@ def setup(self, executables=None): # Initialize modules which requires Buxybox (e.g. shutil dependent tasks) self._update_modules('setup') - self.execute('mkdir -p {}'.format(quote(self._file_transfer_cache))) + await self.execute.asyn('mkdir -p {}'.format(quote(self._file_transfer_cache))) def reboot(self, hard=False, connect=True, timeout=180): if hard: @@ -437,8 +504,8 @@ def reboot(self, hard=False, connect=True, timeout=180): # file transfer - @contextmanager - def _xfer_cache_path(self, name): + @asyn.asynccontextmanager + async def _xfer_cache_path(self, name): """ Context manager to provide a unique path in the transfer cache with the basename of the given name. @@ -454,15 +521,16 @@ def _xfer_cache_path(self, name): check_rm = False try: - self.makedirs(folder) + await self.makedirs.asyn(folder) # Don't check the exit code as the folder might not even exist # before this point, if creating it failed check_rm = True yield self.path.join(folder, name) finally: - self.execute('rm -rf -- {}'.format(quote(folder)), check_exit_code=check_rm) + await self.execute.asyn('rm -rf -- {}'.format(quote(folder)), check_exit_code=check_rm) - def _prepare_xfer(self, action, sources, dest, pattern=None, as_root=False): + @asyn.asyncf + async def _prepare_xfer(self, action, sources, dest, pattern=None, as_root=False): """ Check the sanity of sources and destination and prepare the ground for transfering multiple sources. @@ -603,28 +671,37 @@ def rewrite_dst(src, dst): for src in sources }) + @asyn.asyncf @call_conn - def push(self, source, dest, as_root=False, timeout=None, globbing=False): # pylint: disable=arguments-differ + async def push(self, source, dest, as_root=False, timeout=None, globbing=False): # pylint: disable=arguments-differ source = str(source) dest = str(dest) sources = glob.glob(source) if globbing else [source] - mapping = self._prepare_xfer('push', sources, dest, pattern=source if globbing else None, as_root=as_root) + mapping = await self._prepare_xfer.asyn('push', sources, dest, pattern=source if globbing else None, as_root=as_root) def do_push(sources, dest): + for src in sources: + self.async_manager.track_access( + asyn.PathAccess(namespace='host', path=src, mode='r') + ) + self.async_manager.track_access( + asyn.PathAccess(namespace='target', path=dest, mode='w') + ) return self.conn.push(sources, dest, timeout=timeout) if as_root: for sources, dest in mapping.items(): for source in sources: - with self._xfer_cache_path(source) as device_tempfile: + async with self._xfer_cache_path(source) as device_tempfile: do_push([source], device_tempfile) - self.execute("mv -f -- {} {}".format(quote(device_tempfile), quote(dest)), as_root=True) + await self.execute.asyn("mv -f -- {} {}".format(quote(device_tempfile), quote(dest)), as_root=True) else: for sources, dest in mapping.items(): do_push(sources, dest) - def _expand_glob(self, pattern, **kwargs): + @asyn.asyncf + async def _expand_glob(self, pattern, **kwargs): """ Expand the given path globbing pattern on the target using the shell globbing. @@ -657,20 +734,21 @@ def _expand_glob(self, pattern, **kwargs): cmd = '{} sh -c {} 2>/dev/null'.format(quote(self.busybox), quote(cmd)) # On some shells, match failure will make the command "return" a # non-zero code, even though the command was not actually called - result = self.execute(cmd, strip_colors=False, check_exit_code=False, **kwargs) + result = await self.execute.asyn(cmd, strip_colors=False, check_exit_code=False, **kwargs) paths = result.splitlines() if not paths: raise TargetStableError('No file matching: {}'.format(pattern)) return paths + @asyn.asyncf @call_conn - def pull(self, source, dest, as_root=False, timeout=None, globbing=False, via_temp=False): # pylint: disable=arguments-differ + async def pull(self, source, dest, as_root=False, timeout=None, globbing=False, via_temp=False): # pylint: disable=arguments-differ source = str(source) dest = str(dest) if globbing: - sources = self._expand_glob(source, as_root=as_root) + sources = await self._expand_glob.asyn(source, as_root=as_root) else: sources = [source] @@ -678,23 +756,31 @@ def pull(self, source, dest, as_root=False, timeout=None, globbing=False, via_te # so use a temporary copy instead. via_temp |= as_root - mapping = self._prepare_xfer('pull', sources, dest, pattern=source if globbing else None, as_root=as_root) + mapping = await self._prepare_xfer.asyn('pull', sources, dest, pattern=source if globbing else None, as_root=as_root) def do_pull(sources, dest): + for src in sources: + self.async_manager.track_access( + asyn.PathAccess(namespace='target', path=src, mode='r') + ) + self.async_manager.track_access( + asyn.PathAccess(namespace='host', path=dest, mode='w') + ) self.conn.pull(sources, dest, timeout=timeout) if via_temp: for sources, dest in mapping.items(): for source in sources: - with self._xfer_cache_path(source) as device_tempfile: - self.execute("cp -r -- {} {}".format(quote(source), quote(device_tempfile)), as_root=as_root) - self.execute("{} chmod 0644 -- {}".format(self.busybox, quote(device_tempfile)), as_root=as_root) + async with self._xfer_cache_path(source) as device_tempfile: + await self.execute.asyn("cp -r -- {} {}".format(quote(source), quote(device_tempfile)), as_root=as_root) + await self.execute.asyn("{} chmod 0644 -- {}".format(self.busybox, quote(device_tempfile)), as_root=as_root) do_pull([device_tempfile], dest) else: for sources, dest in mapping.items(): do_pull(sources, dest) - def get_directory(self, source_dir, dest, as_root=False): + @asyn.asyncf + async def get_directory(self, source_dir, dest, as_root=False): """ Pull a directory from the device, after compressing dir """ # Create all file names tar_file_name = source_dir.lstrip(self.path.sep).replace(self.path.sep, '.') @@ -708,12 +794,12 @@ def get_directory(self, source_dir, dest, as_root=False): tar_file_cm = self._xfer_cache_path if as_root else nullcontext # Does the folder exist? - self.execute('ls -la {}'.format(quote(source_dir)), as_root=as_root) + await self.execute.asyn('ls -la {}'.format(quote(source_dir)), as_root=as_root) - with tar_file_cm(tar_file_name) as tar_file_name: + async with tar_file_cm(tar_file_name) as tar_file_name: # Try compressing the folder try: - self.execute('{} tar -cvf {} {}'.format( + await self.execute.asyn('{} tar -cvf {} {}'.format( quote(self.busybox), quote(tar_file_name), quote(source_dir) ), as_root=as_root) except TargetStableError: @@ -722,7 +808,7 @@ def get_directory(self, source_dir, dest, as_root=False): # Pull the file if not os.path.exists(dest): os.mkdir(dest) - self.pull(tar_file_name, tmpfile) + await self.pull.asyn(tar_file_name, tmpfile) # Decompress with tarfile.open(tmpfile, 'r') as f: f.extractall(outdir) @@ -743,8 +829,41 @@ def _prepare_cmd(self, command, force_locale): return command + class _BrokenConnection(Exception): + pass + + @asyn.asyncf @call_conn - def execute(self, command, timeout=None, check_exit_code=True, + async def _execute_async(self, *args, **kwargs): + execute = functools.partial( + self._execute, + *args, **kwargs + ) + pool = self._async_pool + + if pool is None: + return execute() + else: + + def thread_f(): + # If we cannot successfully connect from the thread, it might + # mean that something external opened a connection on the + # target, so we just revert to the blocking path. + try: + self.conn + except Exception: + raise self._BrokenConnection + else: + return execute() + + loop = asyncio.get_running_loop() + try: + return await loop.run_in_executor(pool, thread_f) + except self._BrokenConnection: + return execute() + + @call_conn + def _execute(self, command, timeout=None, check_exit_code=True, as_root=False, strip_colors=True, will_succeed=False, force_locale='C'): @@ -753,9 +872,15 @@ def execute(self, command, timeout=None, check_exit_code=True, check_exit_code=check_exit_code, as_root=as_root, strip_colors=strip_colors, will_succeed=will_succeed) + execute = asyn._AsyncPolymorphicFunction( + asyn=_execute_async.asyn, + blocking=_execute, + ) + @call_conn def background(self, command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, as_root=False, force_locale='C', timeout=None): + conn = self.conn command = self._prepare_cmd(command, force_locale) bg_cmd = self.conn.background(command, stdout, stderr, as_root) if timeout is not None: @@ -837,32 +962,43 @@ def kick_off(self, command, as_root=False): # sysfs interaction - def read_value(self, path, kind=None): - output = self.execute('cat {}'.format(quote(path)), as_root=self.needs_su).strip() # pylint: disable=E1103 + @asyn.asyncf + async def read_value(self, path, kind=None): + self.async_manager.track_access( + asyn.PathAccess(namespace='target', path=path, mode='r') + ) + output = await self.execute.asyn('cat {}'.format(quote(path)), as_root=self.needs_su) # pylint: disable=E1103 + output = output.strip() if kind: return kind(output) else: return output - def read_int(self, path): - return self.read_value(path, kind=integer) + @asyn.asyncf + async def read_int(self, path): + return await self.read_value.asyn(path, kind=integer) - def read_bool(self, path): - return self.read_value(path, kind=boolean) + @asyn.asyncf + async def read_bool(self, path): + return await self.read_value.asyn(path, kind=boolean) - @contextmanager - def revertable_write_value(self, path, value, verify=True): + @asyn.asynccontextmanager + async def revertable_write_value(self, path, value, verify=True): orig_value = self.read_value(path) try: - self.write_value(path, value, verify) + await self.write_value.asyn(path, value, verify) yield finally: - self.write_value(path, orig_value, verify) + await self.write_value.asyn(path, orig_value, verify) def batch_revertable_write_value(self, kwargs_list): return batch_contextmanager(self.revertable_write_value, kwargs_list) - def write_value(self, path, value, verify=True): + @asyn.asyncf + async def write_value(self, path, value, verify=True): + self.async_manager.track_access( + asyn.PathAccess(namespace='target', path=path, mode='w') + ) value = str(value) if verify: @@ -891,7 +1027,7 @@ def write_value(self, path, value, verify=True): cmd = cmd.format(busybox=quote(self.busybox), path=quote(path), value=quote(value)) try: - self.execute(cmd, check_exit_code=True, as_root=True) + await self.execute.asyn(cmd, check_exit_code=True, as_root=True) except TargetCalledProcessError as e: if e.returncode == 10: raise TargetStableError('Could not write "{value}" to {path}: {e.output}'.format( @@ -942,22 +1078,26 @@ def ps(self, **kwargs): # files - def makedirs(self, path, as_root=False): - self.execute('mkdir -p {}'.format(quote(path)), as_root=as_root) + @asyn.asyncf + async def makedirs(self, path, as_root=False): + await self.execute.asyn('mkdir -p {}'.format(quote(path)), as_root=as_root) - def file_exists(self, filepath): + @asyn.asyncf + async def file_exists(self, filepath): command = 'if [ -e {} ]; then echo 1; else echo 0; fi' - output = self.execute(command.format(quote(filepath)), as_root=self.is_rooted) + output = await self.execute.asyn(command.format(quote(filepath)), as_root=self.is_rooted) return boolean(output.strip()) - def directory_exists(self, filepath): - output = self.execute('if [ -d {} ]; then echo 1; else echo 0; fi'.format(quote(filepath))) + @asyn.asyncf + async def directory_exists(self, filepath): + output = await self.execute.asyn('if [ -d {} ]; then echo 1; else echo 0; fi'.format(quote(filepath))) # output from ssh my contain part of the expression in the buffer, # split out everything except the last word. return boolean(output.split()[-1]) # pylint: disable=maybe-no-member - def list_file_systems(self): - output = self.execute('mount') + @asyn.asyncf + async def list_file_systems(self): + output = await self.execute.asyn('mount') fstab = [] for line in output.split('\n'): line = line.strip() @@ -972,31 +1112,44 @@ def list_file_systems(self): fstab.append(FstabEntry(*line.split())) return fstab - def list_directory(self, path, as_root=False): + @asyn.asyncf + async def list_directory(self, path, as_root=False): + self.async_manager.track_access( + asyn.PathAccess(namespace='target', path=path, mode='r') + ) + return await self._list_directory(path, as_root=as_root) + + def _list_directory(self, path, as_root=False): raise NotImplementedError() def get_workpath(self, name): return self.path.join(self.working_directory, name) - def tempfile(self, prefix='', suffix=''): - names = tempfile._get_candidate_names() # pylint: disable=W0212 - for _ in range(tempfile.TMP_MAX): - name = next(names) - path = self.get_workpath(prefix + name + suffix) - if not self.file_exists(path): - return path - raise IOError('No usable temporary filename found') + @asyn.asyncf + async def tempfile(self, prefix='', suffix=''): + name = '{prefix}_{uuid}_{suffix}'.format( + prefix=prefix, + uuid=uuid.uuid4().hex, + suffix=suffix, + ) + path = self.get_workpath(name) + if (await self.file_exists.asyn(path)): + raise FileExistsError('Path already exists on the target: {}'.format(path)) + else: + return path - def remove(self, path, as_root=False): - self.execute('rm -rf -- {}'.format(quote(path)), as_root=as_root) + @asyn.asyncf + async def remove(self, path, as_root=False): + await self.execute.asyn('rm -rf -- {}'.format(quote(path)), as_root=as_root) # misc def core_cpus(self, core): return [i for i, c in enumerate(self.core_names) if c == core] - def list_online_cpus(self, core=None): + @asyn.asyncf + async def list_online_cpus(self, core=None): path = self.path.join('/sys/devices/system/cpu/online') - output = self.read_value(path) + output = await self.read_value.asyn(path) all_online = ranges_to_list(output) if core: cpus = self.core_cpus(core) @@ -1006,13 +1159,16 @@ def list_online_cpus(self, core=None): else: return all_online - def list_offline_cpus(self): - online = self.list_online_cpus() + @asyn.asyncf + async def list_offline_cpus(self): + online = await self.list_online_cpus.asyn() return [c for c in range(self.number_of_cpus) if c not in online] - def getenv(self, variable): - return self.execute('echo ${}'.format(variable)).rstrip('\r\n') + @asyn.asyncf + async def getenv(self, variable): + var = await self.execute.asyn('printf "%s" ${}'.format(variable)) + return var.rstrip('\r\n') def capture_screen(self, filepath): raise NotImplementedError() @@ -1023,32 +1179,36 @@ def install(self, filepath, timeout=None, with_name=None): def uninstall(self, name): raise NotImplementedError() - def get_installed(self, name, search_system_binaries=True): + @asyn.asyncf + async def get_installed(self, name, search_system_binaries=True): # Check user installed binaries first if self.file_exists(self.executables_directory): - if name in self.list_directory(self.executables_directory): + if name in (await self.list_directory.asyn(self.executables_directory)): return self.path.join(self.executables_directory, name) # Fall back to binaries in PATH if search_system_binaries: - for path in self.getenv('PATH').split(self.path.pathsep): + PATH = await self.getenv.asyn('PATH') + for path in PATH.split(self.path.pathsep): try: - if name in self.list_directory(path): + if name in (await self.list_directory.asyn(path)): return self.path.join(path, name) except TargetStableError: pass # directory does not exist or no executable permissions which = get_installed - def install_if_needed(self, host_path, search_system_binaries=True, timeout=None): + @asyn.asyncf + async def install_if_needed(self, host_path, search_system_binaries=True, timeout=None): - binary_path = self.get_installed(os.path.split(host_path)[1], + binary_path = await self.get_installed.asyn(os.path.split(host_path)[1], search_system_binaries=search_system_binaries) if not binary_path: - binary_path = self.install(host_path, timeout=timeout) + binary_path = await self.install.asyn(host_path, timeout=timeout) return binary_path - def is_installed(self, name): - return bool(self.get_installed(name)) + @asyn.asyncf + async def is_installed(self, name): + return bool(await self.get_installed.asyn(name)) def bin(self, name): return self._installed_binaries.get(name, name) @@ -1056,30 +1216,29 @@ def bin(self, name): def has(self, modname): return hasattr(self, identifier(modname)) - def lsmod(self): - lines = self.execute('lsmod').splitlines() + @asyn.asyncf + async def lsmod(self): + lines = (await self.execute.asyn('lsmod')).splitlines() entries = [] for line in lines[1:]: # first line is the header if not line.strip(): continue - parts = line.split() - name = parts[0] - size = int(parts[1]) - use_count = int(parts[2]) - if len(parts) > 3: - used_by = ''.join(parts[3:]).split(',') + name, size, use_count, *remainder = line.split() + if remainder: + used_by = ''.join(remainder).split(',') else: used_by = [] entries.append(LsmodEntry(name, size, use_count, used_by)) return entries - def insmod(self, path): + @asyn.asyncf + async def insmod(self, path): target_path = self.get_workpath(os.path.basename(path)) - self.push(path, target_path) - self.execute('insmod {}'.format(quote(target_path)), as_root=True) - + await self.push.asyn(path, target_path) + await self.execute.asyn('insmod {}'.format(quote(target_path)), as_root=True) - def extract(self, path, dest=None): + @asyn.asyncf + async def extract(self, path, dest=None): """ Extract the specified on-target file. The extraction method to be used (unzip, gunzip, bunzip2, or tar) will be based on the file's extension. @@ -1100,27 +1259,32 @@ def extract(self, path, dest=None): for ending in ['.tar.gz', '.tar.bz', '.tar.bz2', '.tgz', '.tbz', '.tbz2']: if path.endswith(ending): - return self._extract_archive(path, 'tar xf {} -C {}', dest) + return await self._extract_archive(path, 'tar xf {} -C {}', dest) ext = self.path.splitext(path)[1] if ext in ['.bz', '.bz2']: - return self._extract_file(path, 'bunzip2 -f {}', dest) + return await self._extract_file(path, 'bunzip2 -f {}', dest) elif ext == '.gz': - return self._extract_file(path, 'gunzip -f {}', dest) + return await self._extract_file(path, 'gunzip -f {}', dest) elif ext == '.zip': - return self._extract_archive(path, 'unzip {} -d {}', dest) + return await self._extract_archive(path, 'unzip {} -d {}', dest) else: raise ValueError('Unknown compression format: {}'.format(ext)) - def sleep(self, duration): + @asyn.asyncf + async def sleep(self, duration): timeout = duration + 10 - self.execute('sleep {}'.format(duration), timeout=timeout) + await self.execute.asyn('sleep {}'.format(duration), timeout=timeout) - def read_tree_tar_flat(self, path, depth=1, check_exit_code=True, + @asyn.asyncf + async def read_tree_tar_flat(self, path, depth=1, check_exit_code=True, decode_unicode=True, strip_null_chars=True): + self.async_manager.track_access( + asyn.PathAccess(namespace='target', path=path, mode='r') + ) command = 'read_tree_tgz_b64 {} {} {}'.format(quote(path), depth, quote(self.working_directory)) - output = self._execute_util(command, as_root=self.is_rooted, + output = await self._execute_util.asyn(command, as_root=self.is_rooted, check_exit_code=check_exit_code) result = {} @@ -1153,9 +1317,13 @@ def read_tree_tar_flat(self, path, depth=1, check_exit_code=True, return result - def read_tree_values_flat(self, path, depth=1, check_exit_code=True): + @asyn.asyncf + async def read_tree_values_flat(self, path, depth=1, check_exit_code=True): + self.async_manager.track_access( + asyn.PathAccess(namespace='target', path=path, mode='r') + ) command = 'read_tree_values {} {}'.format(quote(path), depth) - output = self._execute_util(command, as_root=self.is_rooted, + output = await self._execute_util.asyn(command, as_root=self.is_rooted, check_exit_code=check_exit_code) accumulator = defaultdict(list) @@ -1168,7 +1336,8 @@ def read_tree_values_flat(self, path, depth=1, check_exit_code=True): result = {k: '\n'.join(v).strip() for k, v in accumulator.items()} return result - def read_tree_values(self, path, depth=1, dictcls=dict, + @asyn.asyncf + async def read_tree_values(self, path, depth=1, dictcls=dict, check_exit_code=True, tar=False, decode_unicode=True, strip_null_chars=True): """ @@ -1187,9 +1356,9 @@ def read_tree_values(self, path, depth=1, dictcls=dict, :returns: a tree-like dict with the content of files as leafs """ if not tar: - value_map = self.read_tree_values_flat(path, depth, check_exit_code) + value_map = await self.read_tree_values_flat.asyn(path, depth, check_exit_code) else: - value_map = self.read_tree_tar_flat(path, depth, check_exit_code, + value_map = await self.read_tree_tar_flat.asyn(path, depth, check_exit_code, decode_unicode, strip_null_chars) return _build_path_tree(value_map, path, self.path.sep, dictcls) @@ -1208,42 +1377,47 @@ def install_module(self, mod, **params): # internal methods - def _setup_shutils(self): + @asyn.asyncf + async def _setup_shutils(self): shutils_ifile = os.path.join(PACKAGE_BIN_DIRECTORY, 'scripts', 'shutils.in') - tmp_dir = tempfile.mkdtemp() - shutils_ofile = os.path.join(tmp_dir, 'shutils') with open(shutils_ifile) as fh: lines = fh.readlines() - with open(shutils_ofile, 'w') as ofile: - for line in lines: - line = line.replace("__DEVLIB_BUSYBOX__", self.busybox) - ofile.write(line) - self._shutils = self.install(shutils_ofile) - os.remove(shutils_ofile) - os.rmdir(tmp_dir) - + with tempfile.TemporaryDirectory() as folder: + shutils_ofile = os.path.join(folder, 'shutils') + with open(shutils_ofile, 'w') as ofile: + for line in lines: + line = line.replace("__DEVLIB_BUSYBOX__", self.busybox) + ofile.write(line) + self._shutils = await self.install.asyn(shutils_ofile) + + @asyn.asyncf @call_conn - def _execute_util(self, command, timeout=None, check_exit_code=True, as_root=False): + async def _execute_util(self, command, timeout=None, check_exit_code=True, as_root=False): command = '{} sh {} {}'.format(quote(self.busybox), quote(self.shutils), command) - return self.conn.execute(command, timeout, check_exit_code, as_root) + return await self.execute.asyn( + command, + timeout=timeout, + check_exit_code=check_exit_code, + as_root=as_root + ) - def _extract_archive(self, path, cmd, dest=None): + async def _extract_archive(self, path, cmd, dest=None): cmd = '{} ' + cmd # busybox if dest: extracted = dest else: extracted = self.path.dirname(path) cmdtext = cmd.format(quote(self.busybox), quote(path), quote(extracted)) - self.execute(cmdtext) + await self.execute.asyn(cmdtext) return extracted - def _extract_file(self, path, cmd, dest=None): + async def _extract_file(self, path, cmd, dest=None): cmd = '{} ' + cmd # busybox cmdtext = cmd.format(quote(self.busybox), quote(path)) - self.execute(cmdtext) + await self.execute.asyn(cmdtext) extracted = self.path.splitext(path)[0] if dest: - self.execute('mv -f {} {}'.format(quote(extracted), quote(dest))) + await self.execute.asyn('mv -f {} {}'.format(quote(extracted), quote(dest))) if dest.endswith('/'): extracted = self.path.join(dest, self.path.basename(extracted)) else: @@ -1287,7 +1461,8 @@ def _install_module(self, mod, **params): def _resolve_paths(self): raise NotImplementedError() - def is_network_connected(self): + @asyn.asyncf + async def is_network_connected(self): self.logger.debug('Checking for internet connectivity...') timeout_s = 5 @@ -1302,7 +1477,7 @@ def is_network_connected(self): attempts = 5 for _ in range(attempts): try: - self.execute(command) + await self.execute.asyn(command) return True except TargetStableError as e: err = str(e).lower() @@ -1392,29 +1567,33 @@ def kick_off(self, command, as_root=False): command = 'sh -c {} 1>/dev/null 2>/dev/null &'.format(quote(command)) return self.conn.execute(command, as_root=as_root) - def get_pids_of(self, process_name): + @asyn.asyncf + async def get_pids_of(self, process_name): """Returns a list of PIDs of all processes with the specified name.""" # result should be a column of PIDs with the first row as "PID" header - result = self.execute('ps -C {} -o pid'.format(quote(process_name)), # NOQA - check_exit_code=False).strip().split() + result = await self.execute.asyn('ps -C {} -o pid'.format(quote(process_name)), # NOQA + check_exit_code=False) + result = result.strip().split() if len(result) >= 2: # at least one row besides the header return list(map(int, result[1:])) else: return [] - def ps(self, threads=False, **kwargs): + @asyn.asyncf + async def ps(self, threads=False, **kwargs): ps_flags = '-eo' if threads: ps_flags = '-eLo' command = 'ps {} user,pid,tid,ppid,vsize,rss,wchan,pcpu,state,fname'.format(ps_flags) - lines = iter(convert_new_lines(self.execute(command)).split('\n')) - next(lines) # header + out = await self.execute.asyn(command) result = [] - for line in lines: + lines = convert_new_lines(out).splitlines() + # Skip header + for line in lines[1:]: parts = re.split(r'\s+', line, maxsplit=9) - if parts and parts != ['']: + if parts: result.append(PsEntry(*(parts[0:1] + list(map(int, parts[1:6])) + parts[6:]))) if not kwargs: @@ -1426,34 +1605,37 @@ def ps(self, threads=False, **kwargs): filtered_result.append(entry) return filtered_result - def list_directory(self, path, as_root=False): - contents = self.execute('ls -1 {}'.format(quote(path)), as_root=as_root) + async def _list_directory(self, path, as_root=False): + contents = await self.execute.asyn('ls -1 {}'.format(quote(path)), as_root=as_root) return [x.strip() for x in contents.split('\n') if x.strip()] - def install(self, filepath, timeout=None, with_name=None): # pylint: disable=W0221 + @asyn.asyncf + async def install(self, filepath, timeout=None, with_name=None): # pylint: disable=W0221 destpath = self.path.join(self.executables_directory, with_name and with_name or self.path.basename(filepath)) - self.push(filepath, destpath, timeout=timeout) - self.execute('chmod a+x {}'.format(quote(destpath)), timeout=timeout) + await self.push.asyn(filepath, destpath, timeout=timeout) + await self.execute.asyn('chmod a+x {}'.format(quote(destpath)), timeout=timeout) self._installed_binaries[self.path.basename(destpath)] = destpath return destpath - def uninstall(self, name): + @asyn.asyncf + async def uninstall(self, name): path = self.path.join(self.executables_directory, name) - self.remove(path) + await self.remove.asyn(path) - def capture_screen(self, filepath): - if not self.is_installed('scrot'): + @asyn.asyncf + async def capture_screen(self, filepath): + if not (await self.is_installed.asyn('scrot')): self.logger.debug('Could not take screenshot as scrot is not installed.') return try: - tmpfile = self.tempfile() + tmpfile = await self.tempfile.asyn() cmd = 'DISPLAY=:0.0 scrot {} && {} date -u -Iseconds' - ts = self.execute(cmd.format(quote(tmpfile), quote(self.busybox))).strip() + ts = (await self.execute.asyn(cmd.format(quote(tmpfile), quote(self.busybox)))).strip() filepath = filepath.format(ts=ts) - self.pull(tmpfile, filepath) - self.remove(tmpfile) + await self.pull.asyn(tmpfile, filepath) + await self.remove.asyn(tmpfile) except TargetStableError as e: if "Can't open X dispay." not in e.message: raise e @@ -1599,31 +1781,39 @@ def __setstate__(self, dct): self.__dict__.update(dct) self._init_logcat_lock() - def reset(self, fastboot=False): # pylint: disable=arguments-differ + @asyn.asyncf + async def reset(self, fastboot=False): # pylint: disable=arguments-differ try: - self.execute('reboot {}'.format(fastboot and 'fastboot' or ''), + await self.execute.asyn('reboot {}'.format(fastboot and 'fastboot' or ''), as_root=self.needs_su, timeout=2) except (DevlibTransientError, subprocess.CalledProcessError): # on some targets "reboot" doesn't return gracefully pass self.conn.connected_as_root = None - def wait_boot_complete(self, timeout=10): + @asyn.asyncf + async def wait_boot_complete(self, timeout=10): start = time.time() - boot_completed = boolean(self.getprop('sys.boot_completed')) + boot_completed = boolean(await self.getprop.asyn('sys.boot_completed')) while not boot_completed and timeout >= time.time() - start: time.sleep(5) - boot_completed = boolean(self.getprop('sys.boot_completed')) + boot_completed = boolean(await self.getprop.asyn('sys.boot_completed')) if not boot_completed: # Raise a TargetStableError as this usually happens because of # an issue with Android more than a timeout that is too small. raise TargetStableError('Connected but Android did not fully boot.') - def connect(self, timeout=30, check_boot_completed=True): # pylint: disable=arguments-differ + @asyn.asyncf + async def connect(self, timeout=30, check_boot_completed=True, max_async=50): # pylint: disable=arguments-differ device = self.connection_settings.get('device') - super(AndroidTarget, self).connect(timeout=timeout, check_boot_completed=check_boot_completed) + await super(AndroidTarget, self).connect.asyn( + timeout=timeout, + check_boot_completed=check_boot_completed, + max_async=max_async, + ) - def kick_off(self, command, as_root=None): + @asyn.asyncf + async def kick_off(self, command, as_root=None): """ Like execute but closes adb session and returns immediately, leaving the command running on the device (this is different from execute(background=True) which keeps adb connection open and returns @@ -1633,11 +1823,12 @@ def kick_off(self, command, as_root=None): as_root = self.needs_su try: command = 'cd {} && {} nohup {} &'.format(quote(self.working_directory), quote(self.busybox), command) - self.execute(command, timeout=1, as_root=as_root) + await self.execute.asyn(command, timeout=1, as_root=as_root) except TimeoutError: pass - def __setup_list_directory(self): + @asyn.asyncf + async def __setup_list_directory(self): # In at least Linaro Android 16.09 (which was their first Android 7 release) and maybe # AOSP 7.0 as well, the ls command was changed. # Previous versions default to a single column listing, which is nice and easy to parse. @@ -1646,44 +1837,48 @@ def __setup_list_directory(self): # so we try the new version, and if it fails we use the old version. self.ls_command = 'ls -1' try: - self.execute('ls -1 {}'.format(quote(self.working_directory)), as_root=False) + await self.execute.asyn('ls -1 {}'.format(quote(self.working_directory)), as_root=False) except TargetStableError: self.ls_command = 'ls' - def list_directory(self, path, as_root=False): + async def _list_directory(self, path, as_root=False): if self.ls_command == '': - self.__setup_list_directory() - contents = self.execute('{} {}'.format(self.ls_command, quote(path)), as_root=as_root) + await self.__setup_list_directory.asyn() + contents = await self.execute.asyn('{} {}'.format(self.ls_command, quote(path)), as_root=as_root) return [x.strip() for x in contents.split('\n') if x.strip()] - def install(self, filepath, timeout=None, with_name=None): # pylint: disable=W0221 + @asyn.asyncf + async def install(self, filepath, timeout=None, with_name=None): # pylint: disable=W0221 ext = os.path.splitext(filepath)[1].lower() if ext == '.apk': - return self.install_apk(filepath, timeout) + return await self.install_apk.asyn(filepath, timeout) else: - return self.install_executable(filepath, with_name, timeout) + return await self.install_executable.asyn(filepath, with_name, timeout) - def uninstall(self, name): - if self.package_is_installed(name): - self.uninstall_package(name) + @asyn.asyncf + async def uninstall(self, name): + if await self.package_is_installed.asyn(name): + await self.uninstall_package.asyn(name) else: - self.uninstall_executable(name) + await self.uninstall_executable.asyn(name) - def get_pids_of(self, process_name): + @asyn.asyncf + async def get_pids_of(self, process_name): result = [] search_term = process_name[-15:] - for entry in self.ps(): + for entry in await self.ps.asyn(): if search_term in entry.name: result.append(entry.pid) return result - def ps(self, threads=False, **kwargs): + @asyn.asyncf + async def ps(self, threads=False, **kwargs): maxsplit = 9 if threads else 8 command = 'ps' if threads: command = 'ps -AT' - lines = iter(convert_new_lines(self.execute(command)).split('\n')) + lines = iter(convert_new_lines(await self.execute.asyn(command)).split('\n')) next(lines) # header result = [] for line in lines: @@ -1712,37 +1907,42 @@ def ps(self, threads=False, **kwargs): filtered_result.append(entry) return filtered_result - def capture_screen(self, filepath): + @asyn.asyncf + async def capture_screen(self, filepath): on_device_file = self.path.join(self.working_directory, 'screen_capture.png') cmd = 'screencap -p {} && {} date -u -Iseconds' - ts = self.execute(cmd.format(quote(on_device_file), quote(self.busybox))).strip() + ts = (await self.execute.asyn(cmd.format(quote(on_device_file), quote(self.busybox)))).strip() filepath = filepath.format(ts=ts) - self.pull(on_device_file, filepath) - self.remove(on_device_file) + await self.pull.asyn(on_device_file, filepath) + await self.remove.asyn(on_device_file) # Android-specific - def input_tap(self, x, y): + @asyn.asyncf + async def input_tap(self, x, y): command = 'input tap {} {}' - self.execute(command.format(x, y)) + await self.execute.asyn(command.format(x, y)) - def input_tap_pct(self, x, y): + @asyn.asyncf + async def input_tap_pct(self, x, y): width, height = self.screen_resolution x = (x * width) // 100 y = (y * height) // 100 - self.input_tap(x, y) + await self.input_tap.asyn(x, y) - def input_swipe(self, x1, y1, x2, y2): + @asyn.asyncf + async def input_swipe(self, x1, y1, x2, y2): """ Issue a swipe on the screen from (x1, y1) to (x2, y2) Uses absolute screen positions """ command = 'input swipe {} {} {} {}' - self.execute(command.format(x1, y1, x2, y2)) + await self.execute.asyn(command.format(x1, y1, x2, y2)) - def input_swipe_pct(self, x1, y1, x2, y2): + @asyn.asyncf + async def input_swipe_pct(self, x1, y1, x2, y2): """ Issue a swipe on the screen from (x1, y1) to (x2, y2) Uses percent-based positions @@ -1754,38 +1954,43 @@ def input_swipe_pct(self, x1, y1, x2, y2): x2 = (x2 * width) // 100 y2 = (y2 * height) // 100 - self.input_swipe(x1, y1, x2, y2) + await self.input_swipe.asyn(x1, y1, x2, y2) - def swipe_to_unlock(self, direction="diagonal"): + @asyn.asyncf + async def swipe_to_unlock(self, direction="diagonal"): width, height = self.screen_resolution if direction == "diagonal": start = 100 stop = width - start swipe_height = height * 2 // 3 - self.input_swipe(start, swipe_height, stop, 0) + await self.input_swipe.asyn(start, swipe_height, stop, 0) elif direction == "horizontal": swipe_height = height * 2 // 3 start = 100 stop = width - start - self.input_swipe(start, swipe_height, stop, swipe_height) + await self.input_swipe.asyn(start, swipe_height, stop, swipe_height) elif direction == "vertical": swipe_middle = width / 2 swipe_height = height * 2 // 3 - self.input_swipe(swipe_middle, swipe_height, swipe_middle, 0) + await self.input_swipe.asyn(swipe_middle, swipe_height, swipe_middle, 0) else: raise TargetStableError("Invalid swipe direction: {}".format(direction)) - def getprop(self, prop=None): - props = AndroidProperties(self.execute('getprop')) + @asyn.asyncf + async def getprop(self, prop=None): + props = AndroidProperties(await self.execute.asyn('getprop')) if prop: return props[prop] return props - def capture_ui_hierarchy(self, filepath): + @asyn.asyncf + async def capture_ui_hierarchy(self, filepath): on_target_file = self.get_workpath('screen_capture.xml') - self.execute('uiautomator dump {}'.format(on_target_file)) - self.pull(on_target_file, filepath) - self.remove(on_target_file) + try: + await self.execute.asyn('uiautomator dump {}'.format(on_target_file)) + await self.pull.asyn(on_target_file, filepath) + finally: + await self.remove.asyn(on_target_file) parsed_xml = xml.dom.minidom.parse(filepath) with open(filepath, 'w') as f: @@ -1794,26 +1999,31 @@ def capture_ui_hierarchy(self, filepath): else: f.write(parsed_xml.toprettyxml().encode('utf-8')) - def is_installed(self, name): - return super(AndroidTarget, self).is_installed(name) or self.package_is_installed(name) + @asyn.asyncf + async def is_installed(self, name): + return (await super(AndroidTarget, self).is_installed.asyn(name)) or (await self.package_is_installed.asyn(name)) - def package_is_installed(self, package_name): - return package_name in self.list_packages() + @asyn.asyncf + async def package_is_installed(self, package_name): + return package_name in (await self.list_packages.asyn()) - def list_packages(self): - output = self.execute('pm list packages') + @asyn.asyncf + async def list_packages(self): + output = await self.execute.asyn('pm list packages') output = output.replace('package:', '') return output.split() - def get_package_version(self, package): - output = self.execute('dumpsys package {}'.format(quote(package))) + @asyn.asyncf + async def get_package_version(self, package): + output = await self.execute.asyn('dumpsys package {}'.format(quote(package))) for line in convert_new_lines(output).split('\n'): if 'versionName' in line: return line.split('=', 1)[1] return None - def get_package_info(self, package): - output = self.execute('pm list packages -f {}'.format(quote(package))) + @asyn.asyncf + async def get_package_info(self, package): + output = await self.execute.asyn('pm list packages -f {}'.format(quote(package))) for entry in output.strip().split('\n'): rest, entry_package = entry.rsplit('=', 1) if entry_package != package: @@ -1821,13 +2031,15 @@ def get_package_info(self, package): _, apk_path = rest.split(':') return installed_package_info(apk_path, entry_package) - def get_sdk_version(self): + @asyn.asyncf + async def get_sdk_version(self): try: - return int(self.getprop('ro.build.version.sdk')) + return int(await self.getprop.asyn('ro.build.version.sdk')) except (ValueError, TypeError): return None - def install_apk(self, filepath, timeout=None, replace=False, allow_downgrade=False): # pylint: disable=W0221 + @asyn.asyncf + async def install_apk(self, filepath, timeout=None, replace=False, allow_downgrade=False): # pylint: disable=W0221 ext = os.path.splitext(filepath)[1].lower() if ext == '.apk': flags = [] @@ -1843,16 +2055,17 @@ def install_apk(self, filepath, timeout=None, replace=False, allow_downgrade=Fal timeout=timeout, adb_server=self.adb_server) else: dev_path = self.get_workpath(filepath.rsplit(os.path.sep, 1)[-1]) - self.push(quote(filepath), dev_path, timeout=timeout) - result = self.execute("pm install {} {}".format(' '.join(flags), quote(dev_path)), timeout=timeout) - self.remove(dev_path) + await self.push.asyn(quote(filepath), dev_path, timeout=timeout) + result = await self.execute.asyn("pm install {} {}".format(' '.join(flags), quote(dev_path)), timeout=timeout) + await self.remove.asyn(dev_path) return result else: raise TargetStableError('Can\'t install {}: unsupported format.'.format(filepath)) - def grant_package_permission(self, package, permission): + @asyn.asyncf + async def grant_package_permission(self, package, permission): try: - return self.execute('pm grant {} {}'.format(quote(package), quote(permission))) + return await self.execute.asyn('pm grant {} {}'.format(quote(package), quote(permission))) except TargetStableError as e: if 'is not a changeable permission type' in e.message: pass # Ignore if unchangeable @@ -1865,61 +2078,68 @@ def grant_package_permission(self, package, permission): else: raise - def refresh_files(self, file_list): + @asyn.asyncf + async def refresh_files(self, file_list): """ Depending on the android version and root status, determine the appropriate method of forcing a re-index of the mediaserver cache for a given list of files. """ - if self.is_rooted or self.get_sdk_version() < 24: # MM and below + if self.is_rooted or (await self.get_sdk_version.asyn()) < 24: # MM and below common_path = commonprefix(file_list, sep=self.path.sep) - self.broadcast_media_mounted(common_path, self.is_rooted) + await self.broadcast_media_mounted.asyn(common_path, self.is_rooted) else: for f in file_list: - self.broadcast_media_scan_file(f) + await self.broadcast_media_scan_file.asyn(f) - def broadcast_media_scan_file(self, filepath): + @asyn.asyncf + async def broadcast_media_scan_file(self, filepath): """ Force a re-index of the mediaserver cache for the specified file. """ command = 'am broadcast -a android.intent.action.MEDIA_SCANNER_SCAN_FILE -d {}' - self.execute(command.format(quote('file://' + filepath))) + await self.execute.asyn(command.format(quote('file://' + filepath))) - def broadcast_media_mounted(self, dirpath, as_root=False): + @asyn.asyncf + async def broadcast_media_mounted(self, dirpath, as_root=False): """ Force a re-index of the mediaserver cache for the specified directory. """ command = 'am broadcast -a android.intent.action.MEDIA_MOUNTED -d {} '\ '-n com.android.providers.media/.MediaScannerReceiver' - self.execute(command.format(quote('file://'+dirpath)), as_root=as_root) + await self.execute.asyn(command.format(quote('file://'+dirpath)), as_root=as_root) - def install_executable(self, filepath, with_name=None, timeout=None): + @asyn.asyncf + async def install_executable(self, filepath, with_name=None, timeout=None): self._ensure_executables_directory_is_writable() executable_name = with_name or os.path.basename(filepath) on_device_file = self.path.join(self.working_directory, executable_name) on_device_executable = self.path.join(self.executables_directory, executable_name) - self.push(filepath, on_device_file, timeout=timeout) + await self.push.asyn(filepath, on_device_file, timeout=timeout) if on_device_file != on_device_executable: - self.execute('cp {} {}'.format(quote(on_device_file), quote(on_device_executable)), + await self.execute.asyn('cp {} {}'.format(quote(on_device_file), quote(on_device_executable)), as_root=self.needs_su, timeout=timeout) - self.remove(on_device_file, as_root=self.needs_su) - self.execute("chmod 0777 {}".format(quote(on_device_executable)), as_root=self.needs_su) + await self.remove.asyn(on_device_file, as_root=self.needs_su) + await self.execute.asyn("chmod 0777 {}".format(quote(on_device_executable)), as_root=self.needs_su) self._installed_binaries[executable_name] = on_device_executable return on_device_executable - def uninstall_package(self, package): + @asyn.asyncf + async def uninstall_package(self, package): if isinstance(self.conn, AdbConnection): adb_command(self.adb_name, "uninstall {}".format(quote(package)), timeout=30, adb_server=self.adb_server) else: - self.execute("pm uninstall {}".format(quote(package)), timeout=30) + await self.execute.asyn("pm uninstall {}".format(quote(package)), timeout=30) - def uninstall_executable(self, executable_name): + @asyn.asyncf + async def uninstall_executable(self, executable_name): on_device_executable = self.path.join(self.executables_directory, executable_name) self._ensure_executables_directory_is_writable() - self.remove(on_device_executable, as_root=self.needs_su) + await self.remove.asyn(on_device_executable, as_root=self.needs_su) - def dump_logcat(self, filepath, filter=None, logcat_format=None, append=False, + @asyn.asyncf + async def dump_logcat(self, filepath, filter=None, logcat_format=None, append=False, timeout=60): # pylint: disable=redefined-builtin op = '>>' if append else '>' filtstr = ' -s {}'.format(quote(filter)) if filter else '' @@ -1931,18 +2151,19 @@ def dump_logcat(self, filepath, filter=None, logcat_format=None, append=False, else: dev_path = self.get_workpath('logcat') command = 'logcat {} {} {}'.format(logcat_opts, op, quote(dev_path)) - self.execute(command, timeout=timeout) - self.pull(dev_path, filepath) - self.remove(dev_path) + await self.execute.asyn(command, timeout=timeout) + await self.pull.asyn(dev_path, filepath) + await self.remove.asyn(dev_path) - def clear_logcat(self): + @asyn.asyncf + async def clear_logcat(self): locked = self.clear_logcat_lock.acquire(blocking=False) if locked: try: if isinstance(self.conn, AdbConnection): adb_command(self.adb_name, 'logcat -c', timeout=30, adb_server=self.adb_server) else: - self.execute('logcat -c', timeout=30) + await self.execute.asyn('logcat -c', timeout=30) finally: self.clear_logcat_lock.release() @@ -1957,8 +2178,9 @@ def wait_for_device(self, timeout=30): def reboot_bootloader(self, timeout=30): self.conn.reboot_bootloader() - def is_screen_on(self): - output = self.execute('dumpsys power') + @asyn.asyncf + async def is_screen_on(self): + output = await self.execute.asyn('dumpsys power') match = ANDROID_SCREEN_STATE_REGEX.search(output) if match: if 'DOZE' in match.group(1).upper(): @@ -1971,121 +2193,145 @@ def is_screen_on(self): else: raise TargetStableError('Could not establish screen state.') - def ensure_screen_is_on(self, verify=True): - if not self.is_screen_on(): + @asyn.asyncf + async def ensure_screen_is_on(self, verify=True): + if not await self.is_screen_on.asyn(): self.execute('input keyevent 26') - if verify and not self.is_screen_on(): + if verify and not await self.is_screen_on.asyn(): raise TargetStableError('Display cannot be turned on.') - def ensure_screen_is_on_and_stays(self, verify=True, mode=7): - self.ensure_screen_is_on(verify=verify) - self.set_stay_on_mode(mode) + @asyn.asyncf + async def ensure_screen_is_on_and_stays(self, verify=True, mode=7): + await self.ensure_screen_is_on.asyn(verify=verify) + await self.set_stay_on_mode.asyn(mode) - def ensure_screen_is_off(self, verify=True): + @asyn.asyncf + async def ensure_screen_is_off(self, verify=True): # Allow 2 attempts to help with cases of ambient display modes # where the first attempt will switch the display fully on. for _ in range(2): - if self.is_screen_on(): - self.execute('input keyevent 26') + if await self.is_screen_on.asyn(): + await self.execute.asyn('input keyevent 26') time.sleep(0.5) - if verify and self.is_screen_on(): + if verify and await self.is_screen_on.asyn(): msg = 'Display cannot be turned off. Is always on display enabled?' raise TargetStableError(msg) - def set_auto_brightness(self, auto_brightness): + @asyn.asyncf + async def set_auto_brightness(self, auto_brightness): cmd = 'settings put system screen_brightness_mode {}' - self.execute(cmd.format(int(boolean(auto_brightness)))) + await self.execute.asyn(cmd.format(int(boolean(auto_brightness)))) - def get_auto_brightness(self): + @asyn.asyncf + async def get_auto_brightness(self): cmd = 'settings get system screen_brightness_mode' - return boolean(self.execute(cmd).strip()) + return boolean((await self.execute.asyn(cmd)).strip()) - def set_brightness(self, value): + @asyn.asyncf + async def set_brightness(self, value): if not 0 <= value <= 255: msg = 'Invalid brightness "{}"; Must be between 0 and 255' raise ValueError(msg.format(value)) - self.set_auto_brightness(False) + await self.set_auto_brightness.asyn(False) cmd = 'settings put system screen_brightness {}' - self.execute(cmd.format(int(value))) + await self.execute.asyn(cmd.format(int(value))) - def get_brightness(self): + @asyn.asyncf + async def get_brightness(self): cmd = 'settings get system screen_brightness' - return integer(self.execute(cmd).strip()) + return integer((await self.execute.asyn(cmd)).strip()) - def set_screen_timeout(self, timeout_ms): + @asyn.asyncf + async def set_screen_timeout(self, timeout_ms): cmd = 'settings put system screen_off_timeout {}' - self.execute(cmd.format(int(timeout_ms))) + await self.execute.asyn(cmd.format(int(timeout_ms))) - def get_screen_timeout(self): + @asyn.asyncf + async def get_screen_timeout(self): cmd = 'settings get system screen_off_timeout' - return int(self.execute(cmd).strip()) + return int((await self.execute.asyn(cmd)).strip()) - def get_airplane_mode(self): + @asyn.asyncf + async def get_airplane_mode(self): cmd = 'settings get global airplane_mode_on' - return boolean(self.execute(cmd).strip()) + return boolean((await self.execute.asyn(cmd)).strip()) - def get_stay_on_mode(self): + @asyn.asyncf + async def get_stay_on_mode(self): cmd = 'settings get global stay_on_while_plugged_in' - return int(self.execute(cmd).strip()) + return int((await self.execute.asyn(cmd)).strip()) - def set_airplane_mode(self, mode): - root_required = self.get_sdk_version() > 23 + @asyn.asyncf + async def set_airplane_mode(self, mode): + root_required = await self.get_sdk_version.asyn() > 23 if root_required and not self.is_rooted: raise TargetStableError('Root is required to toggle airplane mode on Android 7+') mode = int(boolean(mode)) cmd = 'settings put global airplane_mode_on {}' - self.execute(cmd.format(mode)) - self.execute('am broadcast -a android.intent.action.AIRPLANE_MODE ' + await self.execute.asyn(cmd.format(mode)) + await self.execute.asyn('am broadcast -a android.intent.action.AIRPLANE_MODE ' '--ez state {}'.format(mode), as_root=root_required) - def get_auto_rotation(self): + @asyn.asyncf + async def get_auto_rotation(self): cmd = 'settings get system accelerometer_rotation' - return boolean(self.execute(cmd).strip()) + return boolean((await self.execute.asyn(cmd)).strip()) - def set_auto_rotation(self, autorotate): + @asyn.asyncf + async def set_auto_rotation(self, autorotate): cmd = 'settings put system accelerometer_rotation {}' - self.execute(cmd.format(int(boolean(autorotate)))) + await self.execute.asyn(cmd.format(int(boolean(autorotate)))) - def set_natural_rotation(self): - self.set_rotation(0) + @asyn.asyncf + async def set_natural_rotation(self): + await self.set_rotation.asyn(0) - def set_left_rotation(self): - self.set_rotation(1) + @asyn.asyncf + async def set_left_rotation(self): + await self.set_rotation.asyn(1) - def set_inverted_rotation(self): - self.set_rotation(2) + @asyn.asyncf + async def set_inverted_rotation(self): + await self.set_rotation.asyn(2) - def set_right_rotation(self): - self.set_rotation(3) + @asyn.asyncf + async def set_right_rotation(self): + await self.set_rotation.asyn(3) - def get_rotation(self): - output = self.execute('dumpsys input') + @asyn.asyncf + async def get_rotation(self): + output = await self.execute.asyn('dumpsys input') match = ANDROID_SCREEN_ROTATION_REGEX.search(output) if match: return int(match.group('rotation')) else: return None - def set_rotation(self, rotation): + @asyn.asyncf + async def set_rotation(self, rotation): if not 0 <= rotation <= 3: raise ValueError('Rotation value must be between 0 and 3') - self.set_auto_rotation(False) + await self.set_auto_rotation.asyn(False) cmd = 'settings put system user_rotation {}' - self.execute(cmd.format(rotation)) + await self.execute.asyn(cmd.format(rotation)) - def set_stay_on_never(self): - self.set_stay_on_mode(0) + @asyn.asyncf + async def set_stay_on_never(self): + await self.set_stay_on_mode.asyn(0) - def set_stay_on_while_powered(self): - self.set_stay_on_mode(7) + @asyn.asyncf + async def set_stay_on_while_powered(self): + await self.set_stay_on_mode.asyn(7) - def set_stay_on_mode(self, mode): + @asyn.asyncf + async def set_stay_on_mode(self, mode): if not 0 <= mode <= 7: raise ValueError('Screen stay on mode must be between 0 and 7') cmd = 'settings put global stay_on_while_plugged_in {}' - self.execute(cmd.format(mode)) + await self.execute.asyn(cmd.format(mode)) - def open_url(self, url, force_new=False): + @asyn.asyncf + async def open_url(self, url, force_new=False): """ Start a view activity by specifying an URL @@ -2102,10 +2348,11 @@ def open_url(self, url, force_new=False): cmd = cmd + ' -f {}'.format(INTENT_FLAGS['ACTIVITY_NEW_TASK'] | INTENT_FLAGS['ACTIVITY_CLEAR_TASK']) - self.execute(cmd.format(quote(url))) + await self.execute.asyn(cmd.format(quote(url))) - def homescreen(self): - self.execute('am start -a android.intent.action.MAIN -c android.intent.category.HOME') + @asyn.asyncf + async def homescreen(self): + await self.execute.asyn('am start -a android.intent.action.MAIN -c android.intent.category.HOME') def _resolve_paths(self): if self.working_directory is None: @@ -2114,15 +2361,16 @@ def _resolve_paths(self): if self.executables_directory is None: self.executables_directory = '/data/local/tmp/bin' - def _ensure_executables_directory_is_writable(self): + @asyn.asyncf + async def _ensure_executables_directory_is_writable(self): matched = [] - for entry in self.list_file_systems(): + for entry in await self.list_file_systems.asyn(): if self.executables_directory.rstrip('/').startswith(entry.mount_point): matched.append(entry) if matched: entry = sorted(matched, key=lambda x: len(x.mount_point))[-1] if 'rw' not in entry.options: - self.execute('mount -o rw,remount {} {}'.format(quote(entry.device), + await self.execute.asyn('mount -o rw,remount {} {}'.format(quote(entry.device), quote(entry.mount_point)), as_root=True) else: @@ -2726,8 +2974,12 @@ def __getattr__(self, attr): else: raise - def connect(self, timeout=30, check_boot_completed=True): - super(ChromeOsTarget, self).connect(timeout, check_boot_completed) + def connect(self, timeout=30, check_boot_completed=True, max_async=50): + super(ChromeOsTarget, self).connect( + timeout=timeout, + check_boot_completed=check_boot_completed, + max_async=max_async, + ) # Assume device supports android apps if container directory is present if self.supports_android is None: From 6eb24a60d801e79726c71b38026a57bf2f9e3cb8 Mon Sep 17 00:00:00 2001 From: Douglas Raillard Date: Fri, 8 Apr 2022 15:20:33 +0100 Subject: [PATCH 08/11] target: Expose Target(max_async=50) parameter Allow the user to set a maximum number of conrruent connections used to dispatch non-blocking commands when using the async API. --- devlib/target.py | 38 +++++++++++++++++++++++++++++++------- doc/target.rst | 13 ++++++++----- 2 files changed, 39 insertions(+), 12 deletions(-) diff --git a/devlib/target.py b/devlib/target.py index cd4990057..78930b3db 100644 --- a/devlib/target.py +++ b/devlib/target.py @@ -306,9 +306,11 @@ def __init__(self, load_default_modules=True, shell_prompt=DEFAULT_SHELL_PROMPT, conn_cls=None, - is_container=False + is_container=False, + max_async=50, ): self._async_pool = None + self._async_pool_size = None self._unused_conns = set() self._is_rooted = None @@ -352,7 +354,7 @@ def __init__(self, self.modules = merge_lists(*module_lists, duplicates='first') self._update_modules('early') if connect: - self.connect() + self.connect(max_async=max_async) def __getstate__(self): # tls_property will recreate the underlying value automatically upon @@ -363,12 +365,25 @@ def __getstate__(self): for k, v in inspect.getmembers(self.__class__) if isinstance(v, _BoundTLSProperty) } + ignored.update(( + '_async_pool', + '_unused_conns', + )) return { k: v for k, v in self.__dict__.items() if k not in ignored } + def __setstate__(self, dct): + self.__dict__ = dct + pool_size = self._async_pool_size + if pool_size is None: + self._async_pool = None + else: + self._async_pool = ThreadPoolExecutor(pool_size) + self._unused_conns = set() + # connection and initialization @asyn.asyncf @@ -433,6 +448,7 @@ def make_conn(_): max_conns = len(conns) self.logger.debug(f'Detected max number of async commands: {max_conns}') + self._async_pool_size = max_conns self._async_pool = ThreadPoolExecutor(max_conns) @asyn.asyncf @@ -1547,6 +1563,7 @@ def __init__(self, shell_prompt=DEFAULT_SHELL_PROMPT, conn_cls=SshConnection, is_container=False, + max_async=50, ): super(LinuxTarget, self).__init__(connection_settings=connection_settings, platform=platform, @@ -1557,7 +1574,8 @@ def __init__(self, load_default_modules=load_default_modules, shell_prompt=shell_prompt, conn_cls=conn_cls, - is_container=is_container) + is_container=is_container, + max_async=max_async) def wait_boot_complete(self, timeout=10): pass @@ -1752,6 +1770,7 @@ def __init__(self, conn_cls=AdbConnection, package_data_directory="/data/data", is_container=False, + max_async=50, ): super(AndroidTarget, self).__init__(connection_settings=connection_settings, platform=platform, @@ -1762,7 +1781,8 @@ def __init__(self, load_default_modules=load_default_modules, shell_prompt=shell_prompt, conn_cls=conn_cls, - is_container=is_container) + is_container=is_container, + max_async=max_async) self.package_data_directory = package_data_directory self._init_logcat_lock() @@ -2823,6 +2843,7 @@ def __init__(self, shell_prompt=DEFAULT_SHELL_PROMPT, conn_cls=LocalConnection, is_container=False, + max_async=50, ): super(LocalLinuxTarget, self).__init__(connection_settings=connection_settings, platform=platform, @@ -2833,7 +2854,8 @@ def __init__(self, load_default_modules=load_default_modules, shell_prompt=shell_prompt, conn_cls=conn_cls, - is_container=is_container) + is_container=is_container, + max_async=max_async) def _resolve_paths(self): if self.working_directory is None: @@ -2906,7 +2928,8 @@ def __init__(self, load_default_modules=True, shell_prompt=DEFAULT_SHELL_PROMPT, package_data_directory="/data/data", - is_container=False + is_container=False, + max_async=50, ): self.supports_android = None @@ -2932,7 +2955,8 @@ def __init__(self, load_default_modules=load_default_modules, shell_prompt=shell_prompt, conn_cls=SshConnection, - is_container=is_container) + is_container=is_container, + max_async=max_async) # We can't determine if the target supports android until connected to the linux host so # create unconditionally. diff --git a/doc/target.rst b/doc/target.rst index f17cfe4e1..17b2bbd3a 100644 --- a/doc/target.rst +++ b/doc/target.rst @@ -3,7 +3,7 @@ Target ====== -.. class:: Target(connection_settings=None, platform=None, working_directory=None, executables_directory=None, connect=True, modules=None, load_default_modules=True, shell_prompt=DEFAULT_SHELL_PROMPT, conn_cls=None) +.. class:: Target(connection_settings=None, platform=None, working_directory=None, executables_directory=None, connect=True, modules=None, load_default_modules=True, shell_prompt=DEFAULT_SHELL_PROMPT, conn_cls=None, max_async=50) :class:`~devlib.target.Target` is the primary interface to the remote device. All interactions with the device are performed via a @@ -76,6 +76,9 @@ Target :param conn_cls: This is the type of connection that will be used to communicate with the device. + :param max_async: Maximum number of opened connections to the target used to + issue non-blocking commands when using the async API. + .. attribute:: Target.core_names This is a list containing names of CPU cores on the target, in the order in @@ -606,7 +609,7 @@ Target Linux Target ------------ -.. class:: LinuxTarget(connection_settings=None, platform=None, working_directory=None, executables_directory=None, connect=True, modules=None, load_default_modules=True, shell_prompt=DEFAULT_SHELL_PROMPT, conn_cls=SshConnection, is_container=False,) +.. class:: LinuxTarget(connection_settings=None, platform=None, working_directory=None, executables_directory=None, connect=True, modules=None, load_default_modules=True, shell_prompt=DEFAULT_SHELL_PROMPT, conn_cls=SshConnection, is_container=False, max_async=50) :class:`LinuxTarget` is a subclass of :class:`~devlib.target.Target` with customisations specific to a device running linux. @@ -615,7 +618,7 @@ Linux Target Local Linux Target ------------------ -.. class:: LocalLinuxTarget(connection_settings=None, platform=None, working_directory=None, executables_directory=None, connect=True, modules=None, load_default_modules=True, shell_prompt=DEFAULT_SHELL_PROMPT, conn_cls=SshConnection, is_container=False,) +.. class:: LocalLinuxTarget(connection_settings=None, platform=None, working_directory=None, executables_directory=None, connect=True, modules=None, load_default_modules=True, shell_prompt=DEFAULT_SHELL_PROMPT, conn_cls=SshConnection, is_container=False, max_async=50) :class:`LocalLinuxTarget` is a subclass of :class:`~devlib.target.LinuxTarget` with customisations specific to using @@ -625,7 +628,7 @@ Local Linux Target Android Target --------------- -.. class:: AndroidTarget(connection_settings=None, platform=None, working_directory=None, executables_directory=None, connect=True, modules=None, load_default_modules=True, shell_prompt=DEFAULT_SHELL_PROMPT, conn_cls=AdbConnection, package_data_directory="/data/data") +.. class:: AndroidTarget(connection_settings=None, platform=None, working_directory=None, executables_directory=None, connect=True, modules=None, load_default_modules=True, shell_prompt=DEFAULT_SHELL_PROMPT, conn_cls=AdbConnection, package_data_directory="/data/data", max_async=50) :class:`AndroidTarget` is a subclass of :class:`~devlib.target.Target` with additional features specific to a device running Android. @@ -773,7 +776,7 @@ Android Target ChromeOS Target --------------- -.. class:: ChromeOsTarget(connection_settings=None, platform=None, working_directory=None, executables_directory=None, android_working_directory=None, android_executables_directory=None, connect=True, modules=None, load_default_modules=True, shell_prompt=DEFAULT_SHELL_PROMPT, package_data_directory="/data/data") +.. class:: ChromeOsTarget(connection_settings=None, platform=None, working_directory=None, executables_directory=None, android_working_directory=None, android_executables_directory=None, connect=True, modules=None, load_default_modules=True, shell_prompt=DEFAULT_SHELL_PROMPT, package_data_directory="/data/data", max_async=50) :class:`ChromeOsTarget` is a subclass of :class:`LinuxTarget` with additional features specific to a device running ChromeOS for example, From 5361ec1039055492d0c9ce20d82f9c3a97b61472 Mon Sep 17 00:00:00 2001 From: Douglas Raillard Date: Mon, 15 Nov 2021 14:47:16 +0000 Subject: [PATCH 09/11] devlib: Use async Target API Make use of the new async API to speedup other parts of devlib. --- devlib/collector/ftrace.py | 23 +-- devlib/module/cgroups.py | 24 ++- devlib/module/cpufreq.py | 290 ++++++++++++++++++++++--------------- devlib/module/cpuidle.py | 79 ++++++---- 4 files changed, 257 insertions(+), 159 deletions(-) diff --git a/devlib/collector/ftrace.py b/devlib/collector/ftrace.py index 9849f4e21..292aa8271 100644 --- a/devlib/collector/ftrace.py +++ b/devlib/collector/ftrace.py @@ -28,6 +28,7 @@ from devlib.host import PACKAGE_BIN_DIRECTORY from devlib.exception import TargetStableError, HostError from devlib.utils.misc import check_output, which, memoized +from devlib.utils.asyn import asyncf TRACE_MARKER_START = 'TRACE_MARKER_START' @@ -243,7 +244,8 @@ def reset(self): self.target.write_value(self.function_profile_file, 0, verify=False) self._reset_needed = False - def start(self): + @asyncf + async def start(self): self.start_time = time.time() if self._reset_needed: self.reset() @@ -282,14 +284,17 @@ def start(self): self.target.cpuidle.perturb_cpus() # Enable kernel function profiling if self.functions and self.tracer is None: - self.target.execute('echo nop > {}'.format(self.current_tracer_file), - as_root=True) - self.target.execute('echo 0 > {}'.format(self.function_profile_file), - as_root=True) - self.target.execute('echo {} > {}'.format(self.function_string, self.ftrace_filter_file), - as_root=True) - self.target.execute('echo 1 > {}'.format(self.function_profile_file), - as_root=True) + target = self.target + await target.async_manager.concurrently( + execute.asyn('echo nop > {}'.format(self.current_tracer_file), + as_root=True), + execute.asyn('echo 0 > {}'.format(self.function_profile_file), + as_root=True), + execute.asyn('echo {} > {}'.format(self.function_string, self.ftrace_filter_file), + as_root=True), + execute.asyn('echo 1 > {}'.format(self.function_profile_file), + as_root=True), + ) def stop(self): diff --git a/devlib/module/cgroups.py b/devlib/module/cgroups.py index b3cdb1d5d..ece52c278 100644 --- a/devlib/module/cgroups.py +++ b/devlib/module/cgroups.py @@ -19,11 +19,13 @@ from shlex import quote import itertools import warnings +import asyncio from devlib.module import Module from devlib.exception import TargetStableError from devlib.utils.misc import list_to_ranges, isiterable from devlib.utils.types import boolean +from devlib.utils.asyn import asyncf class Controller(object): @@ -55,7 +57,8 @@ def __init__(self, kind, hid, clist): self.mount_point = None self._cgroups = {} - def mount(self, target, mount_root): + @asyncf + async def mount(self, target, mount_root): mounted = target.list_file_systems() if self.mount_name in [e.device for e in mounted]: @@ -68,16 +71,16 @@ def mount(self, target, mount_root): else: # Mount the controller if not already in use self.mount_point = target.path.join(mount_root, self.mount_name) - target.execute('mkdir -p {} 2>/dev/null'\ + await target.execute.asyn('mkdir -p {} 2>/dev/null'\ .format(self.mount_point), as_root=True) - target.execute('mount -t cgroup -o {} {} {}'\ + await target.execute.asyn('mount -t cgroup -o {} {} {}'\ .format(','.join(self.clist), self.mount_name, self.mount_point), as_root=True) # Check if this controller uses "noprefix" option - output = target.execute('mount | grep "{} "'.format(self.mount_name)) + output = await target.execute.asyn('mount | grep "{} "'.format(self.mount_name)) if 'noprefix' in output: self._noprefix = True # self.logger.debug('Controller %s using "noprefix" option', @@ -394,11 +397,12 @@ def __init__(self, target): # Initialize controllers self.logger.info('Available controllers:') self.controllers = {} - for ss in subsys: + + async def register_controller(ss): hid = ss.hierarchy controller = Controller(ss.name, hid, hierarchy[hid]) try: - controller.mount(self.target, self.cgroup_root) + await controller.mount.asyn(self.target, self.cgroup_root) except TargetStableError: message = 'Failed to mount "{}" controller' raise TargetStableError(message.format(controller.kind)) @@ -406,6 +410,14 @@ def __init__(self, target): controller.mount_point) self.controllers[ss.name] = controller + asyncio.run( + target.async_manager.map_concurrently( + register_controller, + subsys, + ) + ) + + def list_subsystems(self): subsystems = [] for line in self.target.execute('{} cat /proc/cgroups'\ diff --git a/devlib/module/cpufreq.py b/devlib/module/cpufreq.py index f559ef6cc..f147d4170 100644 --- a/devlib/module/cpufreq.py +++ b/devlib/module/cpufreq.py @@ -13,10 +13,12 @@ # limitations under the License. # from contextlib import contextmanager +from operator import itemgetter from devlib.module import Module from devlib.exception import TargetStableError from devlib.utils.misc import memoized +import devlib.utils.asyn as asyn # a dict of governor name and a list of it tunables that can't be read @@ -30,44 +32,52 @@ class CpufreqModule(Module): name = 'cpufreq' @staticmethod - def probe(target): - - # x86 with Intel P-State driver - if target.abi == 'x86_64': - path = '/sys/devices/system/cpu/intel_pstate' - if target.file_exists(path): - return True - - # Generic CPUFreq support (single policy) - path = '/sys/devices/system/cpu/cpufreq/policy0' - if target.file_exists(path): - return True - - # Generic CPUFreq support (per CPU policy) - path = '/sys/devices/system/cpu/cpu0/cpufreq' - return target.file_exists(path) + @asyn.asyncf + async def probe(target): + paths = [ + # x86 with Intel P-State driver + (target.abi == 'x86_64', '/sys/devices/system/cpu/intel_pstate'), + # Generic CPUFreq support (single policy) + (True, '/sys/devices/system/cpu/cpufreq/policy0'), + # Generic CPUFreq support (per CPU policy) + (True, '/sys/devices/system/cpu/cpu0/cpufreq'), + ] + paths = [ + path[1] for path in paths + if path[0] + ] + + exists = await target.async_manager.map_concurrently( + target.file_exists.asyn, + paths, + ) + + return any(exists.values()) def __init__(self, target): super(CpufreqModule, self).__init__(target) self._governor_tunables = {} @memoized - def list_governors(self, cpu): + @asyn.asyncf + async def list_governors(self, cpu): """Returns a list of governors supported by the cpu.""" if isinstance(cpu, int): cpu = 'cpu{}'.format(cpu) sysfile = '/sys/devices/system/cpu/{}/cpufreq/scaling_available_governors'.format(cpu) - output = self.target.read_value(sysfile) + output = await self.target.read_value.asyn(sysfile) return output.strip().split() - def get_governor(self, cpu): + @asyn.asyncf + async def get_governor(self, cpu): """Returns the governor currently set for the specified CPU.""" if isinstance(cpu, int): cpu = 'cpu{}'.format(cpu) sysfile = '/sys/devices/system/cpu/{}/cpufreq/scaling_governor'.format(cpu) - return self.target.read_value(sysfile) + return await self.target.read_value.asyn(sysfile) - def set_governor(self, cpu, governor, **kwargs): + @asyn.asyncf + async def set_governor(self, cpu, governor, **kwargs): """ Set the governor for the specified CPU. See https://www.kernel.org/doc/Documentation/cpu-freq/governors.txt @@ -90,15 +100,15 @@ def set_governor(self, cpu, governor, **kwargs): """ if isinstance(cpu, int): cpu = 'cpu{}'.format(cpu) - supported = self.list_governors(cpu) + supported = await self.list_governors.asyn(cpu) if governor not in supported: raise TargetStableError('Governor {} not supported for cpu {}'.format(governor, cpu)) sysfile = '/sys/devices/system/cpu/{}/cpufreq/scaling_governor'.format(cpu) - self.target.write_value(sysfile, governor) - self.set_governor_tunables(cpu, governor, **kwargs) + await self.target.write_value.asyn(sysfile, governor) + return await self.set_governor_tunables.asyn(cpu, governor, **kwargs) - @contextmanager - def use_governor(self, governor, cpus=None, **kwargs): + @asyn.asynccontextmanager + async def use_governor(self, governor, cpus=None, **kwargs): """ Use a given governor, then restore previous governor(s) @@ -111,66 +121,97 @@ def use_governor(self, governor, cpus=None, **kwargs): :Keyword Arguments: Governor tunables, See :meth:`set_governor_tunables` """ if not cpus: - cpus = self.target.list_online_cpus() - - # Setting a governor & tunables for a cpu will set them for all cpus - # in the same clock domain, so only manipulating one cpu per domain - # is enough - domains = set(self.get_affected_cpus(cpu)[0] for cpu in cpus) - prev_governors = {cpu : (self.get_governor(cpu), self.get_governor_tunables(cpu)) - for cpu in domains} - - # Special case for userspace, frequency is not seen as a tunable - userspace_freqs = {} - for cpu, (prev_gov, _) in prev_governors.items(): - if prev_gov == "userspace": - userspace_freqs[cpu] = self.get_frequency(cpu) - - for cpu in domains: - self.set_governor(cpu, governor, **kwargs) + cpus = await self.target.list_online_cpus.asyn() + + async def get_cpu_info(cpu): + return await self.target.async_manager.concurrently(( + self.get_affected_cpus.asyn(cpu), + self.get_governor.asyn(cpu), + self.get_governor_tunables.asyn(cpu), + # We won't always use the frequency, but it's much quicker to + # do concurrently anyway so do it now + self.get_frequency.asyn(cpu), + )) + + cpus_infos = await self.target.async_manager.map_concurrently(get_cpu_info, cpus) + + # Setting a governor & tunables for a cpu will set them for all cpus in + # the same cpufreq policy, so only manipulating one cpu per domain is + # enough + domains = set( + info[0][0] + for info in cpus_infos.values() + ) + + await self.target.async_manager.concurrently( + self.set_governor.asyn(cpu, governor, **kwargs) + for cpu in domains + ) try: yield - finally: - for cpu, (prev_gov, tunables) in prev_governors.items(): - self.set_governor(cpu, prev_gov, **tunables) + async def set_gov(cpu): + domain, prev_gov, tunables, freq = cpus_infos[cpu] + await self.set_governor.asyn(cpu, prev_gov, **tunables) + # Special case for userspace, frequency is not seen as a tunable if prev_gov == "userspace": - self.set_frequency(cpu, userspace_freqs[cpu]) + await self.set_frequency.asyn(cpu, freq) + + await self.target.async_manager.concurrently( + set_gov(cpu) + for cpu in domains + ) - def list_governor_tunables(self, cpu): + @asyn.asyncf + async def list_governor_tunables(self, cpu): """Returns a list of tunables available for the governor on the specified CPU.""" if isinstance(cpu, int): cpu = 'cpu{}'.format(cpu) - governor = self.get_governor(cpu) + governor = await self.get_governor.asyn(cpu) if governor not in self._governor_tunables: try: tunables_path = '/sys/devices/system/cpu/{}/cpufreq/{}'.format(cpu, governor) - self._governor_tunables[governor] = self.target.list_directory(tunables_path) + self._governor_tunables[governor] = await self.target.list_directory.asyn(tunables_path) except TargetStableError: # probably an older kernel try: tunables_path = '/sys/devices/system/cpu/cpufreq/{}'.format(governor) - self._governor_tunables[governor] = self.target.list_directory(tunables_path) + self._governor_tunables[governor] = await self.target.list_directory.asyn(tunables_path) except TargetStableError: # governor does not support tunables self._governor_tunables[governor] = [] return self._governor_tunables[governor] - def get_governor_tunables(self, cpu): + @asyn.asyncf + async def get_governor_tunables(self, cpu): if isinstance(cpu, int): cpu = 'cpu{}'.format(cpu) - governor = self.get_governor(cpu) + governor, tunable_list = await self.target.async_manager.concurrently(( + self.get_governor.asyn(cpu), + self.list_governor_tunables.asyn(cpu) + )) + + write_only = set(WRITE_ONLY_TUNABLES.get(governor, [])) + tunable_list = [ + tunable + for tunable in tunable_list + if tunable not in write_only + ] + tunables = {} - for tunable in self.list_governor_tunables(cpu): - if tunable not in WRITE_ONLY_TUNABLES.get(governor, []): - try: - path = '/sys/devices/system/cpu/{}/cpufreq/{}/{}'.format(cpu, governor, tunable) - tunables[tunable] = self.target.read_value(path) - except TargetStableError: # May be an older kernel - path = '/sys/devices/system/cpu/cpufreq/{}/{}'.format(governor, tunable) - tunables[tunable] = self.target.read_value(path) + async def get_tunable(tunable): + try: + path = '/sys/devices/system/cpu/{}/cpufreq/{}/{}'.format(cpu, governor, tunable) + x = await self.target.read_value.asyn(path) + except TargetStableError: # May be an older kernel + path = '/sys/devices/system/cpu/cpufreq/{}/{}'.format(governor, tunable) + x = await self.target.read_value.asyn(path) + return x + + tunables = await self.target.async_manager.map_concurrently(get_tunable, tunable_list) return tunables - def set_governor_tunables(self, cpu, governor=None, **kwargs): + @asyn.asyncf + async def set_governor_tunables(self, cpu, governor=None, **kwargs): """ Set tunables for the specified governor. Tunables should be specified as keyword arguments. Which tunables and values are valid depends on the @@ -191,34 +232,35 @@ def set_governor_tunables(self, cpu, governor=None, **kwargs): if isinstance(cpu, int): cpu = 'cpu{}'.format(cpu) if governor is None: - governor = self.get_governor(cpu) - valid_tunables = self.list_governor_tunables(cpu) + governor = await self.get_governor.asyn(cpu) + valid_tunables = await self.list_governor_tunables.asyn(cpu) for tunable, value in kwargs.items(): if tunable in valid_tunables: path = '/sys/devices/system/cpu/{}/cpufreq/{}/{}'.format(cpu, governor, tunable) try: - self.target.write_value(path, value) + await self.target.write_value.asyn(path, value) except TargetStableError: - if self.target.file_exists(path): + if await self.target.file_exists.asyn(path): # File exists but we did something wrong raise # Expected file doesn't exist, try older sysfs layout. path = '/sys/devices/system/cpu/cpufreq/{}/{}'.format(governor, tunable) - self.target.write_value(path, value) + await self.target.write_value.asyn(path, value) else: message = 'Unexpected tunable {} for governor {} on {}.\n'.format(tunable, governor, cpu) message += 'Available tunables are: {}'.format(valid_tunables) raise TargetStableError(message) @memoized - def list_frequencies(self, cpu): + @asyn.asyncf + async def list_frequencies(self, cpu): """Returns a sorted list of frequencies supported by the cpu or an empty list if not could be found.""" if isinstance(cpu, int): cpu = 'cpu{}'.format(cpu) try: cmd = 'cat /sys/devices/system/cpu/{}/cpufreq/scaling_available_frequencies'.format(cpu) - output = self.target.execute(cmd) + output = await self.target.execute.asyn(cmd) available_frequencies = list(map(int, output.strip().split())) # pylint: disable=E1103 except TargetStableError: # On some devices scaling_frequencies is not generated. @@ -226,7 +268,7 @@ def list_frequencies(self, cpu): # Fall back to parsing stats/time_in_state path = '/sys/devices/system/cpu/{}/cpufreq/stats/time_in_state'.format(cpu) try: - out_iter = iter(self.target.read_value(path).split()) + out_iter = (await self.target.read_value.asyn(path)).split() except TargetStableError: if not self.target.file_exists(path): # Probably intel_pstate. Can't get available freqs. @@ -254,7 +296,8 @@ def get_min_available_frequency(self, cpu): freqs = self.list_frequencies(cpu) return min(freqs) if freqs else None - def get_min_frequency(self, cpu): + @asyn.asyncf + async def get_min_frequency(self, cpu): """ Returns the min frequency currently set for the specified CPU. @@ -268,9 +311,10 @@ def get_min_frequency(self, cpu): if isinstance(cpu, int): cpu = 'cpu{}'.format(cpu) sysfile = '/sys/devices/system/cpu/{}/cpufreq/scaling_min_freq'.format(cpu) - return self.target.read_int(sysfile) + return await self.target.read_int.asyn(sysfile) - def set_min_frequency(self, cpu, frequency, exact=True): + @asyn.asyncf + async def set_min_frequency(self, cpu, frequency, exact=True): """ Set's the minimum value for CPU frequency. Actual frequency will depend on the Governor used and may vary during execution. The value should be @@ -289,7 +333,7 @@ def set_min_frequency(self, cpu, frequency, exact=True): """ if isinstance(cpu, int): cpu = 'cpu{}'.format(cpu) - available_frequencies = self.list_frequencies(cpu) + available_frequencies = await self.list_frequencies.asyn(cpu) try: value = int(frequency) if exact and available_frequencies and value not in available_frequencies: @@ -297,11 +341,12 @@ def set_min_frequency(self, cpu, frequency, exact=True): value, available_frequencies)) sysfile = '/sys/devices/system/cpu/{}/cpufreq/scaling_min_freq'.format(cpu) - self.target.write_value(sysfile, value) + await self.target.write_value.asyn(sysfile, value) except ValueError: raise ValueError('Frequency must be an integer; got: "{}"'.format(frequency)) - def get_frequency(self, cpu, cpuinfo=False): + @asyn.asyncf + async def get_frequency(self, cpu, cpuinfo=False): """ Returns the current frequency currently set for the specified CPU. @@ -321,9 +366,10 @@ def get_frequency(self, cpu, cpuinfo=False): sysfile = '/sys/devices/system/cpu/{}/cpufreq/{}'.format( cpu, 'cpuinfo_cur_freq' if cpuinfo else 'scaling_cur_freq') - return self.target.read_int(sysfile) + return await self.target.read_int.asyn(sysfile) - def set_frequency(self, cpu, frequency, exact=True): + @asyn.asyncf + async def set_frequency(self, cpu, frequency, exact=True): """ Set's the minimum value for CPU frequency. Actual frequency will depend on the Governor used and may vary during execution. The value should be @@ -347,23 +393,24 @@ def set_frequency(self, cpu, frequency, exact=True): try: value = int(frequency) if exact: - available_frequencies = self.list_frequencies(cpu) + available_frequencies = await self.list_frequencies.asyn(cpu) if available_frequencies and value not in available_frequencies: raise TargetStableError('Can\'t set {} frequency to {}\nmust be in {}'.format(cpu, value, available_frequencies)) - if self.get_governor(cpu) != 'userspace': + if await self.get_governor.asyn(cpu) != 'userspace': raise TargetStableError('Can\'t set {} frequency; governor must be "userspace"'.format(cpu)) sysfile = '/sys/devices/system/cpu/{}/cpufreq/scaling_setspeed'.format(cpu) - self.target.write_value(sysfile, value, verify=False) - cpuinfo = self.get_frequency(cpu, cpuinfo=True) + await self.target.write_value.asyn(sysfile, value, verify=False) + cpuinfo = await self.get_frequency.asyn(cpu, cpuinfo=True) if cpuinfo != value: self.logger.warning( 'The cpufreq value has not been applied properly cpuinfo={} request={}'.format(cpuinfo, value)) except ValueError: raise ValueError('Frequency must be an integer; got: "{}"'.format(frequency)) - def get_max_frequency(self, cpu): + @asyn.asyncf + async def get_max_frequency(self, cpu): """ Returns the max frequency currently set for the specified CPU. @@ -376,9 +423,10 @@ def get_max_frequency(self, cpu): if isinstance(cpu, int): cpu = 'cpu{}'.format(cpu) sysfile = '/sys/devices/system/cpu/{}/cpufreq/scaling_max_freq'.format(cpu) - return self.target.read_int(sysfile) + return await self.target.read_int.asyn(sysfile) - def set_max_frequency(self, cpu, frequency, exact=True): + @asyn.asyncf + async def set_max_frequency(self, cpu, frequency, exact=True): """ Set's the minimum value for CPU frequency. Actual frequency will depend on the Governor used and may vary during execution. The value should be @@ -397,7 +445,7 @@ def set_max_frequency(self, cpu, frequency, exact=True): """ if isinstance(cpu, int): cpu = 'cpu{}'.format(cpu) - available_frequencies = self.list_frequencies(cpu) + available_frequencies = await self.list_frequencies.asyn(cpu) try: value = int(frequency) if exact and available_frequencies and value not in available_frequencies: @@ -405,45 +453,53 @@ def set_max_frequency(self, cpu, frequency, exact=True): value, available_frequencies)) sysfile = '/sys/devices/system/cpu/{}/cpufreq/scaling_max_freq'.format(cpu) - self.target.write_value(sysfile, value) + await self.target.write_value.asyn(sysfile, value) except ValueError: raise ValueError('Frequency must be an integer; got: "{}"'.format(frequency)) - def set_governor_for_cpus(self, cpus, governor, **kwargs): + @asyn.asyncf + async def set_governor_for_cpus(self, cpus, governor, **kwargs): """ Set the governor for the specified list of CPUs. See https://www.kernel.org/doc/Documentation/cpu-freq/governors.txt :param cpus: The list of CPU for which the governor is to be set. """ - for cpu in cpus: + await self.target.async_manager.map_concurrently( self.set_governor(cpu, governor, **kwargs) + for cpu in sorted(set(cpus)) + ) - def set_frequency_for_cpus(self, cpus, freq, exact=False): + @asyn.asyncf + async def set_frequency_for_cpus(self, cpus, freq, exact=False): """ Set the frequency for the specified list of CPUs. See https://www.kernel.org/doc/Documentation/cpu-freq/governors.txt :param cpus: The list of CPU for which the frequency has to be set. """ - for cpu in cpus: + await self.target.async_manager.map_concurrently( self.set_frequency(cpu, freq, exact) + for cpu in sorted(set(cpus)) + ) - def set_all_frequencies(self, freq): + @asyn.asyncf + async def set_all_frequencies(self, freq): """ Set the specified (minimum) frequency for all the (online) CPUs """ # pylint: disable=protected-access - return self.target._execute_util( + return await self.target._execute_util.asyn( 'cpufreq_set_all_frequencies {}'.format(freq), as_root=True) - def get_all_frequencies(self): + @asyn.asyncf + async def get_all_frequencies(self): """ Get the current frequency for all the (online) CPUs """ # pylint: disable=protected-access - output = self.target._execute_util( + output = await self.target._execute_util.asyn( 'cpufreq_get_all_frequencies', as_root=True) frequencies = {} for x in output.splitlines(): @@ -453,32 +509,34 @@ def get_all_frequencies(self): frequencies[kv[0]] = kv[1] return frequencies - def set_all_governors(self, governor): + @asyn.asyncf + async def set_all_governors(self, governor): """ Set the specified governor for all the (online) CPUs """ try: # pylint: disable=protected-access - return self.target._execute_util( + return await self.target._execute_util.asyn( 'cpufreq_set_all_governors {}'.format(governor), as_root=True) except TargetStableError as e: if ("echo: I/O error" in str(e) or "write error: Invalid argument" in str(e)): - cpus_unsupported = [c for c in self.target.list_online_cpus() - if governor not in self.list_governors(c)] + cpus_unsupported = [c for c in await self.target.list_online_cpus.asyn() + if governor not in await self.list_governors.asyn(c)] raise TargetStableError("Governor {} unsupported for CPUs {}".format( governor, cpus_unsupported)) else: raise - def get_all_governors(self): + @asyn.asyncf + async def get_all_governors(self): """ Get the current governor for all the (online) CPUs """ # pylint: disable=protected-access - output = self.target._execute_util( + output = await self.target._execute_util.asyn( 'cpufreq_get_all_governors', as_root=True) governors = {} for x in output.splitlines(): @@ -488,14 +546,16 @@ def get_all_governors(self): governors[kv[0]] = kv[1] return governors - def trace_frequencies(self): + @asyn.asyncf + async def trace_frequencies(self): """ Report current frequencies on trace file """ # pylint: disable=protected-access - return self.target._execute_util('cpufreq_trace_all_frequencies', as_root=True) + return await self.target._execute_util.asyn('cpufreq_trace_all_frequencies', as_root=True) - def get_affected_cpus(self, cpu): + @asyn.asyncf + async def get_affected_cpus(self, cpu): """ Get the online CPUs that share a frequency domain with the given CPU """ @@ -504,10 +564,12 @@ def get_affected_cpus(self, cpu): sysfile = '/sys/devices/system/cpu/{}/cpufreq/affected_cpus'.format(cpu) - return [int(c) for c in self.target.read_value(sysfile).split()] + content = await self.target.read_value.asyn(sysfile) + return [int(c) for c in content.split()] - @memoized - def get_related_cpus(self, cpu): + @asyn.asyncf + @asyn.memoized_method + async def get_related_cpus(self, cpu): """ Get the CPUs that share a frequency domain with the given CPU """ @@ -516,10 +578,11 @@ def get_related_cpus(self, cpu): sysfile = '/sys/devices/system/cpu/{}/cpufreq/related_cpus'.format(cpu) - return [int(c) for c in self.target.read_value(sysfile).split()] + return [int(c) for c in (await self.target.read_value.asyn(sysfile)).split()] - @memoized - def get_driver(self, cpu): + @asyn.asyncf + @asyn.memoized_method + async def get_driver(self, cpu): """ Get the name of the driver used by this cpufreq policy. """ @@ -528,15 +591,16 @@ def get_driver(self, cpu): sysfile = '/sys/devices/system/cpu/{}/cpufreq/scaling_driver'.format(cpu) - return self.target.read_value(sysfile).strip() + return (await self.target.read_value.asyn(sysfile)).strip() - def iter_domains(self): + @asyn.asyncf + async def iter_domains(self): """ Iterate over the frequency domains in the system """ cpus = set(range(self.target.number_of_cpus)) while cpus: cpu = next(iter(cpus)) # pylint: disable=stop-iteration-return - domain = self.target.cpufreq.get_related_cpus(cpu) + domain = await self.target.cpufreq.get_related_cpus.asyn(cpu) yield domain cpus = cpus.difference(domain) diff --git a/devlib/module/cpuidle.py b/devlib/module/cpuidle.py index 7b1704c73..a7d0fef64 100644 --- a/devlib/module/cpuidle.py +++ b/devlib/module/cpuidle.py @@ -22,6 +22,7 @@ from devlib.exception import TargetStableError from devlib.utils.types import integer, boolean from devlib.utils.misc import memoized +import devlib.utils.asyn as asyn class CpuidleState(object): @@ -59,19 +60,23 @@ def __init__(self, target, index, path, name, desc, power, latency, residency): self.id = self.target.path.basename(self.path) self.cpu = self.target.path.basename(self.target.path.dirname(path)) - def enable(self): - self.set('disable', 0) + @asyn.asyncf + async def enable(self): + await self.set.asyn('disable', 0) - def disable(self): - self.set('disable', 1) + @asyn.asyncf + async def disable(self): + await self.set.asyn('disable', 1) - def get(self, prop): + @asyn.asyncf + async def get(self, prop): property_path = self.target.path.join(self.path, prop) - return self.target.read_value(property_path) + return await self.target.read_value.asyn(property_path) - def set(self, prop, value): + @asyn.asyncf + async def set(self, prop, value): property_path = self.target.path.join(self.path, prop) - self.target.write_value(property_path, value) + await self.target.write_value.asyn(property_path, value) def __eq__(self, other): if isinstance(other, CpuidleState): @@ -96,8 +101,9 @@ class Cpuidle(Module): root_path = '/sys/devices/system/cpu/cpuidle' @staticmethod - def probe(target): - return target.file_exists(Cpuidle.root_path) + @asyn.asyncf + async def probe(target): + return await target.file_exists.asyn(Cpuidle.root_path) def __init__(self, target): super(Cpuidle, self).__init__(target) @@ -148,29 +154,39 @@ def get_state(self, state, cpu=0): return s raise ValueError('Cpuidle state {} does not exist'.format(state)) - def enable(self, state, cpu=0): - self.get_state(state, cpu).enable() - - def disable(self, state, cpu=0): - self.get_state(state, cpu).disable() - - def enable_all(self, cpu=0): - for state in self.get_states(cpu): - state.enable() - - def disable_all(self, cpu=0): - for state in self.get_states(cpu): - state.disable() - - def perturb_cpus(self): + @asyn.asyncf + async def enable(self, state, cpu=0): + await self.get_state(state, cpu).enable.asyn() + + @asyn.asyncf + async def disable(self, state, cpu=0): + await self.get_state(state, cpu).disable.asyn() + + @asyn.asyncf + async def enable_all(self, cpu=0): + await self.target.async_manager.concurrently( + state.enable.asyn() + for state in self.get_states(cpu) + ) + + @asyn.asyncf + async def disable_all(self, cpu=0): + await self.target.async_manager.concurrently( + state.disable.asyn() + for state in self.get_states(cpu) + ) + + @asyn.asyncf + async def perturb_cpus(self): """ Momentarily wake each CPU. Ensures cpu_idle events in trace file. """ # pylint: disable=protected-access - self.target._execute_util('cpuidle_wake_all_cpus') + await self.target._execute_util.asyn('cpuidle_wake_all_cpus') - def get_driver(self): - return self.target.read_value(self.target.path.join(self.root_path, 'current_driver')) + @asyn.asyncf + async def get_driver(self): + return await self.target.read_value.asyn(self.target.path.join(self.root_path, 'current_driver')) @memoized def list_governors(self): @@ -179,12 +195,13 @@ def list_governors(self): output = self.target.read_value(sysfile) return output.strip().split() - def get_governor(self): + @asyn.asyncf + async def get_governor(self): """Returns the currently selected idle governor.""" path = self.target.path.join(self.root_path, 'current_governor_ro') - if not self.target.file_exists(path): + if not await self.target.file_exists.asyn(path): path = self.target.path.join(self.root_path, 'current_governor') - return self.target.read_value(path) + return await self.target.read_value.asyn(path) def set_governor(self, governor): """ From 1a6dcb8f5d4f5853a3c38858690960b403606213 Mon Sep 17 00:00:00 2001 From: Douglas Raillard Date: Tue, 26 Jul 2022 14:42:33 +0100 Subject: [PATCH 10/11] module/cpufreq: Fix async use_governor() use_governor() was trying to set concurrently both per-cpu and global tunables for each governor, which lead to a write conflict. Split the work into the per-governor global tunables and the per-cpu tunables, and do all that in concurrently. Each task is therefore responsible of a distinct set of files and all is well. Also remove @memoized on async functions. It will be reintroduced in a later commit when there is a safe alternative for async functions. --- devlib/module/cpufreq.py | 112 ++++++++++++++++++++++++++------------- 1 file changed, 76 insertions(+), 36 deletions(-) diff --git a/devlib/module/cpufreq.py b/devlib/module/cpufreq.py index f147d4170..c66ae3831 100644 --- a/devlib/module/cpufreq.py +++ b/devlib/module/cpufreq.py @@ -58,7 +58,6 @@ def __init__(self, target): super(CpufreqModule, self).__init__(target) self._governor_tunables = {} - @memoized @asyn.asyncf async def list_governors(self, cpu): """Returns a list of governors supported by the cpu.""" @@ -105,7 +104,7 @@ async def set_governor(self, cpu, governor, **kwargs): raise TargetStableError('Governor {} not supported for cpu {}'.format(governor, cpu)) sysfile = '/sys/devices/system/cpu/{}/cpufreq/scaling_governor'.format(cpu) await self.target.write_value.asyn(sysfile, governor) - return await self.set_governor_tunables.asyn(cpu, governor, **kwargs) + await self.set_governor_tunables.asyn(cpu, governor, **kwargs) @asyn.asynccontextmanager async def use_governor(self, governor, cpus=None, **kwargs): @@ -151,44 +150,83 @@ async def get_cpu_info(cpu): try: yield finally: - async def set_gov(cpu): + async def set_per_cpu_tunables(cpu): domain, prev_gov, tunables, freq = cpus_infos[cpu] - await self.set_governor.asyn(cpu, prev_gov, **tunables) + # Per-cpu tunables are safe to set concurrently + await self.set_governor_tunables.asyn(cpu, prev_gov, per_cpu=True, **tunables) # Special case for userspace, frequency is not seen as a tunable if prev_gov == "userspace": await self.set_frequency.asyn(cpu, freq) + per_cpu_tunables = self.target.async_manager.concurrently( + set_per_cpu_tunables(cpu) + for cpu in domains + ) + + # Non-per-cpu tunables have to be set one after the other, for each + # governor that we had to deal with. + global_tunables = { + prev_gov: (cpu, tunables) + for cpu, (domain, prev_gov, tunables, freq) in cpus_infos.items() + } + + global_tunables = self.target.async_manager.concurrently( + self.set_governor_tunables.asyn(cpu, gov, per_cpu=False, **tunables) + for gov, (cpu, tunables) in global_tunables.items() + ) + + # Set the governor first await self.target.async_manager.concurrently( - set_gov(cpu) + self.set_governor.asyn(cpu, cpus_infos[cpu][1]) for cpu in domains ) + # And then set all the tunables concurrently. Each task has a + # specific and non-overlapping set of file to write. + await self.target.async_manager.concurrently( + (per_cpu_tunables, global_tunables) + ) @asyn.asyncf - async def list_governor_tunables(self, cpu): - """Returns a list of tunables available for the governor on the specified CPU.""" + async def _list_governor_tunables(self, cpu, governor=None): if isinstance(cpu, int): cpu = 'cpu{}'.format(cpu) - governor = await self.get_governor.asyn(cpu) - if governor not in self._governor_tunables: - try: - tunables_path = '/sys/devices/system/cpu/{}/cpufreq/{}'.format(cpu, governor) - self._governor_tunables[governor] = await self.target.list_directory.asyn(tunables_path) - except TargetStableError: # probably an older kernel + + if governor is None: + governor = await self.get_governor.asyn(cpu) + + try: + return self._governor_tunables[governor] + except KeyError: + for per_cpu, path in ( + (True, '/sys/devices/system/cpu/{}/cpufreq/{}'.format(cpu, governor)), + # On old kernels + (False, '/sys/devices/system/cpu/cpufreq/{}'.format(governor)), + ): try: - tunables_path = '/sys/devices/system/cpu/cpufreq/{}'.format(governor) - self._governor_tunables[governor] = await self.target.list_directory.asyn(tunables_path) - except TargetStableError: # governor does not support tunables - self._governor_tunables[governor] = [] - return self._governor_tunables[governor] + tunables = await self.target.list_directory.asyn(path) + except TargetStableError: + continue + else: + break + else: + per_cpu = False + tunables = [] + + data = (governor, per_cpu, tunables) + self._governor_tunables[governor] = data + return data + + @asyn.asyncf + async def list_governor_tunables(self, cpu): + """Returns a list of tunables available for the governor on the specified CPU.""" + _, _, tunables = await self._list_governor_tunables.asyn(cpu) + return tunables @asyn.asyncf async def get_governor_tunables(self, cpu): if isinstance(cpu, int): cpu = 'cpu{}'.format(cpu) - governor, tunable_list = await self.target.async_manager.concurrently(( - self.get_governor.asyn(cpu), - self.list_governor_tunables.asyn(cpu) - )) + governor, _, tunable_list = await self._list_governor_tunables.asyn(cpu) write_only = set(WRITE_ONLY_TUNABLES.get(governor, [])) tunable_list = [ @@ -211,7 +249,7 @@ async def get_tunable(tunable): return tunables @asyn.asyncf - async def set_governor_tunables(self, cpu, governor=None, **kwargs): + async def set_governor_tunables(self, cpu, governor=None, per_cpu=None, **kwargs): """ Set tunables for the specified governor. Tunables should be specified as keyword arguments. Which tunables and values are valid depends on the @@ -220,6 +258,9 @@ async def set_governor_tunables(self, cpu, governor=None, **kwargs): :param cpu: The cpu for which the governor will be set. ``int`` or full cpu name as it appears in sysfs, e.g. ``cpu0``. :param governor: The name of the governor. Must be all lower case. + :param per_cpu: If ``None``, both per-cpu and global governor tunables + will be set. If ``True``, only per-CPU tunables will be set and if + ``False``, only global tunables will be set. The rest should be keyword parameters mapping tunable name onto the value to be set for it. @@ -229,29 +270,28 @@ async def set_governor_tunables(self, cpu, governor=None, **kwargs): tunable. """ + if not kwargs: + return if isinstance(cpu, int): cpu = 'cpu{}'.format(cpu) - if governor is None: - governor = await self.get_governor.asyn(cpu) - valid_tunables = await self.list_governor_tunables.asyn(cpu) + + governor, gov_per_cpu, valid_tunables = await self._list_governor_tunables.asyn(cpu, governor=governor) for tunable, value in kwargs.items(): if tunable in valid_tunables: - path = '/sys/devices/system/cpu/{}/cpufreq/{}/{}'.format(cpu, governor, tunable) - try: - await self.target.write_value.asyn(path, value) - except TargetStableError: - if await self.target.file_exists.asyn(path): - # File exists but we did something wrong - raise - # Expected file doesn't exist, try older sysfs layout. + if per_cpu is not None and gov_per_cpu != per_cpu: + pass + + if gov_per_cpu: + path = '/sys/devices/system/cpu/{}/cpufreq/{}/{}'.format(cpu, governor, tunable) + else: path = '/sys/devices/system/cpu/cpufreq/{}/{}'.format(governor, tunable) - await self.target.write_value.asyn(path, value) + + await self.target.write_value.asyn(path, value) else: message = 'Unexpected tunable {} for governor {} on {}.\n'.format(tunable, governor, cpu) message += 'Available tunables are: {}'.format(valid_tunables) raise TargetStableError(message) - @memoized @asyn.asyncf async def list_frequencies(self, cpu): """Returns a sorted list of frequencies supported by the cpu or an empty list From 843e954bdb496dd500f09b6723d4516f57c0d0fc Mon Sep 17 00:00:00 2001 From: Douglas Raillard Date: Tue, 26 Jul 2022 16:27:13 +0100 Subject: [PATCH 11/11] utils/asyn: Add memoize_method() decorator Add a memoize_method decorator that works for async methods. It will not leak memory since the memoization cache is held in the instance __dict__, and it does not rely on hacks to hash unhashable data. --- devlib/module/cpufreq.py | 2 + devlib/utils/asyn.py | 81 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 83 insertions(+) diff --git a/devlib/module/cpufreq.py b/devlib/module/cpufreq.py index c66ae3831..0910c436e 100644 --- a/devlib/module/cpufreq.py +++ b/devlib/module/cpufreq.py @@ -59,6 +59,7 @@ def __init__(self, target): self._governor_tunables = {} @asyn.asyncf + @asyn.memoized_method async def list_governors(self, cpu): """Returns a list of governors supported by the cpu.""" if isinstance(cpu, int): @@ -293,6 +294,7 @@ async def set_governor_tunables(self, cpu, governor=None, per_cpu=None, **kwargs raise TargetStableError(message) @asyn.asyncf + @asyn.memoized_method async def list_frequencies(self, cpu): """Returns a sorted list of frequencies supported by the cpu or an empty list if not could be found.""" diff --git a/devlib/utils/asyn.py b/devlib/utils/asyn.py index ddda1158f..a993077d4 100644 --- a/devlib/utils/asyn.py +++ b/devlib/utils/asyn.py @@ -25,6 +25,7 @@ import contextlib import pathlib import os.path +import inspect # Allow nesting asyncio loops, which is necessary for: # * Being able to call the blocking variant of a function from an async @@ -211,6 +212,86 @@ def __getattr__(self, attr): return getattr(self.asyn, attr) +class memoized_method: + """ + Decorator to memmoize a method. + + It works for: + + * async methods (coroutine functions) + * non-async methods + * method already decorated with :func:`devlib.asyn.asyncf`. + + .. note:: This decorator does not rely on hacks to hash unhashable data. If + such input is required, it will either have to be coerced to a hashable + first (e.g. converting a list to a tuple), or the code of + :func:`devlib.asyn.memoized_method` will have to be updated to do so. + """ + def __init__(self, f): + memo = self + + sig = inspect.signature(f) + + def bind(self, *args, **kwargs): + bound = sig.bind(self, *args, **kwargs) + bound.apply_defaults() + key = (bound.args[1:], tuple(sorted(bound.kwargs.items()))) + + return (key, bound.args, bound.kwargs) + + def get_cache(self): + try: + cache = self.__dict__[memo.name] + except KeyError: + cache = {} + self.__dict__[memo.name] = cache + return cache + + + if inspect.iscoroutinefunction(f): + @functools.wraps(f) + async def wrapper(self, *args, **kwargs): + cache = get_cache(self) + key, args, kwargs = bind(self, *args, **kwargs) + try: + return cache[key] + except KeyError: + x = await f(*args, **kwargs) + cache[key] = x + return x + else: + @functools.wraps(f) + def wrapper(self, *args, **kwargs): + cache = get_cache(self) + key, args, kwargs = bind(self, *args, **kwargs) + try: + return cache[key] + except KeyError: + x = f(*args, **kwargs) + cache[key] = x + return x + + + self.f = wrapper + self._name = f.__name__ + + @property + def name(self): + return '__memoization_cache_of_' + self._name + + def __call__(self, *args, **kwargs): + return self.f(*args, **kwargs) + + def __get__(self, obj, owner=None): + return self.f.__get__(obj, owner) + + def __set__(self, obj, value): + raise RuntimeError("Cannot monkey-patch a memoized function") + + def __set_name__(self, owner, name): + self.name = name + + def asyncf(f): """ Decorator used to turn a coroutine into a blocking function, with an