diff --git a/cogs/edgegpt.py b/cogs/edgegpt.py index d1664fb..9d9eabf 100644 --- a/cogs/edgegpt.py +++ b/cogs/edgegpt.py @@ -4,7 +4,7 @@ from discord import app_commands from core.classes import Cog_Extension from src.log import setup_logger -from src.user_chatbot import set_chatbot, get_users_chatbot, set_dalle3_unofficial_apikey +from src.user_chatbot import set_chatbot, get_users_chatbot from dotenv import load_dotenv load_dotenv() @@ -52,7 +52,7 @@ async def dalle3_setting(self, interaction: discord.Interaction, api_key: str): if allowed_channel_id and int(allowed_channel_id) != interaction.channel_id: await interaction.followup.send(f"> **Command can only used on <#{allowed_channel_id}>**") return - await set_dalle3_unofficial_apikey(interaction.user.id, api_key) + await set_chatbot(interaction.user.id, dalle3_apikey=api_key) await interaction.followup.send("> **Setting success!**") # Chat with Copilot. @@ -60,19 +60,25 @@ async def dalle3_setting(self, interaction: discord.Interaction, api_key: str): @app_commands.choices(version=[app_commands.Choice(name="default", value="default"), app_commands.Choice(name="jail_break", value="jailbreak")]) @app_commands.choices(style=[app_commands.Choice(name="Creative", value="creative"), app_commands.Choice(name="Balanced", value="balanced"), app_commands.Choice(name="Precise", value="precise")]) @app_commands.choices(type=[app_commands.Choice(name="private", value="private"), app_commands.Choice(name="public", value="public")]) - async def chat(self, interaction: discord.Interaction, version: app_commands.Choice[str], style: app_commands.Choice[str], type: app_commands.Choice[str]): + @app_commands.choices(plugin=[app_commands.Choice(name="Suno", value="suno")]) + async def chat(self, interaction: discord.Interaction, version: app_commands.Choice[str], style: app_commands.Choice[str], + type: app_commands.Choice[str], plugin: app_commands.Choice[str]=None): await interaction.response.defer(thinking=True) allowed_channel_id = os.getenv("CHAT_CHANNEL_ID") if allowed_channel_id and int(allowed_channel_id) != interaction.channel_id: await interaction.followup.send(f"> **Command can only used on <#{allowed_channel_id}>**") return if isinstance(interaction.channel, discord.Thread): - await interaction.followup.send("> This command is disabled in thread.") + await interaction.followup.send("> **This command is disabled in thread.**") + return + if version.value == "jailbreak" and plugin != None: + await interaction.followup.send("> **jail break is not support plugins.**") return user_id = interaction.user.id try: - await set_chatbot(user_id=user_id, conversation_style=style.value, version=version.value) + plugin = plugin.value if plugin else None + await set_chatbot(user_id=user_id, conversation_style=style.value, version=version.value, plugin=plugin) except Exception as e: await interaction.followup.send(f"> **ERROR:{e}**") return diff --git a/src/bing_chat/response.py b/src/bing_chat/response.py index 2d27e75..fb4412f 100644 --- a/src/bing_chat/response.py +++ b/src/bing_chat/response.py @@ -8,6 +8,7 @@ from .jail_break import sydney, config from .button_view import ButtonView from re_edge_gpt import ImageGenAsync +from re_edge_gpt.plugins.suno import generate_suno_music from ..image.image_create import concatenate_images from dotenv import load_dotenv @@ -17,7 +18,7 @@ config = config.Config() -async def send_message(user_chatbot, user_message: str, image: str, interaction: discord.Interaction=None): +async def send_message(user_chatbot, user_message: str, image: str, plugin: str=None, interaction: discord.Interaction=None): reply = '' text = '' link_embed = '' @@ -75,17 +76,32 @@ async def send_message(user_chatbot, user_message: str, image: str, interaction: else: conversation_style=ConversationStyle.balanced + add_options = None + plugins=None + if plugin == "suno": + add_options = ["014CB21D"] + plugins = [{"Id": "c310c353-b9f0-4d76-ab0d-1dd5e979cf68", "Category": 1}] chatbot: Chatbot reply = await chatbot.ask( prompt=user_message, conversation_style=conversation_style, simplify_response=True, - attachment={"image_url":f"{image}"} + attachment={"image_url":f"{image}"}, + add_options=add_options, + plugins=plugins, + message_type="GenerateContentQuery" ) + music = None + try: + if plugin == "suno": + music = await generate_suno_music(user_chatbot.cookies, reply.get("messageId"), reply.get("requestId")) + except: + pass + # Get reply text suggest_responses = reply["suggestions"] - text = f"{reply['text']}" + text = f"{reply['text']}\n\n[Video]({music['Video']})" if music and music['Video'] else f"{reply['text']}" urls = [(i+1, x, reply["source_values"][i]) for i, x in enumerate(reply["source_keys"])] end = text.find("Generating answers for you...") text = text[:end] if end != -1 else text diff --git a/src/user_chatbot.py b/src/user_chatbot.py index 604a2a4..d2b2796 100644 --- a/src/user_chatbot.py +++ b/src/user_chatbot.py @@ -17,7 +17,7 @@ logger = setup_logger(__name__) users_chatbot = {} -async def set_chatbot(user_id, conversation_style=None, version=None, cookies=None, dalle3_apikey=None): +async def set_chatbot(user_id, conversation_style=None, version=None, cookies=None, dalle3_apikey=None, plugin=None): if user_id not in users_chatbot: users_chatbot[user_id] = UserChatbot(user_id) @@ -32,13 +32,14 @@ async def set_chatbot(user_id, conversation_style=None, version=None, cookies=No if cookies: users_chatbot[user_id].set_cookies(cookies) + if dalle3_apikey: + users_chatbot[user_id].set_dalle3_apikey(dalle3_apikey) + + if plugin: + users_chatbot[user_id].set_plugin(plugin) + def get_users_chatbot(): return users_chatbot - -async def set_dalle3_unofficial_apikey(user_id, dalle3_apikey: str): - if user_id not in users_chatbot: - users_chatbot[user_id] = UserChatbot(user_id) - users_chatbot[user_id].set_dalle3_apikey(dalle3_apikey) class UserChatbot(): def __init__(self, user_id): @@ -47,6 +48,7 @@ def __init__(self, user_id): self.sem_create_image_dalle3 = Semaphore(1) self.cookies = None self.dalle3_unoffcial_apikey=None + self.plugin = None self.chatbot = None self.thread = None self.jailbreak = None @@ -87,6 +89,9 @@ def update_chat_history(self, text: str): def get_chatbot(self): return self.chatbot + def set_plugin(self, plugin: str): + self.plugin = plugin + async def initialize_chatbot(self, jailbreak: bool): self.jailbreak = jailbreak @@ -102,7 +107,7 @@ async def initialize_chatbot(self, jailbreak: bool): self.cookies = json.load(file) if not self.jailbreak: - self.chatbot = await Chatbot.create(cookies=self.cookies, mode="Bing") + self.chatbot = await Chatbot.create(cookies=self.cookies, mode="Bing", plugin_ids=[self.plugin]) async def send_message(self, message: str, interaction: discord.Interaction=None, image: str=None): if not self.sem_send_message.locked(): @@ -114,10 +119,10 @@ async def send_message(self, message: str, interaction: discord.Interaction=None async with self.sem_send_message: if interaction: if interaction.type == discord.InteractionType.component or self.thread == None: - await send_message(self, message, image, interaction) + await send_message(self, message, image, self.plugin, interaction) else: async with self.thread.typing(): - await send_message(self, message, image) + await send_message(self, message, image, self.plugin) else: if interaction: if not interaction.response.is_done(): @@ -148,11 +153,11 @@ async def create_image(self, interaction: discord.Interaction, prompt: str, serv if not self.sem_create_image_dalle3.locked(): if self.dalle3_unoffcial_apikey == None and os.getenv("DALLE3_UNOFFICIAL_APIKEY"): self.dalle3_unoffcial_apikey = os.getenv("DALLE3_UNOFFICIAL_APIKEY") - else: + elif self.dalle3_unoffcial_apikey == None: await interaction.followup.send("> **ERROR:Please upload your api key.**") return async with self.sem_create_image_dalle3: - await create_image_dalle3(interaction, prompt, self, self.dalle3_unoffcial_apikey) + await create_image_dalle3(interaction, prompt, self) else: await interaction.followup.send("> **ERROR:Please wait for the previous command to complete.**")