Skip to content

Commit

Permalink
Refactor Discord message sending functionality
Browse files Browse the repository at this point in the history
tristiisch committed Nov 23, 2023
1 parent 46da8e3 commit fb3068d
Showing 10 changed files with 274 additions and 246 deletions.
74 changes: 52 additions & 22 deletions src/pyramid/connector/discord/bot.py
Original file line number Diff line number Diff line change
@@ -8,11 +8,18 @@
import discord
from discord import (
Guild,
Interaction,
PrivilegedIntentsRequired,
)
from discord.ext.commands import Bot, Context
from discord.ext.commands.errors import CommandNotFound, MissingPermissions, MissingRequiredArgument

from discord.ext.commands.errors import (
CommandNotFound,
MissingPermissions,
MissingRequiredArgument,
CommandError,
)
from discord.app_commands.errors import AppCommandError, CommandInvokeError
from discord.abc import Messageable
from data.functional.application_info import ApplicationInfo
from data.environment import Environment
from data.guild_data import GuildData
@@ -21,16 +28,12 @@
from connector.discord.guild_cmd import GuildCmd
from connector.discord.guild_queue import GuildQueue
from data.functional.engine_source import EngineSource
from data.exceptions import DiscordMessageException
from tools.configuration.configuration import Configuration


class DiscordBot:
def __init__(
self,
logger: logging.Logger,
information: ApplicationInfo,
config: Configuration
):
def __init__(self, logger: logging.Logger, information: ApplicationInfo, config: Configuration):
self.__logger = logger
self.__information = information
self.__token = config.discord__token
@@ -61,7 +64,7 @@ def create(self):
)

@self.bot.event
async def on_command_error(ctx: Context, error):
async def on_command_error(ctx: Context, error: CommandError):
if isinstance(error, CommandNotFound):
await ctx.send("That command didn't exists !")
return
@@ -73,7 +76,7 @@ async def on_command_error(ctx: Context, error):
"You dont have all the requirements or permissions for using this command :angry:"
)
return
logging.error("Command error from on_command_error : %s", error)
logging.error("Command error from on_command_error : %s", error)

@self.bot.event
async def on_error(event, *args, **kwargs):
@@ -82,6 +85,43 @@ async def on_error(event, *args, **kwargs):
logging.error("Error from on_error : %s", traceback.format_exc())
# await bot.send_message(message.channel, "You caused an error!")

async def on_tree_error(ctx: Interaction, app_error: AppCommandError, /):
# if not ctx.response.is_done():
# await ctx.response.defer(thinking=True)

if isinstance(app_error, CommandInvokeError):
msg = ", ".join(app_error.args)
error = app_error.original
else:
msg = "Error from on_tree_error"
error = app_error
trace = "".join(traceback.format_exception(type(error), error, error.__traceback__))
logging.error("%s :\n%s", msg, trace)

discord_explanation = ":warning: You caused an error!"
if isinstance(error, DiscordMessageException):
discord_explanation = str(error)
else:
attributes_dict = vars(ctx.namespace)
formatted_attributes = " ".join(
f"{key}: {value}" for key, value in attributes_dict.items()
)
discord_explanation = (
":warning: An error occurred while processing the command `/%s%s`%s"
% (
ctx.command.name if ctx.command else "<unknown command>",
f" {formatted_attributes}" if formatted_attributes != "" else "",
f"\n```{trace}```" if self.__environment is not Environment.PRODUCTION else "",
)
)

if ctx.response.is_done():
await ctx.followup.send(discord_explanation)
elif isinstance(ctx.channel, Messageable):
await ctx.channel.send(content=f"{ctx.user.mention} {discord_explanation}")

self.bot.tree.on_error = on_tree_error

@self.bot.event
async def on_command(ctx: Context):
logging.debug("on_command : %s", ctx.author)
@@ -99,7 +139,6 @@ async def start(self):
except PrivilegedIntentsRequired as ex:
raise ex


async def stop(self):
# self.bot.clear()
logging.info("Discord bot stop")
@@ -109,23 +148,14 @@ async def stop(self):
def __get_guild_cmd(self, guild: Guild) -> GuildCmd:
if guild.id not in self.guilds_instances:
self.guilds_instances[guild.id] = GuildInstances(
guild,
self.__logger.getChild(guild.name),
self.__engine_source,
self.__ffmpeg
guild, self.__logger.getChild(guild.name), self.__engine_source, self.__ffmpeg
)

return self.guilds_instances[guild.id].cmds


class GuildInstances:
def __init__(
self,
guild: Guild,
logger: Logger,
engine_source: EngineSource,
ffmpeg_path: str
):
def __init__(self, guild: Guild, logger: Logger, engine_source: EngineSource, ffmpeg_path: str):
self.data = GuildData(guild, engine_source)
self.songs = GuildQueue(self.data, ffmpeg_path)
self.cmds = GuildCmd(logger, self.data, self.songs, engine_source)
28 changes: 14 additions & 14 deletions src/pyramid/connector/discord/bot_cmd.py
Original file line number Diff line number Diff line change
@@ -126,7 +126,7 @@ async def cmd_play(ctx: Interaction, input: str):
if (await self.__use_on_guild_only(ctx)) is False:
return
ms = MessageSenderQueued(ctx)
await ms.waiting()
await ms.thinking()
guild: Guild = ctx.guild # type: ignore
guild_cmd: GuildCmd = self.__get_guild_cmd(guild)

@@ -137,7 +137,7 @@ async def cmd_play_next(ctx: Interaction, input: str):
if (await self.__use_on_guild_only(ctx)) is False:
return
ms = MessageSenderQueued(ctx)
await ms.waiting()
await ms.thinking()
guild: Guild = ctx.guild # type: ignore
guild_cmd: GuildCmd = self.__get_guild_cmd(guild)

@@ -148,7 +148,7 @@ async def cmd_pause(ctx: Interaction):
if (await self.__use_on_guild_only(ctx)) is False:
return
ms = MessageSenderQueued(ctx)
await ms.waiting()
await ms.thinking()
guild: Guild = ctx.guild # type: ignore
guild_cmd: GuildCmd = self.__get_guild_cmd(guild)

@@ -159,7 +159,7 @@ async def cmd_resume(ctx: Interaction):
if (await self.__use_on_guild_only(ctx)) is False:
return
ms = MessageSenderQueued(ctx)
await ms.waiting()
await ms.thinking()
guild: Guild = ctx.guild # type: ignore
guild_cmd: GuildCmd = self.__get_guild_cmd(guild)

@@ -170,7 +170,7 @@ async def cmd_stop(ctx: Interaction):
if (await self.__use_on_guild_only(ctx)) is False:
return
ms = MessageSenderQueued(ctx)
await ms.waiting()
await ms.thinking()
guild: Guild = ctx.guild # type: ignore
guild_cmd: GuildCmd = self.__get_guild_cmd(guild)

@@ -181,7 +181,7 @@ async def cmd_next(ctx: Interaction):
if (await self.__use_on_guild_only(ctx)) is False:
return
ms = MessageSenderQueued(ctx)
await ms.waiting()
await ms.thinking()
guild: Guild = ctx.guild # type: ignore
guild_cmd: GuildCmd = self.__get_guild_cmd(guild)

@@ -192,7 +192,7 @@ async def cmd_shuffle(ctx: Interaction):
if (await self.__use_on_guild_only(ctx)) is False:
return
ms = MessageSenderQueued(ctx)
await ms.waiting()
await ms.thinking()
guild: Guild = ctx.guild # type: ignore
guild_cmd: GuildCmd = self.__get_guild_cmd(guild)

@@ -203,7 +203,7 @@ async def cmd_remove(ctx: Interaction, number_in_queue: int):
if (await self.__use_on_guild_only(ctx)) is False:
return
ms = MessageSenderQueued(ctx)
await ms.waiting()
await ms.thinking()
guild: Guild = ctx.guild # type: ignore
guild_cmd: GuildCmd = self.__get_guild_cmd(guild)

@@ -214,7 +214,7 @@ async def cmd_goto(ctx: Interaction, number_in_queue: int):
if (await self.__use_on_guild_only(ctx)) is False:
return
ms = MessageSenderQueued(ctx)
await ms.waiting()
await ms.thinking()
guild: Guild = ctx.guild # type: ignore
guild_cmd: GuildCmd = self.__get_guild_cmd(guild)

@@ -225,7 +225,7 @@ async def cmd_queue(ctx: Interaction):
if (await self.__use_on_guild_only(ctx)) is False:
return
ms = MessageSenderQueued(ctx)
await ms.waiting()
await ms.thinking()
guild: Guild = ctx.guild # type: ignore
guild_cmd: GuildCmd = self.__get_guild_cmd(guild)

@@ -239,7 +239,7 @@ async def cmd_search(ctx: Interaction, input: str, engine: str | None):
if (await self.__use_on_guild_only(ctx)) is False:
return
ms = MessageSenderQueued(ctx)
await ms.waiting()
await ms.thinking()
guild: Guild = ctx.guild # type: ignore
guild_cmd: GuildCmd = self.__get_guild_cmd(guild)

@@ -252,7 +252,7 @@ async def cmd_play_multiple(ctx: Interaction, input: str):
if (await self.__use_on_guild_only(ctx)) is False:
return
ms = MessageSenderQueued(ctx)
await ms.waiting()
await ms.thinking()
guild: Guild = ctx.guild # type: ignore
guild_cmd: GuildCmd = self.__get_guild_cmd(guild)

@@ -265,7 +265,7 @@ async def cmd_play_url(ctx: Interaction, url: str):
if (await self.__use_on_guild_only(ctx)) is False:
return
ms = MessageSenderQueued(ctx)
await ms.waiting()
await ms.thinking()
guild: Guild = ctx.guild # type: ignore
guild_cmd: GuildCmd = self.__get_guild_cmd(guild)

@@ -279,7 +279,7 @@ async def cmd_play_url_next(ctx: Interaction, url: str):
if (await self.__use_on_guild_only(ctx)) is False:
return
ms = MessageSenderQueued(ctx)
await ms.waiting()
await ms.thinking()
guild: Guild = ctx.guild # type: ignore
guild_cmd: GuildCmd = self.__get_guild_cmd(guild)

58 changes: 29 additions & 29 deletions src/pyramid/connector/discord/guild_cmd.py
Original file line number Diff line number Diff line change
@@ -29,11 +29,11 @@ async def play(self, ms: MessageSenderQueued, ctx: Interaction, input: str, at_e
if not voice_channel:
return False

ms.response_message(content=f"Searching **{input}**")
ms.edit_message(f"Searching **{input}**", "search")

track: TrackMinimal | None = self.data.search_engine.default_engine.search_track(input)
if not track:
ms.response_message(content=f"**{input}** not found.")
ms.edit_message("**{input}** not found.", "search")
return False

return await self._execute_play(ms, voice_channel, track, at_end=at_end)
@@ -45,10 +45,10 @@ async def stop(self, ms: MessageSenderQueued, ctx: Interaction) -> bool:

self.data.track_list.clear()
if await self.queue.exit() is False:
ms.response_message(content="The bot does not currently play music")
ms.add_message("The bot does not currently play music")
return False

ms.response_message(content="Music stop")
ms.add_message("Music stop")
return True

async def pause(self, ms: MessageSenderQueued, ctx: Interaction) -> bool:
@@ -57,10 +57,10 @@ async def pause(self, ms: MessageSenderQueued, ctx: Interaction) -> bool:
return False

if self.queue.pause() is False:
ms.response_message(content="The bot does not currently play music")
ms.add_message("The bot does not currently play music")
return False

ms.response_message(content="Music paused")
ms.add_message("Music paused")
return True

async def resume(self, ms: MessageSenderQueued, ctx: Interaction) -> bool:
@@ -69,10 +69,10 @@ async def resume(self, ms: MessageSenderQueued, ctx: Interaction) -> bool:
return False

if self.queue.resume() is False:
ms.response_message(content="The bot is not currently paused")
ms.add_message("The bot is not currently paused")
return False

ms.response_message(content="Music resume")
ms.add_message("Music resume")
return True

async def next(self, ms: MessageSenderQueued, ctx: Interaction) -> bool:
@@ -82,18 +82,18 @@ async def next(self, ms: MessageSenderQueued, ctx: Interaction) -> bool:

if self.queue.has_next() is False:
if self.queue.stop() is False:
ms.response_message(content="The bot does not currently play music")
ms.add_message("The bot does not currently play music")
return False
else:
ms.response_message(content="The bot didn't have next music")
ms.add_message("The bot didn't have next music")
return True

await self.queue.goto_channel(voice_channel)
if self.queue.next() is False:
ms.response_message(content="Unable to play next music")
ms.add_message("Unable to play next music")
return False

ms.response_message(content="Skip musique")
ms.add_message("Skip musique")
return True

async def suffle(self, ms: MessageSenderQueued, ctx: Interaction):
@@ -102,10 +102,10 @@ async def suffle(self, ms: MessageSenderQueued, ctx: Interaction):
return False

if not self.queue.shuffle():
ms.response_message(content="No need to shuffle the queue.")
ms.add_message("No need to shuffle the queue.")
return False

ms.response_message(content="The queue has been shuffled.")
ms.add_message("The queue has been shuffled.")
return True

async def remove(self, ms: MessageSenderQueued, ctx: Interaction, number_in_queue: int):
@@ -114,25 +114,25 @@ async def remove(self, ms: MessageSenderQueued, ctx: Interaction, number_in_queu
return False

if number_in_queue <= 0:
ms.response_message(
ms.add_message(
content=f"Unable to remove element with the number {number_in_queue} in the queue"
)
return False

if number_in_queue == 1:
ms.response_message(
ms.add_message(
content="Unable to remove the current track from the queue. Use `next` instead"
)
return False

track_deleted = self.queue.remove(number_in_queue - 1)
if track_deleted is None:
ms.response_message(
ms.add_message(
content=f"There is no element with the number {number_in_queue} in the queue"
)
return False

ms.response_message(
ms.add_message(
content=f"**{track_deleted.get_full_name()}** has been removed from queue"
)
return True
@@ -143,32 +143,32 @@ async def goto(self, ms: MessageSenderQueued, ctx: Interaction, number_in_queue:
return False

if number_in_queue <= 0:
ms.response_message(
ms.add_message(
content=f"Unable to go to element with number {number_in_queue} in the queue"
)
return False

if number_in_queue == 1:
ms.response_message(
ms.add_message(
content="Unable to remove the current track from the queue. Use `next` instead"
)
return False

tracks_removed = self.queue.goto(number_in_queue - 1)
if tracks_removed <= 0:
ms.response_message(
ms.add_message(
content=f"There is no element with the number {number_in_queue} in the queue"
)
return False

# +1 for current track
ms.response_message(content=f"f{tracks_removed + 1} tracks has been skipped")
ms.add_message(f"f{tracks_removed + 1} tracks has been skipped")
return True

def queue_list(self, ms: MessageSenderQueued, ctx: Interaction) -> bool:
queue: str | None = self.queue.queue_list()
if queue is None:
ms.response_message(content="Queue is empty")
ms.add_message("Queue is empty")
return False

ms.add_code_message(queue, prefix="Here's the music in the queue :")
@@ -182,15 +182,15 @@ def search(
else:
test_value = self.data.search_engines.get_engine(engine)
if not test_value:
ms.response_message(content=f"Search engine **{engine}** not found.")
ms.add_message(f"Search engine **{engine}** not found.")
return False
else:
search_engine = test_value

result: list[TrackMinimal] | None = search_engine.search_tracks(input)

if not result:
ms.response_message(content=f"**{input}** not found.")
ms.add_message(f"**{input}** not found.")
return False

hsa = utils_list_track.to_str.tracks(result)
@@ -202,11 +202,11 @@ async def play_multiple(self, ms: MessageSenderQueued, ctx: Interaction, input:
if not voice_channel:
return False

ms.response_message(content=f"Searching **{input}** ...")
ms.edit_message(f"Searching **{input}** ...", "search")

tracks: list[TrackMinimal] | None = self.data.search_engine.default_engine.search_tracks(input)
if not tracks:
ms.response_message(content=f"**{input}** not found.")
ms.edit_message(f"**{input}** not found.", "search")
return False

return await self._execute_play_multiple(ms, voice_channel, tracks)
@@ -216,13 +216,13 @@ async def play_url(self, ms: MessageSenderQueued, ctx: Interaction, url: str, at
if not voice_channel:
return False

ms.response_message(content=f"Searching **{url}** ...")
ms.edit_message(f"Searching **{url}** ...", "search")

result: (
tuple[list[TrackMinimal], list[TrackMinimal]] | TrackMinimal | None
) = await self.data.search_engine.search_by_url(url)
if not result:
ms.response_message(content=f"**{url}** not found.")
ms.edit_message(f"**{url}** not found.", "search")
return False

if isinstance(result, tuple):
39 changes: 18 additions & 21 deletions src/pyramid/connector/discord/guild_cmd_tools.py
Original file line number Diff line number Diff line change
@@ -29,7 +29,7 @@ async def _verify_voice_channel(

# only play music if user is in a voice channel
if member.voice is None:
ms.response_message(content="You're not in a channel.")
ms.edit_message("You're not in a channel.")
return None
voice_state: VoiceState = member.voice

@@ -42,7 +42,7 @@ async def _verify_voice_channel(
bot: Member = self.data.guild.me
permissions = voice_channel.permissions_for(bot)
if not permissions.connect:
ms.response_message(content=f"I can't go to {voice_channel.mention}")
ms.edit_message(f"I can't go to {voice_channel.mention}")
return None

if not permissions.speak:
@@ -54,7 +54,7 @@ async def _verify_bot_channel(self, ms: MessageSenderQueued, channel: VoiceChann
vc: VoiceClient = self.data.voice_client

if vc.channel.id != channel.id:
ms.response_message(content="You're not in the bot channel.")
ms.edit_message("You're not in the bot channel.")
return False
return True

@@ -87,7 +87,7 @@ async def _execute_play_multiple(
ms.add_message(content=f"Can't find the audio for this track:\n* {out}")

length = len(tracks)
ms.response_message(content=f"Downloading ... 0/{length}")
ms.edit_message(f"Downloading ... 0/{length}", "download")

cant_dl = 0
for i, track in enumerate(tracks):
@@ -96,54 +96,51 @@ async def _execute_play_multiple(
ms.add_message(content=f"ERROR > **{track.get_full_name()}** can't be downloaded.")
cant_dl += 1
continue
if (
at_end is True
and not (tl.add_track(track_downloaded)
or tl.add_track_after(track_downloaded))
if at_end is True and not (
tl.add_track(track_downloaded) or tl.add_track_after(track_downloaded)
):
ms.add_message(
content=f"ERROR > **{track.get_full_name()}** can't be add to the queue."
)
cant_dl += 1
continue
ms.response_message(
content=f"Downloading ... {i + 1 - cant_dl}/{length - cant_dl}"
)
ms.edit_message(f"Downloading ... {i + 1 - cant_dl}/{length - cant_dl}", "download")
if i == 0:
await self.queue.goto_channel(voice_channel)
await self.queue.play(ms)

if length == cant_dl:
ms.response_message(content="None of the music could be downloaded")
ms.edit_message("None of the music could be downloaded", "download")
return False

await self.queue.goto_channel(voice_channel)

if await self.queue.play(ms) is False:
ms.response_message(content=f"**{length}** tracks have been added to the queue")
ms.edit_message(f"**{length}** tracks have been added to the queue", "download")
return True

async def _execute_play(
self, ms: MessageSenderQueued, voice_channel: VoiceChannel, track: TrackMinimal, at_end=True
) -> bool:
tl: TrackList = self.data.track_list
ms.response_message(content=f"**{track.get_full_name()}** found ! Downloading ...")
ms.edit_message(f"**{track.get_full_name()}** found ! Downloading ...", "download")

track_downloaded: Track | None = await self.engine_source.download_track(track)
if not track_downloaded:
ms.response_message(content=f"ERROR > **{track.get_full_name()}** can't be downloaded.")
ms.add_message(f"ERROR > **{track.get_full_name()}** can't be downloaded.")
return False

if (
(at_end is True and not tl.add_track(track_downloaded))
or not tl.add_track_after(track_downloaded)
if (at_end is True and not tl.add_track(track_downloaded)) or not tl.add_track_after(
track_downloaded
):
ms.add_message(content=f"ERROR > **{track.get_full_name()}** can't be add to the queue.")
ms.add_message(
content=f"ERROR > **{track.get_full_name()}** can't be add to the queue."
)
return False
await self.queue.goto_channel(voice_channel)

if await self.queue.play(ms) is False:
ms.response_message(content=f"**{track.get_full_name()}** is added to the queue")
ms.edit_message(f"**{track.get_full_name()}** is added to the queue", "download")
else:
ms.response_message(content=f"Playing **{track.get_full_name()}**")
ms.edit_message(f"Playing **{track.get_full_name()}**", "download")
return True
2 changes: 1 addition & 1 deletion src/pyramid/connector/discord/guild_queue.py
Original file line number Diff line number Diff line change
@@ -180,7 +180,7 @@ async def __song_end_continue(self, err: Exception | None, msg_sender: MessageSe

if tl.is_empty():
await vc.disconnect()
msg_sender.response_message(content="Bye bye")
msg_sender.add_message("Bye bye")
else:
await self.play(msg_sender)

17 changes: 11 additions & 6 deletions src/pyramid/connector/discord/music_player_interface.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import discord
from discord import Embed, Locale, Message, TextChannel
from discord import Embed, Locale, Message
from discord.abc import Messageable

from data.track import Track
from data.tracklist import TrackList
@@ -11,16 +12,20 @@ def __init__(self, locale: Locale, track_list: TrackList):
self.last_message: Message | None = None
self.track_list = track_list

async def send_player(self, txt_channel: TextChannel, track: Track):
async def send_player(self, txt_channel: Messageable, track: Track):
embed = self.__embed_track(track)

if self.last_message is not None:
if txt_channel.last_message_id == self.last_message.id:
self.last_message = await self.last_message.edit(content="", embed=embed)
last_channel_message = None
history = txt_channel.history(limit=1)
async for message in history:
last_channel_message = message

if last_channel_message and self.last_message is not None:
if last_channel_message.id == self.last_message.id:
self.last_message = await last_channel_message.edit(content="", embed=embed)
return
else:
await self.last_message.delete()
# await self.last_message.edit(content="", suppress=True)
self.last_message = await txt_channel.send(content="", embed=embed)

def __embed_track(self, track: Track) -> Embed:
9 changes: 9 additions & 0 deletions src/pyramid/data/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
class CustomException(Exception):
def __init__(self, *args: object):
super().__init__(*args)

class DiscordMessageException(CustomException):
def __init__(self, *args: object):
msg = str(args[0]) % args[1:]
super().__init__(msg)

2 changes: 1 addition & 1 deletion src/pyramid/data/functional/application_info.py
Original file line number Diff line number Diff line change
@@ -9,7 +9,7 @@ class ApplicationInfo:
def __init__(self):
self.name = "pyramid"
self.os = get_os().lower()
self.version = "0.3.2"
self.version = "0.3.3"
self.git_info = GitInfo()

def load_git_info(self):
216 changes: 124 additions & 92 deletions src/pyramid/data/functional/messages/message_sender.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
import asyncio
import logging
from typing import Callable

import tools.utils as tools
from discord import Interaction, Message, TextChannel, WebhookMessage
from discord.errors import HTTPException
from discord import Interaction, Message, WebhookMessage
from discord.utils import MISSING
from discord.abc import Messageable

MAX_MSG_LENGTH = 2000

@@ -15,100 +13,83 @@ def __init__(self, ctx: Interaction):
self.ctx = ctx
if ctx.channel is None:
raise NotImplementedError("Unable to create a MessageSender without channel")
if not isinstance(ctx.channel, TextChannel):
if not isinstance(ctx.channel, Messageable):
raise NotImplementedError("Unable to create a MessageSender without text channel")
self.txt_channel: TextChannel = ctx.channel
self.last_reponse: Message | WebhookMessage | None = None
self.txt_channel: Messageable = ctx.channel
self.last_message: Message | WebhookMessage | None = None
self.last_message_surname: dict[str, Message | WebhookMessage] = dict()
self.loop: asyncio.AbstractEventLoop = ctx.client.loop
self.think = False

async def thinking(self):
await self.ctx.response.defer(thinking=True)
self.think = True

async def add_message(
# def add_message(
self,
content: str = MISSING,
callback: Callable | None = None,
# ) -> None:
) -> Message:
content: str = MISSING
) -> Message | WebhookMessage:
"""
Add a message as a response or follow-up. If no message has been sent yet, the message is sent as a response.
Otherwise, the message will be linked to the response (sent as a follow-up message).
If the message exceeds the maximum character limit, it will be truncated.
"""
if content != MISSING and content != "":
new_content, is_used = tools.substring_with_end_msg(
content, MAX_MSG_LENGTH, "{} more characters..."
)
if is_used:
content = new_content
return await self.__add_message(content)

async def __add_message(
self,
content: str = MISSING
) -> Message | WebhookMessage:

if not self.ctx.response.is_done():
msg = await self.txt_channel.send(content)
# If a reply has already been sent
if self.last_message:
last_message = await self._send_after_last_msg(content)

# First reply
else:
msg = await self.ctx.followup.send(
content,
wait=True,
)
return msg
last_message = await self._send_as_first_reply(content)

self.last_message = last_message
return last_message

async def response_message(
# def response_message(
async def edit_message(
self,
content: str = MISSING,
):
surname_content: str | None = None,
) -> Message | WebhookMessage:
"""
Send a message as a response. If the response has already been sent, it will be modified.
If it is not possible to modify it, a new message will be sent as a follow-up.
If the message exceeds the maximum character limit, it will be truncated.
"""
if content != MISSING and content != "":
new_content, is_used = tools.substring_with_end_msg(
content, MAX_MSG_LENGTH, "{} more characters..."
)
if is_used:
content = new_content

if self.last_reponse is not None:
self.last_reponse.edit(content=content)
# If a reply has already been sent
if self.last_message:

elif self.ctx.response.is_done():
try:
await self.ctx.edit_original_response(
content=content,
)
# queue.add(
# QueueItem(
# "Edit response",
# self.__ctx.edit_original_response,
# self.loop,
# content=content,
# )
# )
except HTTPException as err:
if err.code == 50027: # 401 Unauthorized : Invalid Webhook Token
logging.warning(
"Unable to modify original response, send message instead", exc_info=True
)
# self.last_reponse = await self.add_message(content)
await self.add_message(content, lambda msg: setattr(self, "last_response", msg))
# If the message has a nickname, only the last reply with the same nickname is changed.
if surname_content:
last_message_edited = self.last_message_surname.get(surname_content)

if last_message_edited:
msg = await last_message_edited.edit(content=self._tuncate_msg_if_overflow(content))
self.last_message_surname[surname_content] = await last_message_edited.edit(content=self._tuncate_msg_if_overflow(content))
else:
raise err
msg = await self.__add_message(content)
self.last_message_surname[surname_content] = msg
else:
msg = await self.last_message.edit(content=self._tuncate_msg_if_overflow(content))

# First reply
else:
await self.ctx.response.send_message(
content=content,
)
# queue.add(
# QueueItem(
# "Send followup as response",
# self.__ctx.response.send_message,
# self.loop,
# content=content,
# )
# )
msg = await self._send_as_first_reply(content)
if surname_content:
self.last_message_surname[surname_content] = msg

return msg

async def add_code_message(self, content: str, prefix=None, suffix=None):
"""
Send a message with markdown code formatting. If the character limit is exceeded, send multiple messages.
"""
# def add_code_message(self, content: str, prefix=None, suffix=None):
max_length = MAX_MSG_LENGTH
if prefix is None:
prefix = "```"
@@ -123,28 +104,79 @@ async def add_code_message(self, content: str, prefix=None, suffix=None):

substrings_generator = tools.split_string_by_length(content, max_length)

if not self.ctx.response.is_done():
first_substring = next(substrings_generator, None)
if first_substring is not None:
first_substring_formatted = f"```{first_substring}```"
await self.ctx.response.send_message(content=first_substring_formatted)
# queue.add(
# QueueItem(
# "Send code as response",
# self.__ctx.response.send_message,
# self.loop,
# content=first_substring_formatted,
# )
# )
first_substring = next(substrings_generator, None)
if first_substring is None:
return

first_substring_formatted = f"```{first_substring}```"
if not self.ctx.is_expired() and self.ctx.response.is_done():
self.last_message = await self.ctx.followup.send(first_substring_formatted, wait=True)
else:
self.last_message = await self.txt_channel.send(first_substring_formatted)

last_channel_message = None
for substring in substrings_generator:
substring_formatted = f"```{substring}```"
await self.ctx.followup.send(content=substring_formatted)
# queue.add(
# QueueItem(
# "Send code as followup",
# self.__ctx.followup.send,
# self.loop,
# content=substring_formatted,
# )
# )

last_channel_message = await self._get_last_channel_message()

# If the last message of channel is the last reply
if last_channel_message.id == self.last_message.id:
self.last_message = await self.txt_channel.send(content=substring_formatted)
else:
self.last_message = await self.ctx.followup.send(
content=substring_formatted, wait=True
)

async def _send_as_first_reply(self, content: str) -> Message | WebhookMessage:
# If interaction can be used
if not self.ctx.is_expired() and self.ctx.response.is_done():
last_message = await self.ctx.followup.send(
self._tuncate_msg_if_overflow(content), wait=True
)
else:
command_name = self.ctx.command.name if self.ctx.command else "<unknown command>"
last_message = await self.txt_channel.send(
self._tuncate_msg_if_overflow(
f"{self.ctx.user.mention} `/{command_name}` {content}"
)
)
self.last_message = last_message
return last_message

async def _send_after_last_msg(self, content: str) -> Message | WebhookMessage:
if not self.last_message:
raise Exception("There is no last message")

last_channel_message = await self._get_last_channel_message()

# If the last message of channel is the last reply
if last_channel_message.id == self.last_message.id:
last_message = await self.txt_channel.send(self._tuncate_msg_if_overflow(content))

# If not, send a message linked to the last message
else:
last_message = await self.last_message.reply(
self._tuncate_msg_if_overflow(content)
)
self.last_message = last_message
return last_message

async def _get_last_channel_message(self) -> Message:
last_channel_message = None
history = self.txt_channel.history(limit=1)
async for message in history:
last_channel_message = message
if not last_channel_message:
raise Exception("Channel didn't have history")
return last_channel_message

def _tuncate_msg_if_overflow(self, content: str) -> str:
if content == MISSING or content == "":
return content
new_content, is_used = tools.substring_with_end_msg(
content, MAX_MSG_LENGTH, "{} more characters..."
)
if not is_used:
return content
return new_content
75 changes: 15 additions & 60 deletions src/pyramid/data/functional/messages/message_sender_queued.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
import logging
from typing import Callable
from typing import Any, Callable

import tools.utils as tools
from data.functional.messages.message_sender import MessageSender
from discord import Interaction
from discord import Interaction, Message, WebhookMessage
from discord.utils import MISSING
from tools.queue import Queue, QueueItem

@@ -19,73 +17,30 @@ def __init__(self, ctx: Interaction):
self.ctx = ctx
super().__init__(ctx)

async def waiting(self):
await super().response_message("Waiting for result ...")

def add_message(
self,
content: str = MISSING,
callback: Callable | None = None,
callback: Callable[[Message | WebhookMessage], Any] | None = None,
) -> None:
queue.add(
QueueItem(
"add_message", super().add_message, self.loop, content=content, callback=callback
)
QueueItem("add_message", super().add_message, self.loop, callback, content=content)
)

def response_message(
def edit_message(
self,
content: str = MISSING,
surname_content: str | None = None,
callback: Callable[[Message | WebhookMessage], Any] | None = None,
):
if content != MISSING and content != "":
new_content, is_used = tools.substring_with_end_msg(
content, MAX_MSG_LENGTH, "{} more characters..."
)
if is_used:
content = new_content

if self.last_reponse is not None:
queue.add(
QueueItem(
"Edit last response",
self.last_reponse.edit,
self.loop,
content=content,
)
)

elif self.ctx.response.is_done():
def on_error(err):
if err.code == 50027: # 401 Unauthorized : Invalid Webhook Token
logging.warning(
"Unable to modify original response, send message instead", exc_info=True
)
self.add_message(content, lambda msg: setattr(self, "last_response", msg))
else:
raise err

queue.add(
QueueItem(
"Edit response",
self.ctx.edit_original_response,
self.loop,
None,
on_error,
content=content,
)
)
else:
queue.add(
QueueItem(
"Send followup as response",
self.ctx.response.send_message,
self.loop,
content=content,
)
)

queue.add(
QueueItem("response_message", super().response_message, self.loop, content=content)
QueueItem(
"response_message",
super().edit_message,
self.loop,
callback,
content=content,
surname_content=surname_content,
)
)

def add_code_message(self, content: str, prefix=None, suffix=None):

0 comments on commit fb3068d

Please sign in to comment.