diff --git a/otter_welcome_buddy/cogs/hiring_timelines.py b/otter_welcome_buddy/cogs/hiring_timelines.py index fde6e3c..e42e8e4 100644 --- a/otter_welcome_buddy/cogs/hiring_timelines.py +++ b/otter_welcome_buddy/cogs/hiring_timelines.py @@ -11,6 +11,7 @@ from otter_welcome_buddy.common.constants import OTTER_ADMIN from otter_welcome_buddy.common.constants import OTTER_MODERATOR from otter_welcome_buddy.common.utils.dates import DateUtils +from otter_welcome_buddy.common.utils.discord_ import get_channel_by_id from otter_welcome_buddy.common.utils.discord_ import send_plain_message from otter_welcome_buddy.common.utils.types.common import DiscordChannelType from otter_welcome_buddy.database.handlers.db_announcements_config_handler import ( @@ -29,8 +30,9 @@ class Timelines(commands.Cog): """ Timelines command events, where notifications about hiring events are sent every month Commands: - timelines start: Start cronjob for timeline messages - timelines stop: Stop cronjob for timeline messages + timelines start: Start cronjob for timeline messages + timelines stop: Stop cronjob for timeline messages + timelines run send: Send announcement to configured channel """ def __init__(self, bot: Bot, messages_formatter: type[timeline.Formatter]): @@ -93,7 +95,6 @@ async def start( @timelines.command( # type: ignore brief="Remove the interview season announcements for a server", - usage="", ) @commands.has_any_role(OTTER_ADMIN, OTTER_MODERATOR) async def stop( @@ -131,17 +132,18 @@ async def _send_message_on_channel(self) -> None: """ for entry in DbAnnouncementsConfigHandler.get_all_announcements_configs(): try: - guild: discord.Guild = await self.bot.fetch_guild(entry.guild.id) - channel: DiscordChannelType = await guild.fetch_channel(entry.channel_id) - if not isinstance(channel, discord.TextChannel): - raise TypeError("Not valid channel to send the message in") + channel: DiscordChannelType | None = await get_channel_by_id( + self.bot, + entry.channel_id, + ) + if channel is None or not isinstance(channel, discord.TextChannel): + logger.error("Channel %s invalid to send the hiring message", entry.channel_id) + return await channel.send(self._get_hiring_events()) - except discord.NotFound: - logger.error("Fail getting channel %s in guild %s", entry.channel_id, guild.id) except discord.Forbidden: - logger.exception("Not enough permissions to fetch the data in %s", __name__) + logger.error("Not enough permissions to send the message in %s", __name__) except discord.HTTPException: - logger.error("Not guild found in %s", __name__) + logger.error("Sending the message failed in %s", __name__) except Exception: logger.exception("Error while sending the announcement in %s", __name__) diff --git a/otter_welcome_buddy/cogs/interview_match.py b/otter_welcome_buddy/cogs/interview_match.py index dbe5fa4..da901d6 100644 --- a/otter_welcome_buddy/cogs/interview_match.py +++ b/otter_welcome_buddy/cogs/interview_match.py @@ -13,6 +13,9 @@ from otter_welcome_buddy.common.constants import OTTER_ADMIN from otter_welcome_buddy.common.constants import OTTER_MODERATOR from otter_welcome_buddy.common.constants import OTTER_ROLE +from otter_welcome_buddy.common.utils.discord_ import get_channel_by_id +from otter_welcome_buddy.common.utils.discord_ import get_member_by_id +from otter_welcome_buddy.common.utils.discord_ import get_message_by_id from otter_welcome_buddy.common.utils.discord_ import send_plain_message from otter_welcome_buddy.common.utils.image import create_match_image from otter_welcome_buddy.common.utils.types.common import DiscordChannelType @@ -114,7 +117,10 @@ async def _send_weekly_message(self) -> None: role_to_mention=role.mention if role is not None else "", emoji=entry.emoji, ) - channel: DiscordChannelType | None = self.bot.get_channel(entry.channel_id) + channel: DiscordChannelType | None = await get_channel_by_id( + self.bot, + entry.channel_id, + ) if channel is None: logger.error("Fail getting the channel to send the weekly message") if not isinstance(channel, discord.TextChannel): @@ -223,22 +229,19 @@ async def _get_weekly_message( author_id: int, ) -> tuple[discord.TextChannel, discord.Message, discord.Member] | None: try: - channel: DiscordChannelType | None = self.bot.get_channel(channel_id) - if channel is None: - logger.error("No channel to check the weekly message") - return None - if not isinstance(channel, discord.TextChannel): - logger.warning("Not valid channel to send the message in") + message_data = await get_message_by_id(self.bot, channel_id, message_id) + if message_data is None: + logger.error("No message found to be checked") return None - cache_message = await channel.fetch_message(message_id) + message, channel = message_data - placeholder: discord.Member | None = channel.guild.get_member(author_id) + placeholder: discord.Member | None = await get_member_by_id(channel.guild, author_id) if placeholder is None: # TODO: add a fallback when no placeholder logger.error("No placeholder found for weekly check") return None - return channel, cache_message, placeholder + return channel, message, placeholder except discord.NotFound: logger.exception("No message found to be checked") diff --git a/otter_welcome_buddy/common/utils/discord_.py b/otter_welcome_buddy/common/utils/discord_.py index 45838b9..20b6e7d 100644 --- a/otter_welcome_buddy/common/utils/discord_.py +++ b/otter_welcome_buddy/common/utils/discord_.py @@ -1,17 +1,103 @@ import logging import discord +from discord.ext.commands import Bot from discord.ext.commands import Context +from otter_welcome_buddy.common.utils.types.common import DiscordChannelType + logger = logging.getLogger(__name__) +def get_basic_embed(title: str | None = None, description: str | None = None) -> discord.Embed: + """Get a basic embed""" + return discord.Embed(title=title, description=description, color=discord.Color.teal()) + + async def send_plain_message(ctx: Context, message: str) -> None: """Send a message as embed, this allows to use more markdown features""" try: - await ctx.send(embed=discord.Embed(description=message, color=discord.Color.teal())) + await ctx.send(embed=get_basic_embed(description=message)) except discord.Forbidden: logger.exception("Not enough permissions to send the message") except discord.HTTPException: logger.exception("Sending the message failed") + + +async def get_guild_by_id(bot: Bot, guild_id: int) -> discord.Guild | None: + """Get a guild by its id""" + # Check if the guild is in bot's cache + guild: discord.Guild | None = bot.get_guild(guild_id) + if guild is None: + try: + # Fetch the guild from Discord + guild = await bot.fetch_guild(guild_id) + except discord.Forbidden: + logger.error("Not enough permissions to fetch the guild %s", guild_id) + except discord.HTTPException: + logger.error("Getting the guild %s failed", guild_id) + + return guild + + +async def get_channel_by_id(bot: Bot, channel_id: int) -> DiscordChannelType | None: + """Get a channel by its id""" + # Check if the channel is in bot's cache + channel: DiscordChannelType | None = bot.get_channel(channel_id) + if channel is None: + try: + # Fetch the channel from Discord + channel = await bot.fetch_channel(channel_id) + except discord.NotFound: + logger.error("Invalid channel_id %s", channel_id) + except discord.InvalidData: + logger.error("Invalid channel type received for channel %s", channel_id) + except discord.Forbidden: + logger.error("Not enough permissions to fetch the channel %s", channel_id) + except discord.HTTPException: + logger.error("Getting the channel %s failed", channel_id) + + return channel + + +async def get_message_by_id( + bot: Bot, + channel_id: int, + message_id: int, +) -> tuple[discord.Message, discord.TextChannel] | None: + """Get a message by its id and its corresponding text channel""" + channel: DiscordChannelType | None = await get_channel_by_id(bot, channel_id) + if isinstance(channel, discord.TextChannel): + try: + # Fetch the message from Discord + message: discord.Message = await channel.fetch_message(message_id) + return (message, channel) + except discord.NotFound: + logger.error("Message with id %s not found", message_id) + except discord.Forbidden: + logger.error("Not enough permissions to fetch the message %s", message_id) + except discord.HTTPException: + logger.error("Getting the message %s failed", message_id) + else: + logger.error("Invalid channel %s while retrieving the message %s", channel_id, message_id) + + return None + + +async def get_member_by_id(guild: discord.Guild, member_id: int) -> discord.Member | None: + """Get a member by its id""" + # Check if the member is in guild's cache + member: discord.Member | None = guild.get_member(member_id) + if member is None: + try: + # Fetch the member from Discord + member = await guild.fetch_member(member_id) + except discord.NotFound: + logger.error("Member with id %s not found", member_id) + except discord.Forbidden: + logger.error("Not enough permissions to fetch the member %s", member_id) + except discord.HTTPException: + logger.error("Getting the member %s failed", member_id) + + return member diff --git a/tests/conftest.py b/tests/conftest.py index 6f817f5..953335d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,13 +1,14 @@ import os from collections.abc import Callable -from unittest.mock import AsyncMock from unittest.mock import MagicMock from unittest.mock import Mock import pytest from discord import Guild from discord import Member +from discord import Message from discord import Role +from discord import TextChannel from discord.ext.commands import Bot from discord.ext.commands import Context from mongoengine import connect as mongo_connect @@ -19,13 +20,13 @@ @pytest.fixture def mock_ctx() -> Context: - mocked_ctx = AsyncMock() + mocked_ctx = Mock() return mocked_ctx @pytest.fixture def mock_bot() -> Bot: - mocked_bot = AsyncMock() + mocked_bot = Mock() return mocked_bot @@ -59,6 +60,18 @@ def mock_role() -> Role: return mocked_role +@pytest.fixture +def mock_text_channel() -> TextChannel: + mocked_text_channel = Mock() + return mocked_text_channel + + +@pytest.fixture +def mock_message() -> Message: + mocked_message = Mock() + return mocked_message + + @pytest.fixture def mock_msg_fmt(): mocked_msg_fmt = MagicMock() diff --git a/tests/utils/test_discord_.py b/tests/utils/test_discord_.py index 930233a..bb7410c 100644 --- a/tests/utils/test_discord_.py +++ b/tests/utils/test_discord_.py @@ -1,5 +1,8 @@ +from unittest.mock import AsyncMock + import discord import pytest +from discord.ext.commands import Bot from discord.ext.commands import Context from pytest_mock import MockFixture @@ -7,16 +10,161 @@ @pytest.mark.asyncio -async def test_send_plain_message(mocker: MockFixture, mock_ctx: Context) -> None: +async def test_send_plain_message(mock_ctx: Context) -> None: # Arrange test_msg: str = "Test message" - mock_ctx_send = mocker.patch.object(mock_ctx, "send") + mock_ctx.send = AsyncMock() # Act await discord_.send_plain_message(mock_ctx, test_msg) # Assert - mock_ctx_send.assert_called_once_with( + mock_ctx.send.assert_called_once_with( embed=discord.Embed(description=test_msg, color=discord.Color.teal()), ) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("is_in_cache", [True, False]) +async def test_getGuildById_Succeed( + mocker: MockFixture, + mock_bot: Bot, + mock_guild: discord.Guild, + is_in_cache: bool, +) -> None: + # Arrange + mocked_guild_id: int = 111 + mock_guild.id = mocked_guild_id + + mock_get_guild = mocker.patch.object( + mock_bot, + "get_guild", + return_value=mock_guild if is_in_cache else None, + ) + mock_fetch_guild = mocker.patch.object( + mock_bot, + "fetch_guild", + new=AsyncMock(return_value=mock_guild), + ) + + # Act + result = await discord_.get_guild_by_id(mock_bot, mocked_guild_id) + + # Assert + mock_get_guild.assert_called_once_with(mocked_guild_id) + if is_in_cache: + mock_fetch_guild.assert_not_called() + else: + mock_fetch_guild.assert_called_once_with(mocked_guild_id) + assert result == mock_guild + + +@pytest.mark.asyncio +@pytest.mark.parametrize("is_in_cache", [True, False]) +async def test_getChannelById_Succeed( + mocker: MockFixture, + mock_bot: Bot, + mock_text_channel: discord.TextChannel, + is_in_cache: bool, +) -> None: + # Arrange + mocked_channel_id: int = 111 + mock_text_channel.id = mocked_channel_id + + mock_get_channel = mocker.patch.object( + mock_bot, + "get_channel", + return_value=mock_text_channel if is_in_cache else None, + ) + mock_fetch_channel = mocker.patch.object( + mock_bot, + "fetch_channel", + new=AsyncMock(return_value=mock_text_channel), + ) + + # Act + result = await discord_.get_channel_by_id(mock_bot, mocked_channel_id) + + # Assert + mock_get_channel.assert_called_once_with(mocked_channel_id) + if is_in_cache: + mock_fetch_channel.assert_not_called() + else: + mock_fetch_channel.assert_called_once_with(mocked_channel_id) + assert result == mock_text_channel + + +@pytest.mark.asyncio +@pytest.mark.parametrize("channel_exists", [True, False]) +async def test_getMessageById_Succeed( + mocker: MockFixture, + mock_bot: Bot, + mock_text_channel: discord.TextChannel, + mock_message: discord.Message, + channel_exists: bool, +) -> None: + # Arrange + mocked_channel_id: int = 111 + mock_text_channel.id = mocked_channel_id + mocked_message_id: int = 222 + mock_message.id = mocked_message_id + + mocker.patch.object(discord_, "isinstance", return_value=channel_exists) + mock_get_channel_by_id = mocker.patch.object( + discord_, + "get_channel_by_id", + new=AsyncMock(return_value=mock_text_channel if channel_exists else None), + ) + mock_fetch_message = mocker.patch.object( + mock_text_channel, + "fetch_message", + new=AsyncMock(return_value=mock_message), + ) + + # Act + result = await discord_.get_message_by_id(mock_bot, mocked_channel_id, mocked_message_id) + + # Assert + mock_get_channel_by_id.assert_called_once_with(mock_bot, mocked_channel_id) + if channel_exists: + mock_fetch_message.assert_called_once_with(mocked_message_id) + assert result == (mock_message, mock_text_channel) + else: + mock_fetch_message.assert_not_called() + assert result is None + + +@pytest.mark.asyncio +@pytest.mark.parametrize("is_in_cache", [True, False]) +async def test_getMemberById_Succeed( + mocker: MockFixture, + mock_guild: discord.Guild, + mock_member: discord.Member, + is_in_cache: bool, +) -> None: + # Arrange + mocked_member_id: int = 111 + mock_member.id = mocked_member_id + + mock_get_member = mocker.patch.object( + mock_guild, + "get_member", + return_value=mock_member if is_in_cache else None, + ) + mock_fetch_member = mocker.patch.object( + mock_guild, + "fetch_member", + new=AsyncMock(return_value=mock_member), + ) + + # Act + result = await discord_.get_member_by_id(mock_guild, mocked_member_id) + + # Assert + mock_get_member.assert_called_once_with(mocked_member_id) + if is_in_cache: + mock_fetch_member.assert_not_called() + else: + mock_fetch_member.assert_called_once_with(mocked_member_id) + assert result == mock_member