diff --git a/pymobiledevice3/cli/webinspector.py b/pymobiledevice3/cli/webinspector.py index a374e36cb..b96195404 100644 --- a/pymobiledevice3/cli/webinspector.py +++ b/pymobiledevice3/cli/webinspector.py @@ -27,11 +27,12 @@ from pymobiledevice3.exceptions import InspectorEvaluateError, LaunchingApplicationError, \ RemoteAutomationNotEnabledError, WebInspectorNotEnabledError, WirError from pymobiledevice3.lockdown import LockdownClient, create_using_usbmux +from pymobiledevice3.lockdown_service_provider import LockdownServiceProvider from pymobiledevice3.osu.os_utils import get_os_utils from pymobiledevice3.services.web_protocol.cdp_server import app from pymobiledevice3.services.web_protocol.driver import By, Cookie, WebDriver from pymobiledevice3.services.web_protocol.inspector_session import InspectorSession -from pymobiledevice3.services.webinspector import SAFARI, Page, WebinspectorService +from pymobiledevice3.services.webinspector import SAFARI, ApplicationPage, WebinspectorService SCRIPT = ''' function inspectedPage_evalResult_getCompletions(primitiveType) {{ @@ -212,9 +213,11 @@ def shell(service_provider: LockdownClient, timeout): @webinspector.command(cls=Command) @click.option('-t', '--timeout', default=3, show_default=True, type=float) @click.option('--automation', is_flag=True, help='Use remote automation') +@click.option('--no-open-safari', is_flag=True, help='Avoid opening the Safari app') @click.argument('url', required=False, default='') @catch_errors -def js_shell(service_provider: LockdownClient, timeout, automation, url): +def js_shell(service_provider: LockdownServiceProvider, timeout: float, automation: bool, no_open_safari: bool, + url: str) -> None: """ Create a javascript shell. This interpreter runs on your local machine, but evaluates each expression on the remote @@ -229,7 +232,7 @@ def js_shell(service_provider: LockdownClient, timeout, automation, url): """ js_shell_class = AutomationJsShell if automation else InspectorJsShell - asyncio.run(run_js_shell(js_shell_class, service_provider, timeout, url)) + asyncio.run(run_js_shell(js_shell_class, service_provider, timeout, url, not no_open_safari)) udid = '' @@ -297,7 +300,7 @@ def get_completions( class JsShell(ABC): - def __init__(self): + def __init__(self) -> None: super().__init__() self.prompt_session = PromptSession(lexer=PygmentsLexer(lexers.JavascriptLexer), auto_suggest=AutoSuggestFromHistory(), @@ -307,7 +310,7 @@ def __init__(self): @classmethod @abstractmethod - def create(cls, lockdown: LockdownClient, timeout: float, app: str): + def create(cls, lockdown: LockdownServiceProvider, timeout: float, open_safari: bool) -> None: pass @abstractmethod @@ -357,8 +360,8 @@ def __init__(self, driver: WebDriver): @classmethod @asynccontextmanager - async def create(cls, lockdown: LockdownClient, timeout: float, app: str) -> 'AutomationJsShell': - inspector, application = create_webinspector_and_launch_app(lockdown, timeout, app) + async def create(cls, lockdown: LockdownClient, timeout: float, open_safari: bool) -> 'AutomationJsShell': + inspector, application = create_webinspector_and_launch_app(lockdown, timeout, SAFARI) automation_session = inspector.automation_session(application) driver = WebDriver(automation_session) driver.start_session() @@ -382,13 +385,16 @@ def __init__(self, inspector_session: InspectorSession): @classmethod @asynccontextmanager - async def create(cls, lockdown: LockdownClient, timeout: float, app: str) -> 'InspectorJsShell': - inspector, application = create_webinspector_and_launch_app(lockdown, timeout, app) - page = InspectorJsShell.query_page(inspector) - if page is None: + async def create(cls, lockdown: LockdownClient, timeout: float, open_safari: bool) -> 'InspectorJsShell': + inspector = WebinspectorService(lockdown=lockdown) + inspector.connect(timeout) + if open_safari: + _ = inspector.open_app(SAFARI) + application_page = cls.query_page(inspector, bundle_identifier=SAFARI if open_safari else None) + if application_page is None: raise click.exceptions.Exit() - inspector_session = await inspector.inspector_session(application, page) + inspector_session = await inspector.inspector_session(application_page.application, application_page.page) await inspector_session.console_enable() await inspector_session.runtime_enable() @@ -404,19 +410,22 @@ async def navigate(self, url: str): await self.inspector_session.navigate_to_url(url) @staticmethod - def query_page(inspector: WebinspectorService) -> Optional[Page]: - reload_pages(inspector) - available_pages = list(inspector.get_open_pages().get('Safari', [])) + def query_page(inspector: WebinspectorService, bundle_identifier: Optional[str] = None) \ + -> Optional[ApplicationPage]: + available_pages = inspector.get_open_application_pages(timeout=1) + if bundle_identifier is not None: + available_pages = [application_page for application_page in available_pages if + application_page.application.bundle == bundle_identifier] if not available_pages: logger.error('Unable to find available pages (try to unlock device)') - return + return None page_query = [inquirer3.List('page', message='choose page', choices=available_pages, carousel=True)] page = inquirer3.prompt(page_query, theme=GreenPassion(), raise_keyboard_interrupt=True)['page'] return page -async def run_js_shell(js_shell_class: type[JsShell], lockdown: LockdownClient, - timeout: float, url: str): - async with js_shell_class.create(lockdown, timeout, SAFARI) as js_shell_instance: +async def run_js_shell(js_shell_class: type[JsShell], lockdown: LockdownServiceProvider, + timeout: float, url: str, open_safari: bool) -> None: + async with js_shell_class.create(lockdown, timeout, open_safari) as js_shell_instance: await js_shell_instance.start(url) diff --git a/pymobiledevice3/services/webinspector.py b/pymobiledevice3/services/webinspector.py index c8fe08538..9ba976aab 100644 --- a/pymobiledevice3/services/webinspector.py +++ b/pymobiledevice3/services/webinspector.py @@ -104,6 +104,15 @@ def from_application_dictionary(cls, app_dict) -> 'Application': ) +@dataclass +class ApplicationPage: + application: Application + page: Page + + def __str__(self) -> str: + return f'<{self.application.name}({self.application.pid}) TYPE:{self.page.type_.value} URL:{self.page.web_url}>' + + class WebinspectorService: SERVICE_NAME = 'com.apple.webinspector' RSD_SERVICE_NAME = 'com.apple.webinspector.shim.remote' @@ -185,10 +194,10 @@ def automation_session(self, app: Application) -> AutomationSession: self.await_(asyncio.sleep(0)) return AutomationSession(SessionProtocol(self, session_id, app, page)) - async def inspector_session(self, app: Application, page: Page, wait_target: bool = True) -> InspectorSession: + async def inspector_session(self, app: Application, page: Page) -> InspectorSession: session_id = str(uuid.uuid4()).upper() return await InspectorSession.create(SessionProtocol(self, session_id, app, page, method_prefix=''), - wait_target=wait_target) + wait_target=page.type_ != WirTypes.JAVASCRIPT) def get_open_pages(self) -> dict: apps = {} @@ -198,6 +207,20 @@ def get_open_pages(self) -> dict: apps[self.connected_application[app].name] = self.application_pages[app].values() return apps + def get_open_application_pages(self, timeout: float) -> list[ApplicationPage]: + # Query all connected applications + self.await_(self._get_connected_applications()) + + # Give some time for `webinspectord` to reply with all inspectable applications + self.await_(asyncio.sleep(timeout)) + + result = [] + for app in self.connected_application: + if self.application_pages.get(app, False): + for page in self.application_pages[app].values(): + result.append(ApplicationPage(self.connected_application[app], page)) + return result + def open_app(self, bundle: str, timeout: Union[float, int] = 3) -> Application: self.await_(self._request_application_launch(bundle)) self.get_open_pages() @@ -235,17 +258,22 @@ def _handle_report_connected_application_list(self, arg): for key, application in arg['WIRApplicationDictionaryKey'].items(): self.connected_application[key] = Application.from_application_dictionary(application) + # Immediately also query the application pages + self.await_(self._forward_get_listing(application)) + def _handle_report_connected_driver_list(self, arg): pass def _handle_application_sent_listing(self, arg): if arg['WIRApplicationIdentifierKey'] in self.application_pages: + # Update existing application pages for id_, page in arg['WIRListingKey'].items(): if id_ in self.application_pages[arg['WIRApplicationIdentifierKey']]: self.application_pages[arg['WIRApplicationIdentifierKey']][id_].update(page) else: self.application_pages[arg['WIRApplicationIdentifierKey']][id_] = Page.from_page_dictionary(page) else: + # Add new application pages pages = {} for id_, page in arg['WIRListingKey'].items(): pages[id_] = Page.from_page_dictionary(page) @@ -281,6 +309,9 @@ async def _forward_get_listing(self, app_id): async def _request_application_launch(self, bundle: str): await self._send_message('_rpc_requestApplicationLaunch:', {'WIRApplicationBundleIdentifierKey': bundle}) + async def _get_connected_applications(self) -> None: + await self._send_message('_rpc_getConnectedApplications:', {}) + async def _forward_automation_session_request(self, session_id: str, app_id: str): await self._send_message('_rpc_forwardAutomationSessionRequest:', { 'WIRApplicationIdentifierKey': app_id, diff --git a/tests/services/test_web_protocol/test_driver.py b/tests/services/test_web_protocol/test_driver.py index a485a6121..3ac2ed36b 100644 --- a/tests/services/test_web_protocol/test_driver.py +++ b/tests/services/test_web_protocol/test_driver.py @@ -4,7 +4,7 @@ def test_back(webdriver): webdriver.get('https://www.google.com') - webdriver.get('https://www.github.com') + webdriver.get('https://github.com') webdriver.back() assert webdriver.current_url.rstrip('/') == 'https://www.google.com' @@ -18,10 +18,10 @@ def test_current_url(webdriver): def test_forward(webdriver): webdriver.get('https://www.google.com') - webdriver.get('https://www.github.com') + webdriver.get('https://github.com') webdriver.back() webdriver.forward() - assert webdriver.current_url.rstrip('/') == 'https://www.github.com' + assert webdriver.current_url.rstrip('/') == 'https://github.com' def test_find_element(webdriver):