diff --git a/devtool.py b/devtool.py index a37228f3..5094e06a 100644 --- a/devtool.py +++ b/devtool.py @@ -43,7 +43,7 @@ def analyse_args( command: list[str | Any], raise_exception: bool = True, context_style: Literal["bracket", "parentheses"] | None = None, - **kwargs + **kwargs, ): meta = CommandMeta(keep_crlf=False, fuzzy_match=False, raise_exception=raise_exception, context_style=context_style) argv: Argv[DataCollection] = Argv(meta, dev_space) @@ -66,7 +66,7 @@ def analyse_header( compact: bool = False, raise_exception: bool = True, context_style: Literal["bracket", "parentheses"] | None = None, - **kwargs + **kwargs, ): meta = CommandMeta(keep_crlf=False, fuzzy_match=False, raise_exception=raise_exception, context_style=context_style) argv: Argv[DataCollection] = Argv(meta, dev_space, separators=sep) @@ -86,7 +86,7 @@ def analyse_option( command: DataCollection[str | Any], raise_exception: bool = True, context_style: Literal["bracket", "parentheses"] | None = None, - **kwargs + **kwargs, ): meta = CommandMeta(keep_crlf=False, fuzzy_match=False, raise_exception=raise_exception, context_style=context_style) argv: Argv[DataCollection] = Argv(meta, dev_space) @@ -95,12 +95,12 @@ def analyse_option( _analyser.command.separators = " " _analyser.need_main_args = False _analyser.command.options.append(option) - default_compiler(_analyser, argv.param_ids) + default_compiler(_analyser) _analyser.command.options.clear() try: argv.enter(kwargs) argv.build(command) - alo(_analyser, argv, option) + alo(_analyser, argv, option, False) return _analyser.options_result[option.dest] except Exception as e: if raise_exception: @@ -113,7 +113,7 @@ def analyse_subcommand( command: DataCollection[str | Any], raise_exception: bool = True, context_style: Literal["bracket", "parentheses"] | None = None, - **kwargs + **kwargs, ): meta = CommandMeta(keep_crlf=False, fuzzy_match=False, raise_exception=raise_exception, context_style=context_style) argv: Argv[DataCollection] = Argv(meta, dev_space) @@ -122,7 +122,7 @@ def analyse_subcommand( _analyser.command.separators = " " _analyser.need_main_args = False _analyser.command.options.append(subcommand) - default_compiler(_analyser, argv.param_ids) + default_compiler(_analyser) _analyser.command.options.clear() try: argv.enter(kwargs) diff --git a/src/arclet/alconna/_internal/_analyser.py b/src/arclet/alconna/_internal/_analyser.py index e47c7983..4c0aa872 100644 --- a/src/arclet/alconna/_internal/_analyser.py +++ b/src/arclet/alconna/_internal/_analyser.py @@ -9,10 +9,11 @@ from ..action import Action from ..args import Args from ..arparma import Arparma -from ..base import Completion, Help, Option, Shortcut, Subcommand -from ..completion import comp_ctx +from ..base import Option, Subcommand +from ..completion import comp_ctx, prompt from ..exceptions import ( ArgumentMissing, + AnalyseException, FuzzyMatchSuccess, InvalidHeader, InvalidParam, @@ -27,7 +28,6 @@ analyse_args, analyse_param, handle_opt_default, - prompt, ) from ._util import levenshtein @@ -43,7 +43,7 @@ def default_compiler(analyser: SubAnalyser): analyser (SubAnalyser): 任意子解析器 """ for opts in analyser.command.options: - if isinstance(opts, Option) and not isinstance(opts, (Help, Shortcut, Completion)): + if isinstance(opts, Option): if opts.compact or opts.action.type == 2 or not set(analyser.command.separators).issuperset(opts.separators): # noqa: E501 analyser.compact_params.append(opts) for alias in opts.aliases: @@ -153,21 +153,21 @@ def process(self, argv: Argv[TDC], name_validated: bool = True) -> Self: ParamsUnmatched: 名称不匹配 FuzzyMatchSuccess: 模糊匹配成功 """ - sub = argv.current_node = self.command + sub = self.command if not name_validated: name, _ = argv.next(sub.separators) if name not in sub.aliases: argv.rollback(name) if not argv.fuzzy_match: - raise InvalidParam(lang.require("subcommand", "name_error").format(source=sub.dest, target=name)) + raise InvalidParam(lang.require("subcommand", "name_error").format(source=sub.dest, target=name), sub) for al in sub.aliases: if levenshtein(name, al) >= argv.fuzzy_threshold: - raise FuzzyMatchSuccess(lang.require("fuzzy", "matched").format(source=al, target=name)) - raise InvalidParam(lang.require("subcommand", "name_error").format(source=sub.dest, target=name)) + raise FuzzyMatchSuccess(lang.require("fuzzy", "matched").format(source=al, target=name), sub) + raise InvalidParam(lang.require("subcommand", "name_error").format(source=sub.dest, target=name), sub) self.value_result = sub.action.value argv.stack_params.enter(self.compile_params) - while analyse_param(self, argv, self.command.separators): + while analyse_param(self, argv, self.command.separators) and argv.current_index != argv.ndata: pass if self.default_main_only and not self.args_result: self.args_result = analyse_args(argv, self.self_args) @@ -175,26 +175,12 @@ def process(self, argv: Argv[TDC], name_validated: bool = True) -> Self: raise ArgumentMissing( self.self_args.argument[0].field.get_missing_tips( lang.require("subcommand", "args_missing").format(name=self.command.dest) - ) + ), + sub ) argv.stack_params.leave() return self - def get_sub_analyser(self, target: Subcommand) -> SubAnalyser | None: - """获取子解析器 - - Args: - target (Subcommand): 目标子命令 - - Returns: - SubAnalyser[TDC] | None: 子解析器 - """ - if target == self.command: - return self - for param in self.compile_params.values(): - if isinstance(param, SubAnalyser): - return param.get_sub_analyser(target) - class Analyser(SubAnalyser): """命令解析器""" @@ -215,7 +201,7 @@ def __init__(self, alconna: Alconna, compiler: TCompile | None = None): def compile(self): self.extra_allow = not self.command.meta.strict or not self.command.namespace_config.strict self._compiler(self) - command_manager.resolve(self.command).stack_params.enter(self.compile_params) + command_manager.resolve(self.command).stack_params.base = self.compile_params return self def __repr__(self): @@ -242,39 +228,43 @@ def process(self, argv: Argv[TDC], name_validated: bool = True) -> Exception | N pass except FuzzyMatchSuccess as e: return e - # except SpecialOptionTriggered as sot: - # return _SPECIAL[sot.args[0]](self, argv) except (InvalidParam, ArgumentMissing) as e1: - # if (rest := argv.release()) and isinstance(rest[-1], str): - # if rest[-1] in argv.completion_names and "completion" not in argv.namespace.disable_builtin_options: - # argv.bak_data[-1] = argv.bak_data[-1][: -len(rest[-1])].rstrip() - # return handle_completion(self, argv) - # if (handler := argv.special.get(rest[-1])) and handler not in argv.namespace.disable_builtin_options: - # return _SPECIAL[handler](self, argv) if comp_ctx.get(None): if isinstance(e1, InvalidParam): - argv.free(argv.current_node.separators if argv.current_node else None) - return PauseTriggered(prompt(self, argv), e1, argv) + argv.free(e1.context_node.separators if e1.context_node else None) + return PauseTriggered( + prompt(self.command, argv, [*self.args_result.keys()], [*self.options_result.keys(), *self.subcommands_result.keys()], e1.context_node), + e1, + argv + ) return e1 if self.default_main_only and not self.args_result: - self.args_result = analyse_args(argv, self.self_args) + try: + self.args_result = analyse_args(argv, self.self_args) + except FuzzyMatchSuccess as e1: + return e1 + except AnalyseException as e2: + e2.context_node = None + if not argv.error: + argv.error = e2 if argv.current_index == argv.ndata and (not self.need_main_args or self.args_result): return rest = argv.release() if len(rest) > 0: - # if isinstance(rest[-1], str) and rest[-1] in argv.completion_names: - # argv.bak_data[-1] = argv.bak_data[-1][: -len(rest[-1])].rstrip() - # return handle_completion(self, argv, rest[-2]) - exc = ParamsUnmatched(lang.require("analyser", "param_unmatched").format(target=argv.next(move=False)[0])) + exc = ParamsUnmatched(lang.require("analyser", "param_unmatched").format(target=argv.next()[0])) else: exc = ArgumentMissing( self.self_args.argument[0].field.get_missing_tips(lang.require("analyser", "param_missing")) ) - if comp_ctx.get(None) and isinstance(exc, ArgumentMissing): - return PauseTriggered(prompt(self, argv), exc, argv) + if comp_ctx.get(None): + return PauseTriggered( + prompt(self.command, argv, [*self.args_result.keys()], [*self.options_result.keys(), *self.subcommands_result.keys()]), + exc, + argv + ) return exc def export( @@ -290,23 +280,30 @@ def export( fail (bool, optional): 是否解析失败. Defaults to False. exception (Exception | None, optional): 解析失败时的异常. Defaults to None. """ + if argv.error: + fail = True + exception = argv.error result = Arparma(self.command._hash, argv.origin, not fail, self.header_result, ctx=argv.exit()) if fail: + if self.command.meta.raise_exception and not isinstance(exception, FuzzyMatchSuccess): + raise exception result.error_info = exception result.error_data = argv.release() - else: - if self.default_opt_result: - handle_opt_default(self.default_opt_result, self.options_result) - if self.default_sub_result: - for k, v in self.default_sub_result.items(): - if k not in self.subcommands_result: - self.subcommands_result[k] = v - result.main_args = self.args_result - result.options = self.options_result - result.subcommands = self.subcommands_result - result.unpack() - if argv.message_cache: - command_manager.record(argv.token, result) + if isinstance(exception, FuzzyMatchSuccess): + result.output = str(exception) + + if self.default_opt_result: + handle_opt_default(self.default_opt_result, self.options_result) + if self.default_sub_result: + for k, v in self.default_sub_result.items(): + if k not in self.subcommands_result: + self.subcommands_result[k] = v + result.main_args = self.args_result + result.options = self.options_result + result.subcommands = self.subcommands_result + result.unpack() + if not fail and argv.message_cache: + command_manager.record(argv.token, result) self.reset() return result # type: ignore diff --git a/src/arclet/alconna/_internal/_argv.py b/src/arclet/alconna/_internal/_argv.py index 62cf5ee4..1ae275e8 100644 --- a/src/arclet/alconna/_internal/_argv.py +++ b/src/arclet/alconna/_internal/_argv.py @@ -1,14 +1,12 @@ from __future__ import annotations -from collections import deque from dataclasses import InitVar, dataclass, field, fields from typing import Any, Callable, ClassVar, Generic, Iterable, Literal, TYPE_CHECKING from typing_extensions import Self from tarina import lang, split, split_once -from ..args import Arg -from ..base import Option, Subcommand +from ..base import Option from ..config import Namespace, config from ..constraint import ARGV_OVERRIDES from ..exceptions import NullMessage @@ -51,11 +49,10 @@ class Argv(Generic[TDC]): context_style: Literal["bracket", "parentheses"] | None = field(init=False) "命令上下文插值的风格,None 为关闭,bracket 为 {...},parentheses 为 $(...)" - current_node: Arg | Subcommand | Option | None = field(init=False) - """当前节点""" current_index: int = field(init=False) """当前数据的索引""" - stack_params: ChainMap["SubAnalyser | Option"] = field(init=False, default_factory=ChainMap) + stack_params: ChainMap[SubAnalyser | Option] = field(init=False, default_factory=lambda: ChainMap()) + error: Exception | None = field(init=False) ndata: int = field(init=False) """原始数据的长度""" bak_data: list[str | Any] = field(init=False) @@ -67,8 +64,6 @@ class Argv(Generic[TDC]): origin: TDC = field(init=False) """原始命令""" context: dict[str, Any] = field(init=False, default_factory=dict) - # special: dict[str, str] = field(init=False, default_factory=dict) - # completion_names: set[str] = field(init=False, default_factory=set) _sep: str | None = field(init=False) _cache: ClassVar[dict[type, dict[str, Any]]] = {} @@ -91,13 +86,6 @@ def compile(self, meta: CommandMeta): self.message_cache = self.namespace.enable_message_cache self.filter_crlf = not meta.keep_crlf self.context_style = meta.context_style - # self.special = {} - # self.special.update( - # [(i, "help") for i in self.namespace.builtin_option_name["help"]] - # + [(i, "completion") for i in self.namespace.builtin_option_name["completion"]] - # + [(i, "shortcut") for i in self.namespace.builtin_option_name["shortcut"]] - # ) - # self.completion_names = self.namespace.builtin_option_name["completion"] def reset(self): """重置命令行参数""" @@ -105,11 +93,11 @@ def reset(self): self.ndata = 0 self.bak_data = [] self.raw_data = [] - self.stack_params.maps = [] + self.error = None + self.stack_params.stack = [] self.token = 0 self.origin = "None" # type: ignore self._sep = None - self.current_node = None @staticmethod def generate_token(data: list) -> int: @@ -190,12 +178,11 @@ def addon(self, data: Iterable[str | Any], merge_str: bool = True) -> Self: self.token = self.generate_token(self.raw_data) return self - def next(self, separate: str | None = None, move: bool = True) -> tuple[str | Any, bool]: + def next(self, separate: str | None = None) -> tuple[str | Any, bool]: """获取解析需要的下个数据 Args: separate (str | None, optional): 分隔符. - move (bool, optional): 是否移动指针. Returns: tuple[str | Any, bool]: 下个数据, 是否是字符串. @@ -208,15 +195,13 @@ def next(self, separate: str | None = None, move: bool = True) -> tuple[str | An _current_data = self.raw_data[self.current_index] if _current_data.__class__ is str: _text, _rest_text = split_once(_current_data, separate, self.filter_crlf) # type: ignore - if move: - if _rest_text: - self._sep = separate - self.raw_data[self.current_index] = _rest_text - else: - self.current_index += 1 + if _rest_text: + self._sep = separate + self.raw_data[self.current_index] = _rest_text + else: + self.current_index += 1 return _text, True - if move: - self.current_index += 1 + self.current_index += 1 return _current_data, False def rollback(self, data: str | Any, replace: bool = False): diff --git a/src/arclet/alconna/_internal/_handlers.py b/src/arclet/alconna/_internal/_handlers.py index 69dd95bc..8889ed4e 100644 --- a/src/arclet/alconna/_internal/_handlers.py +++ b/src/arclet/alconna/_internal/_handlers.py @@ -8,8 +8,7 @@ from ..action import Action from ..args import Arg, Args -from ..base import Option, Subcommand, Header, SPECIAL_OPTIONS -from ..completion import Prompt, comp_ctx +from ..base import Option, Header from ..config import config from ..exceptions import ( AlconnaException, @@ -17,8 +16,7 @@ FuzzyMatchSuccess, InvalidHeader, InvalidParam, - PauseTriggered, - SpecialOptionTriggered, + PauseTriggered, ParamsUnmatched, ) from ..model import HeadResult, OptionResult from ..typing import KWBool, MultiKeyWordVar, MultiVar, _AllParamPattern, _StrMulti @@ -26,7 +24,7 @@ from ._util import levenshtein if TYPE_CHECKING: - from ._analyser import Analyser, SubAnalyser + from ._analyser import SubAnalyser from ._argv import Argv pat = re.compile("(?:-*no)?-*(?P.+)") @@ -47,10 +45,11 @@ def _context(argv: Argv, target: Arg[Any], _arg: str): try: return safe_eval(name, ctx) except NameError: - raise ArgumentMissing(target.field.get_missing_tips(lang.require("args", "missing").format(key=target.name))) + raise ArgumentMissing(target.field.get_missing_tips(lang.require("args", "missing").format(key=target.name)), target) except Exception as e: raise InvalidParam( - target.field.get_unmatch_tips(_arg, lang.require("nepattern", "context_error").format(target=target.name, expected=name)) + target.field.get_unmatch_tips(_arg, lang.require("nepattern", "context_error").format(target=target.name, expected=name)), + target ) @@ -71,13 +70,12 @@ def _validate(argv: Argv, target: Arg[Any], value: BasePattern[Any, Any, Any], r if res.flag == "error": if target.optional: return - raise InvalidParam(target.field.get_unmatch_tips(arg, res.error().args[0])) + raise InvalidParam(target.field.get_unmatch_tips(arg, res.error().args[0]), target) result[target.name] = res._value # noqa def step_varpos(argv: Argv, args: Args, slot: tuple[MultiVar, Arg], result: dict[str, Any]): value, arg = slot - argv.current_node = arg key = arg.name default_val = arg.field.default _result = [] @@ -85,9 +83,6 @@ def step_varpos(argv: Argv, args: Args, slot: tuple[MultiVar, Arg], result: dict count = 0 while argv.current_index != argv.ndata: may_arg, _str = argv.next(arg.separators) - # if _str and may_arg in argv.special: - # if argv.special[may_arg] not in argv.namespace.disable_builtin_options: - # raise SpecialOptionTriggered(argv.special[may_arg]) if not may_arg or (_str and may_arg in argv.stack_params and not argv.stack_params[may_arg].soft_keyword): argv.rollback(may_arg) break @@ -114,7 +109,7 @@ def step_varpos(argv: Argv, args: Args, slot: tuple[MultiVar, Arg], result: dict elif arg.optional: return else: - raise ArgumentMissing(arg.field.get_missing_tips(lang.require("args", "missing").format(key=key))) + raise ArgumentMissing(arg.field.get_missing_tips(lang.require("args", "missing").format(key=key)), arg) if isinstance(value, _StrMulti): result[key] = arg.separators[0].join(_result) else: @@ -123,16 +118,12 @@ def step_varpos(argv: Argv, args: Args, slot: tuple[MultiVar, Arg], result: dict def step_varkey(argv: Argv, slot: tuple[MultiKeyWordVar, Arg], result: dict[str, Any]): value, arg = slot - argv.current_node = arg name = arg.name default_val = arg.field.default _result = {} count = 0 while argv.current_index != argv.ndata: may_arg, _str = argv.next(arg.separators) - # if _str and may_arg in argv.special: - # if argv.special[may_arg] not in argv.namespace.disable_builtin_options: - # raise SpecialOptionTriggered(argv.special[may_arg]) if not may_arg or (_str and may_arg in argv.stack_params and not argv.stack_params[may_arg].soft_keyword) or not _str: argv.rollback(may_arg) break @@ -159,7 +150,7 @@ def step_varkey(argv: Argv, slot: tuple[MultiKeyWordVar, Arg], result: dict[str, elif arg.optional: return else: - raise ArgumentMissing(arg.field.get_missing_tips(lang.require("args", "missing").format(key=name))) + raise ArgumentMissing(arg.field.get_missing_tips(lang.require("args", "missing").format(key=name)), arg) result[name] = _result @@ -172,9 +163,6 @@ def step_keyword(argv: Argv, args: Args, result: dict[str, Any]): count = 0 while count < target: may_arg, _str = argv.next("".join(kwonly_seps)) - # if _str and may_arg in argv.special: - # if argv.special[may_arg] not in argv.namespace.disable_builtin_options: - # raise SpecialOptionTriggered(argv.special[may_arg]) if not may_arg or not _str: argv.rollback(may_arg) break @@ -193,11 +181,11 @@ def step_keyword(argv: Argv, args: Args, result: dict[str, Any]): break for arg in args.argument.keyword_only.values(): if arg.value.base.validate(may_arg).flag == "valid": # type: ignore - raise InvalidParam(lang.require("args", "key_missing").format(target=may_arg, key=arg.name)) + raise InvalidParam(lang.require("args", "key_missing").format(target=may_arg, key=arg.name), arg) for name in args.argument.keyword_only: if levenshtein(_key, name) >= argv.fuzzy_threshold: raise FuzzyMatchSuccess(lang.require("fuzzy", "matched").format(source=name, target=_key)) - raise InvalidParam(lang.require("args", "key_not_found").format(name=_key)) + raise InvalidParam(lang.require("args", "key_not_found").format(name=_key), args) arg = args.argument.keyword_only[_key] value = arg.value.base # type: ignore if not _m_arg: @@ -215,11 +203,11 @@ def step_keyword(argv: Argv, args: Args, result: dict[str, Any]): if arg.field.default is not Empty: result[key] = arg.field.default elif not arg.optional: - raise ArgumentMissing(arg.field.get_missing_tips(lang.require("args", "missing").format(key=key))) + raise ArgumentMissing(arg.field.get_missing_tips(lang.require("args", "missing").format(key=key)), arg) def _raise(target: Arg, arg: Any, res: Any): - raise InvalidParam(target.field.get_unmatch_tips(arg, res.error().args[0])) + raise InvalidParam(target.field.get_unmatch_tips(arg, res.error().args[0]), arg) def analyse_args(argv: Argv, args: Args) -> dict[str, Any]: @@ -235,23 +223,19 @@ def analyse_args(argv: Argv, args: Args) -> dict[str, Any]: """ result = {} for arg in args.argument.normal: - argv.current_node = arg may_arg, _str = argv.next(arg.separators) - # if _str and may_arg in argv.special: - # if argv.special[may_arg] not in argv.namespace.disable_builtin_options: - # raise SpecialOptionTriggered(argv.special[may_arg]) if _str and may_arg in argv.stack_params and not argv.stack_params[may_arg].soft_keyword: argv.rollback(may_arg) if (de := arg.field.default) is not Empty: result[arg.name] = de elif not arg.optional: - raise ArgumentMissing(arg.field.get_missing_tips(lang.require("args", "missing").format(key=arg.name))) + raise ArgumentMissing(arg.field.get_missing_tips(lang.require("args", "missing").format(key=arg.name)), arg) continue if may_arg is None or (_str and not may_arg): if (de := arg.field.default) is not Empty: result[arg.name] = de elif not arg.optional: - raise ArgumentMissing(arg.field.get_missing_tips(lang.require("args", "missing").format(key=arg.name))) + raise ArgumentMissing(arg.field.get_missing_tips(lang.require("args", "missing").format(key=arg.name)), arg) continue value = arg.value if value.alias == "*": @@ -285,7 +269,6 @@ def analyse_args(argv: Argv, args: Args) -> dict[str, Any]: step_keyword(argv, args, result) for slot in args.argument.vars_keyword: step_varkey(argv, slot, result) - argv.current_node = None return result @@ -298,7 +281,6 @@ def handle_option(argv: Argv, opt: Option, name_validated: bool) -> tuple[str, O opt (Option): 目标 `Option` name_validated (bool): 是否已经验证过名称 """ - argv.current_node = opt _cnt = 0 error = True if not name_validated: @@ -319,11 +301,11 @@ def handle_option(argv: Argv, opt: Option, name_validated: bool) -> tuple[str, O if error: argv.rollback(name) if not argv.fuzzy_match: - raise InvalidParam(lang.require("option", "name_error").format(source=opt.dest, target=name)) + raise InvalidParam(lang.require("option", "name_error").format(source=opt.dest, target=name), opt) for al in opt.aliases: if levenshtein(name, al) >= argv.fuzzy_threshold: raise FuzzyMatchSuccess(lang.require("fuzzy", "matched").format(source=al, target=name)) - raise InvalidParam(lang.require("option", "name_error").format(source=opt.dest, target=name)) + raise InvalidParam(lang.require("option", "name_error").format(source=opt.dest, target=name), opt) name = opt.dest if opt.nargs: return name, OptionResult(None, analyse_args(argv, opt.args)) @@ -379,7 +361,7 @@ def analyse_compact_params(analyser: SubAnalyser, argv: Argv, prefix: str): argv (Argv): 命令行参数 prefix (str): 参数前缀 """ - argv.rollback(prefix) + exc = None for param in analyser.compact_params: _data, _index = argv.data_set() try: @@ -390,11 +372,11 @@ def analyse_compact_params(analyser: SubAnalyser, argv: Argv, prefix: str): sparam: SubAnalyser = param # type: ignore try: sparam.process(argv, False) - except (FuzzyMatchSuccess, PauseTriggered, SpecialOptionTriggered): + except (FuzzyMatchSuccess, PauseTriggered): sparam.result() raise - except InvalidParam: - if argv.current_node is sparam.command: + except InvalidParam as e: + if e.context_node is sparam.command: sparam.result() else: analyser.subcommands_result[sparam.command.dest] = sparam.result() @@ -407,10 +389,13 @@ def analyse_compact_params(analyser: SubAnalyser, argv: Argv, prefix: str): _data.clear() return True except InvalidParam as e: - if argv.current_node.__class__ is Arg: - raise e - argv.data_reset(_data, _index) + if e.context_node is not param: + exc = e + else: + argv.data_reset(_data, _index) else: + if exc and not argv.error: + argv.error = exc return False @@ -433,55 +418,55 @@ def analyse_param(analyser: SubAnalyser, argv: Argv, seps: str | None = None): seps (str, optional): 指定的分隔符. """ _text, _str = argv.next(seps) - if _str and _text: - # if _text in argv.special and argv.special[_text] not in argv.namespace.disable_builtin_options: - # if _text in argv.completion_names: - # argv.bak_data[argv.current_index] = argv.bak_data[argv.current_index].replace(_text, "") - # raise SpecialOptionTriggered(argv.special[_text]) - if _param := analyser.compile_params.get(_text): - if _param.__class__ is Option: - oparam: Option = _param # type: ignore + if _str and _text and (_param := analyser.compile_params.get(_text)): + if Option in _param.__class__.__mro__: + oparam: Option = _param # type: ignore + try: analyse_option(analyser, argv, oparam, True) - argv.current_node = None - return True - sparam: SubAnalyser = _param # type: ignore - if sparam.command.dest not in analyser.subcommands_result: - try: - sparam.process(argv) - except (FuzzyMatchSuccess, PauseTriggered, SpecialOptionTriggered): + except AlconnaException as e: + if not argv.error: + argv.error = e + return True + sparam: SubAnalyser = _param # type: ignore + if sparam.command.dest not in analyser.subcommands_result: + try: + sparam.process(argv) + except (FuzzyMatchSuccess, PauseTriggered): + sparam.result() + raise + except InvalidParam as e: + if e.context_node is sparam.command: sparam.result() - raise - except InvalidParam: - if argv.current_node is sparam.command: - sparam.result() - else: - analyser.subcommands_result[sparam.command.dest] = sparam.result() - raise - except AlconnaException: - analyser.subcommands_result[sparam.command.dest] = sparam.result() - raise else: analyser.subcommands_result[sparam.command.dest] = sparam.result() - argv.current_node = None - return True - elif not analyser.compact_params: - argv.rollback(_text) - elif analyse_compact_params(analyser, argv, _text): - argv.current_node = None + if not argv.error: + argv.error = e + except AlconnaException as e1: + analyser.subcommands_result[sparam.command.dest] = sparam.result() + if not argv.error: + argv.error = e1 + else: + analyser.subcommands_result[sparam.command.dest] = sparam.result() return True - else: - argv.rollback(_text) + argv.rollback(_text) + if _str and _text and analyser.compact_params and analyse_compact_params(analyser, argv, _text): + return True if analyser.command.nargs and not analyser.args_result: analyser.args_result = analyse_args(argv, analyser.self_args) if analyser.args_result: - argv.current_node = None return True + if _str and _text and _text in argv.stack_params.parents(): + return False if analyser.extra_allow: analyser.args_result.setdefault("$extra", []).append(_text) - argv.next(seps) + argv.next() return True - else: - return False + elif _str and _text and not argv.stack_params.stack: + if not argv.error: + argv.error = ParamsUnmatched(lang.require("analyser", "param_unmatched").format(target=_text)) + argv.next() + return True + return False def analyse_header(header: "Header", argv: Argv): @@ -528,108 +513,3 @@ def handle_head_fuzzy(header: Header, source: str, threshold: float): for ht in headers_text: if levenshtein(source, ht) >= threshold: return lang.require("fuzzy", "matched").format(target=source, source=ht) - - -# def handle_help(analyser: Analyser, argv: Argv): -# """处理帮助选项触发""" -# _help_param = [str(i) for i in argv.release(recover=True) if str(i) not in argv.special] -# output_manager.send( -# analyser.command.name, -# lambda: analyser.command.formatter.format_node(_help_param), -# ) -# return SpecialOptionTriggered("help") - - -# _args = Args["action?", "delete|list"]["name?", str]["command", str, "$"] -# -# -# def handle_shortcut(analyser: Analyser, argv: Argv): -# """处理快捷命令触发""" -# argv.next() -# try: -# opt_v = analyse_args(argv, _args, None) -# except SpecialOptionTriggered: -# return handle_completion(analyser, argv) -# try: -# if opt_v.get("action") == "list": -# data = analyser.command.get_shortcuts() -# output_manager.send(analyser.command.name, lambda: "\n".join(data)) -# else: -# if not opt_v.get("name"): -# raise ArgumentMissing(lang.require("shortcut", "name_require")) -# if opt_v.get("action") == "delete": -# msg = analyser.command.shortcut(opt_v["name"], delete=True) -# elif opt_v["command"] == "$": -# msg = analyser.command.shortcut(opt_v["name"], fuzzy=True) -# else: -# msg = analyser.command.shortcut(opt_v["name"], fuzzy=True, command=opt_v["command"]) -# output_manager.send(analyser.command.name, lambda: msg) -# except Exception as e: -# output_manager.send(analyser.command.name, lambda: str(e)) -# return SpecialOptionTriggered("shortcut") -# - -def _prompt_unit(analyser: Analyser, argv: Argv, trig: Arg): - if not (comp := trig.field.get_completion()): - return [Prompt(analyser.command.formatter.param(trig), False)] - if isinstance(comp, str): - return [Prompt(f"{trig.name}: {comp}", False)] - releases = argv.release(recover=True) - target = str(releases[-1]) or str(releases[-2]) - o = list(filter(lambda x: target in x, comp)) or comp - return [Prompt(f"{trig.name}: {i}", False, target) for i in o] - - -def _prompt_none(analyser: Analyser, argv: Argv, got: list[str]): - res: list[Prompt] = [] - if not analyser.args_result and analyser.self_args.argument: - unit = analyser.self_args.argument[0] - if not (comp := unit.field.get_completion()): - res.append(Prompt(analyser.command.formatter.param(unit), False)) - elif isinstance(comp, str): - res.append(Prompt(f"{unit.name}: {comp}", False)) - else: - res.extend(Prompt(f"{unit.name}: {i}", False) for i in comp) - for opt in analyser.command.options: - if isinstance(opt, SPECIAL_OPTIONS): - continue - if opt.dest not in got: - res.extend([Prompt(al) for al in opt.aliases] if isinstance(opt, Option) else [Prompt(opt.name)]) - return res - - -def prompt(analyser: Analyser, argv: Argv, trigger: str | None = None): - """获取补全列表""" - _trigger = trigger or argv.current_node - got = [*analyser.options_result.keys(), *analyser.subcommands_result.keys()] - if isinstance(_trigger, Arg): - return _prompt_unit(analyser, argv, _trigger) - elif isinstance(_trigger, Subcommand): - return [Prompt(i) for i in analyser.get_sub_analyser(_trigger).compile_params] # type: ignore - elif isinstance(_trigger, str): - res = list(filter(lambda x: _trigger in x, analyser.compile_params)) - if not res: - return [] - out = [i for i in res if i not in got] - return [Prompt(i, True, _trigger) for i in (out or res)] - releases = argv.release(recover=True) - target = str(releases[-1]) or str(releases[-2]) - if _res := list(filter(lambda x: target in x and target != x, analyser.compile_params)): - out = [i for i in _res if i not in got] - return [Prompt(i, True, target) for i in (out or _res)] - return _prompt_none(analyser, argv, got) - - -def handle_completion(analyser: Analyser, argv: Argv, trigger: str | None = None): - """处理补全选项触发""" - if res := prompt(analyser, argv, trigger): - if comp_ctx.get(None): - raise PauseTriggered(res, trigger, argv) - prompt_other = lang.require("completion", "prompt_other") - node = lang.require('completion', 'node') - node = f"{node}\n" if node else "" - print( - analyser.command.name, - lambda: f"{node}{prompt_other}" + f"\n{prompt_other}".join([i.text for i in res]), - ) - return SpecialOptionTriggered("completion") diff --git a/src/arclet/alconna/_internal/_util.py b/src/arclet/alconna/_internal/_util.py index 030e6429..83658c52 100644 --- a/src/arclet/alconna/_internal/_util.py +++ b/src/arclet/alconna/_internal/_util.py @@ -1,26 +1,34 @@ -from typing import Generic, TypeVar +from typing import Generic, TypeVar, Optional T = TypeVar("T") class ChainMap(Generic[T]): - def __init__(self): - self.maps: "list[dict[str, T]]" = [] + def __init__(self, base: Optional[dict[str, T]] = None, *maps: dict[str, T]): + self.base: dict[str, T] = base or {} + self.stack: "list[dict[str, T]]" = list(maps) def enter(self, map: dict): - self.maps.insert(0, map) + self.stack.insert(0, map) def __contains__(self, item: str): - return any(item in m for m in self.maps) + return item in self.base or any(item in m for m in self.stack) def __getitem__(self, item: str) -> T: - for m in self.maps: + for m in self.stack: if item in m: return m[item] + if item in self.base: + return self.base[item] raise KeyError(item) + def parents(self): + if not self.stack: + return ChainMap() + return ChainMap(self.base, *self.stack[1:]) + def leave(self): - self.maps.pop(0) + self.stack.pop(0) def levenshtein(source: str, target: str) -> float: diff --git a/src/arclet/alconna/arparma.py b/src/arclet/alconna/arparma.py index 50ae0681..c4387f15 100644 --- a/src/arclet/alconna/arparma.py +++ b/src/arclet/alconna/arparma.py @@ -5,7 +5,7 @@ from dataclasses import dataclass, field from functools import lru_cache from types import MappingProxyType -from typing import Any, Callable, ClassVar, Generic, TypeVar, cast, overload +from typing import Any, Callable, ClassVar, Generic, TypeVar, cast, overload, Literal from typing_extensions import Self from tarina import Empty, generic_isinstance, lang, safe_eval @@ -74,23 +74,43 @@ def __getitem__(self, item: type[T1]) -> _Query[T1]: def __call__(self, path: str) -> T | None: ... + @overload + def __call__(self, path: str, *, force_return: Literal[True]) -> T: + ... + @overload def __call__(self, path: str, default: D) -> T | D: ... - def __call__(self, path: str, default: D | None = None) -> T | D | None: + def __call__(self, path: str, default: D | None = None, *, force_return: bool = False) -> T | D | None: """查询 `Arparma` 中的数据 Args: path (str): 要查询的路径 default (T | None, optional): 如果查询失败, 则返回该值 + force_return (bool, optional): 是否强制返回值, 默认为 False; 如果为 True, 则查询失败时抛出异常 """ source, endpoint = self.source.__require__(path.split(".")) if source is None: + if force_return: + raise KeyError(path) return default if isinstance(source, dict): - return source.get(endpoint, default) if endpoint else MappingProxyType(source) # type: ignore - return getattr(source, endpoint, default) if endpoint else source # type: ignore + if endpoint: + if endpoint in source: + return source[endpoint] + if force_return: + raise KeyError(path) + return default + return MappingProxyType(source) # type: ignore + if endpoint: + try: + return getattr(source, endpoint) + except AttributeError: + if force_return: + raise + return default + return source # type: ignore class Arparma(Generic[TDC]): @@ -106,11 +126,14 @@ class Arparma(Generic[TDC]): other_args (dict[str, Any]): 其他参数匹配结果 options (dict[str, OptionResult]): 选项匹配结果 subcommands (dict[str, SubcommandResult]): 子命令匹配结果 + context (dict[str, Any]): 上下文 + output (str | None): 输出信息 """ header_match: HeadResult options: dict[str, OptionResult] subcommands: dict[str, SubcommandResult] + output: str | None def __init__( self, @@ -149,6 +172,7 @@ def __init__( self.options = options or {} self.subcommands = subcommands or {} self.context = ctx or {} + self.output = None _additional: ClassVar[dict[str, Callable[[], Any]]] = {} query = _Query[Any]() @@ -234,8 +258,6 @@ def execute(self, behaviors: list[ArparmaBehavior] | None = None) -> Self: """ if not behaviors: return self - for b in behaviors: - b.before_operate(self) for b in behaviors: try: b.operate(self) @@ -391,29 +413,8 @@ class ArparmaBehavior(metaclass=ABCMeta): requires (list[ArparmaBehavior]): 该行为器所依赖的行为器 """ - record: dict[int, dict[str, tuple[Any, Any]]] = field(default_factory=dict, init=False, repr=False, hash=False) requires: list[ArparmaBehavior] = field(init=False, hash=False, repr=False) - def before_operate(self, interface: Arparma): - """在操作前调用, 用于准备数据""" - if not self.record: - return - if not (_record := self.record.get(interface.token)): - return - for path, (past, current) in _record.items(): - source, end = interface.__require__(path.split(".")) - if source is None: - continue - if isinstance(source, dict): - if past != Empty: - source[end] = past - elif source.get(end, Empty) != current: - source.pop(end) - elif past != Empty: - setattr(source, end, past) - elif getattr(source, end, Empty) != current: - delattr(source, end) - _record.clear() @abstractmethod def operate(self, interface: Arparma): @@ -430,12 +431,9 @@ def update(self, interface: Arparma, path: str, value: Any): """ def _update(tkn, src, pth, ep, val): - _record = self.record.setdefault(tkn, {}) if isinstance(src, dict): - _record[pth] = (src.get(ep, Empty), val) src[ep] = val else: - _record[pth] = (getattr(src, ep, Empty), val) setattr(src, ep, val) source, end = interface.__require__(path.split(".")) diff --git a/src/arclet/alconna/completion.py b/src/arclet/alconna/completion.py index 88565ef6..5b25706d 100644 --- a/src/arclet/alconna/completion.py +++ b/src/arclet/alconna/completion.py @@ -6,10 +6,13 @@ from tarina import ContextModel, lang from .arparma import Arparma -from .exceptions import InvalidParam, ParamsUnmatched, PauseTriggered, SpecialOptionTriggered +from .exceptions import InvalidParam, ParamsUnmatched, PauseTriggered from .manager import command_manager +from .base import Subcommand, SPECIAL_OPTIONS, Option +from .args import Arg if TYPE_CHECKING: + from .argv import Argv from .core import Alconna @@ -124,7 +127,7 @@ def enter(self, content: list | None = None) -> EnterResult: return EnterResult(exception=ValueError(lang.require("completion", "prompt_unavailable"))) if prompt.removal_prefix: argv.bak_data[-1] = argv.bak_data[-1][: -len(prompt.removal_prefix)] - argv.next(move=True) + argv.next() input_ = [prompt.text] if isinstance(self.trigger, InvalidParam): argv.raw_data = argv.bak_data[: max(self.current_index, 1)] @@ -144,7 +147,7 @@ def enter(self, content: list | None = None) -> EnterResult: self.exit() return EnterResult(res) if exc := self.source.process(argv): - if isinstance(exc, (ParamsUnmatched, SpecialOptionTriggered)): + if isinstance(exc, ParamsUnmatched): self.exit() return EnterResult(self.source.export(argv, True, exc)) if isinstance(exc, PauseTriggered): @@ -212,8 +215,8 @@ def fresh(self, exc: PauseTriggered): """ self.clear() self.push(*exc.args[0]) - self.trigger = exc.args[1] - argv = exc.args[2] + self.trigger = exc.context_node + argv = exc.argv self.raw_data = argv.raw_data self.bak_data = argv.bak_data self.current_index = argv.current_index @@ -221,3 +224,53 @@ def fresh(self, exc: PauseTriggered): comp_ctx: ContextModel[CompSession] = ContextModel("comp_ctx") + + +def _prompt_unit(command: Alconna, argv: Argv, trig: Arg): + if not (comp := trig.field.get_completion()): + return [Prompt(command.formatter.param(trig), False)] + if isinstance(comp, str): + return [Prompt(f"{trig.name}: {comp}", False)] + releases = argv.release(recover=True) + target = str(releases[-1]) or str(releases[-2]) + o = list(filter(lambda x: target in x, comp)) or comp + return [Prompt(f"{trig.name}: {i}", False, target) for i in o] + + +def _prompt_none(command: Alconna, args_got: list[str], opts_got: list[str]): + res: list[Prompt] = [] + if unit := next((arg for arg in command.args if arg.name not in args_got), None): + if not (comp := unit.field.get_completion()): + res.append(Prompt(command.formatter.param(unit), False)) + elif isinstance(comp, str): + res.append(Prompt(f"{unit.name}: {comp}", False)) + else: + res.extend(Prompt(f"{unit.name}: {i}", False) for i in comp) + for opt in command.options: + if isinstance(opt, SPECIAL_OPTIONS): + continue + if opt.dest not in opts_got: + res.extend([Prompt(al) for al in opt.aliases] if isinstance(opt, Option) else [Prompt(opt.name)]) + return res + + +def prompt(command: Alconna, argv: Argv, args_got: list[str], opts_got: list[str], trigger: str | Arg | Subcommand | None = None): + """获取补全列表""" + if isinstance(trigger, Arg): + return _prompt_unit(command, argv, trigger) + elif isinstance(trigger, Subcommand): + return [Prompt(i) for i in argv.stack_params.stack[-1]] + elif isinstance(trigger, str): + res = list(filter(lambda x: trigger in x, argv.stack_params.base)) + if not res: + return [] + out = [i for i in res if i not in opts_got] + return [Prompt(i, True, trigger) for i in (out or res)] + releases = argv.release(recover=True) + target = str(releases[-1]) + if isinstance(releases[-1], str) and releases[-1] in command.namespace_config.builtin_option_name["completion"]: + target = str(releases[-2]) + if _res := list(filter(lambda x: target in x and target != x, argv.stack_params.base)): + out = [i for i in _res if i not in opts_got] + return [Prompt(i, True, target) for i in (out or _res)] + return _prompt_none(command, args_got, opts_got) diff --git a/src/arclet/alconna/core.py b/src/arclet/alconna/core.py index 74434ce0..d9ec3aca 100644 --- a/src/arclet/alconna/core.py +++ b/src/arclet/alconna/core.py @@ -1,7 +1,6 @@ """Alconna 主体""" from __future__ import annotations -import re import sys from dataclasses import dataclass, field from pathlib import Path @@ -10,24 +9,26 @@ from weakref import WeakSet from nepattern import TPattern -from tarina import init_spec, lang +from tarina import init_spec, lang, Empty +from .model import OptionResult from ._internal._analyser import Analyser, TCompile from ._internal._handlers import handle_head_fuzzy, analyse_header from ._internal._shortcut import shortcut as _shortcut from .args import Arg, Args from .arparma import Arparma, ArparmaBehavior, requirement_handler -from .base import Completion, Help, Option, Shortcut, Subcommand, Header +from .base import Completion, Help, Option, Shortcut, Subcommand, Header, SPECIAL_OPTIONS from .config import Namespace, config from .constraint import SHORTCUT_ARGS, SHORTCUT_REGEX_MATCH, SHORTCUT_REST, SHORTCUT_TRIGGER from .exceptions import ( AlconnaException, + AnalyseException, ExecuteFailed, FuzzyMatchSuccess, InvalidHeader, PauseTriggered, - SpecialOptionTriggered, ) +from .completion import prompt, comp_ctx from .formatter import TextFormatter from .manager import ShortcutArgs, command_manager from .typing import TDC, CommandMeta, InnerShortcutArgs, ShortcutRegWrapper @@ -45,21 +46,71 @@ def handle_argv(): return head -def add_builtin_options(options: list[Option | Subcommand], ns: Namespace) -> None: +def add_builtin_options(options: list[Option | Subcommand], cmd: Alconna, ns: Namespace) -> None: if "help" not in ns.disable_builtin_options: options.append(Help("|".join(ns.builtin_option_name["help"]), dest="$help", help_text=lang.require("builtin", "option_help"))) # noqa: E501 + + @cmd.route("$help") + def _(command: Alconna, arp: Arparma): + argv = command_manager.resolve(cmd) + _help_param = [str(i) for i in argv.release(recover=True) if str(i) not in ns.builtin_option_name["help"]] + arp.output = command.formatter.format_node(_help_param) + return True + if "shortcut" not in ns.disable_builtin_options: options.append( Shortcut( "|".join(ns.builtin_option_name["shortcut"]), - Args["action?", "delete|list"]["name?", str]["command", str, "$"], + Args["action?", "delete|list"]["name?", str]["command?", str], dest="$shortcut", help_text=lang.require("builtin", "option_shortcut"), ) ) + + @cmd.route("$shortcut") + def _(command: Alconna, arp: Arparma): + res = arp.query[OptionResult]("$shortcut", force_return=True) + if res.args.get("action") == "list": + data = command.get_shortcuts() + arp.output = "\n".join(data) + return True + if not res.args.get("name"): + raise ValueError(lang.require("shortcut", "name_require")) + if res.args.get("action") == "delete": + msg = command.shortcut(res.args["name"], delete=True) + else: + msg = command.shortcut(res.args["name"], fuzzy=True, command=res.args.get("command")) + arp.output = msg + return True + if "completion" not in ns.disable_builtin_options: options.append(Completion("|".join(ns.builtin_option_name["completion"]), dest="$completion", help_text=lang.require("builtin", "option_completion"))) # noqa: E501 + @cmd.route("$completion") + def _(command: Alconna, arp: Arparma): + argv = command_manager.resolve(cmd) + rest = argv.release() + trigger = None + if rest and isinstance(rest[-1], str) and rest[-1] in ns.builtin_option_name["completion"]: + argv.bak_data[-1] = argv.bak_data[-1][: -len(rest[-1])].rstrip() + trigger = rest[-2] + elif isinstance(arp.error_info, AnalyseException): + trigger = arp.error_info.context_node + if res := prompt( + command, + argv, + list(arp.main_args.keys()), + [*arp.options.keys(), *arp.subcommands.keys()], + trigger + ): + if comp_ctx.get(None): + raise PauseTriggered(res, trigger, argv) + prompt_other = lang.require("completion", "prompt_other") + node = lang.require('completion', 'node') + node = f"{node}\n" if node else "" + arp.output = f"{node}{prompt_other}" + f"\n{prompt_other}".join([i.text for i in res]) + return True + @dataclass(init=True, unsafe_hash=True) class ArparmaExecutor(Generic[T]): @@ -89,6 +140,21 @@ def result(self) -> T: raise ExecuteFailed(e) from e +class Router: + def __init__(self): + self._routes = {} + + def execute(self, cmd: Alconna, arp: Arparma): + for route, target in self._routes.items(): + if arp.query(route, Empty) is not Empty: + try: + res = target(cmd, arp) + if res is True: + return + except Exception as e: + return e + + class Alconna(Subcommand): """ 更加精确的命令解析 @@ -161,6 +227,7 @@ def __init__( self.command = next(i for i in args if not isinstance(i, (list, Option, Subcommand, Args, Arg, CommandMeta, ArparmaBehavior))) except StopIteration: self.command = "" if self.prefixes else handle_argv() + self.router = Router() self.namespace = ns_config.name self.formatter = (formatter_type or ns_config.formatter_type or TextFormatter)() self.meta = meta or next((i for i in args if isinstance(i, CommandMeta)), CommandMeta()) @@ -172,7 +239,7 @@ def __init__( self.meta.context_style = self.meta.context_style or ns_config.context_style self._header = Header.generate(self.command, self.prefixes, self.meta.compact) options = [i for i in args if isinstance(i, (Option, Subcommand))] - add_builtin_options(options, ns_config) + add_builtin_options(options, self, ns_config) name = next(iter(self._header.content), self.command or self.prefixes[0]) self.path = f"{self.namespace}::{name}" _args = sum((i for i in args if isinstance(i, (Args, Arg))), Args()) @@ -211,8 +278,8 @@ def reset_namespace(self, namespace: Namespace | str, header: bool = True) -> Se self.dest = name self.path = f"{self.namespace}::{name}" self.aliases = frozenset((name,)) - self.options = [opt for opt in self.options if not isinstance(opt, (Help, Completion, Shortcut))] - add_builtin_options(self.options, namespace) + self.options = [opt for opt in self.options if not isinstance(opt, SPECIAL_OPTIONS)] + add_builtin_options(self.options, self, namespace) self.meta.fuzzy_match = namespace.fuzzy_match or self.meta.fuzzy_match self.meta.raise_exception = namespace.raise_exception or self.meta.raise_exception return self @@ -237,6 +304,12 @@ def _get_shortcuts(self): """返回该命令注册的快捷命令""" return command_manager.get_shortcut(self) + def route(self, path: str): + def wrapper(target: Callable[[Alconna, Arparma], Any]): + self.router._routes[path] = target + return target + return wrapper + @overload def shortcut(self, key: str | TPattern, args: ShortcutArgs) -> str: """操作快捷命令 @@ -307,6 +380,8 @@ def shortcut(self, key: str | TPattern, args: ShortcutArgs | None = None, delete if kwargs and not args: kwargs["args"] = kwargs.pop("arguments", None) kwargs = {k: v for k, v in kwargs.items() if v is not None} + if kwargs.get("command") == "$": + del kwargs["command"] args = cast(ShortcutArgs, kwargs) if args is not None: return command_manager.add_shortcut(self, key, args) @@ -356,7 +431,7 @@ def _parse(self, message: TDC, ctx: dict[str, Any] | None = None) -> Arparma[TDC if not (exc := analyser.process(argv)): return analyser.export(argv) if isinstance(exc, InvalidHeader): - trigger = exc.args[1] + trigger = exc.context_node if trigger.__class__ is str and trigger: argv.context[SHORTCUT_TRIGGER] = trigger try: @@ -371,14 +446,11 @@ def _parse(self, message: TDC, ctx: dict[str, Any] | None = None) -> Arparma[TDC return analyser.export(argv) except ValueError: if argv.fuzzy_match and (res := handle_head_fuzzy(self._header, trigger, argv.fuzzy_threshold)): - output_manager.send(self.name, lambda: res) exc = FuzzyMatchSuccess(res) except AlconnaException as e: exc = e if isinstance(exc, PauseTriggered): raise exc - if self.meta.raise_exception and not isinstance(exc, (FuzzyMatchSuccess, SpecialOptionTriggered)): - raise exc return analyser.export(argv, True, exc) def parse(self, message: TDC, ctx: dict[str, Any] | None = None) -> Arparma[TDC]: @@ -395,9 +467,11 @@ def parse(self, message: TDC, ctx: dict[str, Any] | None = None) -> Arparma[TDC] arp = self._parse(message, ctx) if arp.matched: arp = arp.execute(self.behaviors) - if self._executors: - for ext in self._executors: - self._executors[ext] = arp.call(ext.target) + if arp.matched and self._executors: + for ext in self._executors: + self._executors[ext] = arp.call(ext.target) + if err := self.router.execute(self, arp): + return arp.fail(err) return arp def bind(self, active: bool = True): @@ -448,12 +522,16 @@ def _calc_hash(self): def __call__(self, *args): if args: - return self.parse(list(args)) # type: ignore - head = handle_argv() - argv = [(f"\"{arg}\"" if any(arg.count(sep) for sep in self.separators) else arg) for arg in sys.argv[1:]] - if head != self.command: - return self.parse(argv) # type: ignore - return self.parse([head, *argv]) # type: ignore + res = self.parse(list(args)) # type: ignore + else: + head = handle_argv() + argv = [(f"\"{arg}\"" if any(arg.count(sep) for sep in self.separators) else arg) for arg in sys.argv[1:]] + if head != self.command: + return self.parse(argv) # type: ignore + res = self.parse([head, *argv]) # type: ignore + if res.output: + print(res.output) + return res @property def header_display(self): diff --git a/src/arclet/alconna/exceptions.py b/src/arclet/alconna/exceptions.py index e4ebf5d7..0394a764 100644 --- a/src/arclet/alconna/exceptions.py +++ b/src/arclet/alconna/exceptions.py @@ -4,25 +4,24 @@ class AlconnaException(Exception): """Alconna 异常基类""" - -class ParamsUnmatched(AlconnaException): - """一个传入参数没有被选项或Args匹配""" +class ExecuteFailed(AlconnaException): + """执行失败""" -class InvalidParam(AlconnaException): - """传入参数验证失败""" +class ExceedMaxCount(AlconnaException): + """注册的命令数量超过最大长度""" -class InvalidHeader(AlconnaException): - """传入的消息头部无效""" +class BehaveCancelled(AlconnaException): + """行为执行被停止""" -class ArgumentMissing(AlconnaException): - """组件内的 Args 参数未能解析到任何内容""" +class OutBoundsBehave(AlconnaException): + """越界行为""" -class InvalidArgs(AlconnaException): - """构造 alconna 时某个传入的参数不正确""" +class FuzzyMatchSuccess(AlconnaException): + """模糊匹配成功""" class NullMessage(AlconnaException): @@ -33,29 +32,37 @@ class UnexpectedElement(AlconnaException): """给出的消息含有不期望的元素""" -class ExecuteFailed(AlconnaException): - """执行失败""" +class AnalyseException(AlconnaException): + """Alconna Analyse 异常基类""" + def __init__(self, msg, context_node=None): + super().__init__(msg) + self.context_node = context_node -class ExceedMaxCount(AlconnaException): - """注册的命令数量超过最大长度""" +class ParamsUnmatched(AnalyseException): + """一个传入参数没有被选项或Args匹配""" -class BehaveCancelled(AlconnaException): - """行为执行被停止""" +class InvalidParam(AnalyseException): + """传入参数验证失败""" -class OutBoundsBehave(AlconnaException): - """越界行为""" +class InvalidHeader(AnalyseException): + """传入的消息头部无效""" -class FuzzyMatchSuccess(AlconnaException): - """模糊匹配成功""" +class ArgumentMissing(AnalyseException): + """组件内的 Args 参数未能解析到任何内容""" -class PauseTriggered(AlconnaException): - """解析状态保存触发""" +class InvalidArgs(AnalyseException): + """构造 alconna 时某个传入的参数不正确""" + + +class PauseTriggered(AnalyseException): + """解析状态保存触发""" -class SpecialOptionTriggered(AlconnaException): - """内置选项解析触发""" + def __init__(self, msg, context_node, argv): + super().__init__(msg, context_node) + self.argv = argv diff --git a/src/arclet/alconna/formatter.py b/src/arclet/alconna/formatter.py index 326df849..8538453a 100644 --- a/src/arclet/alconna/formatter.py +++ b/src/arclet/alconna/formatter.py @@ -22,7 +22,7 @@ def ensure_node(targets: list[str], options: list[Option | Subcommand], record: if isinstance(opt, Option) and pf in opt.aliases: record.append(pf) return opt - if isinstance(opt, Subcommand) and pf == opt.name: + if isinstance(opt, Subcommand) and pf in opt.aliases: record.append(pf) if not targets: return opt @@ -110,11 +110,11 @@ def _handle(trace: Trace): prefix += trace.separators[0] + trace.separators[0].join(rec[:-1]) if isinstance(end, Option): return self.format( - Trace({"name": prefix + trace.separators[0] + "│".join(end.aliases), "description": end.help_text, 'example': None, 'usage': None}, end.args, end.separators, [], {}) # noqa: E501 + Trace({"name": prefix + trace.separators[0] + "│".join(sorted(end.aliases)), "description": end.help_text, 'example': None, 'usage': None}, end.args, end.separators, [], {}) # noqa: E501 ) if isinstance(end, Subcommand): return self.format( - Trace({"name": prefix + trace.separators[0] + "│".join(end.aliases), "description": end.help_text, 'example': None, 'usage': None}, end.args, end.separators, end.options, {}) # noqa: E501 + Trace({"name": prefix + trace.separators[0] + "│".join(sorted(end.aliases)), "description": end.help_text, 'example': None, 'usage': None}, end.args, end.separators, end.options, {}) # noqa: E501 ) return self.format(trace) diff --git a/tests/analyser_test.py b/tests/analyser_test.py index 7978c38e..675bef72 100644 --- a/tests/analyser_test.py +++ b/tests/analyser_test.py @@ -33,7 +33,10 @@ def at(user_id: Union[int, str]): def gen_unit(type_: str): return BasePattern( - mode=MatchMode.VALUE_OPERATE, origin=Segment, converter=lambda _, seg: seg if seg.type == type_ else None, alias=type_, + mode=MatchMode.VALUE_OPERATE, + origin=Segment, + converter=lambda _, seg: seg if seg.type == type_ else None, + alias=type_, ) diff --git a/tests/args_test.py b/tests/args_test.py index cfac4df5..4b5ba6cf 100644 --- a/tests/args_test.py +++ b/tests/args_test.py @@ -169,8 +169,7 @@ def test( d: int = 1, e: bool = False, **kwargs: str, - ): - ... + ): ... arg16, _ = Args.from_callable(test) assert len(arg16.argument) == 7 @@ -247,7 +246,9 @@ def test_multi_multi(): def test_contextval(): arg21 = Args["foo", str] assert analyse_args(arg21, ["$(bar)"], context_style="parentheses", bar="baz") == {"foo": "baz"} - assert analyse_args(arg21, ["{bar}"], context_style="parentheses", raise_exception=False, bar="baz") != {"foo": "baz"} + assert analyse_args(arg21, ["{bar}"], context_style="parentheses", raise_exception=False, bar="baz") != { + "foo": "baz" + } assert analyse_args(arg21, ["{bar}"], context_style="bracket", bar="baz") == {"foo": "baz"} assert analyse_args(arg21, ["$(bar)"], context_style="bracket", raise_exception=False, bar="baz") != {"foo": "baz"} @@ -265,7 +266,9 @@ class B: arg21_1 = Args["foo", int] assert analyse_args(arg21_1, ["$(bar)"], context_style="parentheses", bar=123) == {"foo": 123} assert analyse_args(arg21_1, ["$(bar)"], context_style="parentheses", bar="123") == {"foo": 123} - assert analyse_args(arg21_1, ["$(bar)"], context_style="parentheses", raise_exception=False, bar="baz") != {"foo": 123} + assert analyse_args(arg21_1, ["$(bar)"], context_style="parentheses", raise_exception=False, bar="baz") != { + "foo": 123 + } if __name__ == "__main__": diff --git a/tests/components_test.py b/tests/components_test.py index 31c56704..442d3ac6 100644 --- a/tests/components_test.py +++ b/tests/components_test.py @@ -27,7 +27,9 @@ def operate(cls, interface: "Arparma"): assert com1.parse("comp1 --foo 1 --baz 2").matched is False com1.behaviors.clear() - com1.behaviors.append(conflict("foo.bar", "baz.qux", source_limiter=lambda x: x == 2, target_limiter=lambda x: x == 1)) + com1.behaviors.append( + conflict("foo.bar", "baz.qux", source_limiter=lambda x: x == 2, target_limiter=lambda x: x == 1) + ) assert com1.parse("comp1 --foo 1").matched assert com1.parse("comp1 --baz 2").matched @@ -42,12 +44,7 @@ def operate(cls, interface: "Arparma"): assert com1.parse("comp1 --foo 1 --baz 2").matched assert com1.parse("comp1 --foo 1 --baz 1").matched is False - com1_1 = Alconna( - "comp1_1", - Option("-1", dest="one"), - Option("-2", dest="two"), - Option("-3", dest="three") - ) + com1_1 = Alconna("comp1_1", Option("-1", dest="one"), Option("-2", dest="two"), Option("-3", dest="three")) com1_1.behaviors.append(conflict("one", "two")) com1_1.behaviors.append(conflict("two", "three")) @@ -56,21 +53,21 @@ def operate(cls, interface: "Arparma"): assert com1_1.parse("comp1_1 -1 -3").matched -def test_output(): - print("") - output_manager.set_action(lambda x: {"bar": f"{x}!"}, "foo") - output_manager.set(lambda: "123", "foo") - assert output_manager.send("foo") == {"bar": "123!"} - assert output_manager.send("foo", lambda: "321") == {"bar": "321!"} - - com5 = Alconna("comp5", Args["foo", int], Option("--bar", Args["bar", str])) - output_manager.set_action(lambda x: x, "comp5") - with output_manager.capture("comp5") as output: - com5.parse("comp5 --help") - assert output.get("output") - print("") - print(output.get("output")) - print(output.get("output")) # capture will clear when exit context +# def test_output(): +# print("") +# output_manager.set_action(lambda x: {"bar": f"{x}!"}, "foo") +# output_manager.set(lambda: "123", "foo") +# assert output_manager.send("foo") == {"bar": "123!"} +# assert output_manager.send("foo", lambda: "321") == {"bar": "321!"} +# +# com5 = Alconna("comp5", Args["foo", int], Option("--bar", Args["bar", str])) +# output_manager.set_action(lambda x: x, "comp5") +# with output_manager.capture("comp5") as output: +# com5.parse("comp5 --help") +# assert output.get("output") +# print("") +# print(output.get("output")) +# print(output.get("output")) # capture will clear when exit context if __name__ == "__main__": diff --git a/tests/core_test.py b/tests/core_test.py index c372a040..e517877b 100644 --- a/tests/core_test.py +++ b/tests/core_test.py @@ -336,19 +336,22 @@ def test_wildcard(): 5.0, "dsdf", ] - assert alc13.parse( - """core13 + assert ( + alc13.parse( + """core13 import foo def test(): print("Hello, World!")""" - ).foo == [ - """\ + ).foo + == [ + """\ import foo def test(): print("Hello, World!")""" - ] + ] + ) alc13_1 = Alconna("core13_1", Args["foo", AllParam(str)]) assert alc13_1.parse(["core13_1 abc def gh", 123, 5.0, "dsdf"]).foo == [ @@ -376,42 +379,33 @@ def test_alconna_group(): def test_fuzzy(): - from arclet.alconna import output_manager - alc15 = Alconna("!core15", Args["foo", str], meta=CommandMeta(fuzzy_match=True)) - with output_manager.capture("!core15") as cap: - output_manager.set_action(lambda x: x, "!core15") - res = alc15.parse("core15 foo bar") - assert res.matched is False - assert cap["output"] == '无法解析 "core15"。您想要输入的是不是 "!core15" ?' - with output_manager.capture("!core15") as cap: - output_manager.set_action(lambda x: x, "!core15") - res1 = alc15.parse([1, "core15", "foo", "bar"]) - assert res1.matched is False - assert cap["output"] == '无法解析 "1 core15"。您想要输入的是不是 "!core15" ?' + res = alc15.parse("core15 foo bar") + assert res.matched is False + assert res.output == '无法解析 "core15"。您想要输入的是不是 "!core15" ?' + res1 = alc15.parse([1, "core15", "foo", "bar"]) + assert res1.matched is False + assert res1.output == '无法解析 "1 core15"。您想要输入的是不是 "!core15" ?' alc15_1 = Alconna(["/"], "core15_1", meta=CommandMeta(fuzzy_match=True)) - with output_manager.capture("/core15_1") as cap: - output_manager.set_action(lambda x: x, "/core15_1") - res2 = alc15_1.parse("core15_1") - assert res2.matched is False - assert cap["output"] == '无法解析 "core15_1"。您想要输入的是不是 "/core15_1" ?' - with output_manager.capture("/core15_1") as cap: - output_manager.set_action(lambda x: x, "/core15_1") - res2 = alc15_1.parse("@core15_1") - assert res2.matched is False - assert cap["output"] == '无法解析 "@core15_1"。您想要输入的是不是 "/core15_1" ?' + + res2 = alc15_1.parse("core15_1") + assert res2.matched is False + assert res2.output == '无法解析 "core15_1"。您想要输入的是不是 "/core15_1" ?' + + res3 = alc15_1.parse("@core15_1") + assert res3.matched is False + assert res3.output == '无法解析 "@core15_1"。您想要输入的是不是 "/core15_1" ?' alc15_3 = Alconna("core15_3", Option("rank", compact=True), meta=CommandMeta(fuzzy_match=True)) - with output_manager.capture("core15_3") as cap: - output_manager.set_action(lambda x: x, "core15_3") - res6 = alc15_3.parse("core15_3 runk") - assert res6.matched is False - assert cap["output"] == '无法解析 "runk"。您想要输入的是不是 "rank" ?' + + res4 = alc15_3.parse("core15_3 runk") + assert res4.matched is False + assert res4.output == '无法解析 "runk"。您想要输入的是不是 "rank" ?' def test_shortcut(): - from arclet.alconna import namespace, output_manager + from arclet.alconna import namespace with namespace("test16") as ns: ns.disable_builtin_options = set() @@ -493,16 +487,15 @@ def wrapper(slot, content): alc16_6 = Alconna("core16_6", Args["bar", str]) alc16_6.shortcut("test(?P.+)?", fuzzy=False, wrapper=wrapper, arguments=["{bar}"]) assert alc16_6.parse("testabc").bar == "abc" - - with output_manager.capture("core16_6") as cap: - output_manager.set_action(lambda x: x, "core16_6") - alc16_6.parse("testhelp") - assert cap["output"] == """\ + assert ( + alc16_6.parse("testhelp").output + == """\ core16_6 Unknown 快捷命令: 'test(?P.+)?' => core16_6 {bar}\ """ + ) alc16_7 = Alconna("core16_7", Args["bar", str]) alc16_7.shortcut("test 123", {"args": ["abc"]}) @@ -550,10 +543,7 @@ def wrapper2(slot, content): return content alc16_13.shortcut( - r"(?i:io)(?i:rank)\s*(?P[a-zA-Z+-]*)", - command="core16_13 rank {rank}", - fuzzy=False, - wrapper=wrapper2 + r"(?i:io)(?i:rank)\s*(?P[a-zA-Z+-]*)", command="core16_13 rank {rank}", fuzzy=False, wrapper=wrapper2 ) assert alc16_13.parse("iorank x").matched @@ -569,9 +559,6 @@ def wrapper2(slot, content): def test_help(): - from arclet.alconna import output_manager - from arclet.alconna.exceptions import SpecialOptionTriggered - alc17 = Alconna( "core17", Option("foo", Args["bar", str], help_text="Foo bar"), @@ -579,43 +566,34 @@ def test_help(): Subcommand("add", Args["bar", str], help_text="Add bar"), Subcommand("del", Args["bar", str], help_text="Del bar"), ) - with output_manager.capture("core17") as cap: - output_manager.set_action(lambda x: x, "core17") - res = alc17.parse("core17 --help") - assert isinstance(res.error_info, SpecialOptionTriggered) - assert cap["output"] == ( - "core17 \n" - "Unknown\n" - "\n" - "可用的子命令有:\n" - "* Add bar\n" - " add \n" - "* Del bar\n" - " del \n" - "可用的选项有:\n" - "* Foo bar\n" - " foo \n" - "* Baz qux\n" - " baz \n" - ) - with output_manager.capture("core17") as cap: - alc17.parse("core17 --help foo") - assert cap["output"] == "core17 foo \nFoo bar" - with output_manager.capture("core17") as cap: - alc17.parse("core17 foo --help") - assert cap["output"] == "core17 foo \nFoo bar" - with output_manager.capture("core17") as cap: - alc17.parse("core17 --help baz") - assert cap["output"] == "core17 baz \nBaz qux" - with output_manager.capture("core17") as cap: - alc17.parse("core17 baz --help") - assert cap["output"] == "core17 baz \nBaz qux" - with output_manager.capture("core17") as cap: - alc17.parse("core17 add --help") - assert cap["output"] == "core17 add \nAdd bar" - with output_manager.capture("core17") as cap: - alc17.parse("core17 del --help") - assert cap["output"] == "core17 del \nDel bar" + res = alc17.parse("core17 --help") + assert res.output == ( + "core17 \n" + "Unknown\n" + "\n" + "可用的子命令有:\n" + "* Add bar\n" + " add \n" + "* Del bar\n" + " del \n" + "可用的选项有:\n" + "* Foo bar\n" + " foo \n" + "* Baz qux\n" + " baz \n" + ) + + assert alc17.parse("core17 --help foo").output == "core17 foo \nFoo bar" + + assert alc17.parse("core17 foo --help").output == "core17 foo \nFoo bar" + + assert alc17.parse("core17 --help baz").output == "core17 baz \nBaz qux" + + assert alc17.parse("core17 baz --help").output == "core17 baz \nBaz qux" + + assert alc17.parse("core17 add --help").output == "core17 add \nAdd bar" + + assert alc17.parse("core17 del --help").output == "core17 del \nDel bar" alc17_2 = Alconna( "core17_2", Subcommand( @@ -625,12 +603,11 @@ def test_help(): help_text="sub Foo", ), ) - with output_manager.capture("core17_2") as cap: - alc17_2.parse("core17_2 --help foo bar") - assert cap["output"] == "core17_2 foo bar \nFoo bar" - with output_manager.capture("core17_2") as cap: - alc17_2.parse("core17_2 --help foo") - assert cap["output"] == "core17_2 foo \nsub Foo\n\n可用的选项有:\n* Foo bar\n bar \n" + assert alc17_2.parse("core17_2 --help foo bar").output == "core17_2 foo bar \nFoo bar" + assert ( + alc17_2.parse("core17_2 --help foo").output + == "core17_2 foo \nsub Foo\n\n可用的选项有:\n* Foo bar\n bar \n" + ) def test_hide_annotation(): @@ -647,32 +624,64 @@ def test_args_notice(): def test_completion(): - from arclet.alconna.exceptions import SpecialOptionTriggered - - alc20 = ( - "core20" - + Option("fool") - + Option( + alc20 = Alconna( + "core20", + Option("fool"), + Option( "foo", Args["bar", "a|b|c", Field(completion=lambda: "choose a, b or c")], - ) - + Option( + ), + Option( "off", Args["baz", "aaa|aab|abc", Field(completion=lambda: ["use aaa", "use aab", "use abc"])], - ) - + Args["test", int, Field(1, completion=lambda: "try -1")] + ), + Args["test", int, Field(1, completion=lambda: "try -1")] ) - - alc20.parse("core20 --comp") - alc20.parse("core20 f --comp") - alc20.parse("core20 fo --comp") - alc20.parse("core20 foo --comp") - alc20.parse("core20 fool --comp") - alc20.parse("core20 off b --comp") + assert alc20.parse("core20 --comp").output == ( + """\ +以下是建议的输入: +* fool +* foo +* off\ +""" + ) + assert alc20.parse("core20 f --comp").output == ( + """\ +以下是建议的输入: +* fool +* foo +* off\ +""" + ) + assert alc20.parse("core20 fo --comp").output == ( + """\ +以下是建议的输入: +* fool +* foo\ +""" + ) +# assert alc20.parse("core20 f --comp").output == """以下是建议的输入: +# * fool +# * foo +# * off""" +# assert alc20.parse("core20 fo --comp").output == """以下是建议的输入: +# * fool +# * foo""" +# assert alc20.parse("core20 foo --comp").output == """以下是建议的输入: +# * bar: choose a, b or c""" +# assert alc20.parse("core20 fool --comp").output == """以下是建议的输入: +# * foo +# * off""" +# assert alc20.parse("core20 off b --comp").output == """以下是建议的输入: +# * baz: use aaa +# * baz: use aab +# * baz: use abc""" alc20_1 = Alconna("core20_1", Args["foo", int], Option("bar")) res = alc20_1.parse("core20_1 -cp") - assert isinstance(res.error_info, SpecialOptionTriggered) + assert res.output == """以下是建议的输入: +* +* bar""" def test_completion_interface(): @@ -778,9 +787,11 @@ class A: assert alc23.parse(["core23 bar baz --qux", A(), "123"]).matched assert not alc23.parse(["core23 bar baz", A(), "--qux 123"]).matched assert alc23.parse(["core23 bar baz --qux", A(), "123"]).query("Bar.Baz.qux.value") is Ellipsis - print("") + # alc23.parse("core23 --help") - alc23.parse("core23 bar baz --help") + assert alc23.parse("core23 bar baz --help").output == ( + "core23 bar baaz│baz \n" "test nest subcommand; deep 2\n" "\n" "可用的选项有:\n" "* qux\n" " --qux \n" + ) alc23_1 = Alconna( "core23_1", @@ -818,9 +829,7 @@ def test_action(): Option("-x|--xyz", action=count), Option("--q", action=count), ) - res = alc24_2.parse( - "core24_2 -A --a -vvv -x -x --xyzxyz -Fabc -Fdef --flag xyz --i 4 --i 5 --q --qq" - ) + res = alc24_2.parse("core24_2 -A --a -vvv -x -x --xyzxyz -Fabc -Fdef --flag xyz --i 4 --i 5 --q --qq") assert res.query[int]("i.foo") == 5 assert res.query[List[int]]("a.value") == [1, 1] assert res.query[List[str]]("flag.flag") == ["abc", "def", "xyz"] @@ -957,13 +966,20 @@ def test_tips(): core27 = Alconna( "core27", - Args["arg1", Literal["1", "2"], Field(unmatch_tips=lambda x: f"参数arg必须是1或2哦,不能是{x}", missing_tips=lambda: "缺少了arg1参数哦")], + Args[ + "arg1", + Literal["1", "2"], + Field(unmatch_tips=lambda x: f"参数arg必须是1或2哦,不能是{x}", missing_tips=lambda: "缺少了arg1参数哦"), + ], Args["arg2", Literal["1", "2"], Field(missing_tips=lambda: "缺少了arg2参数哦")], ) assert core27.parse("core27 1 1").matched assert str(core27.parse("core27 3 1").error_info) == "参数arg必须是1或2哦,不能是3" assert str(core27.parse("core27 1").error_info) == "缺少了arg2参数哦" - assert str(core27.parse("core27 1 3").error_info) in ("参数 '3' 不正确, 其应该符合 \"'1'|'2'\"", "参数 '3' 不正确, 其应该符合 \"'2'|'1'\"") + assert str(core27.parse("core27 1 3").error_info) in ( + "参数 '3' 不正确, 其应该符合 \"'1'|'2'\"", + "参数 '3' 不正确, 其应该符合 \"'2'|'1'\"", + ) assert str(core27.parse("core27").error_info) == "缺少了arg1参数哦" diff --git a/tests/devtool.py b/tests/devtool.py index e0c2a341..5094e06a 100644 --- a/tests/devtool.py +++ b/tests/devtool.py @@ -43,7 +43,7 @@ def analyse_args( command: list[str | Any], raise_exception: bool = True, context_style: Literal["bracket", "parentheses"] | None = None, - **kwargs + **kwargs, ): meta = CommandMeta(keep_crlf=False, fuzzy_match=False, raise_exception=raise_exception, context_style=context_style) argv: Argv[DataCollection] = Argv(meta, dev_space) @@ -66,7 +66,7 @@ def analyse_header( compact: bool = False, raise_exception: bool = True, context_style: Literal["bracket", "parentheses"] | None = None, - **kwargs + **kwargs, ): meta = CommandMeta(keep_crlf=False, fuzzy_match=False, raise_exception=raise_exception, context_style=context_style) argv: Argv[DataCollection] = Argv(meta, dev_space, separators=sep) @@ -86,7 +86,7 @@ def analyse_option( command: DataCollection[str | Any], raise_exception: bool = True, context_style: Literal["bracket", "parentheses"] | None = None, - **kwargs + **kwargs, ): meta = CommandMeta(keep_crlf=False, fuzzy_match=False, raise_exception=raise_exception, context_style=context_style) argv: Argv[DataCollection] = Argv(meta, dev_space) @@ -113,7 +113,7 @@ def analyse_subcommand( command: DataCollection[str | Any], raise_exception: bool = True, context_style: Literal["bracket", "parentheses"] | None = None, - **kwargs + **kwargs, ): meta = CommandMeta(keep_crlf=False, fuzzy_match=False, raise_exception=raise_exception, context_style=context_style) argv: Argv[DataCollection] = Argv(meta, dev_space) diff --git a/tests/sistana/asserts.py b/tests/sistana/asserts.py index e5d2bbeb..b6f827ec 100644 --- a/tests/sistana/asserts.py +++ b/tests/sistana/asserts.py @@ -39,6 +39,7 @@ def expect_completed(self): def expect_uncompleted(self): assert self.exit_reason != LoopflowExitReason.completed + @dataclass class SnapshotTest: snapshot: AnalyzeSnapshot @@ -64,12 +65,12 @@ class BufferTest: def expect_empty(self): with pytest.raises(OutOfData): self.buffer.next("") - + def expect_non_empty(self): v = None with suppress(OutOfData): v = self.buffer.next() - + assert v is not None def expect_ahead(self): @@ -124,7 +125,7 @@ class TrackTest: def expect_emitted(self, expected: bool = True): assert self.track.emitted == expected - + def expect_cursor(self, expected: int): assert self.track.cursor == expected @@ -152,6 +153,7 @@ def header(self): return FragmentTest(self.mix, self.track, self.track.header) + @dataclass class FragmentTest: mix: Mix @@ -161,14 +163,14 @@ class FragmentTest: @property def assigned(self): return self.fragment.name in self.mix.assignes - + def expect_assigned(self, expected: bool = True): assert self.assigned == expected @property def value(self): return self.mix.assignes[self.fragment.name] - + def expect_value(self, expected): assert self.value == expected diff --git a/tests/sistana/test_fragment.py b/tests/sistana/test_fragment.py index fdcf7682..fd33ce04 100644 --- a/tests/sistana/test_fragment.py +++ b/tests/sistana/test_fragment.py @@ -10,7 +10,11 @@ def test_assert_fragments_order_valid(): - fragments = [_Fragment(name="frag1"), _Fragment(name="frag2", default=Value("default")), _Fragment(name="frag3", variadic=True)] + fragments = [ + _Fragment(name="frag1"), + _Fragment(name="frag2", default=Value("default")), + _Fragment(name="frag3", variadic=True), + ] assert_fragments_order(fragments) diff --git a/tests/sistana/test_mix.py b/tests/sistana/test_mix.py index 2052c9ab..a29512dd 100644 --- a/tests/sistana/test_mix.py +++ b/tests/sistana/test_mix.py @@ -44,10 +44,11 @@ def test_assignable(): def test_header_edge_case(): class DoNothingRx(Rx): - def receive(self, fetch, prev, put): - ... - - pat = SubcommandPattern.build("test", header_fragment=Fragment("arg1", receiver=DoNothingRx(), default=Value("default"))) + def receive(self, fetch, prev, put): ... + + pat = SubcommandPattern.build( + "test", header_fragment=Fragment("arg1", receiver=DoNothingRx(), default=Value("default")) + ) a, sn, bf = analyze(pat, Buffer(["test"])) a.expect_completed() diff --git a/tests/sistana/test_pattern.py b/tests/sistana/test_pattern.py index c220cf18..a25c9ab5 100644 --- a/tests/sistana/test_pattern.py +++ b/tests/sistana/test_pattern.py @@ -57,4 +57,4 @@ def test_add_option(): with pytest.raises(ValueError, match="header_separators must be used with fragments"): pat = SubcommandPattern.build("test", separators="|") - pat.option("name", aliases=["--name"], header_separators="=") \ No newline at end of file + pat.option("name", aliases=["--name"], header_separators="=")