Skip to content
This repository has been archived by the owner on Dec 30, 2024. It is now read-only.

Commit

Permalink
refactor: support Suno plugin
Browse files Browse the repository at this point in the history
  • Loading branch information
FuseFairy committed Apr 21, 2024
1 parent 26a3001 commit 52b39a5
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 19 deletions.
16 changes: 11 additions & 5 deletions cogs/edgegpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from discord import app_commands
from core.classes import Cog_Extension
from src.log import setup_logger
from src.user_chatbot import set_chatbot, get_users_chatbot, set_dalle3_unofficial_apikey
from src.user_chatbot import set_chatbot, get_users_chatbot
from dotenv import load_dotenv

load_dotenv()
Expand Down Expand Up @@ -52,27 +52,33 @@ async def dalle3_setting(self, interaction: discord.Interaction, api_key: str):
if allowed_channel_id and int(allowed_channel_id) != interaction.channel_id:
await interaction.followup.send(f"> **Command can only used on <#{allowed_channel_id}>**")
return
await set_dalle3_unofficial_apikey(interaction.user.id, api_key)
await set_chatbot(interaction.user.id, dalle3_apikey=api_key)
await interaction.followup.send("> **Setting success!**")

# Chat with Copilot.
@app_commands.command(name="copilot", description="Create thread for conversation.")
@app_commands.choices(version=[app_commands.Choice(name="default", value="default"), app_commands.Choice(name="jail_break", value="jailbreak")])
@app_commands.choices(style=[app_commands.Choice(name="Creative", value="creative"), app_commands.Choice(name="Balanced", value="balanced"), app_commands.Choice(name="Precise", value="precise")])
@app_commands.choices(type=[app_commands.Choice(name="private", value="private"), app_commands.Choice(name="public", value="public")])
async def chat(self, interaction: discord.Interaction, version: app_commands.Choice[str], style: app_commands.Choice[str], type: app_commands.Choice[str]):
@app_commands.choices(plugin=[app_commands.Choice(name="Suno", value="suno")])
async def chat(self, interaction: discord.Interaction, version: app_commands.Choice[str], style: app_commands.Choice[str],
type: app_commands.Choice[str], plugin: app_commands.Choice[str]=None):
await interaction.response.defer(thinking=True)
allowed_channel_id = os.getenv("CHAT_CHANNEL_ID")
if allowed_channel_id and int(allowed_channel_id) != interaction.channel_id:
await interaction.followup.send(f"> **Command can only used on <#{allowed_channel_id}>**")
return
if isinstance(interaction.channel, discord.Thread):
await interaction.followup.send("> This command is disabled in thread.")
await interaction.followup.send("> **This command is disabled in thread.**")
return
if version.value == "jailbreak" and plugin != None:
await interaction.followup.send("> **jail break is not support plugins.**")
return

user_id = interaction.user.id
try:
await set_chatbot(user_id=user_id, conversation_style=style.value, version=version.value)
plugin = plugin.value if plugin else None
await set_chatbot(user_id=user_id, conversation_style=style.value, version=version.value, plugin=plugin)
except Exception as e:
await interaction.followup.send(f"> **ERROR:{e}**")
return
Expand Down
22 changes: 19 additions & 3 deletions src/bing_chat/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from .jail_break import sydney, config
from .button_view import ButtonView
from re_edge_gpt import ImageGenAsync
from re_edge_gpt.plugins.suno import generate_suno_music
from ..image.image_create import concatenate_images
from dotenv import load_dotenv

Expand All @@ -17,7 +18,7 @@

config = config.Config()

async def send_message(user_chatbot, user_message: str, image: str, interaction: discord.Interaction=None):
async def send_message(user_chatbot, user_message: str, image: str, plugin: str=None, interaction: discord.Interaction=None):
reply = ''
text = ''
link_embed = ''
Expand Down Expand Up @@ -75,17 +76,32 @@ async def send_message(user_chatbot, user_message: str, image: str, interaction:
else:
conversation_style=ConversationStyle.balanced

add_options = None
plugins=None
if plugin == "suno":
add_options = ["014CB21D"]
plugins = [{"Id": "c310c353-b9f0-4d76-ab0d-1dd5e979cf68", "Category": 1}]
chatbot: Chatbot
reply = await chatbot.ask(
prompt=user_message,
conversation_style=conversation_style,
simplify_response=True,
attachment={"image_url":f"{image}"}
attachment={"image_url":f"{image}"},
add_options=add_options,
plugins=plugins,
message_type="GenerateContentQuery"
)

music = None
try:
if plugin == "suno":
music = await generate_suno_music(user_chatbot.cookies, reply.get("messageId"), reply.get("requestId"))
except:
pass

# Get reply text
suggest_responses = reply["suggestions"]
text = f"{reply['text']}"
text = f"{reply['text']}\n\n[Video]({music['Video']})" if music and music['Video'] else f"{reply['text']}"
urls = [(i+1, x, reply["source_values"][i]) for i, x in enumerate(reply["source_keys"])]
end = text.find("Generating answers for you...")
text = text[:end] if end != -1 else text
Expand Down
27 changes: 16 additions & 11 deletions src/user_chatbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
logger = setup_logger(__name__)
users_chatbot = {}

async def set_chatbot(user_id, conversation_style=None, version=None, cookies=None, dalle3_apikey=None):
async def set_chatbot(user_id, conversation_style=None, version=None, cookies=None, dalle3_apikey=None, plugin=None):
if user_id not in users_chatbot:
users_chatbot[user_id] = UserChatbot(user_id)

Expand All @@ -32,13 +32,14 @@ async def set_chatbot(user_id, conversation_style=None, version=None, cookies=No
if cookies:
users_chatbot[user_id].set_cookies(cookies)

if dalle3_apikey:
users_chatbot[user_id].set_dalle3_apikey(dalle3_apikey)

if plugin:
users_chatbot[user_id].set_plugin(plugin)

def get_users_chatbot():
return users_chatbot

async def set_dalle3_unofficial_apikey(user_id, dalle3_apikey: str):
if user_id not in users_chatbot:
users_chatbot[user_id] = UserChatbot(user_id)
users_chatbot[user_id].set_dalle3_apikey(dalle3_apikey)

class UserChatbot():
def __init__(self, user_id):
Expand All @@ -47,6 +48,7 @@ def __init__(self, user_id):
self.sem_create_image_dalle3 = Semaphore(1)
self.cookies = None
self.dalle3_unoffcial_apikey=None
self.plugin = None
self.chatbot = None
self.thread = None
self.jailbreak = None
Expand Down Expand Up @@ -87,6 +89,9 @@ def update_chat_history(self, text: str):
def get_chatbot(self):
return self.chatbot

def set_plugin(self, plugin: str):
self.plugin = plugin

async def initialize_chatbot(self, jailbreak: bool):
self.jailbreak = jailbreak

Expand All @@ -102,7 +107,7 @@ async def initialize_chatbot(self, jailbreak: bool):
self.cookies = json.load(file)

if not self.jailbreak:
self.chatbot = await Chatbot.create(cookies=self.cookies, mode="Bing")
self.chatbot = await Chatbot.create(cookies=self.cookies, mode="Bing", plugin_ids=[self.plugin])

async def send_message(self, message: str, interaction: discord.Interaction=None, image: str=None):
if not self.sem_send_message.locked():
Expand All @@ -114,10 +119,10 @@ async def send_message(self, message: str, interaction: discord.Interaction=None
async with self.sem_send_message:
if interaction:
if interaction.type == discord.InteractionType.component or self.thread == None:
await send_message(self, message, image, interaction)
await send_message(self, message, image, self.plugin, interaction)
else:
async with self.thread.typing():
await send_message(self, message, image)
await send_message(self, message, image, self.plugin)
else:
if interaction:
if not interaction.response.is_done():
Expand Down Expand Up @@ -148,11 +153,11 @@ async def create_image(self, interaction: discord.Interaction, prompt: str, serv
if not self.sem_create_image_dalle3.locked():
if self.dalle3_unoffcial_apikey == None and os.getenv("DALLE3_UNOFFICIAL_APIKEY"):
self.dalle3_unoffcial_apikey = os.getenv("DALLE3_UNOFFICIAL_APIKEY")
else:
elif self.dalle3_unoffcial_apikey == None:
await interaction.followup.send("> **ERROR:Please upload your api key.**")
return
async with self.sem_create_image_dalle3:
await create_image_dalle3(interaction, prompt, self, self.dalle3_unoffcial_apikey)
await create_image_dalle3(interaction, prompt, self)
else:
await interaction.followup.send("> **ERROR:Please wait for the previous command to complete.**")

Expand Down

0 comments on commit 52b39a5

Please sign in to comment.