From 76fc0c3d8940c254bdfbbbb79988e451f84b975c Mon Sep 17 00:00:00 2001 From: Stefan Verhoeven Date: Sat, 2 Mar 2024 11:14:43 +0100 Subject: [PATCH] Allow picker function to be loaded with `somedir/somefile.py:somefunction` --- docs/configuration.md | 6 ++++-- src/bartender/picker.py | 18 +++++++++++++++--- tests/test_picker.py | 30 +++++++++++++++++++++++++++++- 3 files changed, 48 insertions(+), 6 deletions(-) diff --git a/docs/configuration.md b/docs/configuration.md index 20802e64..b36aae76 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -400,8 +400,10 @@ specify to which destination a job should be submitted. A Python function can be used to pick a destination. By default jobs are submitted to the first destination. -To use a custom picker function set `destination_picker`. The value should be -formatted as `:`. The picker function should have type +To use a custom picker function set `destination_picker`. +The value should be formatted as `:` or +`:`. +The picker function should have type [bartender.picker.DestinationPicker]( https://github.com/i-VRESSE/bartender/blob/bdbef5176e05c498b37f4ada2bf7c09ad0e7b853/src/bartender/picker.py#L8 ). For example to rotate over each diff --git a/src/bartender/picker.py b/src/bartender/picker.py index b46561fc..908ae1d4 100644 --- a/src/bartender/picker.py +++ b/src/bartender/picker.py @@ -1,4 +1,5 @@ from importlib import import_module +from importlib.util import module_from_spec, spec_from_file_location from pathlib import Path from typing import TYPE_CHECKING, Callable @@ -76,16 +77,27 @@ def __call__( def import_picker(destination_picker_name: str) -> DestinationPicker: - """Import a picker function based on a `:` string. + """Import a picker function. Args: destination_picker_name: function import as string. + Format `:` or `:` Returns: Function that can be used to pick to which destination a job should be submitted. + + Raises: + ValueError: If the function could not be imported. """ - # TODO allow somedir/somefile.py:pick_round_robin (module_name, function_name) = destination_picker_name.split(":") - module = import_module(module_name) + if module_name.endswith(".py"): + file_path = Path(module_name) + spec = spec_from_file_location(file_path.name, file_path) + if spec is None or spec.loader is None: + raise ValueError(f"Could not load {file_path}") + module = module_from_spec(spec) + spec.loader.exec_module(module) + else: + module = import_module(module_name) return getattr(module, function_name) diff --git a/tests/test_picker.py b/tests/test_picker.py index aa43443c..1655aba0 100644 --- a/tests/test_picker.py +++ b/tests/test_picker.py @@ -1,10 +1,11 @@ from pathlib import Path +from textwrap import dedent import pytest from bartender.context import ApplicatonConfiguration, Context from bartender.destinations import Destination -from bartender.picker import PickRound, pick_first +from bartender.picker import PickRound, import_picker, pick_first from bartender.user import User @@ -105,3 +106,30 @@ async def test_thirdcall_returns_first(self, context: Context, user: User) -> No expected = "d1" assert actual == expected + + +def test_import_picker_module() -> None: + fn = import_picker("bartender.picker:pick_first") + assert fn.__name__ == "pick_first" + + +def test_import_picker_file( + tmp_path: Path, + user: User, +) -> None: + code = """\ + def mypicker(job_dir, application_name, submitter, context): + return "mydestination" + """ + path = tmp_path / "mymodule.py" + path.write_text(dedent(code)) + + fn = import_picker(f"{path}:mypicker") + context = Context( + destination_picker=fn, + applications={}, + destinations={}, + job_root_dir=tmp_path, + ) + result = fn(tmp_path, "someapp", user, context) + assert result == "mydestination"