diff --git a/.gitignore b/.gitignore index a9275646..88ee8825 100644 --- a/.gitignore +++ b/.gitignore @@ -147,6 +147,8 @@ pdl-live/package-lock.json # PDL trace +*_result.json +*_result.yaml *_trace.json # Built docs diff --git a/pdl/pdl_interpreter.py b/pdl/pdl_interpreter.py index 273c1ab4..211641e7 100644 --- a/pdl/pdl_interpreter.py +++ b/pdl/pdl_interpreter.py @@ -59,7 +59,9 @@ from .pdl_location_utils import append, get_loc_string from .pdl_parser import PDLParseError, parse_file from .pdl_scheduler import ( + CodeYieldResultMessage, ModelCallMessage, + ModelYieldResultMessage, YieldBackgroundMessage, YieldMessage, YieldResultMessage, @@ -286,7 +288,7 @@ def step_block_body( state, scope, block, loc ) if state.yield_result: - yield YieldResultMessage(result) + yield CodeYieldResultMessage(result) if state.yield_background: yield YieldBackgroundMessage(background) case GetBlock(get=var): @@ -889,7 +891,7 @@ def generate_client_response_streaming( role = None for chunk in msg_stream: if state.yield_result: - yield YieldResultMessage(chunk["content"]) + yield ModelYieldResultMessage(chunk["content"]) if state.yield_background: yield YieldBackgroundMessage([chunk]) if complete_msg is None: diff --git a/pdl/pdl_scheduler.py b/pdl/pdl_scheduler.py index 4fe21eb3..5d062ff7 100644 --- a/pdl/pdl_scheduler.py +++ b/pdl/pdl_scheduler.py @@ -40,6 +40,12 @@ def step_to_completion(gen: Generator[Any, Any, GeneratorReturnT]) -> GeneratorR return w.value +MODEL_COLOR = "\033[92m" # Green +CODE_COLOR = "\033[95m" # Purple +END_COLOR = "\033[0m" # End color +NO_COLOR = "" + + class MessageKind(Enum): RESULT = 0 BACKGROUND = 1 @@ -53,9 +59,20 @@ class YieldMessage: @dataclass class YieldResultMessage(YieldMessage): kind = MessageKind.RESULT + color = NO_COLOR result: Any +@dataclass +class ModelYieldResultMessage(YieldResultMessage): + color = MODEL_COLOR + + +@dataclass +class CodeYieldResultMessage(YieldResultMessage): + color = CODE_COLOR + + @dataclass class YieldBackgroundMessage(YieldMessage): kind = MessageKind.BACKGROUND @@ -92,6 +109,11 @@ def schedule( try: msg = gen.send(v) match msg: + case ModelYieldResultMessage( + result=result + ) | CodeYieldResultMessage(result=result): + print(msg.color + stringify(result) + END_COLOR, end="") + todo_next.append((i, gen, None)) case YieldResultMessage(result=result): print(stringify(result), end="") todo_next.append((i, gen, None)) diff --git a/tests/test_line_table.py b/tests/test_line_table.py index 8ebef126..c5e28fb7 100644 --- a/tests/test_line_table.py +++ b/tests/test_line_table.py @@ -1,14 +1,20 @@ from pdl.pdl_interpreter import generate +from pdl.pdl_scheduler import CODE_COLOR, END_COLOR, MODEL_COLOR def do_test(t, capsys): generate(t["file"], None, None, {}, None) captured = capsys.readouterr() - output = captured.out.split("\n") + output_string = remove_coloring(captured.out) + output = output_string.split("\n") print(output) assert set(output) == set(t["errors"]) +def remove_coloring(text): + return text.replace(MODEL_COLOR, "").replace(CODE_COLOR, "").replace(END_COLOR, "") + + line = { "file": "tests/data/line/hello.pdl", "errors": [