Skip to content

Commit

Permalink
sistana: fix nepattern
Browse files Browse the repository at this point in the history
  • Loading branch information
GreyElaina committed Oct 10, 2024
1 parent 4b3058e commit e58e218
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 39 deletions.
40 changes: 19 additions & 21 deletions src/arclet/alconna/sistana/fragment.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def __post_init__(self):

def apply_msgspec(self):
if self.type is None:
return
return self

t = self.type.value

Expand All @@ -58,40 +58,38 @@ def _transform(v: Segment):
return convert(str(v), t)

self.transformer = _transform

return self

def apply_nepattern(self, pat: BasePattern | None = None, capture_mode: bool = False):
if pat is None:
if self.type is None:
return
return self

from nepattern import type_parser
from nepattern import BasePattern

pat = type_parser(self.type.value)
pat = BasePattern.to(self.type.value)
assert pat is not None

from nepattern import MatchMode
def _validate(v: Segment):
if isinstance(v, (Quoted, UnmatchedQuoted)):
if isinstance(v.ref, str):
v = str(v)
else:
v = v.ref[0]
return pat.validate(v).success

if capture_mode:
if pat.mode in (MatchMode.REGEX_MATCH, MatchMode.REGEX_CONVERT):
self.capture = RegexCapture(pat.regex_pattern)
else:
self.capture = RegexCapture(pat.alias) # type: ignore
else:
self.validator = _validate
if self.cast:
def _transform(v: Segment):

def _validate(v: Segment):
if isinstance(v, (Quoted, UnmatchedQuoted)):
if isinstance(v.ref, str):
v = str(v)
else:
v = v.ref
v = v.ref[0]

return pat.validate(v).success

self.validator = _validate

if self.cast:

def _transform(v: Segment):
return pat.validate(str(v)).value()
return pat.validate(v).value()

self.transformer = _transform
return self
2 changes: 1 addition & 1 deletion tests/sistana/test_capture.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,4 +108,4 @@ def test_regex_capture():
assert frag.value.group() == "123"

with pytest.raises(UnexpectedType):
a, sn, bf = analyze(pat, Buffer(["test --name '123", 1]))
a, sn, bf = analyze(pat, Buffer(["test --name '123", 1]))
63 changes: 46 additions & 17 deletions tests/sistana/test_fragment.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,63 @@
import pytest
from elaina_segment import Buffer

from arclet.alconna.sistana.fragment import Fragment
from arclet.alconna.sistana.model.fragment import _Fragment, assert_fragments_order
from arclet.alconna.sistana.model.pattern import SubcommandPattern
from arclet.alconna.sistana.some import Value

from .asserts import analyze


def test_assert_fragments_order_valid():
fragments = [
_Fragment(name="frag1"),
_Fragment(name="frag2", default=Value("default")),
_Fragment(name="frag3", variadic=True)
]
fragments = [_Fragment(name="frag1"), _Fragment(name="frag2", default=Value("default")), _Fragment(name="frag3", variadic=True)]
assert_fragments_order(fragments)


def test_assert_fragments_order_required_after_optional():
fragments = [
_Fragment(name="frag1", default=Value("default")),
_Fragment(name="frag2")
]
fragments = [_Fragment(name="frag1", default=Value("default")), _Fragment(name="frag2")]
with pytest.raises(ValueError, match="Found a required fragment after an optional fragment, which is not allowed."):
assert_fragments_order(fragments)


def test_assert_fragments_order_variadic_with_default():
fragments = [
_Fragment(name="frag1", variadic=True, default=Value("default"))
]
fragments = [_Fragment(name="frag1", variadic=True, default=Value("default"))]
with pytest.raises(ValueError, match="A variadic fragment cannot have a default value."):
assert_fragments_order(fragments)


def test_assert_fragments_order_fragment_after_variadic():
fragments = [
_Fragment(name="frag1", variadic=True),
_Fragment(name="frag2")
]
fragments = [_Fragment(name="frag1", variadic=True), _Fragment(name="frag2")]
with pytest.raises(ValueError, match="Found fragment after a variadic fragment, which is not allowed."):
assert_fragments_order(fragments)
assert_fragments_order(fragments)


def test_nepattern():
from nepattern import WIDE_BOOLEAN

pat = SubcommandPattern.build("test")

pat.option("--foo", Fragment("foo", type=Value(int)))
pat.option("--bar", Fragment("bar", type=Value(float)))
pat.option("--baz", Fragment("baz", type=Value(bool)))
pat.option("--qux", Fragment("qux").apply_nepattern(WIDE_BOOLEAN))

a, sn, bf = analyze(pat, Buffer(["test --foo 123 --bar 123.456 --baz true --qux yes"]))
a.expect_completed()
sn.expect_determined()

frag = sn.mix[("test",), "--foo"]["foo"]
frag.expect_assigned()
frag.expect_value(123)

frag = sn.mix[("test",), "--bar"]["bar"]
frag.expect_assigned()
frag.expect_value(123.456)

frag = sn.mix[("test",), "--baz"]["baz"]
frag.expect_assigned()
frag.expect_value(True)

frag = sn.mix[("test",), "--qux"]["qux"]
frag.expect_assigned()
frag.expect_value(True)

0 comments on commit e58e218

Please sign in to comment.