diff --git a/src/compwa_policy/check_dev_files/binder.py b/src/compwa_policy/check_dev_files/binder.py index 83b881a6..f29e2491 100644 --- a/src/compwa_policy/check_dev_files/binder.py +++ b/src/compwa_policy/check_dev_files/binder.py @@ -23,35 +23,12 @@ def main(python_version: PythonVersion, apt_packages: list[str]) -> None: with Executor() as do: - do(_check_optional_dependencies) do(_update_apt_txt, apt_packages) do(_update_post_build) do(_make_executable, CONFIG_PATH.binder / "postBuild") do(_update_runtime_txt, python_version) -def _check_optional_dependencies() -> None: - required_for_binder = "This is required for Binder." - if not CONFIG_PATH.pyproject.exists(): - msg = f"{CONFIG_PATH.pyproject} does not exist. {required_for_binder}" - raise PrecommitError(msg) - pyproject = Pyproject.load() - table_key = "project.optional-dependencies" - if not pyproject.has_table(table_key): - msg = f"Missing [{table_key}] in {CONFIG_PATH.pyproject}. {required_for_binder}" - raise PrecommitError - optional_dependencies = pyproject.get_table(table_key) - optional_dependency_sections = set(optional_dependencies) - expected_sections = {"jupyter", "notebooks"} - missing_sections = expected_sections - optional_dependency_sections - if missing_sections: - msg = ( - f"Missing sections in [{table_key}]: {', '.join(sorted(missing_sections))}" - f". {required_for_binder}" - ) - raise PrecommitError(msg) - - def _update_apt_txt(apt_packages: list[str]) -> None: apt_txt = CONFIG_PATH.binder / "apt.txt" if not apt_packages: @@ -74,11 +51,13 @@ def _update_post_build() -> None: curl -LsSf https://astral.sh/uv/install.sh | sh source $HOME/.cargo/env """).strip() + + notebook_extras = __get_notebook_extras() if "uv.lock" in set(git_ls_files(untracked=True)): + expected_content += "\nuv export \\" + for extra in notebook_extras: + expected_content += f"\n --extra {extra} \\" expected_content += dedent(R""" - uv export \ - --extra jupyter \ - --extra notebooks \ > requirements.txt uv pip install \ --requirement requirements.txt \ @@ -86,9 +65,12 @@ def _update_post_build() -> None: uv cache clean """) else: - expected_content += dedent(R""" + package = "." + if notebook_extras: + package = f"'.[{','.join(notebook_extras)}]'" + expected_content += dedent(Rf""" uv pip install \ - --editable '.[jupyter,notebooks]' \ + --editable {package} \ --no-cache \ --system """) @@ -98,6 +80,18 @@ def _update_post_build() -> None: ) +def __get_notebook_extras() -> list[str]: + if not CONFIG_PATH.pyproject.exists(): + return [] + pyproject = Pyproject.load() + table_key = "project.optional-dependencies" + if not pyproject.has_table(table_key): + return [] + optional_dependencies = pyproject.get_table(table_key) + allowed_sections = {"jupyter", "notebooks"} + return sorted(allowed_sections & set(optional_dependencies)) + + def _make_executable(path: Path) -> None: if os.access(path, os.X_OK): return