Skip to content

Commit

Permalink
Allow picker function to be loaded with `somedir/somefile.py:somefunc…
Browse files Browse the repository at this point in the history
…tion`
  • Loading branch information
sverhoeven committed Mar 2, 2024
1 parent 5e18bfd commit 76fc0c3
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 6 deletions.
6 changes: 4 additions & 2 deletions docs/configuration.md
Original file line number Diff line number Diff line change
Expand Up @@ -400,8 +400,10 @@ specify to which destination a job should be submitted. A Python function can be
used to pick a destination. By default jobs are submitted to the first
destination.

To use a custom picker function set `destination_picker`. The value should be
formatted as `<module>:<function>`. The picker function should have type
To use a custom picker function set `destination_picker`.
The value should be formatted as `<module>:<function>` or
`<path to python file>:<function>`.
The picker function should have type
[bartender.picker.DestinationPicker](
https://github.com/i-VRESSE/bartender/blob/bdbef5176e05c498b37f4ada2bf7c09ad0e7b853/src/bartender/picker.py#L8
). For example to rotate over each
Expand Down
18 changes: 15 additions & 3 deletions src/bartender/picker.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from importlib import import_module
from importlib.util import module_from_spec, spec_from_file_location
from pathlib import Path
from typing import TYPE_CHECKING, Callable

Expand Down Expand Up @@ -76,16 +77,27 @@ def __call__(


def import_picker(destination_picker_name: str) -> DestinationPicker:
"""Import a picker function based on a `<module>:<function>` string.
"""Import a picker function.
Args:
destination_picker_name: function import as string.
Format `<module>:<function>` or `<path to python file>:<function>`
Returns:
Function that can be used to pick to which destination a job should be
submitted.
Raises:
ValueError: If the function could not be imported.
"""
# TODO allow somedir/somefile.py:pick_round_robin
(module_name, function_name) = destination_picker_name.split(":")
module = import_module(module_name)
if module_name.endswith(".py"):
file_path = Path(module_name)
spec = spec_from_file_location(file_path.name, file_path)
if spec is None or spec.loader is None:
raise ValueError(f"Could not load {file_path}")
module = module_from_spec(spec)
spec.loader.exec_module(module)
else:
module = import_module(module_name)
return getattr(module, function_name)
30 changes: 29 additions & 1 deletion tests/test_picker.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from pathlib import Path
from textwrap import dedent

import pytest

from bartender.context import ApplicatonConfiguration, Context
from bartender.destinations import Destination
from bartender.picker import PickRound, pick_first
from bartender.picker import PickRound, import_picker, pick_first
from bartender.user import User


Expand Down Expand Up @@ -105,3 +106,30 @@ async def test_thirdcall_returns_first(self, context: Context, user: User) -> No

expected = "d1"
assert actual == expected


def test_import_picker_module() -> None:
fn = import_picker("bartender.picker:pick_first")
assert fn.__name__ == "pick_first"


def test_import_picker_file(
tmp_path: Path,
user: User,
) -> None:
code = """\
def mypicker(job_dir, application_name, submitter, context):
return "mydestination"
"""
path = tmp_path / "mymodule.py"
path.write_text(dedent(code))

fn = import_picker(f"{path}:mypicker")
context = Context(
destination_picker=fn,
applications={},
destinations={},
job_root_dir=tmp_path,
)
result = fn(tmp_path, "someapp", user, context)
assert result == "mydestination"

0 comments on commit 76fc0c3

Please sign in to comment.