diff --git a/backend/Dockerfile b/backend/Dockerfile index 5ba5bd7f..269349a9 100644 --- a/backend/Dockerfile +++ b/backend/Dockerfile @@ -52,8 +52,13 @@ RUN cd tinystan && \ emmake make test_models/bernoulli/bernoulli.js -j$(nproc) && \ emstrip test_models/bernoulli/bernoulli.wasm -COPY stan-wasm-server /stan-wasm-server -WORKDIR /stan-wasm-server +COPY stan-wasm-server /app/stan-wasm-server + +# compute a hash of our app to make the model cache use unique URLS +# (to avoid browsers caching across different versions) +RUN find /app/ -type f -print0 | sort -z | xargs -0 sha1sum | sha1sum | cut -d' ' -f1 | tee /app/stan-wasm-server/hash-salt.txt + +WORKDIR /app/stan-wasm-server ENV SWS_PASSCODE=1234 ENV SWS_LOG_LEVEL=debug diff --git a/backend/stan-wasm-server/src/app/logic/compilation.py b/backend/stan-wasm-server/src/app/logic/compilation.py index c8cf1062..1af87f9f 100644 --- a/backend/stan-wasm-server/src/app/logic/compilation.py +++ b/backend/stan-wasm-server/src/app/logic/compilation.py @@ -1,6 +1,7 @@ import asyncio import logging import time +from functools import lru_cache from hashlib import sha1 from pathlib import Path from shutil import copy2 @@ -18,10 +19,27 @@ logger = logging.getLogger(__name__) +@lru_cache +def _get_salt() -> bytes: + """ + Returns a salt to use when hashing Stan programs if one was set during the Docker build. + This helps ensure that models don't get cached by the browser across different deployments. + """ + salt_file = Path(__file__).parent.parent.parent.parent / "hash-salt.txt" + if not salt_file.exists(): + logger.warning("No hash salt file found at %s", salt_file) + salt = b"" + else: + salt = salt_file.read_bytes().strip() + logger.info("Using hash salt from %s: %s", salt_file, repr(salt)) + return salt + + def _compute_stan_program_hash(program_file: Path) -> str: stan_program = program_file.read_text() - # MAYBE: replace stan_program with a canonical form? - return sha1(stan_program.encode()).hexdigest() + hasher = sha1(_get_salt()) + hasher.update(stan_program.encode()) + return hasher.hexdigest() def make_canonical_model_dir(src_file: Path, built_model_dir: Path) -> Path: diff --git a/backend/stan-wasm-server/src/app/main.py b/backend/stan-wasm-server/src/app/main.py index 7951d085..5e10633c 100644 --- a/backend/stan-wasm-server/src/app/main.py +++ b/backend/stan-wasm-server/src/app/main.py @@ -138,18 +138,28 @@ async def compile_stan( return {"model_id": model_dir.name} +def send_interrupt() -> None: + """ + Send an interrupt signal to the parent process. + uvicorn interprets this like Ctrl-C, and gracefully shuts down. + The orchestrator then restarts the server. + """ + import os + import signal + + os.kill(os.getppid(), signal.SIGINT) + + @app.post("/restart") async def restart( - settings: DependsOnSettings, authorization: str = Header(None) -) -> None: + settings: DependsOnSettings, + background_tasks: BackgroundTasks, + authorization: str = Header(None), +) -> DictResponse: if settings.restart_token is None: raise StanPlaygroundAuthenticationException("Restart token not set at startup") check_authorization(authorization, settings.restart_token) - import os - import signal + background_tasks.add_task(send_interrupt) - # send an interrupt signal to the parent process - # uvicorn interprets this like Ctrl-C, and gracefully shuts down - os.kill(os.getppid(), signal.SIGINT) - # actual restart is handled by the orchestrator + return {"status": "restarting"}