From bce1333b2ce30b70b8660b36758a862990c0de8d Mon Sep 17 00:00:00 2001
From: Jialei <jialeicui@126.com>
Date: Sun, 14 Apr 2024 10:02:55 +0800
Subject: [PATCH] add azure speech service

---
 src/server/app.py                |  2 +-
 src/server/db/history_service.py |  9 +++----
 src/server/llm/llm.py            |  3 +--
 src/server/tts/__init__.py       |  0
 src/server/tts/azure.py          | 40 ++++++++++++++++++++++++++++++++
 5 files changed, 45 insertions(+), 9 deletions(-)
 create mode 100644 src/server/tts/__init__.py
 create mode 100644 src/server/tts/azure.py

diff --git a/src/server/app.py b/src/server/app.py
index 53a31dc..71defb5 100644
--- a/src/server/app.py
+++ b/src/server/app.py
@@ -68,7 +68,7 @@ async def message(body: MicoMessage) -> MessageResponse:
     if text in ["开灯", "关灯", "停", "大点声", "小点声", "几点了"]:
         return ignore_resp
 
-    if '后提醒我' in text:
+    if "后提醒我" in text:
         return ignore_resp
 
     # TODO support multiple sessions
diff --git a/src/server/db/history_service.py b/src/server/db/history_service.py
index cb25f42..2985159 100644
--- a/src/server/db/history_service.py
+++ b/src/server/db/history_service.py
@@ -17,17 +17,14 @@ class MessageRole(enum.Enum):
 
 class HistoryService(abc.ABC):
     @abc.abstractmethod
-    def save(self, message: str, role: MessageRole, provider: str) -> None:
-        ...
+    def save(self, message: str, role: MessageRole, provider: str) -> None: ...
 
     @abc.abstractmethod
-    def get(self, limit: int = 10, offset: int = 0) -> list[History]:
-        ...
+    def get(self, limit: int = 10, offset: int = 0) -> list[History]: ...
 
 
 class HistorySvcDummy(HistoryService):
-    def save(self, message: str, role: MessageRole, provider: str) -> None:
-        ...
+    def save(self, message: str, role: MessageRole, provider: str) -> None: ...
 
     def get(self, limit: int = 10, offset: int = 0) -> list[History]:
         return []
diff --git a/src/server/llm/llm.py b/src/server/llm/llm.py
index 7ac7337..59a0d49 100644
--- a/src/server/llm/llm.py
+++ b/src/server/llm/llm.py
@@ -32,8 +32,7 @@ def new_session(self) -> None:
             self.messages.append({"role": "system", "content": self.prompt})
 
     @abstractmethod
-    def round(self, text: str, temperature: float) -> ChatCompletionMessage:
-        ...
+    def round(self, text: str, temperature: float) -> ChatCompletionMessage: ...
 
     def _save_message(self, message: str | None, role: MessageRole) -> None:
         if message is None:
diff --git a/src/server/tts/__init__.py b/src/server/tts/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/server/tts/azure.py b/src/server/tts/azure.py
new file mode 100644
index 0000000..3d95b9e
--- /dev/null
+++ b/src/server/tts/azure.py
@@ -0,0 +1,40 @@
+import textwrap
+
+import requests
+
+
+class AzureTTS:
+    # https://learn.microsoft.com/en-us/azure/ai-services/speech-service/rest-text-to-speech?tabs=streaming
+    def __init__(
+        self, subscription_key: str, region: str, voice: str = "en-US-JessaNeural"
+    ):
+        self.subscription_key = subscription_key
+        self.region = region
+        self.voice = voice
+        self.base_url = (
+            f"https://{self.region}.tts.speech.microsoft.com/cognitiveservices/v1"
+        )
+
+    def synthesize(self, text, filename):
+        # Synthesize the text to speech
+        url = self.base_url
+        headers = {
+            "Authorization": f"Bearer {self.get_token()}",
+            "Content-Type": "application/ssml+xml",
+            "X-Microsoft-OutputFormat": "riff-24khz-16bit-mono-pcm",
+            "User-Agent": "tts",
+        }
+        body = textwrap.dedent(
+            f"""\
+            <speak version='1.0' xmlns='http://www.w3.org/2001/10/synthesis' xml:lang='en-US'>
+                <voice name='{self.voice}'>
+                    {text}
+                </voice>
+            </speak>"""
+        )
+        response = requests.post(url, headers=headers, data=body)
+        if response.status_code == 200:
+            with open(filename, "wb") as f:
+                f.write(response.content)
+        else:
+            print(f"Failed to synthesize speech: {response.text}")