diff --git a/lib/pavilion/cmd_utils.py b/lib/pavilion/cmd_utils.py index 20550dbaa..a15aae092 100644 --- a/lib/pavilion/cmd_utils.py +++ b/lib/pavilion/cmd_utils.py @@ -10,9 +10,9 @@ from pathlib import Path from collections import defaultdict from enum import Enum, auto -from itertools import chain, filterfalse, starmap, tee +from itertools import chain, starmap, tee from typing import (List, TextIO, Union, Iterator, Iterable, - Callable, TypeVar, Tuple, Optional, Any) + Callable, TypeVar, Optional, Any) from pavilion import config from pavilion import dir_db @@ -27,25 +27,13 @@ PavilionError, TestGroupError from pavilion.test_run import TestRun, load_tests, TestAttributes from pavilion.types import ID_Pair +from pavilion.micro import listmap LOGGER = logging.getLogger(__name__) T = TypeVar('T') -def partition(pred: Callable[[T], bool], lst: Iterable[T]) -> Tuple[Iterator[T], Iterator[T]]: - """Partition the sequence into two sequences: one consisting of the elements - for which the given predicate is true and one consisting of those for - which it is false.""" - - f_true, f_false = tee(lst) - - return filter(pred, f_true), filterfalse(pred, f_false) - -def flatten(lst: Iterable[Iterable[T]]) -> Iterator[T]: - """Convert a singly nested iterable into an unnested iterable.""" - return chain.from_iterable(lst) - def expand_range(rng: str) -> Union[List[str], Iterator[str]]: """Expand an integer range (given as a string) into the sequence of (string representations) of integers specified by @@ -90,12 +78,6 @@ def get_last_id(pav_cfg: PavConfig, errfile = None) -> Optional[str]: return raw_id -def remove_all(lst: Iterable[T], item: T) -> Iterator[T]: - return filter(lambda x: x != item, lst) - -def unique(lst: Iterable[T]) -> List[T]: - return list(set(lst)) - def convert_last(raw_ids: Iterable[str], pav_cfg: PavConfig, errfile = None) -> List[str]: raw_ids = list(raw_ids) lastless = list(remove_all(raw_ids, 'last')) @@ -346,7 +328,7 @@ def arg_filtered_series(pav_cfg: config.PavConfig, args: argparse.Namespace, found_series = get_all_series(pav_cfg, sort_by, filter_func, limit, verbose) else: - found_series = list(map(lambda x: series.SeriesInfo.load(pav_cfg, x), args.series)) + found_series = listmap(lambda x: series.SeriesInfo.load(pav_cfg, x), args.series) return found_series diff --git a/lib/pavilion/micro.py b/lib/pavilion/micro.py new file mode 100644 index 000000000..6ca06432a --- /dev/null +++ b/lib/pavilion/micro.py @@ -0,0 +1,81 @@ +"""A collection of 'microfunctions' primarily designed to abstract common +tasks and patterns, for the purpose of conciseness and readability.""" + +from pathlib import Path +from itertools import filterfalse, chain, tee +from typing import (List, Union, TypeVar, Iterator, Iterable, Callable, Optional, + Hashable, Dict, Tuple) + +T = TypeVar('T') +U = TypeVar('U') + + +def partition(pred: Callable[[T], bool], lst: Iterable[T]) -> Tuple[Iterator[T], Iterator[T]]: + """Partition the sequence into two sequences: one consisting of the elements + for which the given predicate is true and one consisting of those for + which it is false.""" + + f_true, f_false = tee(lst) + + return filter(pred, f_true), filterfalse(pred, f_false) + +def flatten(lst: Iterable[Iterable[T]]) -> Iterator[T]: + """Convert a singly nested iterable into an unnested iterable.""" + return chain.from_iterable(lst) + +def remove_all(lst: Iterable[T], item: T) -> Iterator[T]: + """Remove all instances of the given item from the iterable.""" + return filter(lambda x: x != item, lst) + +def unique(lst: Iterable[T]) -> List[T]: + """Return a list of the unique items in the original list.""" + return list(set(lst)) + +def replace(lst: Iterable[T], old: T, new: T) -> Iterator[T]: + """Replace all instances of old with new.""" + return map(lambda x: new if x == old else x, lst) + +def remove_none(lst: Iterable[T]) -> Iterator[T]: + """Remove all instances of None from the iterable.""" + return filter(lambda x: x is not None, lst) + +def first(pred: Callable[[T], bool], lst: Iterable[T]) -> Optional[T]: + """Return the first item of the list that satisfies the given + predicate, or None if no item does.""" + + for item in filter(pred, lst): + return item + +def apply_to_first(func: Callable[[T], U], pred: Callable[[T], bool], + lst: Iterable[T]) -> Optional[U]: + """Apply the function to the first element of the list that satisfies + the given predicate. If no element satisfies the predicate, return None.""" + + fst = first(pred, lst) + + if fst is not None: + return func(fst) + +def get_nested(keys: Iterable[Hashable], nested_dict: Dict) -> Dict: + """Gets the values associated with the given sequence of keys + out of a nested dictionary. If any key in the sequence does + not exist during the process, returns an empty dictionary.""" + + for key in keys: + nested_dict = nested_dict.get(key, {}) + + return nested_dict + +def listmap(func: Callable[[T], U], lst: Iterable[T]) -> List[U]: + """Map a function over an iterable, but return a list instead + of a map object.""" + return list(map(func, lst)) + +def set_default(val: Optional[T], default: T) -> T: + """Set the input value to default, if the original value is None. + Otherwise, return the value unchanged.""" + + if val is None: + return default + + return val