-
Notifications
You must be signed in to change notification settings - Fork 19
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
custom error messages and dynamic response size #262
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,7 @@ | ||
import contextlib | ||
import logging | ||
from datetime import datetime, timedelta, timezone | ||
from math import ceil, exp | ||
from typing import Optional | ||
|
||
import openai | ||
|
@@ -20,6 +22,10 @@ | |
mentions = AllowedMentions(everyone=False, users=False, roles=False, replied_user=True) | ||
model = "gpt-4o-mini" | ||
|
||
# weights/coefficients for sigmoid function | ||
a = 750 | ||
b = 7 | ||
c = 400 | ||
|
||
|
||
def clean(msg, *prefixes): | ||
|
@@ -31,36 +37,37 @@ def clean(msg, *prefixes): | |
class Summarise(commands.Cog): | ||
def __init__(self, bot: Bot): | ||
self.bot = bot | ||
self.cooldowns = {} | ||
openai.api_key = CONFIG.OPENAI_API_KEY | ||
|
||
def build_prompt(self, bullet_points, channel_name): | ||
|
||
bullet_points = "Put it in bullet points for readability." if bullet_points else "" | ||
prompt = f"""People yap too much, I don't want to read all of it. The topic is something related to {channel_name}. In 2 sentences or less give me the gist of what is being said. {bullet_points} Note that the messages are in reverse chronological order: | ||
""" | ||
return prompt | ||
|
||
|
||
def optional_context_manager(self, use: bool, cm: callable): | ||
if use: | ||
return cm() | ||
|
||
return contextlib.nullcontext() | ||
|
||
@commands.cooldown(CONFIG.SUMMARISE_LIMIT, CONFIG.SUMMARISE_COOLDOWN * 60, commands.BucketType.channel) | ||
|
||
@commands.hybrid_command(help=LONG_HELP_TEXT, brief=SHORT_HELP_TEXT) | ||
async def tldr( | ||
self, ctx: Context, number_of_messages: int = 100, bullet_point_output: bool = False, private_view: bool = False): | ||
number_of_messages = 400 if number_of_messages > 400 else number_of_messages | ||
if await self.in_cooldown(ctx): | ||
return | ||
|
||
number_of_messages = CONFIG.SUMMARISE_MESSAGE_LIMIT if number_of_messages > CONFIG.SUMMARISE_MESSAGE_LIMIT else number_of_messages | ||
|
||
# avoid banned users | ||
if not await is_author_banned_openai(ctx): | ||
await ctx.send("You are banned from OpenAI!") | ||
return | ||
|
||
# get the last "number_of_messages" messages from the current channel and build the prompt | ||
prompt = self.build_prompt(bullet_point_output, ctx.channel) | ||
prompt = self.build_prompt(bullet_point_output, ctx.channel, self.sigmoid(number_of_messages)) | ||
|
||
messages = ctx.channel.history(limit=number_of_messages) | ||
messages = await self.create_message(messages, prompt) | ||
messages = await self.create_message(messages, prompt, ctx) | ||
|
||
# send the prompt to the ai overlords to process | ||
async with self.optional_context_manager(not private_view, ctx.typing): | ||
|
@@ -71,6 +78,28 @@ async def tldr( | |
prev = await prev.reply(content, allowed_mentions=mentions, ephemeral=private_view) | ||
|
||
|
||
|
||
async def in_cooldown(self, ctx): | ||
now = datetime.now(timezone.utc) | ||
# channel based cooldown | ||
if self.cooldowns.get(ctx.channel.id): | ||
# check that message limit hasn't been reached | ||
if CONFIG.SUMMARISE_LIMIT <= self.cooldowns[ctx.channel.id][1]: | ||
|
||
message_time = self.cooldowns[ctx.channel.id][0] | ||
cutoff = message_time + timedelta(minutes=CONFIG.SUMMARISE_COOLDOWN) | ||
# check that message time + cooldown time period is still in the future | ||
if now < cutoff: | ||
await ctx.reply("STFU!! Wait " + str(int((cutoff - now).total_seconds())) + " Seconds. You are on Cool Down." ) | ||
return True | ||
else: | ||
self.cooldowns[ctx.channel.id] = [now, 1] # reset the cooldown | ||
else: | ||
self.cooldowns[ctx.channel.id][1]+=1 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Did ruff not pick up on spaces either side of the +=? |
||
else: | ||
self.cooldowns[ctx.channel.id] = [now, 1] | ||
return False | ||
|
||
async def dispatch_api(self, messages) -> Optional[str]: | ||
logging.info(f"Making OpenAI request: {messages}") | ||
|
||
|
@@ -85,19 +114,30 @@ async def dispatch_api(self, messages) -> Optional[str]: | |
reply = clean(reply, "Apollo: ", "apollo: ", name) | ||
return reply | ||
|
||
async def create_message(self, message_chain, prompt): | ||
async def create_message(self, message_chain, prompt, ctx): | ||
# get initial prompt | ||
initial = prompt + "\n" | ||
|
||
# for each message, append it to the prompt as follows --- author : message \n | ||
message_length = 0 | ||
async for msg in message_chain: | ||
if CONFIG.AI_INCLUDE_NAMES and msg.author != self.bot.user: | ||
initial += msg.author.name + ":" + msg.content + "\n" | ||
|
||
message_length += len(msg.content.split()) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Did you want to add the thing to only go back as far as the last tldr (perhaps with an option to override) |
||
initial += msg.author.name + ": " + msg.content + "\n" | ||
messages = [dict(role="system", content=initial)] | ||
|
||
return messages | ||
|
||
|
||
def build_prompt(self, bullet_points, channel_name, response_size): | ||
|
||
bullet_points = "Put it in bullet points for readability." if bullet_points else "" | ||
prompt = f"""People yap too much, I don't want to read all of it. The topic is related to {channel_name}. In {response_size} words or less give me the gist of what is being said. {bullet_points} Note that the messages are in reverse chronological order: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you add this to the config file too, so we can experiment slightly I reckon this could do with a bit of prompt tuning, I don't really know what I'm talking about, but maybe something like this one. At least something to make it waffle less and be more precise |
||
""" | ||
return prompt | ||
|
||
def sigmoid(self, x): | ||
return int(ceil(c / (1 + b * exp((-x)/ a)))) | ||
|
||
async def setup(bot: Bot): | ||
await bot.add_cog(Summarise(bot)) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -29,6 +29,7 @@ def __init__(self, filepath: str): | |
self.LIEGE_CHANCELLOR_ID: int = parsed.get("liege_chancellor_id") | ||
self.SUMMARISE_LIMIT: int = parsed.get("summarise_limit") | ||
self.SUMMARISE_COOLDOWN: int = parsed.get("summarise_cooldown") | ||
self.SUMMARISE_MESSAGE_LIMIT: int = parsed.get("summarise_message_limit") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Remember when you deploy this to add this to the config file on beryillum |
||
self.MARKOV_ENABLED: bool = parsed.get("markov_enabled") | ||
|
||
# Configuration | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe remove STFU