diff --git a/src/pyramid/connector/discord/bot.py b/src/pyramid/connector/discord/bot.py index 1db6325..0212b19 100644 --- a/src/pyramid/connector/discord/bot.py +++ b/src/pyramid/connector/discord/bot.py @@ -8,11 +8,18 @@ import discord from discord import ( Guild, + Interaction, PrivilegedIntentsRequired, ) from discord.ext.commands import Bot, Context -from discord.ext.commands.errors import CommandNotFound, MissingPermissions, MissingRequiredArgument - +from discord.ext.commands.errors import ( + CommandNotFound, + MissingPermissions, + MissingRequiredArgument, + CommandError, +) +from discord.app_commands.errors import AppCommandError, CommandInvokeError +from discord.abc import Messageable from data.functional.application_info import ApplicationInfo from data.environment import Environment from data.guild_data import GuildData @@ -21,16 +28,12 @@ from connector.discord.guild_cmd import GuildCmd from connector.discord.guild_queue import GuildQueue from data.functional.engine_source import EngineSource +from data.exceptions import DiscordMessageException from tools.configuration.configuration import Configuration class DiscordBot: - def __init__( - self, - logger: logging.Logger, - information: ApplicationInfo, - config: Configuration - ): + def __init__(self, logger: logging.Logger, information: ApplicationInfo, config: Configuration): self.__logger = logger self.__information = information self.__token = config.discord__token @@ -61,7 +64,7 @@ def create(self): ) @self.bot.event - async def on_command_error(ctx: Context, error): + async def on_command_error(ctx: Context, error: CommandError): if isinstance(error, CommandNotFound): await ctx.send("That command didn't exists !") return @@ -73,7 +76,7 @@ async def on_command_error(ctx: Context, error): "You dont have all the requirements or permissions for using this command :angry:" ) return - logging.error("Command error from on_command_error : %s", error) + logging.error("Command error from on_command_error : %s", error) @self.bot.event async def on_error(event, *args, **kwargs): @@ -82,6 +85,43 @@ async def on_error(event, *args, **kwargs): logging.error("Error from on_error : %s", traceback.format_exc()) # await bot.send_message(message.channel, "You caused an error!") + async def on_tree_error(ctx: Interaction, app_error: AppCommandError, /): + # if not ctx.response.is_done(): + # await ctx.response.defer(thinking=True) + + if isinstance(app_error, CommandInvokeError): + msg = ", ".join(app_error.args) + error = app_error.original + else: + msg = "Error from on_tree_error" + error = app_error + trace = "".join(traceback.format_exception(type(error), error, error.__traceback__)) + logging.error("%s :\n%s", msg, trace) + + discord_explanation = ":warning: You caused an error!" + if isinstance(error, DiscordMessageException): + discord_explanation = str(error) + else: + attributes_dict = vars(ctx.namespace) + formatted_attributes = " ".join( + f"{key}: {value}" for key, value in attributes_dict.items() + ) + discord_explanation = ( + ":warning: An error occurred while processing the command `/%s%s`%s" + % ( + ctx.command.name if ctx.command else "", + f" {formatted_attributes}" if formatted_attributes != "" else "", + f"\n```{trace}```" if self.__environment is not Environment.PRODUCTION else "", + ) + ) + + if ctx.response.is_done(): + await ctx.followup.send(discord_explanation) + elif isinstance(ctx.channel, Messageable): + await ctx.channel.send(content=f"{ctx.user.mention} {discord_explanation}") + + self.bot.tree.on_error = on_tree_error + @self.bot.event async def on_command(ctx: Context): logging.debug("on_command : %s", ctx.author) @@ -99,7 +139,6 @@ async def start(self): except PrivilegedIntentsRequired as ex: raise ex - async def stop(self): # self.bot.clear() logging.info("Discord bot stop") @@ -109,23 +148,14 @@ async def stop(self): def __get_guild_cmd(self, guild: Guild) -> GuildCmd: if guild.id not in self.guilds_instances: self.guilds_instances[guild.id] = GuildInstances( - guild, - self.__logger.getChild(guild.name), - self.__engine_source, - self.__ffmpeg + guild, self.__logger.getChild(guild.name), self.__engine_source, self.__ffmpeg ) return self.guilds_instances[guild.id].cmds class GuildInstances: - def __init__( - self, - guild: Guild, - logger: Logger, - engine_source: EngineSource, - ffmpeg_path: str - ): + def __init__(self, guild: Guild, logger: Logger, engine_source: EngineSource, ffmpeg_path: str): self.data = GuildData(guild, engine_source) self.songs = GuildQueue(self.data, ffmpeg_path) self.cmds = GuildCmd(logger, self.data, self.songs, engine_source) diff --git a/src/pyramid/connector/discord/bot_cmd.py b/src/pyramid/connector/discord/bot_cmd.py index 06d4ef2..b39a8f2 100644 --- a/src/pyramid/connector/discord/bot_cmd.py +++ b/src/pyramid/connector/discord/bot_cmd.py @@ -126,7 +126,7 @@ async def cmd_play(ctx: Interaction, input: str): if (await self.__use_on_guild_only(ctx)) is False: return ms = MessageSenderQueued(ctx) - await ms.waiting() + await ms.thinking() guild: Guild = ctx.guild # type: ignore guild_cmd: GuildCmd = self.__get_guild_cmd(guild) @@ -137,7 +137,7 @@ async def cmd_play_next(ctx: Interaction, input: str): if (await self.__use_on_guild_only(ctx)) is False: return ms = MessageSenderQueued(ctx) - await ms.waiting() + await ms.thinking() guild: Guild = ctx.guild # type: ignore guild_cmd: GuildCmd = self.__get_guild_cmd(guild) @@ -148,7 +148,7 @@ async def cmd_pause(ctx: Interaction): if (await self.__use_on_guild_only(ctx)) is False: return ms = MessageSenderQueued(ctx) - await ms.waiting() + await ms.thinking() guild: Guild = ctx.guild # type: ignore guild_cmd: GuildCmd = self.__get_guild_cmd(guild) @@ -159,7 +159,7 @@ async def cmd_resume(ctx: Interaction): if (await self.__use_on_guild_only(ctx)) is False: return ms = MessageSenderQueued(ctx) - await ms.waiting() + await ms.thinking() guild: Guild = ctx.guild # type: ignore guild_cmd: GuildCmd = self.__get_guild_cmd(guild) @@ -170,7 +170,7 @@ async def cmd_stop(ctx: Interaction): if (await self.__use_on_guild_only(ctx)) is False: return ms = MessageSenderQueued(ctx) - await ms.waiting() + await ms.thinking() guild: Guild = ctx.guild # type: ignore guild_cmd: GuildCmd = self.__get_guild_cmd(guild) @@ -181,7 +181,7 @@ async def cmd_next(ctx: Interaction): if (await self.__use_on_guild_only(ctx)) is False: return ms = MessageSenderQueued(ctx) - await ms.waiting() + await ms.thinking() guild: Guild = ctx.guild # type: ignore guild_cmd: GuildCmd = self.__get_guild_cmd(guild) @@ -192,7 +192,7 @@ async def cmd_shuffle(ctx: Interaction): if (await self.__use_on_guild_only(ctx)) is False: return ms = MessageSenderQueued(ctx) - await ms.waiting() + await ms.thinking() guild: Guild = ctx.guild # type: ignore guild_cmd: GuildCmd = self.__get_guild_cmd(guild) @@ -203,7 +203,7 @@ async def cmd_remove(ctx: Interaction, number_in_queue: int): if (await self.__use_on_guild_only(ctx)) is False: return ms = MessageSenderQueued(ctx) - await ms.waiting() + await ms.thinking() guild: Guild = ctx.guild # type: ignore guild_cmd: GuildCmd = self.__get_guild_cmd(guild) @@ -214,7 +214,7 @@ async def cmd_goto(ctx: Interaction, number_in_queue: int): if (await self.__use_on_guild_only(ctx)) is False: return ms = MessageSenderQueued(ctx) - await ms.waiting() + await ms.thinking() guild: Guild = ctx.guild # type: ignore guild_cmd: GuildCmd = self.__get_guild_cmd(guild) @@ -225,7 +225,7 @@ async def cmd_queue(ctx: Interaction): if (await self.__use_on_guild_only(ctx)) is False: return ms = MessageSenderQueued(ctx) - await ms.waiting() + await ms.thinking() guild: Guild = ctx.guild # type: ignore guild_cmd: GuildCmd = self.__get_guild_cmd(guild) @@ -239,7 +239,7 @@ async def cmd_search(ctx: Interaction, input: str, engine: str | None): if (await self.__use_on_guild_only(ctx)) is False: return ms = MessageSenderQueued(ctx) - await ms.waiting() + await ms.thinking() guild: Guild = ctx.guild # type: ignore guild_cmd: GuildCmd = self.__get_guild_cmd(guild) @@ -252,7 +252,7 @@ async def cmd_play_multiple(ctx: Interaction, input: str): if (await self.__use_on_guild_only(ctx)) is False: return ms = MessageSenderQueued(ctx) - await ms.waiting() + await ms.thinking() guild: Guild = ctx.guild # type: ignore guild_cmd: GuildCmd = self.__get_guild_cmd(guild) @@ -265,7 +265,7 @@ async def cmd_play_url(ctx: Interaction, url: str): if (await self.__use_on_guild_only(ctx)) is False: return ms = MessageSenderQueued(ctx) - await ms.waiting() + await ms.thinking() guild: Guild = ctx.guild # type: ignore guild_cmd: GuildCmd = self.__get_guild_cmd(guild) @@ -279,7 +279,7 @@ async def cmd_play_url_next(ctx: Interaction, url: str): if (await self.__use_on_guild_only(ctx)) is False: return ms = MessageSenderQueued(ctx) - await ms.waiting() + await ms.thinking() guild: Guild = ctx.guild # type: ignore guild_cmd: GuildCmd = self.__get_guild_cmd(guild) diff --git a/src/pyramid/connector/discord/guild_cmd.py b/src/pyramid/connector/discord/guild_cmd.py index ddda32b..ae3e6a4 100644 --- a/src/pyramid/connector/discord/guild_cmd.py +++ b/src/pyramid/connector/discord/guild_cmd.py @@ -29,11 +29,11 @@ async def play(self, ms: MessageSenderQueued, ctx: Interaction, input: str, at_e if not voice_channel: return False - ms.response_message(content=f"Searching **{input}**") + ms.edit_message(f"Searching **{input}**", "search") track: TrackMinimal | None = self.data.search_engine.default_engine.search_track(input) if not track: - ms.response_message(content=f"**{input}** not found.") + ms.edit_message("**{input}** not found.", "search") return False return await self._execute_play(ms, voice_channel, track, at_end=at_end) @@ -45,10 +45,10 @@ async def stop(self, ms: MessageSenderQueued, ctx: Interaction) -> bool: self.data.track_list.clear() if await self.queue.exit() is False: - ms.response_message(content="The bot does not currently play music") + ms.add_message("The bot does not currently play music") return False - ms.response_message(content="Music stop") + ms.add_message("Music stop") return True async def pause(self, ms: MessageSenderQueued, ctx: Interaction) -> bool: @@ -57,10 +57,10 @@ async def pause(self, ms: MessageSenderQueued, ctx: Interaction) -> bool: return False if self.queue.pause() is False: - ms.response_message(content="The bot does not currently play music") + ms.add_message("The bot does not currently play music") return False - ms.response_message(content="Music paused") + ms.add_message("Music paused") return True async def resume(self, ms: MessageSenderQueued, ctx: Interaction) -> bool: @@ -69,10 +69,10 @@ async def resume(self, ms: MessageSenderQueued, ctx: Interaction) -> bool: return False if self.queue.resume() is False: - ms.response_message(content="The bot is not currently paused") + ms.add_message("The bot is not currently paused") return False - ms.response_message(content="Music resume") + ms.add_message("Music resume") return True async def next(self, ms: MessageSenderQueued, ctx: Interaction) -> bool: @@ -82,18 +82,18 @@ async def next(self, ms: MessageSenderQueued, ctx: Interaction) -> bool: if self.queue.has_next() is False: if self.queue.stop() is False: - ms.response_message(content="The bot does not currently play music") + ms.add_message("The bot does not currently play music") return False else: - ms.response_message(content="The bot didn't have next music") + ms.add_message("The bot didn't have next music") return True await self.queue.goto_channel(voice_channel) if self.queue.next() is False: - ms.response_message(content="Unable to play next music") + ms.add_message("Unable to play next music") return False - ms.response_message(content="Skip musique") + ms.add_message("Skip musique") return True async def suffle(self, ms: MessageSenderQueued, ctx: Interaction): @@ -102,10 +102,10 @@ async def suffle(self, ms: MessageSenderQueued, ctx: Interaction): return False if not self.queue.shuffle(): - ms.response_message(content="No need to shuffle the queue.") + ms.add_message("No need to shuffle the queue.") return False - ms.response_message(content="The queue has been shuffled.") + ms.add_message("The queue has been shuffled.") return True async def remove(self, ms: MessageSenderQueued, ctx: Interaction, number_in_queue: int): @@ -114,25 +114,25 @@ async def remove(self, ms: MessageSenderQueued, ctx: Interaction, number_in_queu return False if number_in_queue <= 0: - ms.response_message( + ms.add_message( content=f"Unable to remove element with the number {number_in_queue} in the queue" ) return False if number_in_queue == 1: - ms.response_message( + ms.add_message( content="Unable to remove the current track from the queue. Use `next` instead" ) return False track_deleted = self.queue.remove(number_in_queue - 1) if track_deleted is None: - ms.response_message( + ms.add_message( content=f"There is no element with the number {number_in_queue} in the queue" ) return False - ms.response_message( + ms.add_message( content=f"**{track_deleted.get_full_name()}** has been removed from queue" ) return True @@ -143,32 +143,32 @@ async def goto(self, ms: MessageSenderQueued, ctx: Interaction, number_in_queue: return False if number_in_queue <= 0: - ms.response_message( + ms.add_message( content=f"Unable to go to element with number {number_in_queue} in the queue" ) return False if number_in_queue == 1: - ms.response_message( + ms.add_message( content="Unable to remove the current track from the queue. Use `next` instead" ) return False tracks_removed = self.queue.goto(number_in_queue - 1) if tracks_removed <= 0: - ms.response_message( + ms.add_message( content=f"There is no element with the number {number_in_queue} in the queue" ) return False # +1 for current track - ms.response_message(content=f"f{tracks_removed + 1} tracks has been skipped") + ms.add_message(f"f{tracks_removed + 1} tracks has been skipped") return True def queue_list(self, ms: MessageSenderQueued, ctx: Interaction) -> bool: queue: str | None = self.queue.queue_list() if queue is None: - ms.response_message(content="Queue is empty") + ms.add_message("Queue is empty") return False ms.add_code_message(queue, prefix="Here's the music in the queue :") @@ -182,7 +182,7 @@ def search( else: test_value = self.data.search_engines.get_engine(engine) if not test_value: - ms.response_message(content=f"Search engine **{engine}** not found.") + ms.add_message(f"Search engine **{engine}** not found.") return False else: search_engine = test_value @@ -190,7 +190,7 @@ def search( result: list[TrackMinimal] | None = search_engine.search_tracks(input) if not result: - ms.response_message(content=f"**{input}** not found.") + ms.add_message(f"**{input}** not found.") return False hsa = utils_list_track.to_str.tracks(result) @@ -202,11 +202,11 @@ async def play_multiple(self, ms: MessageSenderQueued, ctx: Interaction, input: if not voice_channel: return False - ms.response_message(content=f"Searching **{input}** ...") + ms.edit_message(f"Searching **{input}** ...", "search") tracks: list[TrackMinimal] | None = self.data.search_engine.default_engine.search_tracks(input) if not tracks: - ms.response_message(content=f"**{input}** not found.") + ms.edit_message(f"**{input}** not found.", "search") return False return await self._execute_play_multiple(ms, voice_channel, tracks) @@ -216,13 +216,13 @@ async def play_url(self, ms: MessageSenderQueued, ctx: Interaction, url: str, at if not voice_channel: return False - ms.response_message(content=f"Searching **{url}** ...") + ms.edit_message(f"Searching **{url}** ...", "search") result: ( tuple[list[TrackMinimal], list[TrackMinimal]] | TrackMinimal | None ) = await self.data.search_engine.search_by_url(url) if not result: - ms.response_message(content=f"**{url}** not found.") + ms.edit_message(f"**{url}** not found.", "search") return False if isinstance(result, tuple): diff --git a/src/pyramid/connector/discord/guild_cmd_tools.py b/src/pyramid/connector/discord/guild_cmd_tools.py index a314af1..a72b9e5 100644 --- a/src/pyramid/connector/discord/guild_cmd_tools.py +++ b/src/pyramid/connector/discord/guild_cmd_tools.py @@ -29,7 +29,7 @@ async def _verify_voice_channel( # only play music if user is in a voice channel if member.voice is None: - ms.response_message(content="You're not in a channel.") + ms.edit_message("You're not in a channel.") return None voice_state: VoiceState = member.voice @@ -42,7 +42,7 @@ async def _verify_voice_channel( bot: Member = self.data.guild.me permissions = voice_channel.permissions_for(bot) if not permissions.connect: - ms.response_message(content=f"I can't go to {voice_channel.mention}") + ms.edit_message(f"I can't go to {voice_channel.mention}") return None if not permissions.speak: @@ -54,7 +54,7 @@ async def _verify_bot_channel(self, ms: MessageSenderQueued, channel: VoiceChann vc: VoiceClient = self.data.voice_client if vc.channel.id != channel.id: - ms.response_message(content="You're not in the bot channel.") + ms.edit_message("You're not in the bot channel.") return False return True @@ -87,7 +87,7 @@ async def _execute_play_multiple( ms.add_message(content=f"Can't find the audio for this track:\n* {out}") length = len(tracks) - ms.response_message(content=f"Downloading ... 0/{length}") + ms.edit_message(f"Downloading ... 0/{length}", "download") cant_dl = 0 for i, track in enumerate(tracks): @@ -96,54 +96,51 @@ async def _execute_play_multiple( ms.add_message(content=f"ERROR > **{track.get_full_name()}** can't be downloaded.") cant_dl += 1 continue - if ( - at_end is True - and not (tl.add_track(track_downloaded) - or tl.add_track_after(track_downloaded)) + if at_end is True and not ( + tl.add_track(track_downloaded) or tl.add_track_after(track_downloaded) ): ms.add_message( content=f"ERROR > **{track.get_full_name()}** can't be add to the queue." ) cant_dl += 1 continue - ms.response_message( - content=f"Downloading ... {i + 1 - cant_dl}/{length - cant_dl}" - ) + ms.edit_message(f"Downloading ... {i + 1 - cant_dl}/{length - cant_dl}", "download") if i == 0: await self.queue.goto_channel(voice_channel) await self.queue.play(ms) if length == cant_dl: - ms.response_message(content="None of the music could be downloaded") + ms.edit_message("None of the music could be downloaded", "download") return False await self.queue.goto_channel(voice_channel) if await self.queue.play(ms) is False: - ms.response_message(content=f"**{length}** tracks have been added to the queue") + ms.edit_message(f"**{length}** tracks have been added to the queue", "download") return True async def _execute_play( self, ms: MessageSenderQueued, voice_channel: VoiceChannel, track: TrackMinimal, at_end=True ) -> bool: tl: TrackList = self.data.track_list - ms.response_message(content=f"**{track.get_full_name()}** found ! Downloading ...") + ms.edit_message(f"**{track.get_full_name()}** found ! Downloading ...", "download") track_downloaded: Track | None = await self.engine_source.download_track(track) if not track_downloaded: - ms.response_message(content=f"ERROR > **{track.get_full_name()}** can't be downloaded.") + ms.add_message(f"ERROR > **{track.get_full_name()}** can't be downloaded.") return False - if ( - (at_end is True and not tl.add_track(track_downloaded)) - or not tl.add_track_after(track_downloaded) + if (at_end is True and not tl.add_track(track_downloaded)) or not tl.add_track_after( + track_downloaded ): - ms.add_message(content=f"ERROR > **{track.get_full_name()}** can't be add to the queue.") + ms.add_message( + content=f"ERROR > **{track.get_full_name()}** can't be add to the queue." + ) return False await self.queue.goto_channel(voice_channel) if await self.queue.play(ms) is False: - ms.response_message(content=f"**{track.get_full_name()}** is added to the queue") + ms.edit_message(f"**{track.get_full_name()}** is added to the queue", "download") else: - ms.response_message(content=f"Playing **{track.get_full_name()}**") + ms.edit_message(f"Playing **{track.get_full_name()}**", "download") return True diff --git a/src/pyramid/connector/discord/guild_queue.py b/src/pyramid/connector/discord/guild_queue.py index e833aad..e08bdc3 100644 --- a/src/pyramid/connector/discord/guild_queue.py +++ b/src/pyramid/connector/discord/guild_queue.py @@ -180,7 +180,7 @@ async def __song_end_continue(self, err: Exception | None, msg_sender: MessageSe if tl.is_empty(): await vc.disconnect() - msg_sender.response_message(content="Bye bye") + msg_sender.add_message("Bye bye") else: await self.play(msg_sender) diff --git a/src/pyramid/connector/discord/music_player_interface.py b/src/pyramid/connector/discord/music_player_interface.py index 43fe341..f5c5137 100644 --- a/src/pyramid/connector/discord/music_player_interface.py +++ b/src/pyramid/connector/discord/music_player_interface.py @@ -1,5 +1,6 @@ import discord -from discord import Embed, Locale, Message, TextChannel +from discord import Embed, Locale, Message +from discord.abc import Messageable from data.track import Track from data.tracklist import TrackList @@ -11,16 +12,20 @@ def __init__(self, locale: Locale, track_list: TrackList): self.last_message: Message | None = None self.track_list = track_list - async def send_player(self, txt_channel: TextChannel, track: Track): + async def send_player(self, txt_channel: Messageable, track: Track): embed = self.__embed_track(track) - if self.last_message is not None: - if txt_channel.last_message_id == self.last_message.id: - self.last_message = await self.last_message.edit(content="", embed=embed) + last_channel_message = None + history = txt_channel.history(limit=1) + async for message in history: + last_channel_message = message + + if last_channel_message and self.last_message is not None: + if last_channel_message.id == self.last_message.id: + self.last_message = await last_channel_message.edit(content="", embed=embed) return else: await self.last_message.delete() - # await self.last_message.edit(content="", suppress=True) self.last_message = await txt_channel.send(content="", embed=embed) def __embed_track(self, track: Track) -> Embed: diff --git a/src/pyramid/data/exceptions.py b/src/pyramid/data/exceptions.py new file mode 100644 index 0000000..c398b8c --- /dev/null +++ b/src/pyramid/data/exceptions.py @@ -0,0 +1,9 @@ +class CustomException(Exception): + def __init__(self, *args: object): + super().__init__(*args) + +class DiscordMessageException(CustomException): + def __init__(self, *args: object): + msg = str(args[0]) % args[1:] + super().__init__(msg) + \ No newline at end of file diff --git a/src/pyramid/data/functional/application_info.py b/src/pyramid/data/functional/application_info.py index c8f33a0..366c831 100644 --- a/src/pyramid/data/functional/application_info.py +++ b/src/pyramid/data/functional/application_info.py @@ -9,7 +9,7 @@ class ApplicationInfo: def __init__(self): self.name = "pyramid" self.os = get_os().lower() - self.version = "0.3.2" + self.version = "0.3.3" self.git_info = GitInfo() def load_git_info(self): diff --git a/src/pyramid/data/functional/messages/message_sender.py b/src/pyramid/data/functional/messages/message_sender.py index ccb0116..7688745 100644 --- a/src/pyramid/data/functional/messages/message_sender.py +++ b/src/pyramid/data/functional/messages/message_sender.py @@ -1,11 +1,9 @@ import asyncio -import logging -from typing import Callable import tools.utils as tools -from discord import Interaction, Message, TextChannel, WebhookMessage -from discord.errors import HTTPException +from discord import Interaction, Message, WebhookMessage from discord.utils import MISSING +from discord.abc import Messageable MAX_MSG_LENGTH = 2000 @@ -15,100 +13,83 @@ def __init__(self, ctx: Interaction): self.ctx = ctx if ctx.channel is None: raise NotImplementedError("Unable to create a MessageSender without channel") - if not isinstance(ctx.channel, TextChannel): + if not isinstance(ctx.channel, Messageable): raise NotImplementedError("Unable to create a MessageSender without text channel") - self.txt_channel: TextChannel = ctx.channel - self.last_reponse: Message | WebhookMessage | None = None + self.txt_channel: Messageable = ctx.channel + self.last_message: Message | WebhookMessage | None = None + self.last_message_surname: dict[str, Message | WebhookMessage] = dict() self.loop: asyncio.AbstractEventLoop = ctx.client.loop + self.think = False + + async def thinking(self): + await self.ctx.response.defer(thinking=True) + self.think = True async def add_message( - # def add_message( self, - content: str = MISSING, - callback: Callable | None = None, - # ) -> None: - ) -> Message: + content: str = MISSING + ) -> Message | WebhookMessage: """ Add a message as a response or follow-up. If no message has been sent yet, the message is sent as a response. Otherwise, the message will be linked to the response (sent as a follow-up message). If the message exceeds the maximum character limit, it will be truncated. """ - if content != MISSING and content != "": - new_content, is_used = tools.substring_with_end_msg( - content, MAX_MSG_LENGTH, "{} more characters..." - ) - if is_used: - content = new_content + return await self.__add_message(content) + + async def __add_message( + self, + content: str = MISSING + ) -> Message | WebhookMessage: - if not self.ctx.response.is_done(): - msg = await self.txt_channel.send(content) + # If a reply has already been sent + if self.last_message: + last_message = await self._send_after_last_msg(content) + + # First reply else: - msg = await self.ctx.followup.send( - content, - wait=True, - ) - return msg + last_message = await self._send_as_first_reply(content) + + self.last_message = last_message + return last_message - async def response_message( - # def response_message( + async def edit_message( self, content: str = MISSING, - ): + surname_content: str | None = None, + ) -> Message | WebhookMessage: """ Send a message as a response. If the response has already been sent, it will be modified. - If it is not possible to modify it, a new message will be sent as a follow-up. If the message exceeds the maximum character limit, it will be truncated. """ - if content != MISSING and content != "": - new_content, is_used = tools.substring_with_end_msg( - content, MAX_MSG_LENGTH, "{} more characters..." - ) - if is_used: - content = new_content - if self.last_reponse is not None: - self.last_reponse.edit(content=content) + # If a reply has already been sent + if self.last_message: - elif self.ctx.response.is_done(): - try: - await self.ctx.edit_original_response( - content=content, - ) - # queue.add( - # QueueItem( - # "Edit response", - # self.__ctx.edit_original_response, - # self.loop, - # content=content, - # ) - # ) - except HTTPException as err: - if err.code == 50027: # 401 Unauthorized : Invalid Webhook Token - logging.warning( - "Unable to modify original response, send message instead", exc_info=True - ) - # self.last_reponse = await self.add_message(content) - await self.add_message(content, lambda msg: setattr(self, "last_response", msg)) + # If the message has a nickname, only the last reply with the same nickname is changed. + if surname_content: + last_message_edited = self.last_message_surname.get(surname_content) + + if last_message_edited: + msg = await last_message_edited.edit(content=self._tuncate_msg_if_overflow(content)) + self.last_message_surname[surname_content] = await last_message_edited.edit(content=self._tuncate_msg_if_overflow(content)) else: - raise err + msg = await self.__add_message(content) + self.last_message_surname[surname_content] = msg + else: + msg = await self.last_message.edit(content=self._tuncate_msg_if_overflow(content)) + + # First reply else: - await self.ctx.response.send_message( - content=content, - ) - # queue.add( - # QueueItem( - # "Send followup as response", - # self.__ctx.response.send_message, - # self.loop, - # content=content, - # ) - # ) + msg = await self._send_as_first_reply(content) + if surname_content: + self.last_message_surname[surname_content] = msg + + return msg async def add_code_message(self, content: str, prefix=None, suffix=None): """ Send a message with markdown code formatting. If the character limit is exceeded, send multiple messages. """ - # def add_code_message(self, content: str, prefix=None, suffix=None): max_length = MAX_MSG_LENGTH if prefix is None: prefix = "```" @@ -123,28 +104,79 @@ async def add_code_message(self, content: str, prefix=None, suffix=None): substrings_generator = tools.split_string_by_length(content, max_length) - if not self.ctx.response.is_done(): - first_substring = next(substrings_generator, None) - if first_substring is not None: - first_substring_formatted = f"```{first_substring}```" - await self.ctx.response.send_message(content=first_substring_formatted) - # queue.add( - # QueueItem( - # "Send code as response", - # self.__ctx.response.send_message, - # self.loop, - # content=first_substring_formatted, - # ) - # ) + first_substring = next(substrings_generator, None) + if first_substring is None: + return + + first_substring_formatted = f"```{first_substring}```" + if not self.ctx.is_expired() and self.ctx.response.is_done(): + self.last_message = await self.ctx.followup.send(first_substring_formatted, wait=True) + else: + self.last_message = await self.txt_channel.send(first_substring_formatted) + last_channel_message = None for substring in substrings_generator: substring_formatted = f"```{substring}```" - await self.ctx.followup.send(content=substring_formatted) - # queue.add( - # QueueItem( - # "Send code as followup", - # self.__ctx.followup.send, - # self.loop, - # content=substring_formatted, - # ) - # ) + + last_channel_message = await self._get_last_channel_message() + + # If the last message of channel is the last reply + if last_channel_message.id == self.last_message.id: + self.last_message = await self.txt_channel.send(content=substring_formatted) + else: + self.last_message = await self.ctx.followup.send( + content=substring_formatted, wait=True + ) + + async def _send_as_first_reply(self, content: str) -> Message | WebhookMessage: + # If interaction can be used + if not self.ctx.is_expired() and self.ctx.response.is_done(): + last_message = await self.ctx.followup.send( + self._tuncate_msg_if_overflow(content), wait=True + ) + else: + command_name = self.ctx.command.name if self.ctx.command else "" + last_message = await self.txt_channel.send( + self._tuncate_msg_if_overflow( + f"{self.ctx.user.mention} `/{command_name}` {content}" + ) + ) + self.last_message = last_message + return last_message + + async def _send_after_last_msg(self, content: str) -> Message | WebhookMessage: + if not self.last_message: + raise Exception("There is no last message") + + last_channel_message = await self._get_last_channel_message() + + # If the last message of channel is the last reply + if last_channel_message.id == self.last_message.id: + last_message = await self.txt_channel.send(self._tuncate_msg_if_overflow(content)) + + # If not, send a message linked to the last message + else: + last_message = await self.last_message.reply( + self._tuncate_msg_if_overflow(content) + ) + self.last_message = last_message + return last_message + + async def _get_last_channel_message(self) -> Message: + last_channel_message = None + history = self.txt_channel.history(limit=1) + async for message in history: + last_channel_message = message + if not last_channel_message: + raise Exception("Channel didn't have history") + return last_channel_message + + def _tuncate_msg_if_overflow(self, content: str) -> str: + if content == MISSING or content == "": + return content + new_content, is_used = tools.substring_with_end_msg( + content, MAX_MSG_LENGTH, "{} more characters..." + ) + if not is_used: + return content + return new_content diff --git a/src/pyramid/data/functional/messages/message_sender_queued.py b/src/pyramid/data/functional/messages/message_sender_queued.py index e147d97..ca8e366 100644 --- a/src/pyramid/data/functional/messages/message_sender_queued.py +++ b/src/pyramid/data/functional/messages/message_sender_queued.py @@ -1,9 +1,7 @@ -import logging -from typing import Callable +from typing import Any, Callable -import tools.utils as tools from data.functional.messages.message_sender import MessageSender -from discord import Interaction +from discord import Interaction, Message, WebhookMessage from discord.utils import MISSING from tools.queue import Queue, QueueItem @@ -19,73 +17,30 @@ def __init__(self, ctx: Interaction): self.ctx = ctx super().__init__(ctx) - async def waiting(self): - await super().response_message("Waiting for result ...") - def add_message( self, content: str = MISSING, - callback: Callable | None = None, + callback: Callable[[Message | WebhookMessage], Any] | None = None, ) -> None: queue.add( - QueueItem( - "add_message", super().add_message, self.loop, content=content, callback=callback - ) + QueueItem("add_message", super().add_message, self.loop, callback, content=content) ) - def response_message( + def edit_message( self, content: str = MISSING, + surname_content: str | None = None, + callback: Callable[[Message | WebhookMessage], Any] | None = None, ): - if content != MISSING and content != "": - new_content, is_used = tools.substring_with_end_msg( - content, MAX_MSG_LENGTH, "{} more characters..." - ) - if is_used: - content = new_content - - if self.last_reponse is not None: - queue.add( - QueueItem( - "Edit last response", - self.last_reponse.edit, - self.loop, - content=content, - ) - ) - - elif self.ctx.response.is_done(): - def on_error(err): - if err.code == 50027: # 401 Unauthorized : Invalid Webhook Token - logging.warning( - "Unable to modify original response, send message instead", exc_info=True - ) - self.add_message(content, lambda msg: setattr(self, "last_response", msg)) - else: - raise err - - queue.add( - QueueItem( - "Edit response", - self.ctx.edit_original_response, - self.loop, - None, - on_error, - content=content, - ) - ) - else: - queue.add( - QueueItem( - "Send followup as response", - self.ctx.response.send_message, - self.loop, - content=content, - ) - ) - queue.add( - QueueItem("response_message", super().response_message, self.loop, content=content) + QueueItem( + "response_message", + super().edit_message, + self.loop, + callback, + content=content, + surname_content=surname_content, + ) ) def add_code_message(self, content: str, prefix=None, suffix=None):