From e1206bec1f32b711f95f312866031a32087f067e Mon Sep 17 00:00:00 2001 From: Jonathan Sick Date: Thu, 21 Mar 2024 16:06:27 -0400 Subject: [PATCH] Extract and use XSRF token from hub and lab Rather than create a synthetic XSRF token, we actually look at the token returned by the Hub and Lab login endpoints and use those. See https://github.com/lsst-sqre/mobu/pull/340 --- src/noteburst/jupyterclient/jupyterlab.py | 137 ++++++++----------- src/noteburst/jupyterclient/labcontroller.py | 10 +- tests/jupyterclient/jupyterclient_test.py | 55 -------- 3 files changed, 68 insertions(+), 134 deletions(-) delete mode 100644 tests/jupyterclient/jupyterclient_test.py diff --git a/src/noteburst/jupyterclient/jupyterlab.py b/src/noteburst/jupyterclient/jupyterlab.py index 1bbbfed..54a447f 100644 --- a/src/noteburst/jupyterclient/jupyterlab.py +++ b/src/noteburst/jupyterclient/jupyterlab.py @@ -6,16 +6,15 @@ import contextlib import datetime import json -import string from collections.abc import AsyncGenerator, AsyncIterator from dataclasses import dataclass -from random import SystemRandom -from typing import Annotated, Any, Self +from typing import Annotated, Any from urllib.parse import urljoin, urlparse from uuid import uuid4 import httpx import websockets +from httpx import Cookies from pydantic import BaseModel, Field from structlog import BoundLogger from websockets.client import WebSocketClientProtocol @@ -422,83 +421,35 @@ def __init__( str(noteburst_config.environment_url), self.config.url_prefix ) - self._http_client: httpx.AsyncClient | None = None - self._lab_controller_client: LabControllerClient | None = None - self._common_headers: dict[str, str] # set and reset in the context - - @property - def http_client(self) -> httpx.AsyncClient: - """The HTTPX client instance associated with the Jupyter session..""" - if self._http_client is None: - self._open_clients() - if self._http_client is None: - raise RuntimeError("http_client is not set") - return self._http_client - - @property - def lab_controller(self) -> LabControllerClient: - """The Jupyter Lab Controller client, only available in the - JupyterClient context. - """ - if self._lab_controller_client is None: - self._open_clients() - if self._lab_controller_client is None: - raise RuntimeError("LabControllerClient is not set set up") - return self._lab_controller_client - - async def __aenter__(self) -> Self: - self._open_clients() - return self - - def _open_clients(self) -> None: - if (self._http_client is not None) or ( - self._lab_controller_client is not None - ): - raise RuntimeError( - "JupyterClient is already open. Call close() before " - "re-opening?" - ) - - alphabet = string.ascii_uppercase + string.digits - xsrf_token = "".join(SystemRandom().choices(alphabet, k=16)) - headers = { - "x-xsrftoken": xsrf_token, + self._headers = { "Authorization": f"Bearer {self.user.token}", } - self._common_headers = headers - cookies = {"_xsrf": xsrf_token} - - self._http_client = httpx.AsyncClient( - headers=headers, - cookies=cookies, + self.http_client = httpx.AsyncClient( + headers=self._headers, follow_redirects=True, - timeout=30.0, # default is 5, but Hub can be slow + timeout=30, ) + self._lab_controller_client: LabControllerClient | None = None + self._hub_xsrf: str | None = None + self._lab_xsrf: str | None = None + + @property + def lab_controller(self) -> LabControllerClient: + if self._lab_controller_client: + return self._lab_controller_client - # Create a LabController client - # We also send the XSRF token to Lab Controller because of how we're - # sharing the session, but that shouldn't matter. self._lab_controller_client = LabControllerClient( http_client=self.http_client, token=noteburst_config.gafaelfawr_token.get_secret_value(), url_prefix=noteburst_config.nublado_controller_path_prefix, ) - - async def __aexit__(self, *exc_info: object) -> None: - await self.close() + return self._lab_controller_client async def close(self) -> None: - """Manually close the client. - - Do not use this method for manually closing the Jupyter client when - using JupyterClient as an async context manager. The client is - closed automatically. - """ + """Close the client.""" self._lab_controller_client = None await self.http_client.aclose() - self._http_client = None - self._common_headers = {} def url_for(self, path: str) -> str: """Create a URL relative to the jupyter_url.""" @@ -515,14 +466,17 @@ def url_for_websocket(self, path: str) -> str: async def log_into_hub(self) -> None: """Log into JupyterHub or raise a JupyterError.""" self.logger.debug("Logging into JupyterHub") - r = await self.http_client.get( - self.url_for("hub/login"), follow_redirects=False - ) + r = await self.http_client.get(self.url_for("hub/home")) # JupyterHub returns a 302 redirect to the login page on success, # but we don't want to follow that redirect. This request is just # to set cookies. if r.status_code >= 400: raise JupyterError.from_response(self.user.username, r) + cookies = Cookies() + cookies.extract_cookies(r) + xsrf = cookies.get("_xsrf") + if xsrf: + self._hub_xsrf = xsrf async def log_into_lab(self) -> None: """Log into JupyterLab or raise a JupyterError.""" @@ -532,6 +486,11 @@ async def log_into_lab(self) -> None: ) if r.status_code != 200: raise JupyterError.from_response(self.user.username, r) + cookies = Cookies() + cookies.extract_cookies(r) + xsrf = cookies.get("_xsrf") + if xsrf: + self._lab_xsrf = xsrf async def spawn_lab(self) -> JupyterImage: """Spawn a JupyterLab pod.""" @@ -540,13 +499,16 @@ async def spawn_lab(self) -> JupyterImage: # Retrieving the spawn page before POSTing to it appears to trigger # some necessary internal state construction (and also more accurately # simulates a user interaction). See DM-23864. - _ = await self.http_client.get(spawn_url) + headers = dict(self._headers) + if self._hub_xsrf: + headers["X-XSRFToken"] = self._hub_xsrf + _ = await self.http_client.get(spawn_url, headers=headers) # POST the options form to the spawn page. This should redirect to # the spawn-pending page, which will return a 200. image = await self._get_spawn_image() data = self._build_jupyter_spawn_form(image) - r = await self.http_client.post(spawn_url, data=data) + r = await self.http_client.post(spawn_url, data=data, headers=headers) if r.status_code != 200: raise JupyterError.from_response(self.user.username, r) @@ -565,6 +527,8 @@ async def spawn_progress(self) -> AsyncIterator[SpawnProgressMessage]: ) referer_url = self.url_for("hub/home") headers = {"Referer": referer_url} + if self._hub_xsrf: + headers["X-XSRFToken"] = self._hub_xsrf while True: async with self.http_client.stream( "GET", progress_url, headers=headers @@ -626,6 +590,8 @@ async def stop_lab(self) -> None: server_url = self.url_for(f"hub/api/users/{user}/server") referer_url = self.url_for("hub/home") headers = {"Referer": referer_url} + if self._hub_xsrf: + headers["X-XSRFToken"] = self._hub_xsrf r = await self.http_client.delete(server_url, headers=headers) if r.status_code not in [200, 202, 204]: raise JupyterError.from_response(self.user.username, r) @@ -642,6 +608,8 @@ async def is_lab_stopped(self, *, final: bool = False) -> bool: user_url = self.url_for(f"hub/api/users/{self.user.username}") referer_url = self.url_for("hub/home") headers = {"Referer": referer_url} + if self._hub_xsrf: + headers["X-XSRFToken"] = self._hub_xsrf r = await self.http_client.get(user_url, headers=headers) if r.status_code != 200: raise JupyterError.from_response(self.user.username, r) @@ -669,7 +637,12 @@ async def open_lab_session( "path": notebook_name if notebook_name else uuid4().hex, "type": session_type, } - r = await self.http_client.post(session_url, json=body) + headers = {} + if self._lab_xsrf: + headers["X-XSRFToken"] = self._lab_xsrf + r = await self.http_client.post( + session_url, json=body, headers=headers + ) if r.status_code != 201: raise JupyterError.from_response(self.user.username, r) session_resource = r.json() @@ -685,10 +658,11 @@ async def open_lab_session( # Generate a mock request and copy its headers / cookies over to the # websocket connection. mock_request = self.http_client.build_request("GET", http_channels_uri) - copied_headers = ["x-xsrftoken", "authorization", "cookie"] - websocket_headers = { - header: mock_request.headers[header] for header in copied_headers + headers = { + h: mock_request.headers[h] for h in ("authorization", "cookie") } + if self._lab_xsrf: + headers["X-XSRFToken"] = self._lab_xsrf session_id: str | None = None # will be set if a session is opened self.logger.debug("Trying to create websocket connection") @@ -699,7 +673,7 @@ async def open_lab_session( # long lived clients # https://websockets.readthedocs.io/en/stable/reference/client.html#using-a-connection async with websockets.connect( - wss_channels_uri, extra_headers=websocket_headers + wss_channels_uri, extra_headers=headers ) as websocket: self.logger.info("Created websocket connection") jupyter_lab_session = JupyterLabSession( @@ -720,7 +694,9 @@ async def open_lab_session( session_id_url = self.url_for( f"user/{self.user.username}/api/sessions/{session_id}" ) - r = await self.http_client.delete(session_id_url) + r = await self.http_client.delete( + session_id_url, headers=headers + ) if r.status_code != 204: raise JupyterError.from_response(self.user.username, r) @@ -747,10 +723,14 @@ async def execute_notebook( Notebook execution extension. """ exec_url = self.url_for(f"user/{self.user.username}/rubin/execution") + headers = {} + if self._lab_xsrf: + headers["X-XSRFToken"] = self._lab_xsrf try: r = await self.http_client.post( exec_url, content=json.dumps(notebook).encode("utf-8"), + headers=headers, ) except httpx.HTTPError as e: # This often occurs from timeouts, so we want to convert the @@ -777,7 +757,10 @@ async def get_jupyterlab_env(self) -> dict[str, Any]: environment_url = self.url_for( f"user/{self.user.username}/rubin/environment" ) - r = await self.http_client.get(environment_url) + headers = {} + if self._lab_xsrf: + headers["X-XSRFToken"] = self._lab_xsrf + r = await self.http_client.get(environment_url, headers=headers) if r.status_code != 200: raise JupyterError.from_response(self.user.username, r) return r.json() diff --git a/src/noteburst/jupyterclient/labcontroller.py b/src/noteburst/jupyterclient/labcontroller.py index faa5984..79365a8 100644 --- a/src/noteburst/jupyterclient/labcontroller.py +++ b/src/noteburst/jupyterclient/labcontroller.py @@ -137,7 +137,11 @@ class LabControllerClient: """ def __init__( - self, *, http_client: httpx.AsyncClient, token: str, url_prefix: str + self, + *, + http_client: httpx.AsyncClient, + token: str, + url_prefix: str, ) -> None: self._http_client = http_client self._token = token @@ -204,7 +208,9 @@ async def get_by_reference(self, reference: str) -> JupyterImage: return image async def _get_images(self) -> LabControllerImages: - headers = {"Authorization": f"bearer {self._token}"} + headers = { + "Authorization": f"bearer {self._token}", + } url = urljoin( str(config.environment_url), f"{self._url_prefix}/spawner/v1/images", diff --git a/tests/jupyterclient/jupyterclient_test.py b/tests/jupyterclient/jupyterclient_test.py deleted file mode 100644 index 34b0ca2..0000000 --- a/tests/jupyterclient/jupyterclient_test.py +++ /dev/null @@ -1,55 +0,0 @@ -"""Test the JupyterClient.""" - -from __future__ import annotations - -import httpx -import pytest -import respx -import structlog - -from noteburst.config import JupyterImageSelector -from noteburst.jupyterclient.jupyterlab import JupyterClient, JupyterConfig -from noteburst.jupyterclient.user import User -from tests.support.gafaelfawr import mock_gafaelfawr -from tests.support.jupyter import MockJupyter -from tests.support.labcontroller import MockLabController - - -@pytest.mark.asyncio -async def test_jupyterclient( - respx_mock: respx.Router, - jupyter: MockJupyter, - labcontroller: MockLabController, -) -> None: - user = User(username="someuser", uid="1234") - mock_gafaelfawr( - respx_mock=respx_mock, username=user.username, uid=user.uid - ) - - logger = structlog.get_logger(__name__) - - jupyter_config = JupyterConfig( - image_selector=JupyterImageSelector.recommended - ) - - async with httpx.AsyncClient() as http_client: - authed_user = await user.login( - scopes=["exec:notebook"], - http_client=http_client, - token_lifetime=3600, - ) - async with JupyterClient( - user=authed_user, logger=logger, config=jupyter_config - ) as jupyter_client: - await jupyter_client.log_into_hub() - - image_info = await jupyter_client.spawn_lab() - print(image_info) - async for progress in jupyter_client.spawn_progress(): - print(progress) - - await jupyter_client.log_into_lab() - - # Note: the test code for running open_lab_session isn't available - - await jupyter_client.stop_lab()