-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathserver.py
64 lines (46 loc) · 1.61 KB
/
server.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import argparse
import asyncio
from typing import AsyncIterator
from pydantic import Field
from coagent.core import (
AgentSpec,
BaseAgent,
Context,
handler,
idle_loop,
Message,
new,
set_stderr_logger,
)
from coagent.runtimes import NATSRuntime, HTTPRuntime
# Request message that StreamServer.handle responds to. It declares no
# fields of its own; everything comes from the Message base class.
class Ping(Message):
    ...
# One streamed fragment of the Pong reply; StreamServer.handle yields a
# sequence of these, one per word.
class PartialPong(Message):
    # Required text fragment: an Ellipsis default marks the field as required.
    content: str = Field(default=..., description="The content of the Pong message.")
class StreamServer(BaseAgent):
    """The Stream Pong Server."""

    @handler
    async def handle(self, msg: Ping, ctx: Context) -> AsyncIterator[PartialPong]:
        """Handle the Ping message and return a stream of PartialPong messages."""
        # `msg` and `ctx` are not used: the reply is the same for every Ping.
        # Each fragment keeps its trailing space, so joining every streamed
        # `content` reproduces "Hi there, this is the Pong server.".
        pause_seconds = 0.6
        fragments = ("Hi ", "there, ", "this ", "is ", "the ", "Pong ", "server.")
        for fragment in fragments:
            # Wait before emitting each fragment, spacing the stream out.
            await asyncio.sleep(pause_seconds)
            yield PartialPong(content=fragment)
# Agent specification that main() registers with the runtime: the agent
# name "stream_server" paired with new(StreamServer), which presumably
# builds StreamServer instances on demand — confirm against coagent's `new`.
stream_server = AgentSpec("stream_server", new(StreamServer))
async def main(server: str, auth: str):
    """Pick a runtime by URL scheme, register the stream server and idle.

    Args:
        server: Runtime address. A ``nats://`` prefix selects ``NATSRuntime``;
            an ``http://`` or ``https://`` prefix selects ``HTTPRuntime``.
        auth: Value forwarded to ``HTTPRuntime.from_server``; ignored for
            NATS addresses.

    Raises:
        ValueError: If ``server`` starts with none of the supported prefixes.
    """
    use_nats = server.startswith("nats://")
    use_http = server.startswith(("http://", "https://"))
    if not (use_nats or use_http):
        raise ValueError(f"Unsupported server: {server}")

    runtime = (
        NATSRuntime.from_servers(server)
        if use_nats
        else HTTPRuntime.from_server(server, auth)
    )
    async with runtime:
        await runtime.register(stream_server)
        # Keep the process alive while the agent serves requests; presumably
        # idle_loop() only returns on shutdown — confirm in coagent.
        await idle_loop()
if __name__ == "__main__":
    # Send coagent's log output to stderr at the TRACE level.
    set_stderr_logger("TRACE")

    cli = argparse.ArgumentParser()
    cli.add_argument("--server", default="nats://localhost:4222", type=str)
    cli.add_argument("--auth", default="", type=str)
    options = cli.parse_args()

    asyncio.run(main(server=options.server, auth=options.auth))