From feea032fabac89b02a8ab35bd0fa60c42ef99407 Mon Sep 17 00:00:00 2001 From: Tristiisch Date: Sat, 7 Sep 2024 23:32:43 +0200 Subject: [PATCH 01/32] wip: split commands into multiple files --- src/pyramid/connector/discord/bot.py | 2 +- src/pyramid/connector/discord/bot_cmd.py | 104 +++--------------- .../discord/commands/about_command.py | 75 +++++++++++++ .../discord/commands/abstract_command.py | 61 ++++++++++ .../discord/commands/help_command.py | 32 ++++++ .../discord/commands/ping_command.py | 15 +++ src/pyramid/tools/utils.py | 11 +- 7 files changed, 207 insertions(+), 93 deletions(-) create mode 100644 src/pyramid/connector/discord/commands/about_command.py create mode 100644 src/pyramid/connector/discord/commands/abstract_command.py create mode 100644 src/pyramid/connector/discord/commands/help_command.py create mode 100644 src/pyramid/connector/discord/commands/ping_command.py diff --git a/src/pyramid/connector/discord/bot.py b/src/pyramid/connector/discord/bot.py index 6d62e52..352d26e 100644 --- a/src/pyramid/connector/discord/bot.py +++ b/src/pyramid/connector/discord/bot.py @@ -134,8 +134,8 @@ async def on_tree_error(ctx: Interaction, app_error: AppCommandError, /): async def on_command(ctx: Context): logging.debug("on_command : %s", ctx.author) - self.listeners.register() self.cmd.register() + self.listeners.register() async def start(self): try: diff --git a/src/pyramid/connector/discord/bot_cmd.py b/src/pyramid/connector/discord/bot_cmd.py index d859cec..b54e1cb 100644 --- a/src/pyramid/connector/discord/bot_cmd.py +++ b/src/pyramid/connector/discord/bot_cmd.py @@ -1,18 +1,18 @@ import math from logging import Logger import time -from typing import Callable, List +from typing import Callable -from discord import AppInfo, ClientUser, Color, Embed, Guild, Interaction -from discord.app_commands import Command +from discord import Guild, Interaction from discord.ext.commands import Bot -from discord.user import BaseUser +from pyramid.connector.discord.commands.about_command import AboutCommand +from pyramid.connector.discord.commands.help_command import HelpCommand +from pyramid.connector.discord.commands.ping_command import PingCommand from pyramid.connector.discord.guild_cmd import GuildCmd from pyramid.data.environment import Environment from pyramid.data.functional.application_info import ApplicationInfo from pyramid.data.functional.messages.message_sender_queued import MessageSenderQueued from pyramid.data.functional.engine_source import SourceType -import pyramid.tools.utils as tools class BotCmd: @@ -35,92 +35,14 @@ def __init__( def register(self): bot = self.__bot - @bot.tree.command( - name="ping", description="Shows the response time between the bot and Discord API" - ) - async def cmd_ping(ctx: Interaction): - await ctx.response.defer(thinking=True) - await ctx.followup.send(f"Pong ! ({math.trunc(bot.latency * 1000)}ms)") - - @bot.tree.command(name="about", description="Information about the bot") - async def cmd_about(ctx: Interaction): - await ctx.response.defer(thinking=True) - bot_user: ClientUser | None - if bot.user is not None: - bot_user = bot.user - else: - bot_user = None - self.__logger.warning("Unable to get self user instance") - - info = self.__info - embed = Embed(title=info.get_name(), color=Color.gold()) - if bot_user is not None and bot_user.avatar is not None: - embed.set_thumbnail(url=bot_user.avatar.url) - - owner_id: int | None = bot.owner_id - if owner_id is None and bot.owner_ids is not None and len(bot.owner_ids) > 0: - owner_id = next(iter(bot.owner_ids)) - else: - owner_id = None - - owner: BaseUser | None - if owner_id is not None: - owner = await bot.fetch_user(owner_id) - else: - owner = None - - if owner is None: - t: AppInfo = await bot.application_info() - if t.team is not None: - team = t.team - if team.owner is not None: - owner = team.owner - - if owner is not None: - embed.set_footer( - text=f"Owned by {owner.display_name}", - icon_url=owner.avatar.url if owner.avatar is not None else None, - ) - - embed.add_field(name="Version", value=info.get_version(), inline=True) - embed.add_field(name="OS", value=info.get_os(), inline=True) - embed.add_field( - name="Environment", - value=self.__environment.name.capitalize(), - inline=True, - ) - embed.add_field( - name="Uptime", - value=tools.time_to_duration(int(round(time.time() - self.__started))), - inline=True, - ) - - await ctx.followup.send(embed=embed) - - @bot.tree.command(name="help", description="List all commands") - async def cmd_help(ctx: Interaction): - await ctx.response.defer(thinking=True) - all_commands: List[Command] = bot.tree.get_commands() # type: ignore - commands_dict = {command.name: command.description for command in all_commands} - embed_template = Embed(title="List of every commands available", color=Color.gold()) - max_embed = 10 - max_fields = 25 - embeds = [] - - for command, description in commands_dict.items(): - embed_template.add_field(name=command, value=description, inline=True) - if len(embed_template.fields) == max_fields: - embeds.append(embeds) - embed_template.clear_fields() - - # Append the last embed if it's not empty - if len(embed_template.fields) > 0: - embeds.append(embed_template) - - # Sending the first embed as a response and subsequent follow-up embeds - for i in range(0, len(embeds), max_embed): - embeds_chunk = embeds[i : i + max_embed] - await ctx.followup.send(embeds=embeds_chunk) + ping = PingCommand(self.__bot, self.__logger) + ping.register(self.__environment.name.lower()) + + about = AboutCommand(self.__bot, self.__logger, self.__started, self.__environment, self.__info) + about.register(self.__environment.name.lower()) + + help = HelpCommand(self.__bot, self.__logger) + help.register(self.__environment.name.lower()) @bot.tree.command(name="play", description="Adds a track to the end of the queue and plays it") async def cmd_play(ctx: Interaction, input: str, engine: SourceType | None): diff --git a/src/pyramid/connector/discord/commands/about_command.py b/src/pyramid/connector/discord/commands/about_command.py new file mode 100644 index 0000000..b35733d --- /dev/null +++ b/src/pyramid/connector/discord/commands/about_command.py @@ -0,0 +1,75 @@ +import logging +import time +from typing import Union +from discord import AppInfo, ClientUser, Color, Embed, Interaction +from discord.ext.commands import Bot +from discord.user import BaseUser +from discord.app_commands import locale_str +from pyramid.connector.discord.commands.abstract_command import AbstractCommand +from pyramid.data.environment import Environment +from pyramid.data.functional.application_info import ApplicationInfo +import pyramid.tools.utils as tools + +class AboutCommand(AbstractCommand): + def __init__(self, bot: Bot, logger: logging.Logger, started: float, environment: Environment, info: ApplicationInfo): + super().__init__(bot, logger) + self.__started = started + self.__environment = environment + self.__info = info + + def description(self) -> Union[str, locale_str]: + return "About the bot" + + async def execute(self, ctx: Interaction): + await ctx.response.defer(thinking=True) + bot_user: ClientUser | None + if self.bot.user is not None: + bot_user = self.bot.user + else: + bot_user = None + self.logger.warning("Unable to get self user instance") + + info = self.__info + embed = Embed(title=info.get_name(), color=Color.gold()) + if bot_user is not None and bot_user.avatar is not None: + embed.set_thumbnail(url=bot_user.avatar.url) + + owner_id: int | None = self.bot.owner_id + if owner_id is None and self.bot.owner_ids is not None and len(self.bot.owner_ids) > 0: + owner_id = next(iter(self.bot.owner_ids)) + else: + owner_id = None + + owner: BaseUser | None + if owner_id is not None: + owner = await self.bot.fetch_user(owner_id) + else: + owner = None + + if owner is None: + t: AppInfo = await self.bot.application_info() + if t.team is not None: + team = t.team + if team.owner is not None: + owner = team.owner + + if owner is not None: + embed.set_footer( + text=f"Owned by {owner.display_name}", + icon_url=owner.avatar.url if owner.avatar is not None else None, + ) + + embed.add_field(name="Version", value=info.get_version(), inline=True) + embed.add_field(name="OS", value=info.get_os(), inline=True) + embed.add_field( + name="Environment", + value=self.__environment.name.capitalize(), + inline=True, + ) + embed.add_field( + name="Uptime", + value=tools.time_to_duration(int(round(time.time() - self.__started))), + inline=True, + ) + + await ctx.followup.send(embed=embed) diff --git a/src/pyramid/connector/discord/commands/abstract_command.py b/src/pyramid/connector/discord/commands/abstract_command.py new file mode 100644 index 0000000..bd04de3 --- /dev/null +++ b/src/pyramid/connector/discord/commands/abstract_command.py @@ -0,0 +1,61 @@ +from abc import ABC, abstractmethod +import logging +from typing import Any, Dict, Optional, Union + +from discord import Interaction +from discord.app_commands import Group, Command, locale_str +from discord.ext.commands import Bot +from discord.utils import MISSING + +import pyramid.tools.utils as tools + +class AbstractCommand(ABC): + + def __init__(self, bot: Bot, logger: logging.Logger): + self.bot = bot + self.logger = logger + + def name(self) -> Union[str, locale_str]: + class_name = self.__class__.__name__ + + if class_name.endswith("Command"): + command_name = class_name[:-len("Command")] + else: + command_name = class_name + command_name = tools.camel_to_snake(command_name) + return command_name + + def description(self) -> Union[str, locale_str]: + return "" + + def nsfw(self) -> bool: + return False + + def parent(self) -> Optional[Group]: + return None + + def auto_locale_strings(self) -> bool: + return True + + def extras(self) -> Dict[Any, Any]: + return MISSING + + @abstractmethod + async def execute(self, ctx: Interaction): + pass + + def register(self, command_prefix: Optional[str] = None): + if command_prefix is not None: + command_name = "%s_%s" % (command_prefix, self.name()) + else: + command_name = self.name() + command = Command( + name=command_name, + description=self.description(), + callback=self.execute, + nsfw=self.nsfw(), + parent=self.parent(), + auto_locale_strings=self.auto_locale_strings(), + extras=self.extras(), + ) + self.bot.tree.add_command(command, guild=MISSING, guilds=MISSING, override=False) diff --git a/src/pyramid/connector/discord/commands/help_command.py b/src/pyramid/connector/discord/commands/help_command.py new file mode 100644 index 0000000..6c487b0 --- /dev/null +++ b/src/pyramid/connector/discord/commands/help_command.py @@ -0,0 +1,32 @@ +from typing import List, Union +from discord import Color, Embed, Interaction +from discord.app_commands import Command, locale_str +from pyramid.connector.discord.commands.abstract_command import AbstractCommand + +class HelpCommand(AbstractCommand): + + def description(self) -> Union[str, locale_str]: + return "List all commands" + + async def execute(self, ctx: Interaction): + await ctx.response.defer(thinking=True) + all_commands: List[Command] = self.bot.tree.get_commands() # type: ignore + commands_dict = {command.name: command.description for command in all_commands} + embed_template = Embed(title="List of every commands available", color=Color.gold()) + max_embed = 10 + max_fields = 25 + embeds = [] + + for command, description in commands_dict.items(): + embed_template.add_field(name=command, value=description, inline=True) + if len(embed_template.fields) == max_fields: + embeds.append(embeds) + embed_template.clear_fields() + + # Append the last embed if it's not empty + if len(embed_template.fields) > 0: + embeds.append(embed_template) + + for i in range(0, len(embeds), max_embed): + embeds_chunk = embeds[i : i + max_embed] + await ctx.followup.send(embeds=embeds_chunk) diff --git a/src/pyramid/connector/discord/commands/ping_command.py b/src/pyramid/connector/discord/commands/ping_command.py new file mode 100644 index 0000000..2a7e57c --- /dev/null +++ b/src/pyramid/connector/discord/commands/ping_command.py @@ -0,0 +1,15 @@ +import math +from typing import Union +from discord import Interaction +from discord.app_commands import locale_str +from pyramid.connector.discord.commands.abstract_command import AbstractCommand + +class PingCommand(AbstractCommand): + + def description(self) -> Union[str, locale_str]: + return "Displays response time between bot and Discord API" + + async def execute(self, ctx: Interaction): + await ctx.response.defer(thinking=True) + latency = math.trunc(self.bot.latency * 1000) + await ctx.followup.send("Pong ! (%dms)" % latency) diff --git a/src/pyramid/tools/utils.py b/src/pyramid/tools/utils.py index b1b661e..ef01b2f 100644 --- a/src/pyramid/tools/utils.py +++ b/src/pyramid/tools/utils.py @@ -293,7 +293,7 @@ def count_public_variables(obj): return len(public_variables) -def format_bytes_speed(bytes_per_second): +def format_bytes_speed(bytes_per_second: float): units = ["bps", "Kbps", "Mbps", "Gbps", "Tbps"] factor = 1000 for unit in units: @@ -302,3 +302,12 @@ def format_bytes_speed(bytes_per_second): bytes_per_second /= factor return f"{round(bytes_per_second)} {units[-1]}" + +def camel_to_snake(name: str): + snake_case = "" + for i, char in enumerate(name): + if char.isupper() and i != 0: + snake_case += "_" + char.lower() + else: + snake_case += char.lower() + return snake_case From 93927dbb5b4d189ab1a79958bf53418282103b1e Mon Sep 17 00:00:00 2001 From: Tristiisch Date: Mon, 9 Sep 2024 01:08:55 +0200 Subject: [PATCH 02/32] feat: register cmd using decorator --- src/pyramid/connector/deezer/cli_deezer.py | 5 +- src/pyramid/connector/discord/bot_cmd.py | 16 +++-- .../discord/commands/about_command.py | 27 ++++---- .../discord/commands/abstract_command.py | 61 ------------------- .../discord/commands/api/abstract_command.py | 37 +++++++++++ .../commands/api/annotation_command.py | 54 ++++++++++++++++ .../commands/api/parameters_command.py | 23 +++++++ .../discord/commands/api/register_command.py | 14 +++++ .../discord/commands/help_command.py | 11 ++-- .../discord/commands/ping_command.py | 10 ++- src/pyramid/data/a_engine_tools.py | 5 +- src/pyramid/data/a_search.py | 23 ++++--- src/pyramid/data/functional/main.py | 6 +- .../functional/messages/message_sender.py | 6 +- src/pyramid/data/tracklist.py | 6 +- src/pyramid/tools/logs_handler.py | 6 +- src/pyramid/tools/utils.py | 9 --- 17 files changed, 185 insertions(+), 134 deletions(-) delete mode 100644 src/pyramid/connector/discord/commands/abstract_command.py create mode 100644 src/pyramid/connector/discord/commands/api/abstract_command.py create mode 100644 src/pyramid/connector/discord/commands/api/annotation_command.py create mode 100644 src/pyramid/connector/discord/commands/api/parameters_command.py create mode 100644 src/pyramid/connector/discord/commands/api/register_command.py diff --git a/src/pyramid/connector/deezer/cli_deezer.py b/src/pyramid/connector/deezer/cli_deezer.py index 3af2a2f..500f27d 100644 --- a/src/pyramid/connector/deezer/cli_deezer.py +++ b/src/pyramid/connector/deezer/cli_deezer.py @@ -1,7 +1,6 @@ -import abc import asyncio import time -from abc import ABC +from abc import ABC, abstractmethod from typing import Any, Generic, Literal, Self from urllib.parse import parse_qs, urlparse @@ -41,7 +40,7 @@ async def add(self): class ACliDeezer(ABC): - @abc.abstractmethod + @abstractmethod async def async_request( self, method: str, diff --git a/src/pyramid/connector/discord/bot_cmd.py b/src/pyramid/connector/discord/bot_cmd.py index b54e1cb..c6ecc83 100644 --- a/src/pyramid/connector/discord/bot_cmd.py +++ b/src/pyramid/connector/discord/bot_cmd.py @@ -6,6 +6,8 @@ from discord import Guild, Interaction from discord.ext.commands import Bot from pyramid.connector.discord.commands.about_command import AboutCommand +from pyramid.connector.discord.commands.api.parameters_command import ParametersCommand +from pyramid.connector.discord.commands.api.register_command import register_commands from pyramid.connector.discord.commands.help_command import HelpCommand from pyramid.connector.discord.commands.ping_command import PingCommand from pyramid.connector.discord.guild_cmd import GuildCmd @@ -35,14 +37,16 @@ def __init__( def register(self): bot = self.__bot - ping = PingCommand(self.__bot, self.__logger) - ping.register(self.__environment.name.lower()) + register_commands(self.__bot, self.__logger, self.__environment.name.lower()) - about = AboutCommand(self.__bot, self.__logger, self.__started, self.__environment, self.__info) - about.register(self.__environment.name.lower()) + # ping = PingCommand(ParametersCommand("ping"), self.__bot, self.__logger) + # ping.register(self.__environment.name.lower()) - help = HelpCommand(self.__bot, self.__logger) - help.register(self.__environment.name.lower()) + # about = AboutCommand(self.__bot, self.__logger, self.__started, self.__environment, self.__info) + # about.register(self.__environment.name.lower()) + + # help = HelpCommand(self.__bot, self.__logger) + # help.register(self.__environment.name.lower()) @bot.tree.command(name="play", description="Adds a track to the end of the queue and plays it") async def cmd_play(ctx: Interaction, input: str, engine: SourceType | None): diff --git a/src/pyramid/connector/discord/commands/about_command.py b/src/pyramid/connector/discord/commands/about_command.py index b35733d..dd67227 100644 --- a/src/pyramid/connector/discord/commands/about_command.py +++ b/src/pyramid/connector/discord/commands/about_command.py @@ -1,24 +1,19 @@ -import logging import time -from typing import Union from discord import AppInfo, ClientUser, Color, Embed, Interaction -from discord.ext.commands import Bot from discord.user import BaseUser -from discord.app_commands import locale_str -from pyramid.connector.discord.commands.abstract_command import AbstractCommand -from pyramid.data.environment import Environment -from pyramid.data.functional.application_info import ApplicationInfo -import pyramid.tools.utils as tools +from pyramid.connector.discord.commands.api.abstract_command import AbstractCommand +from pyramid.connector.discord.commands.api.annotation_command import discord_command +from pyramid.connector.discord.commands.api.parameters_command import ParametersCommand +from pyramid.tools import utils +@discord_command(parameters=ParametersCommand(description="About the bot")) class AboutCommand(AbstractCommand): - def __init__(self, bot: Bot, logger: logging.Logger, started: float, environment: Environment, info: ApplicationInfo): - super().__init__(bot, logger) - self.__started = started - self.__environment = environment - self.__info = info - def description(self) -> Union[str, locale_str]: - return "About the bot" + # def __init__(self, bot: Bot, logger: logging.Logger, started: float, environment: Environment, info: ApplicationInfo): + # super().__init__(bot, logger) + # self.__started = started + # self.__environment = environment + # self.__info = info async def execute(self, ctx: Interaction): await ctx.response.defer(thinking=True) @@ -68,7 +63,7 @@ async def execute(self, ctx: Interaction): ) embed.add_field( name="Uptime", - value=tools.time_to_duration(int(round(time.time() - self.__started))), + value=utils.time_to_duration(int(round(time.time() - self.__started))), inline=True, ) diff --git a/src/pyramid/connector/discord/commands/abstract_command.py b/src/pyramid/connector/discord/commands/abstract_command.py deleted file mode 100644 index bd04de3..0000000 --- a/src/pyramid/connector/discord/commands/abstract_command.py +++ /dev/null @@ -1,61 +0,0 @@ -from abc import ABC, abstractmethod -import logging -from typing import Any, Dict, Optional, Union - -from discord import Interaction -from discord.app_commands import Group, Command, locale_str -from discord.ext.commands import Bot -from discord.utils import MISSING - -import pyramid.tools.utils as tools - -class AbstractCommand(ABC): - - def __init__(self, bot: Bot, logger: logging.Logger): - self.bot = bot - self.logger = logger - - def name(self) -> Union[str, locale_str]: - class_name = self.__class__.__name__ - - if class_name.endswith("Command"): - command_name = class_name[:-len("Command")] - else: - command_name = class_name - command_name = tools.camel_to_snake(command_name) - return command_name - - def description(self) -> Union[str, locale_str]: - return "" - - def nsfw(self) -> bool: - return False - - def parent(self) -> Optional[Group]: - return None - - def auto_locale_strings(self) -> bool: - return True - - def extras(self) -> Dict[Any, Any]: - return MISSING - - @abstractmethod - async def execute(self, ctx: Interaction): - pass - - def register(self, command_prefix: Optional[str] = None): - if command_prefix is not None: - command_name = "%s_%s" % (command_prefix, self.name()) - else: - command_name = self.name() - command = Command( - name=command_name, - description=self.description(), - callback=self.execute, - nsfw=self.nsfw(), - parent=self.parent(), - auto_locale_strings=self.auto_locale_strings(), - extras=self.extras(), - ) - self.bot.tree.add_command(command, guild=MISSING, guilds=MISSING, override=False) diff --git a/src/pyramid/connector/discord/commands/api/abstract_command.py b/src/pyramid/connector/discord/commands/api/abstract_command.py new file mode 100644 index 0000000..02453ce --- /dev/null +++ b/src/pyramid/connector/discord/commands/api/abstract_command.py @@ -0,0 +1,37 @@ +from abc import ABC, abstractmethod +import copy +import logging +from typing import Optional +from discord import Interaction +from discord.app_commands import Command +from discord.ext.commands import Bot +from discord.utils import MISSING + +from pyramid.connector.discord.commands.api.parameters_command import ParametersCommand + +class AbstractCommand(ABC): + + def __init__(self, parameters: ParametersCommand, bot: Bot, logger: logging.Logger): + self.bot = bot + self.logger = logger + self.parameters = parameters + + @abstractmethod + async def execute(self, ctx: Interaction): + pass + + def register(self, command_prefix: Optional[str] = None): + if command_prefix is not None: + self.parameters.name = "%s_%s" % (command_prefix, self.parameters.name) + + command = Command( + name=self.parameters.name, + description=self.parameters.description, + callback=self.execute, + nsfw=self.parameters.nsfw, + parent=None, + auto_locale_strings=self.parameters.auto_locale_strings, + extras=self.parameters.extras, + ) + # self.bot.tree.add_command(command, guilds=self.parameters.guilds) + self.bot.tree.add_command(command) diff --git a/src/pyramid/connector/discord/commands/api/annotation_command.py b/src/pyramid/connector/discord/commands/api/annotation_command.py new file mode 100644 index 0000000..eecc152 --- /dev/null +++ b/src/pyramid/connector/discord/commands/api/annotation_command.py @@ -0,0 +1,54 @@ +import re +from textwrap import TextWrapper +from discord.utils import MISSING + +from pyramid.connector.discord.commands.api.abstract_command import AbstractCommand +from pyramid.connector.discord.commands.api.parameters_command import ParametersCommand +from pyramid.connector.discord.commands.api.register_command import COMMANDS_AUTOREGISTRED + + +def discord_command(*, parameters: ParametersCommand): + def decorator(cls): + if not issubclass(cls, AbstractCommand): + raise TypeError(f"Class {cls.__name__} must extend from AbstractCommand") + + if parameters.name is MISSING: + class_name = cls.__name__ + if class_name.endswith("Command"): + class_name = class_name[:-len("Command")] + parameters.name = _camel_to_snake(class_name) + + if parameters.description is MISSING: + if cls.__doc__ is None: + parameters.description = '…' + else: + parameters.description = _shorten(cls.__doc__) + COMMANDS_AUTOREGISTRED[cls] = parameters + return cls + return decorator + + +def _camel_to_snake(name: str): + snake_case = "" + for i, char in enumerate(name): + if char.isupper() and i != 0: + snake_case += "_" + char.lower() + else: + snake_case += char.lower() + return snake_case + +def _shorten( + input: str, + *, + _wrapper: TextWrapper = TextWrapper(width=100, max_lines=1, replace_whitespace=True, placeholder='…') +) -> str: + """ + Copy of func :func:`discord.utils._shorten`. + """ + + try: + # split on the first double newline since arguments may appear after that + input, _ = re.split(r'\n\s*\n', input, maxsplit=1) + except ValueError: + pass + return _wrapper.fill(' '.join(input.strip().split())) diff --git a/src/pyramid/connector/discord/commands/api/parameters_command.py b/src/pyramid/connector/discord/commands/api/parameters_command.py new file mode 100644 index 0000000..0d1f2f0 --- /dev/null +++ b/src/pyramid/connector/discord/commands/api/parameters_command.py @@ -0,0 +1,23 @@ +from typing import Any, Dict, Optional, Sequence, Union + +from discord.abc import Snowflake +from discord.app_commands import locale_str +from discord.utils import MISSING + +class ParametersCommand: + def __init__(self, + name: Union[str, locale_str] = MISSING, + description: Union[str, locale_str] = MISSING, + nsfw: bool = False, + guild: Optional[Snowflake] = MISSING, + guilds: Sequence[Snowflake] = MISSING, + auto_locale_strings: bool = True, + extras: Dict[Any, Any] = MISSING + ): + self.name = name + self.description = description + self.nsfw = nsfw + self.guild = guild + self.guilds = guilds + self.auto_locale_strings = auto_locale_strings + self.extras: Dict[Any, Any] = MISSING diff --git a/src/pyramid/connector/discord/commands/api/register_command.py b/src/pyramid/connector/discord/commands/api/register_command.py new file mode 100644 index 0000000..2a3f6b1 --- /dev/null +++ b/src/pyramid/connector/discord/commands/api/register_command.py @@ -0,0 +1,14 @@ + +import logging +from discord.ext.commands import Bot +from pyramid.connector.discord.commands.api.abstract_command import AbstractCommand +from pyramid.connector.discord.commands.api.parameters_command import ParametersCommand + +COMMANDS_AUTOREGISTRED: dict[type[AbstractCommand], ParametersCommand] = {} + + +def register_commands(bot: Bot, logger: logging.Logger, command_prefix: str | None = None): + for cls, parameters in COMMANDS_AUTOREGISTRED.items(): + class_instance = cls(parameters, bot, logger) + class_instance.register(command_prefix) + logger.info("%s - %s" % (vars(cls), vars(parameters))) diff --git a/src/pyramid/connector/discord/commands/help_command.py b/src/pyramid/connector/discord/commands/help_command.py index 6c487b0..e9f9235 100644 --- a/src/pyramid/connector/discord/commands/help_command.py +++ b/src/pyramid/connector/discord/commands/help_command.py @@ -1,13 +1,12 @@ -from typing import List, Union +from typing import List from discord import Color, Embed, Interaction -from discord.app_commands import Command, locale_str -from pyramid.connector.discord.commands.abstract_command import AbstractCommand +from discord.app_commands import Command +from pyramid.connector.discord.commands.api.abstract_command import AbstractCommand, ParametersCommand +from pyramid.connector.discord.commands.api.annotation_command import discord_command +@discord_command(parameters=ParametersCommand(description="List all commands")) class HelpCommand(AbstractCommand): - def description(self) -> Union[str, locale_str]: - return "List all commands" - async def execute(self, ctx: Interaction): await ctx.response.defer(thinking=True) all_commands: List[Command] = self.bot.tree.get_commands() # type: ignore diff --git a/src/pyramid/connector/discord/commands/ping_command.py b/src/pyramid/connector/discord/commands/ping_command.py index 2a7e57c..6480bd6 100644 --- a/src/pyramid/connector/discord/commands/ping_command.py +++ b/src/pyramid/connector/discord/commands/ping_command.py @@ -1,14 +1,12 @@ import math -from typing import Union from discord import Interaction -from discord.app_commands import locale_str -from pyramid.connector.discord.commands.abstract_command import AbstractCommand +from pyramid.connector.discord.commands.api.abstract_command import AbstractCommand +from pyramid.connector.discord.commands.api.annotation_command import discord_command +from pyramid.connector.discord.commands.api.parameters_command import ParametersCommand +@discord_command(parameters=ParametersCommand(description="Displays response time between bot and Discord API")) class PingCommand(AbstractCommand): - def description(self) -> Union[str, locale_str]: - return "Displays response time between bot and Discord API" - async def execute(self, ctx: Interaction): await ctx.response.defer(thinking=True) latency = math.trunc(self.bot.latency * 1000) diff --git a/src/pyramid/data/a_engine_tools.py b/src/pyramid/data/a_engine_tools.py index 2ec67ec..5daf380 100644 --- a/src/pyramid/data/a_engine_tools.py +++ b/src/pyramid/data/a_engine_tools.py @@ -1,9 +1,8 @@ -import abc -from abc import ABC +from abc import ABC, abstractmethod from typing import Any class AEngineTools(ABC): - @abc.abstractmethod + @abstractmethod def extract_from_url(self, url) -> tuple[int | str, Any | None] | tuple[None, None]: pass diff --git a/src/pyramid/data/a_search.py b/src/pyramid/data/a_search.py index 4909bd8..6359e5b 100644 --- a/src/pyramid/data/a_search.py +++ b/src/pyramid/data/a_search.py @@ -1,31 +1,30 @@ -import abc -from abc import ABC +from abc import ABC, abstractmethod from pyramid.data.track import TrackMinimal class ASearch(ABC): - @abc.abstractmethod + @abstractmethod async def search_track(self, search) -> TrackMinimal | None: pass - @abc.abstractmethod + @abstractmethod async def search_tracks(self, search, limit: int | None = None) -> list[TrackMinimal] | None: pass - @abc.abstractmethod + @abstractmethod async def get_playlist_tracks(self, playlist_name) -> list[TrackMinimal] | None: pass - @abc.abstractmethod + @abstractmethod async def get_album_tracks(self, album_name) -> list[TrackMinimal] | None: pass - @abc.abstractmethod + @abstractmethod async def get_top_artist(self, artist_name, limit=10) -> list[TrackMinimal] | None: pass - @abc.abstractmethod + @abstractmethod async def get_by_url( self, url ) -> tuple[list[TrackMinimal], list[TrackMinimal]] | TrackMinimal | None: @@ -33,23 +32,23 @@ async def get_by_url( class ASearchId(ABC): - @abc.abstractmethod + @abstractmethod async def get_track_by_id(self, track_id: int | str) -> TrackMinimal | None: pass - @abc.abstractmethod + @abstractmethod async def get_playlist_tracks_by_id( self, playlist_id: int | str ) -> tuple[list[TrackMinimal], list[TrackMinimal]] | None: pass - @abc.abstractmethod + @abstractmethod async def get_album_tracks_by_id( self, album_id: int | str ) -> tuple[list[TrackMinimal], list[TrackMinimal]] | None: pass - @abc.abstractmethod + @abstractmethod async def get_top_artist_by_id( self, artist_id: int | str, limit: int | None = None ) -> tuple[list[TrackMinimal], list[TrackMinimal]] | None: diff --git a/src/pyramid/data/functional/main.py b/src/pyramid/data/functional/main.py index e825ad1..ddc500f 100644 --- a/src/pyramid/data/functional/main.py +++ b/src/pyramid/data/functional/main.py @@ -6,7 +6,7 @@ from datetime import datetime from threading import Thread -import pyramid.tools.utils as tools +from pyramid.tools import utils from pyramid.data.functional.application_info import ApplicationInfo from pyramid.connector.discord.bot import DiscordBot from pyramid.client.server import SocketServer @@ -43,7 +43,7 @@ def logs(self): self.logger = logging.getLogger() # Deletion of log files over 10 - tools.keep_latest_files(log_dir, 10, "error") + utils.keep_latest_files(log_dir, 10, "error") def config(self): # Config load @@ -60,7 +60,7 @@ def open_socket(self): def clean_data(self): # Songs folder clear - tools.clear_directory(self._config.deezer__folder) + utils.clear_directory(self._config.deezer__folder) def start(self): # Discord Bot Instance diff --git a/src/pyramid/data/functional/messages/message_sender.py b/src/pyramid/data/functional/messages/message_sender.py index 48743f6..9723b76 100644 --- a/src/pyramid/data/functional/messages/message_sender.py +++ b/src/pyramid/data/functional/messages/message_sender.py @@ -1,6 +1,6 @@ import asyncio -import pyramid.tools.utils as tools +from pyramid.tools import utils from discord import Interaction, Message, WebhookMessage from discord.utils import MISSING from discord.abc import Messageable @@ -97,7 +97,7 @@ async def add_code_message(self, content: str, prefix=None, suffix=None): max_length -= len(sep) * 2 first_max_length = max_length - (len(prefix) if prefix is not None else 0) - substrings_generator = tools.split_string_by_length(content, max_length, first_max_length) + substrings_generator = utils.split_string_by_length(content, max_length, first_max_length) first_substring = next(substrings_generator, None) if first_substring is None: @@ -167,7 +167,7 @@ async def _get_last_channel_message(self) -> Message: def _tuncate_msg_if_overflow(self, content: str) -> str: if content == MISSING or content == "": return content - new_content, is_used = tools.substring_with_end_msg( + new_content, is_used = utils.substring_with_end_msg( content, MAX_MSG_LENGTH, "`{} more characters...`" ) if not is_used: diff --git a/src/pyramid/data/tracklist.py b/src/pyramid/data/tracklist.py index 40b2d1d..fbe8d92 100644 --- a/src/pyramid/data/tracklist.py +++ b/src/pyramid/data/tracklist.py @@ -1,7 +1,7 @@ import os import random -import pyramid.tools.utils as tools +from pyramid.tools import utils from pyramid.data.track import Track, TrackMinimal, TrackMinimalDeezer @@ -88,7 +88,7 @@ def get_length(self) -> str: return f"{length} track" def get_duration(self) -> str: - return tools.time_to_duration(sum(t.duration_seconds for t in self.__tracks)) + return utils.time_to_duration(sum(t.duration_seconds for t in self.__tracks)) def to_str(list_of_track: list[TrackMinimal] | list[TrackMinimalDeezer] | list[Track]) -> str: @@ -97,5 +97,5 @@ def to_str(list_of_track: list[TrackMinimal] | list[TrackMinimalDeezer] | list[T for i, track in enumerate(list_of_track) ] columns = ["n°", "Author", "Title", "Album"] - hsa = tools.human_string_array(data, columns, 50) + hsa = utils.human_string_array(data, columns, 50) return hsa diff --git a/src/pyramid/tools/logs_handler.py b/src/pyramid/tools/logs_handler.py index be754a2..c7936a5 100644 --- a/src/pyramid/tools/logs_handler.py +++ b/src/pyramid/tools/logs_handler.py @@ -4,7 +4,7 @@ import sys import coloredlogs -import pyramid.tools.utils as tools +from pyramid.tools import utils from pyramid.data.functional.application_info import ApplicationInfo from pyramid.data.environment import Environment @@ -32,7 +32,7 @@ def log_to_console(self): def log_to_file(self): log_filename = os.path.join(self.__logs_dir, self.__log_filename) - tools.create_parent_directories(log_filename) + utils.create_parent_directories(log_filename) file_handler = logging.handlers.RotatingFileHandler( filename=log_filename, @@ -47,7 +47,7 @@ def log_to_file(self): def log_to_file_exceptions(self): log_filename = os.path.join(self.__logs_dir, self.__error_filename) - tools.create_parent_directories(log_filename) + utils.create_parent_directories(log_filename) file_handler = logging.handlers.RotatingFileHandler( filename=log_filename, diff --git a/src/pyramid/tools/utils.py b/src/pyramid/tools/utils.py index ef01b2f..ccd6757 100644 --- a/src/pyramid/tools/utils.py +++ b/src/pyramid/tools/utils.py @@ -302,12 +302,3 @@ def format_bytes_speed(bytes_per_second: float): bytes_per_second /= factor return f"{round(bytes_per_second)} {units[-1]}" - -def camel_to_snake(name: str): - snake_case = "" - for i, char in enumerate(name): - if char.isupper() and i != 0: - snake_case += "_" + char.lower() - else: - snake_case += char.lower() - return snake_case From f2d89bb514d411d67ff255d83a268b35e017cf07 Mon Sep 17 00:00:00 2001 From: Tristiisch Date: Mon, 9 Sep 2024 22:42:24 +0200 Subject: [PATCH 03/32] chore: rename class --- src/pyramid/connector/discord/bot_cmd.py | 4 ++-- src/pyramid/connector/discord/commands/about_command.py | 6 +++--- .../discord/commands/api/{abstract_command.py => abc.py} | 2 +- .../commands/api/{annotation_command.py => annotation.py} | 6 +++--- .../commands/api/{parameters_command.py => parameters.py} | 0 .../commands/api/{register_command.py => register.py} | 5 ++--- src/pyramid/connector/discord/commands/help_command.py | 4 ++-- src/pyramid/connector/discord/commands/ping_command.py | 6 +++--- 8 files changed, 16 insertions(+), 17 deletions(-) rename src/pyramid/connector/discord/commands/api/{abstract_command.py => abc.py} (92%) rename src/pyramid/connector/discord/commands/api/{annotation_command.py => annotation.py} (83%) rename src/pyramid/connector/discord/commands/api/{parameters_command.py => parameters.py} (100%) rename src/pyramid/connector/discord/commands/api/{register_command.py => register.py} (71%) diff --git a/src/pyramid/connector/discord/bot_cmd.py b/src/pyramid/connector/discord/bot_cmd.py index c6ecc83..d4a410e 100644 --- a/src/pyramid/connector/discord/bot_cmd.py +++ b/src/pyramid/connector/discord/bot_cmd.py @@ -6,8 +6,8 @@ from discord import Guild, Interaction from discord.ext.commands import Bot from pyramid.connector.discord.commands.about_command import AboutCommand -from pyramid.connector.discord.commands.api.parameters_command import ParametersCommand -from pyramid.connector.discord.commands.api.register_command import register_commands +from pyramid.connector.discord.commands.api.parameters import ParametersCommand +from pyramid.connector.discord.commands.api.register import register_commands from pyramid.connector.discord.commands.help_command import HelpCommand from pyramid.connector.discord.commands.ping_command import PingCommand from pyramid.connector.discord.guild_cmd import GuildCmd diff --git a/src/pyramid/connector/discord/commands/about_command.py b/src/pyramid/connector/discord/commands/about_command.py index dd67227..5e64bbb 100644 --- a/src/pyramid/connector/discord/commands/about_command.py +++ b/src/pyramid/connector/discord/commands/about_command.py @@ -1,9 +1,9 @@ import time from discord import AppInfo, ClientUser, Color, Embed, Interaction from discord.user import BaseUser -from pyramid.connector.discord.commands.api.abstract_command import AbstractCommand -from pyramid.connector.discord.commands.api.annotation_command import discord_command -from pyramid.connector.discord.commands.api.parameters_command import ParametersCommand +from pyramid.connector.discord.commands.api.abc import AbstractCommand +from pyramid.connector.discord.commands.api.annotation import discord_command +from pyramid.connector.discord.commands.api.parameters import ParametersCommand from pyramid.tools import utils @discord_command(parameters=ParametersCommand(description="About the bot")) diff --git a/src/pyramid/connector/discord/commands/api/abstract_command.py b/src/pyramid/connector/discord/commands/api/abc.py similarity index 92% rename from src/pyramid/connector/discord/commands/api/abstract_command.py rename to src/pyramid/connector/discord/commands/api/abc.py index 02453ce..0e2908e 100644 --- a/src/pyramid/connector/discord/commands/api/abstract_command.py +++ b/src/pyramid/connector/discord/commands/api/abc.py @@ -7,7 +7,7 @@ from discord.ext.commands import Bot from discord.utils import MISSING -from pyramid.connector.discord.commands.api.parameters_command import ParametersCommand +from pyramid.connector.discord.commands.api.parameters import ParametersCommand class AbstractCommand(ABC): diff --git a/src/pyramid/connector/discord/commands/api/annotation_command.py b/src/pyramid/connector/discord/commands/api/annotation.py similarity index 83% rename from src/pyramid/connector/discord/commands/api/annotation_command.py rename to src/pyramid/connector/discord/commands/api/annotation.py index eecc152..3763589 100644 --- a/src/pyramid/connector/discord/commands/api/annotation_command.py +++ b/src/pyramid/connector/discord/commands/api/annotation.py @@ -2,9 +2,9 @@ from textwrap import TextWrapper from discord.utils import MISSING -from pyramid.connector.discord.commands.api.abstract_command import AbstractCommand -from pyramid.connector.discord.commands.api.parameters_command import ParametersCommand -from pyramid.connector.discord.commands.api.register_command import COMMANDS_AUTOREGISTRED +from pyramid.connector.discord.commands.api.abc import AbstractCommand +from pyramid.connector.discord.commands.api.parameters import ParametersCommand +from pyramid.connector.discord.commands.api.register import COMMANDS_AUTOREGISTRED def discord_command(*, parameters: ParametersCommand): diff --git a/src/pyramid/connector/discord/commands/api/parameters_command.py b/src/pyramid/connector/discord/commands/api/parameters.py similarity index 100% rename from src/pyramid/connector/discord/commands/api/parameters_command.py rename to src/pyramid/connector/discord/commands/api/parameters.py diff --git a/src/pyramid/connector/discord/commands/api/register_command.py b/src/pyramid/connector/discord/commands/api/register.py similarity index 71% rename from src/pyramid/connector/discord/commands/api/register_command.py rename to src/pyramid/connector/discord/commands/api/register.py index 2a3f6b1..90d2e28 100644 --- a/src/pyramid/connector/discord/commands/api/register_command.py +++ b/src/pyramid/connector/discord/commands/api/register.py @@ -1,12 +1,11 @@ import logging from discord.ext.commands import Bot -from pyramid.connector.discord.commands.api.abstract_command import AbstractCommand -from pyramid.connector.discord.commands.api.parameters_command import ParametersCommand +from pyramid.connector.discord.commands.api.abc import AbstractCommand +from pyramid.connector.discord.commands.api.parameters import ParametersCommand COMMANDS_AUTOREGISTRED: dict[type[AbstractCommand], ParametersCommand] = {} - def register_commands(bot: Bot, logger: logging.Logger, command_prefix: str | None = None): for cls, parameters in COMMANDS_AUTOREGISTRED.items(): class_instance = cls(parameters, bot, logger) diff --git a/src/pyramid/connector/discord/commands/help_command.py b/src/pyramid/connector/discord/commands/help_command.py index e9f9235..dea374a 100644 --- a/src/pyramid/connector/discord/commands/help_command.py +++ b/src/pyramid/connector/discord/commands/help_command.py @@ -1,8 +1,8 @@ from typing import List from discord import Color, Embed, Interaction from discord.app_commands import Command -from pyramid.connector.discord.commands.api.abstract_command import AbstractCommand, ParametersCommand -from pyramid.connector.discord.commands.api.annotation_command import discord_command +from pyramid.connector.discord.commands.api.abc import AbstractCommand, ParametersCommand +from pyramid.connector.discord.commands.api.annotation import discord_command @discord_command(parameters=ParametersCommand(description="List all commands")) class HelpCommand(AbstractCommand): diff --git a/src/pyramid/connector/discord/commands/ping_command.py b/src/pyramid/connector/discord/commands/ping_command.py index 6480bd6..55ee13c 100644 --- a/src/pyramid/connector/discord/commands/ping_command.py +++ b/src/pyramid/connector/discord/commands/ping_command.py @@ -1,8 +1,8 @@ import math from discord import Interaction -from pyramid.connector.discord.commands.api.abstract_command import AbstractCommand -from pyramid.connector.discord.commands.api.annotation_command import discord_command -from pyramid.connector.discord.commands.api.parameters_command import ParametersCommand +from pyramid.connector.discord.commands.api.abc import AbstractCommand +from pyramid.connector.discord.commands.api.annotation import discord_command +from pyramid.connector.discord.commands.api.parameters import ParametersCommand @discord_command(parameters=ParametersCommand(description="Displays response time between bot and Discord API")) class PingCommand(AbstractCommand): From c84dedc7f40875361912c7d02f0ca4c880a252f5 Mon Sep 17 00:00:00 2001 From: Tristiisch Date: Tue, 10 Sep 2024 01:16:10 +0200 Subject: [PATCH 04/32] WIP: service injector --- src/pyramid/connector/discord/bot.py | 4 +-- src/pyramid/connector/discord/bot_cmd.py | 16 ++++++++--- .../discord/commands/about_command.py | 10 ++++++- .../connector/discord/commands/api/abc.py | 17 +++++------ .../discord/commands/api/annotation.py | 4 +-- .../discord/commands/api/register.py | 20 ++++++++++--- .../discord/services/api/annotation.py | 12 ++++++++ .../discord/services/api/injector.py | 13 +++++++++ .../discord/services/api/register.py | 24 ++++++++++++++++ .../connector/discord/services/environment.py | 19 +++++++++++++ .../connector/discord/services/logger.py | 28 +++++++++++++++++++ .../data/functional/application_info.py | 11 ++++++-- src/pyramid/data/functional/main.py | 15 ++++++++++ 13 files changed, 169 insertions(+), 24 deletions(-) create mode 100644 src/pyramid/connector/discord/services/api/annotation.py create mode 100644 src/pyramid/connector/discord/services/api/injector.py create mode 100644 src/pyramid/connector/discord/services/api/register.py create mode 100644 src/pyramid/connector/discord/services/environment.py create mode 100644 src/pyramid/connector/discord/services/logger.py diff --git a/src/pyramid/connector/discord/bot.py b/src/pyramid/connector/discord/bot.py index 352d26e..df49047 100644 --- a/src/pyramid/connector/discord/bot.py +++ b/src/pyramid/connector/discord/bot.py @@ -42,7 +42,6 @@ def __init__(self, logger: logging.Logger, information: ApplicationInfo, config: self.__ffmpeg = config.discord__ffmpeg self.__environment: Environment = config.mode self.__engine_source = EngineSource(config) - self.__started = time.time() intents = discord.Intents.default() # intents.members = True @@ -68,8 +67,7 @@ def create(self, health: HealthModules): self.__get_guild_cmd, self.__logger, self.__information, - self.__environment, - self.__started, + self.__environment ) self._health = health diff --git a/src/pyramid/connector/discord/bot_cmd.py b/src/pyramid/connector/discord/bot_cmd.py index d4a410e..e5fd64f 100644 --- a/src/pyramid/connector/discord/bot_cmd.py +++ b/src/pyramid/connector/discord/bot_cmd.py @@ -24,20 +24,28 @@ def __init__( get_guild_cmd: Callable[[Guild], GuildCmd], logger: Logger, info: ApplicationInfo, - environment: Environment, - started: float, + environment: Environment ): self.__bot = bot self.__get_guild_cmd = get_guild_cmd self.__logger = logger self.__info = info self.__environment = environment - self.__started = started def register(self): bot = self.__bot - register_commands(self.__bot, self.__logger, self.__environment.name.lower()) + services: dict[str, object] = dict() + + service = self.__environment + service_name = service.__class__.__name__ + services[service_name] = service + + service = self.__info + service_name = service.__class__.__name__ + services[service_name] = service + + register_commands(services, self.__bot, self.__logger, self.__environment.name.lower()) # ping = PingCommand(ParametersCommand("ping"), self.__bot, self.__logger) # ping.register(self.__environment.name.lower()) diff --git a/src/pyramid/connector/discord/commands/about_command.py b/src/pyramid/connector/discord/commands/about_command.py index 5e64bbb..79bfec0 100644 --- a/src/pyramid/connector/discord/commands/about_command.py +++ b/src/pyramid/connector/discord/commands/about_command.py @@ -1,9 +1,12 @@ + import time from discord import AppInfo, ClientUser, Color, Embed, Interaction from discord.user import BaseUser from pyramid.connector.discord.commands.api.abc import AbstractCommand from pyramid.connector.discord.commands.api.annotation import discord_command from pyramid.connector.discord.commands.api.parameters import ParametersCommand +from pyramid.data.environment import Environment +from pyramid.data.functional.application_info import ApplicationInfo from pyramid.tools import utils @discord_command(parameters=ParametersCommand(description="About the bot")) @@ -15,6 +18,11 @@ class AboutCommand(AbstractCommand): # self.__environment = environment # self.__info = info + def injectService(self, environment: Environment, info: ApplicationInfo): + self.__environment = environment + self.__info = info + # self.logger.info("Injected !") + async def execute(self, ctx: Interaction): await ctx.response.defer(thinking=True) bot_user: ClientUser | None @@ -63,7 +71,7 @@ async def execute(self, ctx: Interaction): ) embed.add_field( name="Uptime", - value=utils.time_to_duration(int(round(time.time() - self.__started))), + value=utils.time_to_duration(int(round(time.time() - self.__info.get_started_at()))), inline=True, ) diff --git a/src/pyramid/connector/discord/commands/api/abc.py b/src/pyramid/connector/discord/commands/api/abc.py index 0e2908e..d65afdb 100644 --- a/src/pyramid/connector/discord/commands/api/abc.py +++ b/src/pyramid/connector/discord/commands/api/abc.py @@ -1,24 +1,18 @@ -from abc import ABC, abstractmethod -import copy import logging +from abc import ABC, abstractmethod from typing import Optional from discord import Interaction from discord.app_commands import Command from discord.ext.commands import Bot -from discord.utils import MISSING from pyramid.connector.discord.commands.api.parameters import ParametersCommand class AbstractCommand(ABC): def __init__(self, parameters: ParametersCommand, bot: Bot, logger: logging.Logger): + self.parameters = parameters self.bot = bot self.logger = logger - self.parameters = parameters - - @abstractmethod - async def execute(self, ctx: Interaction): - pass def register(self, command_prefix: Optional[str] = None): if command_prefix is not None: @@ -35,3 +29,10 @@ def register(self, command_prefix: Optional[str] = None): ) # self.bot.tree.add_command(command, guilds=self.parameters.guilds) self.bot.tree.add_command(command) + + def injectService(self): + pass + + @abstractmethod + async def execute(self, ctx: Interaction): + pass diff --git a/src/pyramid/connector/discord/commands/api/annotation.py b/src/pyramid/connector/discord/commands/api/annotation.py index 3763589..ca1ed58 100644 --- a/src/pyramid/connector/discord/commands/api/annotation.py +++ b/src/pyramid/connector/discord/commands/api/annotation.py @@ -4,7 +4,7 @@ from pyramid.connector.discord.commands.api.abc import AbstractCommand from pyramid.connector.discord.commands.api.parameters import ParametersCommand -from pyramid.connector.discord.commands.api.register import COMMANDS_AUTOREGISTRED +from pyramid.connector.discord.commands.api.register import COMMANDS_TO_REGISTER def discord_command(*, parameters: ParametersCommand): @@ -23,7 +23,7 @@ def decorator(cls): parameters.description = '…' else: parameters.description = _shorten(cls.__doc__) - COMMANDS_AUTOREGISTRED[cls] = parameters + COMMANDS_TO_REGISTER[cls] = parameters return cls return decorator diff --git a/src/pyramid/connector/discord/commands/api/register.py b/src/pyramid/connector/discord/commands/api/register.py index 90d2e28..ed4c844 100644 --- a/src/pyramid/connector/discord/commands/api/register.py +++ b/src/pyramid/connector/discord/commands/api/register.py @@ -1,13 +1,25 @@ +import inspect import logging from discord.ext.commands import Bot from pyramid.connector.discord.commands.api.abc import AbstractCommand from pyramid.connector.discord.commands.api.parameters import ParametersCommand -COMMANDS_AUTOREGISTRED: dict[type[AbstractCommand], ParametersCommand] = {} +COMMANDS_TO_REGISTER: dict[type[AbstractCommand], ParametersCommand] = {} -def register_commands(bot: Bot, logger: logging.Logger, command_prefix: str | None = None): - for cls, parameters in COMMANDS_AUTOREGISTRED.items(): +def register_commands(services: dict[str, object], bot: Bot, logger: logging.Logger, command_prefix: str | None = None): + for cls, parameters in COMMANDS_TO_REGISTER.items(): class_instance = cls(parameters, bot, logger) class_instance.register(command_prefix) - logger.info("%s - %s" % (vars(cls), vars(parameters))) + # logger.info("%s - %s" % (vars(cls), vars(parameters))) + # logger.info("services %s" % ", ".join(services.keys())) + + signature = inspect.signature(class_instance.injectService) + params = list(signature.parameters.values()) + # for param in params: + # logger.info("param %s" % param.annotation) + + # logger.info("params %s" % ", ".join(params)) + dependencies = [services[param.annotation.__name__] for param in params] + # logger.info("dependencies %s" % (vars(dependencies))) + class_instance.injectService(*dependencies) diff --git a/src/pyramid/connector/discord/services/api/annotation.py b/src/pyramid/connector/discord/services/api/annotation.py new file mode 100644 index 0000000..f5b4996 --- /dev/null +++ b/src/pyramid/connector/discord/services/api/annotation.py @@ -0,0 +1,12 @@ +from pyramid.connector.discord.services.api.register import SERVICE_TO_REGISTER + + +def pyramid_service(): + def decorator(cls): + # if not issubclass(cls, AbstractService): + # raise TypeError(f"Class {cls.__name__} must extend from AbstractListener") + + class_name = cls.__name__ + SERVICE_TO_REGISTER[class_name] = cls + return cls + return decorator diff --git a/src/pyramid/connector/discord/services/api/injector.py b/src/pyramid/connector/discord/services/api/injector.py new file mode 100644 index 0000000..adb4cfa --- /dev/null +++ b/src/pyramid/connector/discord/services/api/injector.py @@ -0,0 +1,13 @@ + +import logging +from abc import ABC, abstractmethod +from discord.ext.commands import Bot + +class ServiceInjector(ABC): + + def __init__(self, bot: Bot, logger: logging.Logger): + self.bot = bot + self.logger = logger + + def injectService(self): + pass diff --git a/src/pyramid/connector/discord/services/api/register.py b/src/pyramid/connector/discord/services/api/register.py new file mode 100644 index 0000000..bfd201e --- /dev/null +++ b/src/pyramid/connector/discord/services/api/register.py @@ -0,0 +1,24 @@ +import logging +from discord.ext.commands import Bot +from pyramid.connector.discord.services.api.injector import ServiceInjector + + +SERVICE_TO_REGISTER: dict[str, type[object]] = {} +SERVICE_REGISTRED: dict[str, object] = {} + +def register_services(bot: Bot, logger: logging.Logger): + for name, cls in SERVICE_TO_REGISTER.items(): + if issubclass(cls, ServiceInjector): + class_instance = cls(bot, logger) + else: + class_instance = cls() + SERVICE_REGISTRED[name] = class_instance + logger.info("SERVICE_REGISTRED %s" % name) + +def get_service(name: str): + return SERVICE_REGISTRED[name] + +def define_bot(bot: Bot): + for _, instance in SERVICE_REGISTRED.items(): + if isinstance(instance, ServiceInjector): + instance.bot = bot diff --git a/src/pyramid/connector/discord/services/environment.py b/src/pyramid/connector/discord/services/environment.py new file mode 100644 index 0000000..ecc41b8 --- /dev/null +++ b/src/pyramid/connector/discord/services/environment.py @@ -0,0 +1,19 @@ + +from pyramid.connector.discord.services.api.annotation import pyramid_service +from pyramid.data.environment import Environment + + +@pyramid_service() +class EnvironmentService: + + def __init__(self): + self.__type: Environment = Environment.PRODUCTION + + def get_type(self): + return self.__type + + def get_type_name(self): + return self.__type.name.capitalize() + + def set_type(self, environnement: Environment): + self.__type = environnement diff --git a/src/pyramid/connector/discord/services/logger.py b/src/pyramid/connector/discord/services/logger.py new file mode 100644 index 0000000..9dfe878 --- /dev/null +++ b/src/pyramid/connector/discord/services/logger.py @@ -0,0 +1,28 @@ +import logging + +from pyramid.connector.discord.services.api.annotation import pyramid_service + + +@pyramid_service() +class LoggerService: + + def __init__(self): + self.__logger = logging.getLogger() + + def critical(self, msg, *args, **kwargs): + self.__logger.critical(msg, *args, **kwargs) + + def error(self, msg, *args, **kwargs): + self.__logger.error(msg, *args, **kwargs) + + def warning(self, msg, *args, **kwargs): + self.__logger.warning(msg, *args, **kwargs) + + def info(self, msg, *args, **kwargs): + self.__logger.info(msg, *args, **kwargs) + + def debug(self, msg, *args, **kwargs): + self.__logger.debug(msg, *args, **kwargs) + + def log(self, level, msg, *args, **kwargs): + self.__logger.log(msg, level, *args, **kwargs) diff --git a/src/pyramid/data/functional/application_info.py b/src/pyramid/data/functional/application_info.py index 12ae19d..3a36219 100644 --- a/src/pyramid/data/functional/application_info.py +++ b/src/pyramid/data/functional/application_info.py @@ -1,24 +1,31 @@ -import json import os import platform import subprocess +import time +from pyramid.connector.discord.services.api.annotation import pyramid_service + +@pyramid_service() class ApplicationInfo: def __init__(self): self.__name = "pyramid" self.__os = self.__detect_os().lower() self.__version = os.getenv("PROJECT_VERSION") + self.__started_at = time.time() def get_name(self): return self.__name.capitalize() def get_version(self): - return f"v{self.__version}" + return self.__version def get_os(self): return self.__os + def get_started_at(self): + return self.__started_at + def __detect_os(self) -> str: os_name = platform.system() if os_name == "Linux": diff --git a/src/pyramid/data/functional/main.py b/src/pyramid/data/functional/main.py index ddc500f..07099c2 100644 --- a/src/pyramid/data/functional/main.py +++ b/src/pyramid/data/functional/main.py @@ -6,6 +6,8 @@ from datetime import datetime from threading import Thread +from pyramid.connector.discord.services.api.register import get_service, register_services +from pyramid.connector.discord.services.environment import EnvironmentService from pyramid.tools import utils from pyramid.data.functional.application_info import ApplicationInfo from pyramid.connector.discord.bot import DiscordBot @@ -21,6 +23,7 @@ def __init__(self): # Program information self._info = ApplicationInfo() self._health = HealthModules() + self._discord_bot = None # Argument management def args(self): @@ -89,6 +92,18 @@ def handle_signal(signum: int, frame): logging.info(f"Received signal {signum}. shutting down ...") asyncio.run_coroutine_threadsafe(shutdown(loop), loop) + # -- Service [TEMP] + self._discord_bot = discord_bot + if self._discord_bot is None: + raise Exception("Bot is not connected") + register_services(self._discord_bot.bot, self.logger) + environment_service = get_service("EnvironmentService") + self.logger.info("environment_service %s" % environment_service) + if not isinstance(environment_service, EnvironmentService): + raise Exception("environment_service is not from type EnvironmentService, got %s" % type(environment_service)) + environment_service.set_type(self._config.mode) + # -- + previous_handler = signal.signal(signal.SIGTERM, handle_signal) # Connect bot to Discord servers in his own thread From 74262790233b8e01e03ed3e81f1b03eb2cfd3db9 Mon Sep 17 00:00:00 2001 From: Tristiisch Date: Mon, 16 Sep 2024 01:08:54 +0200 Subject: [PATCH 05/32] feat: top level service injector --- src/pyramid/api/services/configuration.py | 15 ++ src/pyramid/api/services/information.py | 8 + src/pyramid/api/services/logger.py | 36 ++++ src/pyramid/api/services/socket_server.py | 8 + src/pyramid/api/services/tools/annotation.py | 18 ++ src/pyramid/api/services/tools/exceptions.py | 11 ++ .../api => api/services/tools}/injector.py | 12 +- src/pyramid/api/services/tools/register.py | 158 ++++++++++++++++++ src/pyramid/connector/discord/bot.py | 12 +- src/pyramid/connector/discord/bot_cmd.py | 11 +- .../commands/{about_command.py => about.py} | 13 +- .../discord/commands/api/register.py | 25 --- .../commands/{help_command.py => help.py} | 4 +- .../commands/{ping_command.py => ping.py} | 6 +- .../discord/commands/{api => tools}/abc.py | 3 +- .../commands/{api => tools}/annotation.py | 10 +- .../commands/{api => tools}/parameters.py | 0 .../discord/commands/tools/register.py | 34 ++++ .../discord/services/api/annotation.py | 12 -- .../discord/services/api/register.py | 24 --- .../connector/discord/services/environment.py | 19 --- .../connector/discord/services/logger.py | 28 ---- .../data/functional/application_info.py | 4 - src/pyramid/data/functional/engine_source.py | 3 +- src/pyramid/data/functional/main.py | 79 ++++----- src/pyramid/services/clean.py | 19 +++ src/pyramid/services/configuration.py | 66 ++++++++ src/pyramid/services/discord.py | 62 +++++++ src/pyramid/services/information.py | 14 ++ src/pyramid/services/logger.py | 120 +++++++++++++ src/pyramid/services/logger_level.py | 37 ++++ src/pyramid/services/socket_server.py | 28 ++++ .../tools/configuration/configuration.py | 16 +- .../tools/configuration/configuration_load.py | 15 +- src/pyramid/tools/deprecated_class.py | 16 ++ src/pyramid/tools/logs_handler.py | 2 + src/startup.py | 6 - 37 files changed, 723 insertions(+), 231 deletions(-) create mode 100644 src/pyramid/api/services/configuration.py create mode 100644 src/pyramid/api/services/information.py create mode 100644 src/pyramid/api/services/logger.py create mode 100644 src/pyramid/api/services/socket_server.py create mode 100644 src/pyramid/api/services/tools/annotation.py create mode 100644 src/pyramid/api/services/tools/exceptions.py rename src/pyramid/{connector/discord/services/api => api/services/tools}/injector.py (55%) create mode 100644 src/pyramid/api/services/tools/register.py rename src/pyramid/connector/discord/commands/{about_command.py => about.py} (80%) delete mode 100644 src/pyramid/connector/discord/commands/api/register.py rename src/pyramid/connector/discord/commands/{help_command.py => help.py} (86%) rename src/pyramid/connector/discord/commands/{ping_command.py => ping.py} (61%) rename src/pyramid/connector/discord/commands/{api => tools}/abc.py (90%) rename src/pyramid/connector/discord/commands/{api => tools}/annotation.py (77%) rename src/pyramid/connector/discord/commands/{api => tools}/parameters.py (100%) create mode 100644 src/pyramid/connector/discord/commands/tools/register.py delete mode 100644 src/pyramid/connector/discord/services/api/annotation.py delete mode 100644 src/pyramid/connector/discord/services/api/register.py delete mode 100644 src/pyramid/connector/discord/services/environment.py delete mode 100644 src/pyramid/connector/discord/services/logger.py create mode 100644 src/pyramid/services/clean.py create mode 100644 src/pyramid/services/configuration.py create mode 100644 src/pyramid/services/discord.py create mode 100644 src/pyramid/services/information.py create mode 100644 src/pyramid/services/logger.py create mode 100644 src/pyramid/services/logger_level.py create mode 100644 src/pyramid/services/socket_server.py create mode 100644 src/pyramid/tools/deprecated_class.py diff --git a/src/pyramid/api/services/configuration.py b/src/pyramid/api/services/configuration.py new file mode 100644 index 0000000..acfe945 --- /dev/null +++ b/src/pyramid/api/services/configuration.py @@ -0,0 +1,15 @@ +from abc import ABC +from pyramid.data.environment import Environment + +class IConfigurationService(ABC): + + def __init__(self): + self.discord__token: str = "" + self.discord__ffmpeg: str = "" + self.deezer__arl: str = "" + self.deezer__folder: str = "" + self.spotify__client_id: str = "" + self.spotify__client_secret: str = "" + self.general__limit_tracks: int = 0 + self.mode: Environment = Environment.PRODUCTION + self.version: str = "" diff --git a/src/pyramid/api/services/information.py b/src/pyramid/api/services/information.py new file mode 100644 index 0000000..a878956 --- /dev/null +++ b/src/pyramid/api/services/information.py @@ -0,0 +1,8 @@ +from abc import ABC, abstractmethod +from pyramid.data.functional.application_info import ApplicationInfo + +class IInformationService(ABC): + + @abstractmethod + def get(self) -> ApplicationInfo: + pass diff --git a/src/pyramid/api/services/logger.py b/src/pyramid/api/services/logger.py new file mode 100644 index 0000000..f67915a --- /dev/null +++ b/src/pyramid/api/services/logger.py @@ -0,0 +1,36 @@ +from abc import ABC, abstractmethod +from logging import Logger + +class ILoggerService(ABC): + + @abstractmethod + def critical(self, msg, *args, **kwargs): + pass + + @abstractmethod + def error(self, msg, *args, **kwargs): + pass + + @abstractmethod + def warning(self, msg, *args, **kwargs): + pass + + @abstractmethod + def info(self, msg, *args, **kwargs): + pass + + @abstractmethod + def debug(self, msg, *args, **kwargs): + pass + + @abstractmethod + def log(self, level, msg, *args, **kwargs): + pass + + @abstractmethod + def getChild(self, suffix: str) -> Logger: + pass + + @abstractmethod + def getLogger(self) -> Logger: + pass diff --git a/src/pyramid/api/services/socket_server.py b/src/pyramid/api/services/socket_server.py new file mode 100644 index 0000000..eb35498 --- /dev/null +++ b/src/pyramid/api/services/socket_server.py @@ -0,0 +1,8 @@ +from abc import ABC, abstractmethod +from pyramid.data.functional.application_info import ApplicationInfo + +class ISocketServerService(ABC): + + @abstractmethod + def start(self): + pass diff --git a/src/pyramid/api/services/tools/annotation.py b/src/pyramid/api/services/tools/annotation.py new file mode 100644 index 0000000..f8360ea --- /dev/null +++ b/src/pyramid/api/services/tools/annotation.py @@ -0,0 +1,18 @@ +from typing import Optional +from pyramid.api.services.tools.injector import ServiceInjector +from pyramid.api.services.tools.register import ServiceRegister + + +def pyramid_service(*, interface: Optional[type] = None): + def decorator(cls): + class_name = cls.__name__ + if not issubclass(cls, ServiceInjector): + raise TypeError("Class %s must inherit from ServiceInjector" % class_name) + + service_name = class_name + if interface is not None: + service_name = interface.__name__ + + ServiceRegister.register_service(service_name, cls) + return cls + return decorator diff --git a/src/pyramid/api/services/tools/exceptions.py b/src/pyramid/api/services/tools/exceptions.py new file mode 100644 index 0000000..f181181 --- /dev/null +++ b/src/pyramid/api/services/tools/exceptions.py @@ -0,0 +1,11 @@ +class ServiceRegisterException(Exception): + pass + +class ServiceAlreadyRegisterException(ServiceRegisterException): + pass + +class ServiceAlreadyNotRegisterException(ServiceRegisterException): + pass + +class ServiceCicularDependencyException(ServiceRegisterException): + pass diff --git a/src/pyramid/connector/discord/services/api/injector.py b/src/pyramid/api/services/tools/injector.py similarity index 55% rename from src/pyramid/connector/discord/services/api/injector.py rename to src/pyramid/api/services/tools/injector.py index adb4cfa..8849087 100644 --- a/src/pyramid/connector/discord/services/api/injector.py +++ b/src/pyramid/api/services/tools/injector.py @@ -1,13 +1,13 @@ - -import logging from abc import ABC, abstractmethod from discord.ext.commands import Bot class ServiceInjector(ABC): - def __init__(self, bot: Bot, logger: logging.Logger): - self.bot = bot - self.logger = logger - def injectService(self): pass + + def start(self): + pass + + def stop(self): + pass diff --git a/src/pyramid/api/services/tools/register.py b/src/pyramid/api/services/tools/register.py new file mode 100644 index 0000000..302bc9b --- /dev/null +++ b/src/pyramid/api/services/tools/register.py @@ -0,0 +1,158 @@ + +from collections import defaultdict, deque +import importlib +import inspect +import pkgutil +from typing import Type, TypeVar +from pyramid.api.services.tools.exceptions import ServiceAlreadyNotRegisterException, ServiceAlreadyRegisterException, ServiceCicularDependencyException +from pyramid.api.services.tools.injector import ServiceInjector + +T = TypeVar('T') + +class ServiceRegister: + + __SERVICE_TO_REGISTER: dict[str, type[ServiceInjector]] = {} + __SERVICE_REGISTERED: dict[str, ServiceInjector] = {} + + @staticmethod + def register_service(name: str, type: type[object]): + if not issubclass(type, ServiceInjector): + raise TypeError("Service %s is not a subclass of ServiceInjector and cannot be initialized." % name) + if name in ServiceRegister.__SERVICE_TO_REGISTER: + already_class_name = ServiceRegister.__SERVICE_TO_REGISTER[name].__name__ + raise ServiceAlreadyRegisterException( + "Cannot register %s with %s, it is already registered with the class %s." + % (name, type.__name__, already_class_name) + ) + ServiceRegister.__SERVICE_TO_REGISTER[name] = type + + @staticmethod + def import_services(): + package_name = "pyramid.services" + package = importlib.import_module(package_name) + + for loader, module_name, is_pkg in pkgutil.iter_modules(package.__path__): + full_module_name = f"{package_name}.{module_name}" + module = importlib.import_module(full_module_name) + + @staticmethod + def create_services(): + for name, cls in ServiceRegister.__SERVICE_TO_REGISTER.items(): + class_instance = cls() + ServiceRegister.__SERVICE_REGISTERED[name] = class_instance + + @staticmethod + def inject_services(): + # Step 1: Create a graph of dependencies + dependency_graph = defaultdict(list) + indegree = defaultdict(int) # To track the number of dependencies + + # Create instances but delay injecting dependencies + for name, cls in ServiceRegister.__SERVICE_TO_REGISTER.items(): + class_instance = cls() + ServiceRegister.__SERVICE_REGISTERED[name] = class_instance + + # Step 2: Parse dependencies for each service + signature = inspect.signature(class_instance.injectService) + method_parameters = list(signature.parameters.values()) + + for method_parameter in method_parameters: + dependency_name = method_parameter.annotation.__name__ + if dependency_name not in ServiceRegister.__SERVICE_REGISTERED: + raise ServiceAlreadyNotRegisterException( + "Cannot register %s as a dependency for %s because the dependency is not registered." + % (dependency_name, name) + ) + # Add an edge in the dependency graph + dependency_graph[dependency_name].append(name) + indegree[name] += 1 + + # Step 3: Perform a topological sort to determine the order of instantiation + sorted_services = [] + queue = deque([service for service in ServiceRegister.__SERVICE_TO_REGISTER if indegree[service] == 0]) + + while queue: + service = queue.popleft() + sorted_services.append(service) + + for dependent in dependency_graph[service]: + indegree[dependent] -= 1 + if indegree[dependent] == 0: + queue.append(dependent) + + if len(sorted_services) != len(ServiceRegister.__SERVICE_TO_REGISTER): + unresolved_services = set(ServiceRegister.__SERVICE_TO_REGISTER) - set(sorted_services) + raise ServiceCicularDependencyException( + "Circular dependency detected! The following services are involved in a circular dependency: %s" + % ', '.join(unresolved_services) + ) + + # Step 4: Inject dependencies in the correct order + for service_name in sorted_services: + class_instance = ServiceRegister.__SERVICE_REGISTERED[service_name] + signature = inspect.signature(class_instance.injectService) + method_parameters = list(signature.parameters.values()) + + services_dependencies = [] + for method_parameter in method_parameters: + dependency_name = method_parameter.annotation.__name__ + dependency_instance = ServiceRegister.__SERVICE_REGISTERED[dependency_name] + services_dependencies.append(dependency_instance) + + class_instance.injectService(*services_dependencies) + + @staticmethod + def get_dependency_tree(): + # Step 1: Build dependency graph + dependency_graph = defaultdict(list) + for name, cls in ServiceRegister.__SERVICE_TO_REGISTER.items(): + class_instance = ServiceRegister.__SERVICE_REGISTERED[name] + + signature = inspect.signature(class_instance.injectService) + method_parameters = list(signature.parameters.values()) + + for method_parameter in method_parameters: + dependency_name = method_parameter.annotation.__name__ + dependency_graph[dependency_name].append(name) + + # Step 2: Internal buffer for storing the tree structure + buffer = [] + + def append_to_buffer(line): + buffer.append(line) + + # Step 3: Recursive function to build the tree structure + def build_tree(node, prefix="", last=True): + """ Recursively builds the tree structure and stores it in the buffer. """ + connector = "└── " if last else "├── " + append_to_buffer(prefix + connector + node) + + prefix += " " if last else "│ " + children = dependency_graph[node] + for i, child in enumerate(children): + build_tree(child, prefix, i == len(children) - 1) + + # Step 4: Find root services (those with no dependencies) + all_services = set(ServiceRegister.__SERVICE_TO_REGISTER.keys()) + dependent_services = set(dep for deps in dependency_graph.values() for dep in deps) + root_services = all_services - dependent_services + + if not root_services: + raise ServiceCicularDependencyException("No root services found. Possible circular dependencies.") + + # Step 5: Build the tree starting from each root service + for root in root_services: + build_tree(root) + + return "Services dependencies :\n" + "\n".join(buffer) + + + @staticmethod + def start_services(): + for name, class_instance in ServiceRegister.__SERVICE_REGISTERED.items(): + class_instance.start() + + @staticmethod + def get_service(class_type: Type[T]) -> T: + class_name = class_type.__name__ + return ServiceRegister.__SERVICE_REGISTERED[class_name] diff --git a/src/pyramid/connector/discord/bot.py b/src/pyramid/connector/discord/bot.py index df49047..db7ed2d 100644 --- a/src/pyramid/connector/discord/bot.py +++ b/src/pyramid/connector/discord/bot.py @@ -19,6 +19,7 @@ CommandError, ) from discord.app_commands.errors import AppCommandError, CommandInvokeError +from pyramid.api.services.configuration import IConfigurationService from pyramid.data.functional.application_info import ApplicationInfo from pyramid.data.environment import Environment from pyramid.data.guild_data import GuildData @@ -35,7 +36,7 @@ class DiscordBot: - def __init__(self, logger: logging.Logger, information: ApplicationInfo, config: Configuration): + def __init__(self, logger: logging.Logger, information: ApplicationInfo, config: IConfigurationService): self.__logger = logger self.__information = information self.__token = config.discord__token @@ -59,7 +60,8 @@ def __init__(self, logger: logging.Logger, information: ApplicationInfo, config: self.guilds_instances: Dict[int, GuildInstances] = {} - def create(self, health: HealthModules): + # def create(self, health: HealthModules): + def create(self): self.__logger.info("Discord bot creating with discord.py v%s ...", discord.__version__) self.listeners = BotListener(self.bot, self.__logger) self.cmd = BotCmd( @@ -69,7 +71,7 @@ def create(self, health: HealthModules): self.__information, self.__environment ) - self._health = health + # self._health = health @self.bot.event async def on_command_error(ctx: Context, error: CommandError): @@ -140,10 +142,10 @@ async def start(self): self.__logger.info("Discord bot login") await self.bot.login(self.__token) self.__logger.info("Discord bot connecting") - self._health.discord = True + # self._health.discord = True await self.bot.connect() except PrivilegedIntentsRequired as ex: - self._health.discord = False + # self._health.discord = False self.__logger.critical(ex) sys.exit(1) diff --git a/src/pyramid/connector/discord/bot_cmd.py b/src/pyramid/connector/discord/bot_cmd.py index e5fd64f..a752a0c 100644 --- a/src/pyramid/connector/discord/bot_cmd.py +++ b/src/pyramid/connector/discord/bot_cmd.py @@ -1,15 +1,9 @@ -import math from logging import Logger -import time from typing import Callable from discord import Guild, Interaction from discord.ext.commands import Bot -from pyramid.connector.discord.commands.about_command import AboutCommand -from pyramid.connector.discord.commands.api.parameters import ParametersCommand -from pyramid.connector.discord.commands.api.register import register_commands -from pyramid.connector.discord.commands.help_command import HelpCommand -from pyramid.connector.discord.commands.ping_command import PingCommand +from pyramid.connector.discord.commands.tools.register import CommandRegister from pyramid.connector.discord.guild_cmd import GuildCmd from pyramid.data.environment import Environment from pyramid.data.functional.application_info import ApplicationInfo @@ -45,7 +39,8 @@ def register(self): service_name = service.__class__.__name__ services[service_name] = service - register_commands(services, self.__bot, self.__logger, self.__environment.name.lower()) + CommandRegister.import_commands() + CommandRegister.create_commands(services, self.__bot, self.__logger, self.__environment.name.lower()) # ping = PingCommand(ParametersCommand("ping"), self.__bot, self.__logger) # ping.register(self.__environment.name.lower()) diff --git a/src/pyramid/connector/discord/commands/about_command.py b/src/pyramid/connector/discord/commands/about.py similarity index 80% rename from src/pyramid/connector/discord/commands/about_command.py rename to src/pyramid/connector/discord/commands/about.py index 79bfec0..765cbe2 100644 --- a/src/pyramid/connector/discord/commands/about_command.py +++ b/src/pyramid/connector/discord/commands/about.py @@ -2,9 +2,9 @@ import time from discord import AppInfo, ClientUser, Color, Embed, Interaction from discord.user import BaseUser -from pyramid.connector.discord.commands.api.abc import AbstractCommand -from pyramid.connector.discord.commands.api.annotation import discord_command -from pyramid.connector.discord.commands.api.parameters import ParametersCommand +from pyramid.connector.discord.commands.tools.abc import AbstractCommand +from pyramid.connector.discord.commands.tools.annotation import discord_command +from pyramid.connector.discord.commands.tools.parameters import ParametersCommand from pyramid.data.environment import Environment from pyramid.data.functional.application_info import ApplicationInfo from pyramid.tools import utils @@ -12,16 +12,9 @@ @discord_command(parameters=ParametersCommand(description="About the bot")) class AboutCommand(AbstractCommand): - # def __init__(self, bot: Bot, logger: logging.Logger, started: float, environment: Environment, info: ApplicationInfo): - # super().__init__(bot, logger) - # self.__started = started - # self.__environment = environment - # self.__info = info - def injectService(self, environment: Environment, info: ApplicationInfo): self.__environment = environment self.__info = info - # self.logger.info("Injected !") async def execute(self, ctx: Interaction): await ctx.response.defer(thinking=True) diff --git a/src/pyramid/connector/discord/commands/api/register.py b/src/pyramid/connector/discord/commands/api/register.py deleted file mode 100644 index ed4c844..0000000 --- a/src/pyramid/connector/discord/commands/api/register.py +++ /dev/null @@ -1,25 +0,0 @@ - -import inspect -import logging -from discord.ext.commands import Bot -from pyramid.connector.discord.commands.api.abc import AbstractCommand -from pyramid.connector.discord.commands.api.parameters import ParametersCommand - -COMMANDS_TO_REGISTER: dict[type[AbstractCommand], ParametersCommand] = {} - -def register_commands(services: dict[str, object], bot: Bot, logger: logging.Logger, command_prefix: str | None = None): - for cls, parameters in COMMANDS_TO_REGISTER.items(): - class_instance = cls(parameters, bot, logger) - class_instance.register(command_prefix) - # logger.info("%s - %s" % (vars(cls), vars(parameters))) - # logger.info("services %s" % ", ".join(services.keys())) - - signature = inspect.signature(class_instance.injectService) - params = list(signature.parameters.values()) - # for param in params: - # logger.info("param %s" % param.annotation) - - # logger.info("params %s" % ", ".join(params)) - dependencies = [services[param.annotation.__name__] for param in params] - # logger.info("dependencies %s" % (vars(dependencies))) - class_instance.injectService(*dependencies) diff --git a/src/pyramid/connector/discord/commands/help_command.py b/src/pyramid/connector/discord/commands/help.py similarity index 86% rename from src/pyramid/connector/discord/commands/help_command.py rename to src/pyramid/connector/discord/commands/help.py index dea374a..47ba370 100644 --- a/src/pyramid/connector/discord/commands/help_command.py +++ b/src/pyramid/connector/discord/commands/help.py @@ -1,8 +1,8 @@ from typing import List from discord import Color, Embed, Interaction from discord.app_commands import Command -from pyramid.connector.discord.commands.api.abc import AbstractCommand, ParametersCommand -from pyramid.connector.discord.commands.api.annotation import discord_command +from pyramid.connector.discord.commands.tools.abc import AbstractCommand, ParametersCommand +from pyramid.connector.discord.commands.tools.annotation import discord_command @discord_command(parameters=ParametersCommand(description="List all commands")) class HelpCommand(AbstractCommand): diff --git a/src/pyramid/connector/discord/commands/ping_command.py b/src/pyramid/connector/discord/commands/ping.py similarity index 61% rename from src/pyramid/connector/discord/commands/ping_command.py rename to src/pyramid/connector/discord/commands/ping.py index 55ee13c..23c23b2 100644 --- a/src/pyramid/connector/discord/commands/ping_command.py +++ b/src/pyramid/connector/discord/commands/ping.py @@ -1,8 +1,8 @@ import math from discord import Interaction -from pyramid.connector.discord.commands.api.abc import AbstractCommand -from pyramid.connector.discord.commands.api.annotation import discord_command -from pyramid.connector.discord.commands.api.parameters import ParametersCommand +from pyramid.connector.discord.commands.tools.abc import AbstractCommand +from pyramid.connector.discord.commands.tools.annotation import discord_command +from pyramid.connector.discord.commands.tools.parameters import ParametersCommand @discord_command(parameters=ParametersCommand(description="Displays response time between bot and Discord API")) class PingCommand(AbstractCommand): diff --git a/src/pyramid/connector/discord/commands/api/abc.py b/src/pyramid/connector/discord/commands/tools/abc.py similarity index 90% rename from src/pyramid/connector/discord/commands/api/abc.py rename to src/pyramid/connector/discord/commands/tools/abc.py index d65afdb..f1f4bd0 100644 --- a/src/pyramid/connector/discord/commands/api/abc.py +++ b/src/pyramid/connector/discord/commands/tools/abc.py @@ -5,7 +5,7 @@ from discord.app_commands import Command from discord.ext.commands import Bot -from pyramid.connector.discord.commands.api.parameters import ParametersCommand +from pyramid.connector.discord.commands.tools.parameters import ParametersCommand class AbstractCommand(ABC): @@ -27,6 +27,7 @@ def register(self, command_prefix: Optional[str] = None): auto_locale_strings=self.parameters.auto_locale_strings, extras=self.parameters.extras, ) + # TODO check this usage # self.bot.tree.add_command(command, guilds=self.parameters.guilds) self.bot.tree.add_command(command) diff --git a/src/pyramid/connector/discord/commands/api/annotation.py b/src/pyramid/connector/discord/commands/tools/annotation.py similarity index 77% rename from src/pyramid/connector/discord/commands/api/annotation.py rename to src/pyramid/connector/discord/commands/tools/annotation.py index ca1ed58..aaeb8d9 100644 --- a/src/pyramid/connector/discord/commands/api/annotation.py +++ b/src/pyramid/connector/discord/commands/tools/annotation.py @@ -2,15 +2,15 @@ from textwrap import TextWrapper from discord.utils import MISSING -from pyramid.connector.discord.commands.api.abc import AbstractCommand -from pyramid.connector.discord.commands.api.parameters import ParametersCommand -from pyramid.connector.discord.commands.api.register import COMMANDS_TO_REGISTER +from pyramid.connector.discord.commands.tools.abc import AbstractCommand +from pyramid.connector.discord.commands.tools.parameters import ParametersCommand +from pyramid.connector.discord.commands.tools.register import CommandRegister def discord_command(*, parameters: ParametersCommand): def decorator(cls): if not issubclass(cls, AbstractCommand): - raise TypeError(f"Class {cls.__name__} must extend from AbstractCommand") + raise TypeError("Class %s must extend from AbstractCommand" % cls.__name__) if parameters.name is MISSING: class_name = cls.__name__ @@ -23,7 +23,7 @@ def decorator(cls): parameters.description = '…' else: parameters.description = _shorten(cls.__doc__) - COMMANDS_TO_REGISTER[cls] = parameters + CommandRegister.register_command(cls, parameters) return cls return decorator diff --git a/src/pyramid/connector/discord/commands/api/parameters.py b/src/pyramid/connector/discord/commands/tools/parameters.py similarity index 100% rename from src/pyramid/connector/discord/commands/api/parameters.py rename to src/pyramid/connector/discord/commands/tools/parameters.py diff --git a/src/pyramid/connector/discord/commands/tools/register.py b/src/pyramid/connector/discord/commands/tools/register.py new file mode 100644 index 0000000..bea32b3 --- /dev/null +++ b/src/pyramid/connector/discord/commands/tools/register.py @@ -0,0 +1,34 @@ +import importlib +import inspect +import logging +import pkgutil +from discord.ext.commands import Bot +from pyramid.connector.discord.commands.tools.abc import AbstractCommand +from pyramid.connector.discord.commands.tools.parameters import ParametersCommand + +class CommandRegister: + + __COMMANDS_TO_REGISTER: dict[type[AbstractCommand], ParametersCommand] = {} + + @staticmethod + def register_command(type: type[AbstractCommand], parameterCommand: ParametersCommand): + CommandRegister.__COMMANDS_TO_REGISTER[type] = parameterCommand + + @staticmethod + def import_commands(): + package_name = "pyramid.connector.discord.commands" + package = importlib.import_module(package_name) + + for loader, module_name, is_pkg in pkgutil.iter_modules(package.__path__): + full_module_name = f"{package_name}.{module_name}" + module = importlib.import_module(full_module_name) + + @staticmethod + def create_commands(services: dict[str, object], bot: Bot, logger: logging.Logger, command_prefix: str | None = None): + for cls, parameters in CommandRegister.__COMMANDS_TO_REGISTER.items(): + class_instance = cls(parameters, bot, logger) + class_instance.register(command_prefix) + signature = inspect.signature(class_instance.injectService) + parameters = list(signature.parameters.values()) + dependencies = [services[param.annotation.__name__] for param in parameters] + class_instance.injectService(*dependencies) diff --git a/src/pyramid/connector/discord/services/api/annotation.py b/src/pyramid/connector/discord/services/api/annotation.py deleted file mode 100644 index f5b4996..0000000 --- a/src/pyramid/connector/discord/services/api/annotation.py +++ /dev/null @@ -1,12 +0,0 @@ -from pyramid.connector.discord.services.api.register import SERVICE_TO_REGISTER - - -def pyramid_service(): - def decorator(cls): - # if not issubclass(cls, AbstractService): - # raise TypeError(f"Class {cls.__name__} must extend from AbstractListener") - - class_name = cls.__name__ - SERVICE_TO_REGISTER[class_name] = cls - return cls - return decorator diff --git a/src/pyramid/connector/discord/services/api/register.py b/src/pyramid/connector/discord/services/api/register.py deleted file mode 100644 index bfd201e..0000000 --- a/src/pyramid/connector/discord/services/api/register.py +++ /dev/null @@ -1,24 +0,0 @@ -import logging -from discord.ext.commands import Bot -from pyramid.connector.discord.services.api.injector import ServiceInjector - - -SERVICE_TO_REGISTER: dict[str, type[object]] = {} -SERVICE_REGISTRED: dict[str, object] = {} - -def register_services(bot: Bot, logger: logging.Logger): - for name, cls in SERVICE_TO_REGISTER.items(): - if issubclass(cls, ServiceInjector): - class_instance = cls(bot, logger) - else: - class_instance = cls() - SERVICE_REGISTRED[name] = class_instance - logger.info("SERVICE_REGISTRED %s" % name) - -def get_service(name: str): - return SERVICE_REGISTRED[name] - -def define_bot(bot: Bot): - for _, instance in SERVICE_REGISTRED.items(): - if isinstance(instance, ServiceInjector): - instance.bot = bot diff --git a/src/pyramid/connector/discord/services/environment.py b/src/pyramid/connector/discord/services/environment.py deleted file mode 100644 index ecc41b8..0000000 --- a/src/pyramid/connector/discord/services/environment.py +++ /dev/null @@ -1,19 +0,0 @@ - -from pyramid.connector.discord.services.api.annotation import pyramid_service -from pyramid.data.environment import Environment - - -@pyramid_service() -class EnvironmentService: - - def __init__(self): - self.__type: Environment = Environment.PRODUCTION - - def get_type(self): - return self.__type - - def get_type_name(self): - return self.__type.name.capitalize() - - def set_type(self, environnement: Environment): - self.__type = environnement diff --git a/src/pyramid/connector/discord/services/logger.py b/src/pyramid/connector/discord/services/logger.py deleted file mode 100644 index 9dfe878..0000000 --- a/src/pyramid/connector/discord/services/logger.py +++ /dev/null @@ -1,28 +0,0 @@ -import logging - -from pyramid.connector.discord.services.api.annotation import pyramid_service - - -@pyramid_service() -class LoggerService: - - def __init__(self): - self.__logger = logging.getLogger() - - def critical(self, msg, *args, **kwargs): - self.__logger.critical(msg, *args, **kwargs) - - def error(self, msg, *args, **kwargs): - self.__logger.error(msg, *args, **kwargs) - - def warning(self, msg, *args, **kwargs): - self.__logger.warning(msg, *args, **kwargs) - - def info(self, msg, *args, **kwargs): - self.__logger.info(msg, *args, **kwargs) - - def debug(self, msg, *args, **kwargs): - self.__logger.debug(msg, *args, **kwargs) - - def log(self, level, msg, *args, **kwargs): - self.__logger.log(msg, level, *args, **kwargs) diff --git a/src/pyramid/data/functional/application_info.py b/src/pyramid/data/functional/application_info.py index 3a36219..2b452d0 100644 --- a/src/pyramid/data/functional/application_info.py +++ b/src/pyramid/data/functional/application_info.py @@ -3,10 +3,6 @@ import subprocess import time -from pyramid.connector.discord.services.api.annotation import pyramid_service - - -@pyramid_service() class ApplicationInfo: def __init__(self): self.__name = "pyramid" diff --git a/src/pyramid/data/functional/engine_source.py b/src/pyramid/data/functional/engine_source.py index 9c41144..4bfc568 100644 --- a/src/pyramid/data/functional/engine_source.py +++ b/src/pyramid/data/functional/engine_source.py @@ -1,6 +1,7 @@ from enum import Enum from typing import Dict +from pyramid.api.services.configuration import IConfigurationService from pyramid.connector.deezer.downloader import DeezerDownloader from pyramid.connector.deezer.search import DeezerSearch from pyramid.connector.spotify.search import SpotifySearch @@ -16,7 +17,7 @@ class SourceType(Enum): class EngineSource: - def __init__(self, config: Configuration): + def __init__(self, config: IConfigurationService): self.__downloader = DeezerDownloader(config.deezer__folder, config.deezer__arl) self.__deezer_search = DeezerSearch(config.general__limit_tracks) self.__spotify_search = SpotifySearch( diff --git a/src/pyramid/data/functional/main.py b/src/pyramid/data/functional/main.py index 07099c2..4b63419 100644 --- a/src/pyramid/data/functional/main.py +++ b/src/pyramid/data/functional/main.py @@ -3,18 +3,15 @@ import logging import sys import signal -from datetime import datetime from threading import Thread -from pyramid.connector.discord.services.api.register import get_service, register_services -from pyramid.connector.discord.services.environment import EnvironmentService -from pyramid.tools import utils +from pyramid.api.services.configuration import IConfigurationService +from pyramid.api.services.information import IInformationService +from pyramid.api.services.logger import ILoggerService +from pyramid.api.services.tools.register import ServiceRegister from pyramid.data.functional.application_info import ApplicationInfo from pyramid.connector.discord.bot import DiscordBot -from pyramid.client.server import SocketServer -from pyramid.data.health import HealthModules -from pyramid.tools.configuration.configuration import Configuration -from pyramid.tools.logs_handler import LogsHandler +from pyramid.tools import utils from pyramid.tools.custom_queue import Queue @@ -22,7 +19,7 @@ class Main: def __init__(self): # Program information self._info = ApplicationInfo() - self._health = HealthModules() + # self._health = HealthModules() self._discord_bot = None # Argument management @@ -35,41 +32,22 @@ def args(self): print(f"{self._info.get_version()}") sys.exit(0) - # Logs management - def logs(self): - current_datetime = datetime.now() - log_dir = "./logs" - log_name = f"./{current_datetime.strftime('%Y_%m_%d %H_%M')}.log" - - self._logs_handler = LogsHandler() - self._logs_handler.init(self._info, log_dir, log_name, "error.log") - self.logger = logging.getLogger() - - # Deletion of log files over 10 - utils.keep_latest_files(log_dir, 10, "error") - - def config(self): - # Config load - self._config = Configuration(self.logger) - self._config.load() - self._health.configuration = True - - self._logs_handler.set_log_level(self._config.mode) + def start(self): + ServiceRegister.import_services() + ServiceRegister.create_services() + ServiceRegister.inject_services() + ServiceRegister.start_services() - def open_socket(self): - self.socket_server = SocketServer(self.logger.getChild("socket"), self._health) - thread = Thread(name="Socket", target=self.socket_server.start_server, daemon=True) - thread.start() + logger = ServiceRegister.get_service(ILoggerService) + info = ServiceRegister.get_service(IInformationService) + config = ServiceRegister.get_service(IConfigurationService) - def clean_data(self): - # Songs folder clear - utils.clear_directory(self._config.deezer__folder) + logger.debug(ServiceRegister.get_dependency_tree()) - def start(self): # Discord Bot Instance - discord_bot = DiscordBot(self.logger.getChild("Discord"), self._info, self._config) + discord_bot = DiscordBot(logger.getChild("Discord"), info.get(), config) # Create bot - discord_bot.create(self._health) + discord_bot.create() loop: asyncio.AbstractEventLoop = asyncio.new_event_loop() @@ -92,17 +70,16 @@ def handle_signal(signum: int, frame): logging.info(f"Received signal {signum}. shutting down ...") asyncio.run_coroutine_threadsafe(shutdown(loop), loop) - # -- Service [TEMP] - self._discord_bot = discord_bot - if self._discord_bot is None: - raise Exception("Bot is not connected") - register_services(self._discord_bot.bot, self.logger) - environment_service = get_service("EnvironmentService") - self.logger.info("environment_service %s" % environment_service) - if not isinstance(environment_service, EnvironmentService): - raise Exception("environment_service is not from type EnvironmentService, got %s" % type(environment_service)) - environment_service.set_type(self._config.mode) - # -- + # # -- Service [TEMP] + # self._discord_bot = discord_bot + # if self._discord_bot is None: + # raise Exception("Bot is not connected") + # # environment_service = get_service("EnvironmentService") + # # self.logger.info("environment_service %s" % environment_service) + # # if not isinstance(environment_service, EnvironmentService): + # # raise Exception("environment_service is not from type EnvironmentService, got %s" % type(environment_service)) + # # environment_service.set_type(self._config.mode) + # # -- previous_handler = signal.signal(signal.SIGTERM, handle_signal) @@ -112,7 +89,7 @@ def handle_signal(signum: int, frame): thread.join() signal.signal(signal.SIGTERM, previous_handler) - + def stop(self): logging.info("Wait for background tasks to stop") Queue.wait_for_end(5) diff --git a/src/pyramid/services/clean.py b/src/pyramid/services/clean.py new file mode 100644 index 0000000..693fe3d --- /dev/null +++ b/src/pyramid/services/clean.py @@ -0,0 +1,19 @@ +import logging +import coloredlogs + +from pyramid.api.services.tools.annotation import pyramid_service +from pyramid.api.services.tools.injector import ServiceInjector +from pyramid.api.services.configuration import IConfigurationService +from pyramid.tools import utils + +@pyramid_service() +class LoggerLevelService(ServiceInjector): + + def injectService(self, + configuration_service: IConfigurationService + ): + self.__configuration_service = configuration_service + + def start(self): + # Songs folder clear + utils.clear_directory(self.__configuration_service.deezer__folder) diff --git a/src/pyramid/services/configuration.py b/src/pyramid/services/configuration.py new file mode 100644 index 0000000..75b3205 --- /dev/null +++ b/src/pyramid/services/configuration.py @@ -0,0 +1,66 @@ +from pyramid.api.services.configuration import IConfigurationService +from pyramid.api.services.logger import ILoggerService +from pyramid.api.services.tools.annotation import pyramid_service +from pyramid.api.services.tools.injector import ServiceInjector +from pyramid.tools import utils +from pyramid.tools.configuration.configuration_load import ConfigurationFromEnv, ConfigurationFromYAML +from pyramid.tools.configuration.configuration_save import ConfigurationToYAML + +@pyramid_service(interface=IConfigurationService) +class ConfigurationService(IConfigurationService, ConfigurationFromYAML, ConfigurationToYAML, ConfigurationFromEnv, ServiceInjector): + + def injectService(self, + logger_service: ILoggerService + ): + self.logger = logger_service + super().set_logger(logger_service) + + def start(self): + self.load() + + def load(self, config_file: str = "config.yml", use_env_vars: bool = True) -> bool: + """ + Loads configuration values from environment variables and/or a configuration file. + + Parameters: + - use_env_vars (bool): If True, loads configuration values from environment variables. + + Returns: + - bool: True if the loading process is successful, False otherwise. + """ + keys_length = utils.count_public_variables(self) + + # Load from environment variables if enabled + result_1 = True + if use_env_vars: + raw_values_env = self._get_env_vars() + result_values = self._transform_all(raw_values_env, keys_length) + result_1 = self._validate_all( + raw_values_env, result_values, "env vars", True, keys_length + ) + + # Load raw values from environment variables and config file + try: + raw_values_file = self._get_file_vars(config_file) + result_values = self._transform_all(raw_values_file, keys_length) + result_2 = self._validate_all( + raw_values_file, result_values, "file", keys_length=keys_length + ) + except FileNotFoundError as err: + if not result_1: + self.logger.critical( + "Unable to read configuration file '%s' :\n%s", config_file, err + ) + return False + result_2 = True + + return result_1 and result_2 + + def save(self, file_name): + """ + Saves the configuration values to a YAML file. + + Parameters: + - file_name (str): The name of the file to which the configuration will be saved. + """ + self._save_to_yaml(file_name) diff --git a/src/pyramid/services/discord.py b/src/pyramid/services/discord.py new file mode 100644 index 0000000..8092166 --- /dev/null +++ b/src/pyramid/services/discord.py @@ -0,0 +1,62 @@ +import asyncio +import logging +import signal +from threading import Thread + +from pyramid.api.services.configuration import IConfigurationService +from pyramid.api.services.information import IInformationService +from pyramid.api.services.logger import ILoggerService +from pyramid.api.services.tools.annotation import pyramid_service +from pyramid.api.services.tools.injector import ServiceInjector +from pyramid.connector.discord.bot import DiscordBot +from pyramid.tools.custom_queue import Queue + + +# @pyramid_service() +class DiscordBotService(ServiceInjector): + + def injectService(self, + logger_service: ILoggerService, + information_service: IInformationService, + configuration_service: IConfigurationService + ): + self.__logger_service = logger_service + self.__information_service = information_service + self.__configuration_service = configuration_service + + def start(self): + self.discord_bot = DiscordBot(self.__logger_service.getChild("Discord"), self.__information_service.get(), self.__configuration_service) + self.discord_bot.create() + + loop: asyncio.AbstractEventLoop = asyncio.new_event_loop() + + def running(loop: asyncio.AbstractEventLoop): + asyncio.set_event_loop(loop) + + loop.create_task(self.discord_bot.start()) + try: + loop.run_forever() + finally: + loop.close() + + async def shutdown(loop: asyncio.AbstractEventLoop): + await self.discord_bot.stop() + loop.stop() + + def handle_signal(signum: int, frame): + logging.info("Received signal %d. shutting down ..." % signum) + asyncio.run_coroutine_threadsafe(shutdown(loop), loop) + + previous_handler = signal.signal(signal.SIGTERM, handle_signal) + + # Connect bot to Discord servers in his own thread + thread = Thread(name="Discord", target=running, args=(loop,)) + thread.start() + thread.join() + + signal.signal(signal.SIGTERM, previous_handler) + + def stop(self): + logging.info("Wait for background tasks to stop") + Queue.wait_for_end(5) + logging.info("Bye bye") diff --git a/src/pyramid/services/information.py b/src/pyramid/services/information.py new file mode 100644 index 0000000..ddc2d0c --- /dev/null +++ b/src/pyramid/services/information.py @@ -0,0 +1,14 @@ + +from pyramid.api.services.information import IInformationService +from pyramid.api.services.tools.annotation import pyramid_service +from pyramid.api.services.tools.injector import ServiceInjector +from pyramid.data.functional.application_info import ApplicationInfo + +@pyramid_service(interface=IInformationService) +class InformationService(IInformationService, ServiceInjector): + + def __init__(self): + self._info = ApplicationInfo() + + def get(self) -> ApplicationInfo: + return self._info diff --git a/src/pyramid/services/logger.py b/src/pyramid/services/logger.py new file mode 100644 index 0000000..2a40695 --- /dev/null +++ b/src/pyramid/services/logger.py @@ -0,0 +1,120 @@ +import logging +import logging.handlers +import os +import sys +import coloredlogs +from datetime import datetime + +from pyramid.api.services.information import IInformationService +from pyramid.api.services.tools.annotation import pyramid_service +from pyramid.api.services.tools.injector import ServiceInjector +from pyramid.data.environment import Environment +from pyramid.api.services.configuration import IConfigurationService +from pyramid.api.services.logger import ILoggerService +from pyramid.tools import utils + +from pyramid.tools import utils + +@pyramid_service(interface=ILoggerService) +class LoggerService(ILoggerService, ServiceInjector): + + def __init__(self): + self.__date = "%d/%m/%Y %H:%M:%S" + self.__console_format = "%(asctime)s %(levelname)s %(message)s" + self.__file_format = "[{asctime}] [{levelname:<8}] {name}: {message}" + self.logger = logging.getLogger() + self.logger.setLevel("INFO") + current_datetime = datetime.now() + self.__logs_dir = "./logs" + self.__log_filename = "./%s.log" % current_datetime.strftime('%Y_%m_%d %H_%M') + self.__error_filename = "./error.log" + + def injectService(self, + information_service: IInformationService + ): + self.__information_service = information_service + + def start(self): + self.__enable_log_to_console() + self.__enable_log_to_file() + self.__enable_log_to_file_exceptions() + + # Deletion of log files over 10 + utils.keep_latest_files(self.__logs_dir, 10, "error") + + def critical(self, msg, *args, **kwargs): + self.logger.critical(msg, *args, **kwargs) + + def error(self, msg, *args, **kwargs): + self.logger.error(msg, *args, **kwargs) + + def warning(self, msg, *args, **kwargs): + self.logger.warning(msg, *args, **kwargs) + + def info(self, msg, *args, **kwargs): + self.logger.info(msg, *args, **kwargs) + + def debug(self, msg, *args, **kwargs): + self.logger.debug(msg, *args, **kwargs) + + def log(self, level, msg, *args, **kwargs): + self.logger.log(msg, level, *args, **kwargs) + + def getLogger(self) -> logging.Logger: + return self.logger + + def getChild(self, suffix: str) -> logging.Logger: + return self.logger.getChild(suffix) + + def __enable_log_to_console(self): + coloredlogs.install(fmt=self.__console_format, datefmt=self.__date, isatty=True) + coloredlogs.set_level("INFO") + + def __enable_log_to_file(self): + log_filename = os.path.join(self.__logs_dir, self.__log_filename) + utils.create_parent_directories(log_filename) + + file_handler = logging.handlers.RotatingFileHandler( + filename=log_filename, + encoding="utf-8", + maxBytes=512 * 1024 * 1024, # 512 Mo + ) + + formatter = logging.Formatter(self.__file_format, self.__date, style="{") + file_handler.setFormatter(formatter) + + self.logger.addHandler(file_handler) + + def __enable_log_to_file_exceptions(self): + log_filename = os.path.join(self.__logs_dir, self.__error_filename) + utils.create_parent_directories(log_filename) + + file_handler = logging.handlers.RotatingFileHandler( + filename=log_filename, + encoding="utf-8", + maxBytes=10 * 1024 * 1024, # 10 Mo + backupCount=10, + ) + + formatter = logging.Formatter(self.__file_format, self.__date, style="{") + file_handler.setFormatter(formatter) + + # Retrieves warning exceptions and above + file_handler.setLevel("WARNING") + logging.getLogger().addHandler(file_handler) + + # Retrieves unhandled exceptions + self.logger_unhandled_exception = logging.getLogger("Unhandled Exception") + self.logger_unhandled_exception.addHandler(file_handler) + sys.excepthook = self.__handle_unhandled_exception + + def __handle_unhandled_exception(self, exc_type, exc_value, exc_traceback): + if issubclass(exc_type, KeyboardInterrupt): + # Will call default excepthook + sys.__excepthook__(exc_type, exc_value, exc_traceback) + return + info = self.__information_service.get() + # Create a critical level log message with info from the except hook. + self.logger_unhandled_exception.critical( + info, exc_info=(exc_type, exc_value, exc_traceback) + ) diff --git a/src/pyramid/services/logger_level.py b/src/pyramid/services/logger_level.py new file mode 100644 index 0000000..a31353c --- /dev/null +++ b/src/pyramid/services/logger_level.py @@ -0,0 +1,37 @@ +import logging +import coloredlogs + +from pyramid.api.services.tools.annotation import pyramid_service +from pyramid.api.services.tools.injector import ServiceInjector +from pyramid.data.environment import Environment +from pyramid.api.services.configuration import IConfigurationService +from pyramid.api.services.logger import ILoggerService + +@pyramid_service() +class LoggerLevelService(ServiceInjector): + + def injectService(self, + logger_service: ILoggerService, + configuration_service: IConfigurationService + ): + self.__logger_service = logger_service + self.__configuration_service = configuration_service + + def start(self): + logger = self.__logger_service.getLogger() + # logger_colored = self.__logger_service.getLogger() + logger_discord = logging.getLogger("discord") + # logger_aiohttpweb = logging.getLogger("aiohttpweb") + # logger_urllib3 = logging.getLogger("urllib3") + logger_asyncio = logging.getLogger("asyncio") + + logger_dict = logging.root.manager.loggerDict + active_loggers = [name for name in logger_dict if isinstance(logger_dict[name], logging.Logger)] + if self.__configuration_service.mode == Environment.PRODUCTION: + logger.setLevel("INFO") + coloredlogs.set_level("INFO") + else: + logger.setLevel("DEBUG") + coloredlogs.set_level("DEBUG") + logger_discord.setLevel("INFO") + logger_asyncio.setLevel("INFO") diff --git a/src/pyramid/services/socket_server.py b/src/pyramid/services/socket_server.py new file mode 100644 index 0000000..46f9f20 --- /dev/null +++ b/src/pyramid/services/socket_server.py @@ -0,0 +1,28 @@ + +from threading import Thread + +from pyramid.api.services.logger import ILoggerService +from pyramid.api.services.socket_server import ISocketServerService +from pyramid.api.services.tools.annotation import pyramid_service +from pyramid.api.services.tools.injector import ServiceInjector +from pyramid.client.server import SocketServer +from pyramid.data.health import HealthModules + + +@pyramid_service(interface=ISocketServerService) +class SocketServerService(ISocketServerService, ServiceInjector): + + def __init__(self): + pass + + def injectService(self, logger_service: ILoggerService): + self.logger_service = logger_service + + def start(self): + self._health = HealthModules() + self._health.configuration = True + self._health.discord = True + + self.socket_server = SocketServer(self.logger_service.getChild("socket"), self._health) + thread = Thread(name="Socket", target=self.socket_server.start_server, daemon=True) + thread.start() diff --git a/src/pyramid/tools/configuration/configuration.py b/src/pyramid/tools/configuration/configuration.py index 0e72635..e4f4156 100644 --- a/src/pyramid/tools/configuration/configuration.py +++ b/src/pyramid/tools/configuration/configuration.py @@ -8,13 +8,7 @@ class Configuration(ConfigurationFromYAML, ConfigurationToYAML, ConfigurationFromEnv): - def __init__(self, logger: Logger | None = None): - """ - Initializes a Configuration object with default values and an optional logger. - - Parameters: - - logger (Logger): Optional logger for logging messages. If not provided, a default logger named "config" is used. - """ + def __init__(self): self.discord__token: str = "" self.discord__ffmpeg: str = "" self.deezer__arl: str = "" @@ -25,12 +19,6 @@ def __init__(self, logger: Logger | None = None): self.mode: Environment = Environment.PRODUCTION self.version: str = "" - if logger is None: - self.__logger = logging.getLogger("config") - else: - self.__logger = logger - super().__init__(self.__logger) - def load(self, config_file: str = "config.yml", use_env_vars: bool = True) -> bool: """ Loads configuration values from environment variables and/or a configuration file. @@ -61,7 +49,7 @@ def load(self, config_file: str = "config.yml", use_env_vars: bool = True) -> bo ) except FileNotFoundError as err: if not result_1: - self.__logger.critical( + self.logger.critical( "Unable to read configuration file '%s' :\n%s", config_file, err ) return False diff --git a/src/pyramid/tools/configuration/configuration_load.py b/src/pyramid/tools/configuration/configuration_load.py index c90d9b5..40012ca 100644 --- a/src/pyramid/tools/configuration/configuration_load.py +++ b/src/pyramid/tools/configuration/configuration_load.py @@ -1,11 +1,11 @@ import os import re from abc import ABC -from logging import Logger from typing import Any, Callable, List, Optional import yaml from pyramid.data.environment import Environment +from pyramid.api.services.logger import ILoggerService class ConfigurationFromEnv(ABC): @@ -51,8 +51,9 @@ def __load_env_vars_from_file(self, file) -> dict[str, str]: class ConfigurationFromYAML(ABC): - def __init__(self, logger: Logger): - self.__logger = logger + + def set_logger(self, logger: ILoggerService): + self.logger = logger def _get_file_vars(self, file_name: str) -> dict[str, str]: # Load from YAML file @@ -127,7 +128,7 @@ def _validate_all( continue elif isinstance(self_value, int) and self_value != 0: continue - self.__logger.warning(f"'{key_with_dot}' in {type} is not set") + self.logger.warning(f"'{key_with_dot}' in {type} is not set") continue errors_msg.append(f"'{key_with_dot}' with value '{value}' : {err}") @@ -138,7 +139,7 @@ def _validate_all( ] if len(key_not_used) != 0: - self.__logger.warning( + self.logger.warning( "Keys '%s' in %s configuration are not used", ", ".join(key_not_used), type ) @@ -150,9 +151,9 @@ def _validate_all( "\n- ".join(errors_msg), ) if ignore is True: - self.__logger.warning(full_error) + self.logger.warning(full_error) return True - self.__logger.critical(full_error) + self.logger.critical(full_error) return False def __check( diff --git a/src/pyramid/tools/deprecated_class.py b/src/pyramid/tools/deprecated_class.py new file mode 100644 index 0000000..11cd6da --- /dev/null +++ b/src/pyramid/tools/deprecated_class.py @@ -0,0 +1,16 @@ +import warnings +from functools import wraps + +def deprecated_class(cls): + @wraps(cls) + def new_init(*args, **kwargs): + warnings.warn( + f"{cls.__name__} is deprecated and will be removed in a future version.", + category=DeprecationWarning, + stacklevel=2 + ) + return original_init(*args, **kwargs) + + original_init = cls.__init__ + cls.__init__ = new_init + return cls diff --git a/src/pyramid/tools/logs_handler.py b/src/pyramid/tools/logs_handler.py index c7936a5..8c61334 100644 --- a/src/pyramid/tools/logs_handler.py +++ b/src/pyramid/tools/logs_handler.py @@ -7,8 +7,10 @@ from pyramid.tools import utils from pyramid.data.functional.application_info import ApplicationInfo from pyramid.data.environment import Environment +from pyramid.tools.deprecated_class import deprecated_class +@deprecated_class class LogsHandler: def __init__(self): self.__date = "%d/%m/%Y %H:%M:%S" diff --git a/src/startup.py b/src/startup.py index d42c18d..187fe19 100644 --- a/src/startup.py +++ b/src/startup.py @@ -2,12 +2,6 @@ def startup(): main = Main() - main.args() - main.logs() - main.config() - main.clean_data() - - main.open_socket() main.start() main.stop() From 0146a06a7925b368f7d031066319d6852d7db2a7 Mon Sep 17 00:00:00 2001 From: Tristiisch Date: Mon, 16 Sep 2024 01:11:04 +0200 Subject: [PATCH 06/32] fix: class name --- src/pyramid/services/clean.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pyramid/services/clean.py b/src/pyramid/services/clean.py index 693fe3d..2693c12 100644 --- a/src/pyramid/services/clean.py +++ b/src/pyramid/services/clean.py @@ -7,7 +7,7 @@ from pyramid.tools import utils @pyramid_service() -class LoggerLevelService(ServiceInjector): +class CleanService(ServiceInjector): def injectService(self, configuration_service: IConfigurationService From 379f1ed99e0b215ec362a5e51b584cba195ed0c8 Mon Sep 17 00:00:00 2001 From: Tristiisch Date: Mon, 16 Sep 2024 22:58:09 +0200 Subject: [PATCH 07/32] tests: fix old config class usage --- src/pyramid/tools/configuration/configuration.py | 5 ++--- tests/config_test.py | 7 +------ 2 files changed, 3 insertions(+), 9 deletions(-) diff --git a/src/pyramid/tools/configuration/configuration.py b/src/pyramid/tools/configuration/configuration.py index e4f4156..b824e01 100644 --- a/src/pyramid/tools/configuration/configuration.py +++ b/src/pyramid/tools/configuration/configuration.py @@ -1,12 +1,11 @@ -import logging -from logging import Logger - from pyramid.tools import utils +from pyramid.tools.deprecated_class import deprecated_class from pyramid.data.environment import Environment from pyramid.tools.configuration.configuration_load import ConfigurationFromEnv, ConfigurationFromYAML from pyramid.tools.configuration.configuration_save import ConfigurationToYAML +@deprecated_class class Configuration(ConfigurationFromYAML, ConfigurationToYAML, ConfigurationFromEnv): def __init__(self): self.discord__token: str = "" diff --git a/tests/config_test.py b/tests/config_test.py index 8dd7e1a..07f4f7b 100644 --- a/tests/config_test.py +++ b/tests/config_test.py @@ -1,6 +1,6 @@ import logging import unittest -from unittest.mock import MagicMock, patch +from unittest.mock import patch from pyramid.tools.configuration.configuration import Configuration @@ -12,11 +12,6 @@ def test_default_logger(self): self.assertIsNotNone(logging.getLogger("config")) self.assertEqual(logging.getLogger("config").name, "config") - def test_custom_logger(self): - custom_logger = MagicMock() - self.config = Configuration(logger=custom_logger) - self.assertEqual(getattr(self.config, "_Configuration__logger"), custom_logger) - @patch("pyramid.tools.configuration.configuration_load.ConfigurationFromEnv._get_env_vars") @patch("pyramid.tools.configuration.configuration_load.ConfigurationFromYAML._get_file_vars") @patch("pyramid.tools.configuration.configuration_load.ConfigurationFromYAML._validate_all") From 0660dd55950232ba10aa076ddadec5669300d7cf Mon Sep 17 00:00:00 2001 From: Tristiisch Date: Tue, 17 Sep 2024 02:01:07 +0200 Subject: [PATCH 08/32] chore: change MessageSender initialization --- src/pyramid/api/services/tools/register.py | 2 +- src/pyramid/connector/discord/bot.py | 12 +++++----- src/pyramid/connector/discord/bot_listener.py | 4 ++-- .../data/functional/application_info.py | 3 +++ src/pyramid/data/functional/main.py | 24 +++++++++---------- .../messages/message_sender_queued.py | 12 ++++------ src/pyramid/services/logger.py | 3 +++ src/pyramid/tools/custom_queue.py | 5 ++-- src/pyramid/tools/main_queue.py | 13 ++++++++++ tests/queue_test.py | 9 +++++++ 10 files changed, 56 insertions(+), 31 deletions(-) create mode 100644 src/pyramid/tools/main_queue.py diff --git a/src/pyramid/api/services/tools/register.py b/src/pyramid/api/services/tools/register.py index 302bc9b..cf909ba 100644 --- a/src/pyramid/api/services/tools/register.py +++ b/src/pyramid/api/services/tools/register.py @@ -144,7 +144,7 @@ def build_tree(node, prefix="", last=True): for root in root_services: build_tree(root) - return "Services dependencies :\n" + "\n".join(buffer) + return "Services tree :\n" + "\n".join(buffer) @staticmethod diff --git a/src/pyramid/connector/discord/bot.py b/src/pyramid/connector/discord/bot.py index db7ed2d..c1d42a2 100644 --- a/src/pyramid/connector/discord/bot.py +++ b/src/pyramid/connector/discord/bot.py @@ -86,13 +86,13 @@ async def on_command_error(ctx: Context, error: CommandError): "You dont have all the requirements or permissions for using this command :angry:" ) return - logging.error("Command error from on_command_error : %s", error) + self.__logger.error("Command error from on_command_error : %s", error) @self.bot.event async def on_error(event, *args, **kwargs): # message = args[0] # Message object # traceback.extract_stack - logging.error("Error from on_error : %s", traceback.format_exc()) + self.__logger.error("Error from on_error : %s", traceback.format_exc()) # await bot.send_message(message.channel, "You caused an error!") async def on_tree_error(ctx: Interaction, app_error: AppCommandError, /): @@ -105,7 +105,7 @@ async def on_tree_error(ctx: Interaction, app_error: AppCommandError, /): msg = "Error from on_tree_error" error = app_error trace = "".join(traceback.format_exception(type(error), error, error.__traceback__)) - logging.error("%s :\n%s", msg, trace) + self.__logger.error("%s :\n%s", msg, trace) discord_explanation = ":warning: You caused an error!" if isinstance(error, DiscordMessageException): @@ -132,7 +132,7 @@ async def on_tree_error(ctx: Interaction, app_error: AppCommandError, /): @self.bot.event async def on_command(ctx: Context): - logging.debug("on_command : %s", ctx.author) + self.__logger.debug("on_command : %s", ctx.author) self.cmd.register() self.listeners.register() @@ -151,9 +151,9 @@ async def start(self): async def stop(self): # self.bot.clear() - logging.info("Discord bot stop") + self.__logger.info("Discord bot stop") await self.bot.close() - logging.info("Discord bot stopped") + self.__logger.info("Discord bot stopped") def __get_guild_cmd(self, guild: Guild) -> GuildCmd: if guild.id not in self.guilds_instances: diff --git a/src/pyramid/connector/discord/bot_listener.py b/src/pyramid/connector/discord/bot_listener.py index 9d0a1b6..b17a894 100644 --- a/src/pyramid/connector/discord/bot_listener.py +++ b/src/pyramid/connector/discord/bot_listener.py @@ -22,11 +22,11 @@ async def on_ready(): # TODO changed to -> await bot.setup_hook() self.__logger.warning("Unable to get discord bot name") else: self.__logger.info("Discord bot name '%s'", bot.user.name) - self.__logger.info("------ GUILDS ------") + self.__logger.info("────── GUILDS ──────") for guild in bot.guilds: self.show_info_guild(guild) - self.__logger.info("----------------------") + self.__logger.info("─────────────────────") self.__logger.info("Discord bot ready") # await client.close() diff --git a/src/pyramid/data/functional/application_info.py b/src/pyramid/data/functional/application_info.py index 2b452d0..b132a3d 100644 --- a/src/pyramid/data/functional/application_info.py +++ b/src/pyramid/data/functional/application_info.py @@ -49,3 +49,6 @@ def __detect_linux_distro(self) -> str: except FileNotFoundError: pass return "Linux distribution information not available." + + def __str__(self): + return f"{self.__name.capitalize()} {self.get_version()} on {self.get_os()}" diff --git a/src/pyramid/data/functional/main.py b/src/pyramid/data/functional/main.py index 4b63419..f816b62 100644 --- a/src/pyramid/data/functional/main.py +++ b/src/pyramid/data/functional/main.py @@ -11,8 +11,8 @@ from pyramid.api.services.tools.register import ServiceRegister from pyramid.data.functional.application_info import ApplicationInfo from pyramid.connector.discord.bot import DiscordBot -from pyramid.tools import utils from pyramid.tools.custom_queue import Queue +from pyramid.tools.main_queue import MainQueue class Main: @@ -44,6 +44,8 @@ def start(self): logger.debug(ServiceRegister.get_dependency_tree()) + MainQueue.init() + # Discord Bot Instance discord_bot = DiscordBot(logger.getChild("Discord"), info.get(), config) # Create bot @@ -62,6 +64,7 @@ def running(loop: asyncio.AbstractEventLoop): finally: loop.close() + async def shutdown(loop: asyncio.AbstractEventLoop): await discord_bot.stop() loop.stop() @@ -70,17 +73,6 @@ def handle_signal(signum: int, frame): logging.info(f"Received signal {signum}. shutting down ...") asyncio.run_coroutine_threadsafe(shutdown(loop), loop) - # # -- Service [TEMP] - # self._discord_bot = discord_bot - # if self._discord_bot is None: - # raise Exception("Bot is not connected") - # # environment_service = get_service("EnvironmentService") - # # self.logger.info("environment_service %s" % environment_service) - # # if not isinstance(environment_service, EnvironmentService): - # # raise Exception("environment_service is not from type EnvironmentService, got %s" % type(environment_service)) - # # environment_service.set_type(self._config.mode) - # # -- - previous_handler = signal.signal(signal.SIGTERM, handle_signal) # Connect bot to Discord servers in his own thread @@ -89,6 +81,14 @@ def handle_signal(signum: int, frame): thread.join() signal.signal(signal.SIGTERM, previous_handler) + + # def running_async(discord_bot: DiscordBot): + # asyncio.run(discord_bot.start()) + + # thread = Thread(name="Discord", target=running_async, args=(discord_bot,)) + # thread.start() + # thread.join() + def stop(self): logging.info("Wait for background tasks to stop") diff --git a/src/pyramid/data/functional/messages/message_sender_queued.py b/src/pyramid/data/functional/messages/message_sender_queued.py index 2248d9c..7dde574 100644 --- a/src/pyramid/data/functional/messages/message_sender_queued.py +++ b/src/pyramid/data/functional/messages/message_sender_queued.py @@ -4,14 +4,10 @@ from discord import Interaction, Message, WebhookMessage from discord.utils import MISSING from pyramid.tools.custom_queue import Queue, QueueItem +from pyramid.tools.main_queue import MainQueue MAX_MSG_LENGTH = 2000 -queue = Queue(1, "MessageSender") -queue.start() -queue.register_to_wait_on_exit() - - class MessageSenderQueued(MessageSender): def __init__(self, ctx: Interaction): self.ctx = ctx @@ -22,7 +18,7 @@ def add_message( content: str = MISSING, callback: Callable[[Message | WebhookMessage], Any] | None = None, ) -> None: - queue.add( + MainQueue.instance.add( QueueItem("add_message", super().add_message, self.loop, callback, content=content) ) @@ -32,7 +28,7 @@ def edit_message( surname_content: str | None = None, callback: Callable[[Message | WebhookMessage], Any] | None = None, ): - queue.add( + MainQueue.instance.add( QueueItem( "response_message", super().edit_message, @@ -44,7 +40,7 @@ def edit_message( ) def add_code_message(self, content: str, prefix=None, suffix=None): - queue.add( + MainQueue.instance.add( QueueItem( "add_code_message", super().add_code_message, diff --git a/src/pyramid/services/logger.py b/src/pyramid/services/logger.py index 2a40695..861997c 100644 --- a/src/pyramid/services/logger.py +++ b/src/pyramid/services/logger.py @@ -39,6 +39,9 @@ def start(self): self.__enable_log_to_file() self.__enable_log_to_file_exceptions() + self.logger.info("────────────────────────────────────────────") + self.logger.info(self.__information_service.get()) + # Deletion of log files over 10 utils.keep_latest_files(self.__logs_dir, 10, "error") diff --git a/src/pyramid/tools/custom_queue.py b/src/pyramid/tools/custom_queue.py index fbad8a7..70e435e 100644 --- a/src/pyramid/tools/custom_queue.py +++ b/src/pyramid/tools/custom_queue.py @@ -86,7 +86,6 @@ def run_task(func: Callable, loop: asyncio.AbstractEventLoop | None, **kwargs): result = func(**kwargs) return result - class Queue: all_queue = deque() @@ -97,10 +96,12 @@ def __init__(self, threads=1, name=None): self.__threads_list: List[Thread] = [] self.__lock = Lock() self.__worker = worker + self.__name = name + def create_threads(self): for thread_id in range(1, self.__threads + 1): thread = Thread( - name=f"{name} n°{thread_id}", + name="%s n°%d{thread_id}" % (self.__name, thread_id), target=self.__worker, args=(self.__queue, thread_id, self.__lock, self.__event), daemon=True, diff --git a/src/pyramid/tools/main_queue.py b/src/pyramid/tools/main_queue.py new file mode 100644 index 0000000..9e95c14 --- /dev/null +++ b/src/pyramid/tools/main_queue.py @@ -0,0 +1,13 @@ +from pyramid.tools.custom_queue import Queue + + +class MainQueue: + instance: Queue + + @classmethod + def init(cls): + cls.instance = Queue(1, "MessageSender") + cls.instance.create_threads() + cls.instance.start() + cls.instance.register_to_wait_on_exit() + diff --git a/tests/queue_test.py b/tests/queue_test.py index 9b6db50..a78ea82 100644 --- a/tests/queue_test.py +++ b/tests/queue_test.py @@ -7,6 +7,7 @@ class SimpleQueue(unittest.TestCase): def test_add(self): queue = Queue(threads=1) + queue.create_threads() self.assertEqual(queue.length(), 0) item = QueueItem(name="test", func=lambda x: x, x=5) @@ -15,6 +16,7 @@ def test_add(self): def test_add_at_start(self): queue = Queue(threads=1) + queue.create_threads() self.assertEqual(queue.length(), 0) item = QueueItem(name="test", func=lambda x: x, x=5) @@ -23,6 +25,7 @@ def test_add_at_start(self): def test_worker_start_before(self): queue = Queue(threads=1) + queue.create_threads() self.assertEqual(queue.length(), 0) queue.start() @@ -36,6 +39,7 @@ def test_worker_start_before(self): def test_worker_start_after(self): queue = Queue(threads=1) + queue.create_threads() self.assertEqual(queue.length(), 0) item = QueueItem(name="test", func=lambda x: x, x=5) @@ -49,6 +53,7 @@ def test_worker_start_after(self): def test_wait_for_end(self): queue = Queue(threads=1) + queue.create_threads() queue.register_to_wait_on_exit() queue.start() @@ -62,6 +67,7 @@ class MediumQueue(unittest.TestCase): def test_order_simple(self): thread_nb = 1 queue = Queue(threads=thread_nb) + queue.create_threads() results = [] results_excepted = list(range(1, 10)) @@ -80,6 +86,7 @@ def test_order_simple(self): def test_order_reverse(self): thread_nb = 1 queue = Queue(threads=thread_nb) + queue.create_threads() results = [] results_excepted = list(range(9, 0, -1)) @@ -98,6 +105,7 @@ def test_order_reverse(self): def test_order_mixed(self): thread_nb = 1 queue = Queue(threads=thread_nb) + queue.create_threads() results = [] results_excepted = list(range(1, 100)) @@ -154,6 +162,7 @@ def test_wait_for_end_shutdown_threads(self): items = 100 queue = Queue(threads=thread_nb) + queue.create_threads() queue.register_to_wait_on_exit() for i in range(items): From 93bbb7bee0c8eca84e0557333cb6734f107c49ce Mon Sep 17 00:00:00 2001 From: Tristiisch Date: Wed, 18 Sep 2024 00:43:16 +0200 Subject: [PATCH 09/32] feat: add task with service injection --- src/pyramid/api/services/discord.py | 7 ++ src/pyramid/api/services/tools/register.py | 69 ++++++++------- src/pyramid/api/tasks/tools/annotation.py | 16 ++++ src/pyramid/api/tasks/tools/injector.py | 16 ++++ src/pyramid/api/tasks/tools/parameters.py | 12 +++ src/pyramid/api/tasks/tools/register.py | 84 +++++++++++++++++++ .../discord/commands/tools/register.py | 16 ++-- src/pyramid/data/functional/main.py | 48 +---------- src/pyramid/services/discord.py | 44 +--------- src/pyramid/tasks/discord.py | 22 +++++ 10 files changed, 205 insertions(+), 129 deletions(-) create mode 100644 src/pyramid/api/services/discord.py create mode 100644 src/pyramid/api/tasks/tools/annotation.py create mode 100644 src/pyramid/api/tasks/tools/injector.py create mode 100644 src/pyramid/api/tasks/tools/parameters.py create mode 100644 src/pyramid/api/tasks/tools/register.py create mode 100644 src/pyramid/tasks/discord.py diff --git a/src/pyramid/api/services/discord.py b/src/pyramid/api/services/discord.py new file mode 100644 index 0000000..8b39261 --- /dev/null +++ b/src/pyramid/api/services/discord.py @@ -0,0 +1,7 @@ +from abc import ABC, abstractmethod +from pyramid.connector.discord.bot import DiscordBot + +class IDiscordService(ABC): + + def __init__(self): + self.discord_bot: DiscordBot diff --git a/src/pyramid/api/services/tools/register.py b/src/pyramid/api/services/tools/register.py index cf909ba..1b40368 100644 --- a/src/pyramid/api/services/tools/register.py +++ b/src/pyramid/api/services/tools/register.py @@ -14,43 +14,42 @@ class ServiceRegister: __SERVICE_TO_REGISTER: dict[str, type[ServiceInjector]] = {} __SERVICE_REGISTERED: dict[str, ServiceInjector] = {} - @staticmethod - def register_service(name: str, type: type[object]): + @classmethod + def register_service(cls, name: str, type: type[object]): if not issubclass(type, ServiceInjector): raise TypeError("Service %s is not a subclass of ServiceInjector and cannot be initialized." % name) - if name in ServiceRegister.__SERVICE_TO_REGISTER: - already_class_name = ServiceRegister.__SERVICE_TO_REGISTER[name].__name__ + if name in cls.__SERVICE_TO_REGISTER: + already_class_name = cls.__SERVICE_TO_REGISTER[name].__name__ raise ServiceAlreadyRegisterException( "Cannot register %s with %s, it is already registered with the class %s." % (name, type.__name__, already_class_name) ) - ServiceRegister.__SERVICE_TO_REGISTER[name] = type + cls.__SERVICE_TO_REGISTER[name] = type - @staticmethod - def import_services(): + @classmethod + def import_services(cls): package_name = "pyramid.services" package = importlib.import_module(package_name) for loader, module_name, is_pkg in pkgutil.iter_modules(package.__path__): full_module_name = f"{package_name}.{module_name}" - module = importlib.import_module(full_module_name) + importlib.import_module(full_module_name) - @staticmethod - def create_services(): - for name, cls in ServiceRegister.__SERVICE_TO_REGISTER.items(): - class_instance = cls() - ServiceRegister.__SERVICE_REGISTERED[name] = class_instance + @classmethod + def create_services(cls): + for name, service_type in cls.__SERVICE_TO_REGISTER.items(): + class_instance = service_type() + cls.__SERVICE_REGISTERED[name] = class_instance - @staticmethod - def inject_services(): + @classmethod + def inject_services(cls): # Step 1: Create a graph of dependencies dependency_graph = defaultdict(list) indegree = defaultdict(int) # To track the number of dependencies # Create instances but delay injecting dependencies - for name, cls in ServiceRegister.__SERVICE_TO_REGISTER.items(): - class_instance = cls() - ServiceRegister.__SERVICE_REGISTERED[name] = class_instance + for name, service_type in cls.__SERVICE_TO_REGISTER.items(): + class_instance = cls.__SERVICE_REGISTERED[name] # Step 2: Parse dependencies for each service signature = inspect.signature(class_instance.injectService) @@ -58,7 +57,7 @@ def inject_services(): for method_parameter in method_parameters: dependency_name = method_parameter.annotation.__name__ - if dependency_name not in ServiceRegister.__SERVICE_REGISTERED: + if dependency_name not in cls.__SERVICE_REGISTERED: raise ServiceAlreadyNotRegisterException( "Cannot register %s as a dependency for %s because the dependency is not registered." % (dependency_name, name) @@ -69,7 +68,7 @@ def inject_services(): # Step 3: Perform a topological sort to determine the order of instantiation sorted_services = [] - queue = deque([service for service in ServiceRegister.__SERVICE_TO_REGISTER if indegree[service] == 0]) + queue = deque([service for service in cls.__SERVICE_TO_REGISTER if indegree[service] == 0]) while queue: service = queue.popleft() @@ -80,8 +79,8 @@ def inject_services(): if indegree[dependent] == 0: queue.append(dependent) - if len(sorted_services) != len(ServiceRegister.__SERVICE_TO_REGISTER): - unresolved_services = set(ServiceRegister.__SERVICE_TO_REGISTER) - set(sorted_services) + if len(sorted_services) != len(cls.__SERVICE_TO_REGISTER): + unresolved_services = set(cls.__SERVICE_TO_REGISTER) - set(sorted_services) raise ServiceCicularDependencyException( "Circular dependency detected! The following services are involved in a circular dependency: %s" % ', '.join(unresolved_services) @@ -89,24 +88,23 @@ def inject_services(): # Step 4: Inject dependencies in the correct order for service_name in sorted_services: - class_instance = ServiceRegister.__SERVICE_REGISTERED[service_name] + class_instance = cls.__SERVICE_REGISTERED[service_name] signature = inspect.signature(class_instance.injectService) method_parameters = list(signature.parameters.values()) services_dependencies = [] for method_parameter in method_parameters: dependency_name = method_parameter.annotation.__name__ - dependency_instance = ServiceRegister.__SERVICE_REGISTERED[dependency_name] + dependency_instance = cls.__SERVICE_REGISTERED[dependency_name] services_dependencies.append(dependency_instance) class_instance.injectService(*services_dependencies) - @staticmethod - def get_dependency_tree(): + @classmethod + def get_dependency_tree(cls): # Step 1: Build dependency graph dependency_graph = defaultdict(list) - for name, cls in ServiceRegister.__SERVICE_TO_REGISTER.items(): - class_instance = ServiceRegister.__SERVICE_REGISTERED[name] + for name, class_instance in cls.__SERVICE_REGISTERED.items(): signature = inspect.signature(class_instance.injectService) method_parameters = list(signature.parameters.values()) @@ -133,7 +131,7 @@ def build_tree(node, prefix="", last=True): build_tree(child, prefix, i == len(children) - 1) # Step 4: Find root services (those with no dependencies) - all_services = set(ServiceRegister.__SERVICE_TO_REGISTER.keys()) + all_services = set(cls.__SERVICE_TO_REGISTER.keys()) dependent_services = set(dep for deps in dependency_graph.values() for dep in deps) root_services = all_services - dependent_services @@ -146,13 +144,12 @@ def build_tree(node, prefix="", last=True): return "Services tree :\n" + "\n".join(buffer) - - @staticmethod - def start_services(): - for name, class_instance in ServiceRegister.__SERVICE_REGISTERED.items(): + @classmethod + def start_services(cls): + for name, class_instance in cls.__SERVICE_REGISTERED.items(): class_instance.start() - @staticmethod - def get_service(class_type: Type[T]) -> T: + @classmethod + def get_service(cls, class_type: Type[T]) -> T: class_name = class_type.__name__ - return ServiceRegister.__SERVICE_REGISTERED[class_name] + return cls.__SERVICE_REGISTERED[class_name] diff --git a/src/pyramid/api/tasks/tools/annotation.py b/src/pyramid/api/tasks/tools/annotation.py new file mode 100644 index 0000000..9e5c7f3 --- /dev/null +++ b/src/pyramid/api/tasks/tools/annotation.py @@ -0,0 +1,16 @@ +from typing import Optional +from pyramid.api.services.tools.register import ServiceRegister +from pyramid.api.tasks.tools.injector import TaskInjector +from pyramid.api.tasks.tools.parameters import ParametersTask +from pyramid.api.tasks.tools.register import TaskRegister + + +def pyramid_task(*, parameters: ParametersTask): + def decorator(cls): + class_name = cls.__name__ + if not issubclass(cls, TaskInjector): + raise TypeError("Class %s must inherit from TaskInjector" % class_name) + + TaskRegister.register_tasks(cls, parameters) + return cls + return decorator diff --git a/src/pyramid/api/tasks/tools/injector.py b/src/pyramid/api/tasks/tools/injector.py new file mode 100644 index 0000000..8aad739 --- /dev/null +++ b/src/pyramid/api/tasks/tools/injector.py @@ -0,0 +1,16 @@ +from abc import ABC, abstractmethod +from discord.ext.commands import Bot + +class TaskInjector(ABC): + + def injectService(self): + pass + + async def worker_asyc(self): + pass + + async def stop_asyc(self): + pass + + # def worker(self): + # pass diff --git a/src/pyramid/api/tasks/tools/parameters.py b/src/pyramid/api/tasks/tools/parameters.py new file mode 100644 index 0000000..a912a0c --- /dev/null +++ b/src/pyramid/api/tasks/tools/parameters.py @@ -0,0 +1,12 @@ +import asyncio +from threading import Thread + +from pyramid.api.tasks.tools.injector import TaskInjector + + +class ParametersTask: + + def __init__(self): + self.loop: asyncio.AbstractEventLoop + self.thread: Thread + self.task_cls: TaskInjector diff --git a/src/pyramid/api/tasks/tools/register.py b/src/pyramid/api/tasks/tools/register.py new file mode 100644 index 0000000..44d6bd3 --- /dev/null +++ b/src/pyramid/api/tasks/tools/register.py @@ -0,0 +1,84 @@ + +import asyncio +import importlib +import inspect +import logging +import pkgutil +import signal +from threading import Thread +from typing import TypeVar +from pyramid.api.services.tools.register import ServiceRegister +from pyramid.api.tasks.tools.injector import TaskInjector +from pyramid.api.tasks.tools.parameters import ParametersTask + +T = TypeVar('T') + +class TaskRegister: + + __TASKS_REGISTERED: dict[str, ParametersTask] = {} + + @classmethod + def import_tasks(cls): + package_name = "pyramid.tasks" + package = importlib.import_module(package_name) + + for loader, module_name, is_pkg in pkgutil.iter_modules(package.__path__): + full_module_name = f"{package_name}.{module_name}" + importlib.import_module(full_module_name) + + @classmethod + def register_tasks(cls, type: type[object], parameters: ParametersTask): + if not issubclass(type, TaskInjector): + raise TypeError("Service %s is not a subclass of TaskInjector and cannot be initialized." % type.__name__) + parameters.task_cls = type() + cls.__TASKS_REGISTERED[type.__name__] = parameters + + @classmethod + def inject_tasks(cls): + for name, parameters in cls.__TASKS_REGISTERED.items(): + signature = inspect.signature(parameters.task_cls.injectService) + method_parameters = list(signature.parameters.values()) + + services_dependencies = [] + for method_parameter in method_parameters: + dependency_cls = method_parameter.annotation + dependency_instance = ServiceRegister.get_service(dependency_cls) + services_dependencies.append(dependency_instance) + + parameters.task_cls.injectService(*services_dependencies) + + @classmethod + def __handle_signal(cls, signum: int, frame): + logging.info("Received signal %d. shutting down ..." % signum) + + for name, parameters in cls.__TASKS_REGISTERED.items(): + async def shutdown(loop: asyncio.AbstractEventLoop): + await parameters.task_cls.stop_asyc() + parameters.loop.stop() + asyncio.run_coroutine_threadsafe(shutdown(parameters.loop), parameters.loop) + + @classmethod + def start_tasks(cls): + previous_handler = signal.signal(signal.SIGTERM, cls.__handle_signal) + for name, parameters in cls.__TASKS_REGISTERED.items(): + + parameters.loop = asyncio.new_event_loop() + + def running(loop: asyncio.AbstractEventLoop): + asyncio.set_event_loop(loop) + + loop.create_task(parameters.task_cls.worker_asyc()) + try: + loop.run_forever() + finally: + loop.close() + + parameters.thread = Thread(name=name, target=running, args=(parameters.loop,)) + + for name, parameters in cls.__TASKS_REGISTERED.items(): + parameters.thread.start() + + for name, parameters in cls.__TASKS_REGISTERED.items(): + parameters.thread.join() + + signal.signal(signal.SIGTERM, previous_handler) diff --git a/src/pyramid/connector/discord/commands/tools/register.py b/src/pyramid/connector/discord/commands/tools/register.py index bea32b3..520433b 100644 --- a/src/pyramid/connector/discord/commands/tools/register.py +++ b/src/pyramid/connector/discord/commands/tools/register.py @@ -10,22 +10,22 @@ class CommandRegister: __COMMANDS_TO_REGISTER: dict[type[AbstractCommand], ParametersCommand] = {} - @staticmethod - def register_command(type: type[AbstractCommand], parameterCommand: ParametersCommand): + @classmethod + def register_command(cls, type: type[AbstractCommand], parameterCommand: ParametersCommand): CommandRegister.__COMMANDS_TO_REGISTER[type] = parameterCommand - @staticmethod - def import_commands(): + @classmethod + def import_commands(cls): package_name = "pyramid.connector.discord.commands" package = importlib.import_module(package_name) for loader, module_name, is_pkg in pkgutil.iter_modules(package.__path__): full_module_name = f"{package_name}.{module_name}" - module = importlib.import_module(full_module_name) + importlib.import_module(full_module_name) - @staticmethod - def create_commands(services: dict[str, object], bot: Bot, logger: logging.Logger, command_prefix: str | None = None): - for cls, parameters in CommandRegister.__COMMANDS_TO_REGISTER.items(): + @classmethod + def create_commands(cls, services: dict[str, object], bot: Bot, logger: logging.Logger, command_prefix: str | None = None): + for cls, parameters in cls.__COMMANDS_TO_REGISTER.items(): class_instance = cls(parameters, bot, logger) class_instance.register(command_prefix) signature = inspect.signature(class_instance.injectService) diff --git a/src/pyramid/data/functional/main.py b/src/pyramid/data/functional/main.py index f816b62..2a7f694 100644 --- a/src/pyramid/data/functional/main.py +++ b/src/pyramid/data/functional/main.py @@ -9,6 +9,7 @@ from pyramid.api.services.information import IInformationService from pyramid.api.services.logger import ILoggerService from pyramid.api.services.tools.register import ServiceRegister +from pyramid.api.tasks.tools.register import TaskRegister from pyramid.data.functional.application_info import ApplicationInfo from pyramid.connector.discord.bot import DiscordBot from pyramid.tools.custom_queue import Queue @@ -39,55 +40,14 @@ def start(self): ServiceRegister.start_services() logger = ServiceRegister.get_service(ILoggerService) - info = ServiceRegister.get_service(IInformationService) - config = ServiceRegister.get_service(IConfigurationService) logger.debug(ServiceRegister.get_dependency_tree()) MainQueue.init() - # Discord Bot Instance - discord_bot = DiscordBot(logger.getChild("Discord"), info.get(), config) - # Create bot - discord_bot.create() - - loop: asyncio.AbstractEventLoop = asyncio.new_event_loop() - - def running(loop: asyncio.AbstractEventLoop): - asyncio.set_event_loop(loop) - - # Run the asynchronous function in the thread without blocking it - loop.create_task(discord_bot.start()) - try: - # Run tasks in thread infinitly - loop.run_forever() - finally: - loop.close() - - - async def shutdown(loop: asyncio.AbstractEventLoop): - await discord_bot.stop() - loop.stop() - - def handle_signal(signum: int, frame): - logging.info(f"Received signal {signum}. shutting down ...") - asyncio.run_coroutine_threadsafe(shutdown(loop), loop) - - previous_handler = signal.signal(signal.SIGTERM, handle_signal) - - # Connect bot to Discord servers in his own thread - thread = Thread(name="Discord", target=running, args=(loop,)) - thread.start() - thread.join() - - signal.signal(signal.SIGTERM, previous_handler) - - # def running_async(discord_bot: DiscordBot): - # asyncio.run(discord_bot.start()) - - # thread = Thread(name="Discord", target=running_async, args=(discord_bot,)) - # thread.start() - # thread.join() + TaskRegister.import_tasks() + TaskRegister.inject_tasks() + TaskRegister.start_tasks() def stop(self): diff --git a/src/pyramid/services/discord.py b/src/pyramid/services/discord.py index 8092166..b102ea2 100644 --- a/src/pyramid/services/discord.py +++ b/src/pyramid/services/discord.py @@ -1,19 +1,14 @@ -import asyncio -import logging -import signal -from threading import Thread - from pyramid.api.services.configuration import IConfigurationService +from pyramid.api.services.discord import IDiscordService from pyramid.api.services.information import IInformationService from pyramid.api.services.logger import ILoggerService from pyramid.api.services.tools.annotation import pyramid_service from pyramid.api.services.tools.injector import ServiceInjector from pyramid.connector.discord.bot import DiscordBot -from pyramid.tools.custom_queue import Queue -# @pyramid_service() -class DiscordBotService(ServiceInjector): +@pyramid_service(interface=IDiscordService) +class DiscordBotService(IDiscordService, ServiceInjector): def injectService(self, logger_service: ILoggerService, @@ -27,36 +22,3 @@ def injectService(self, def start(self): self.discord_bot = DiscordBot(self.__logger_service.getChild("Discord"), self.__information_service.get(), self.__configuration_service) self.discord_bot.create() - - loop: asyncio.AbstractEventLoop = asyncio.new_event_loop() - - def running(loop: asyncio.AbstractEventLoop): - asyncio.set_event_loop(loop) - - loop.create_task(self.discord_bot.start()) - try: - loop.run_forever() - finally: - loop.close() - - async def shutdown(loop: asyncio.AbstractEventLoop): - await self.discord_bot.stop() - loop.stop() - - def handle_signal(signum: int, frame): - logging.info("Received signal %d. shutting down ..." % signum) - asyncio.run_coroutine_threadsafe(shutdown(loop), loop) - - previous_handler = signal.signal(signal.SIGTERM, handle_signal) - - # Connect bot to Discord servers in his own thread - thread = Thread(name="Discord", target=running, args=(loop,)) - thread.start() - thread.join() - - signal.signal(signal.SIGTERM, previous_handler) - - def stop(self): - logging.info("Wait for background tasks to stop") - Queue.wait_for_end(5) - logging.info("Bye bye") diff --git a/src/pyramid/tasks/discord.py b/src/pyramid/tasks/discord.py new file mode 100644 index 0000000..27f3ebd --- /dev/null +++ b/src/pyramid/tasks/discord.py @@ -0,0 +1,22 @@ +from pyramid.api.services.discord import IDiscordService +from pyramid.api.tasks.tools.annotation import pyramid_task +from pyramid.api.tasks.tools.injector import TaskInjector +from pyramid.api.tasks.tools.parameters import ParametersTask + +@pyramid_task(parameters=ParametersTask()) +class DiscordTask(TaskInjector): + + def __init__(self): + pass + + def injectService(self, + discord_service: IDiscordService + ): + self.__discord_service = discord_service + + async def worker_asyc(self): + self.discord_bot = self.__discord_service.discord_bot + await self.discord_bot.start() + + async def stop_asyc(self): + await self.discord_bot.stop() From 57cccf57c71ce12fd1f506e62e7fa4253e616461 Mon Sep 17 00:00:00 2001 From: Tristiisch Date: Wed, 18 Sep 2024 01:21:29 +0200 Subject: [PATCH 10/32] feat: create service module --- src/pyramid/api/services/__init__.py | 5 +++++ src/pyramid/api/services/tools/register.py | 18 +++++++++--------- src/pyramid/data/functional/main.py | 9 ++++----- src/pyramid/services/configuration.py | 3 +-- src/pyramid/services/discord.py | 5 +---- src/pyramid/services/information.py | 2 +- src/pyramid/services/logger.py | 4 +--- src/pyramid/services/logger_level.py | 3 +-- src/pyramid/services/socket_server.py | 3 +-- 9 files changed, 24 insertions(+), 28 deletions(-) create mode 100644 src/pyramid/api/services/__init__.py diff --git a/src/pyramid/api/services/__init__.py b/src/pyramid/api/services/__init__.py new file mode 100644 index 0000000..321640f --- /dev/null +++ b/src/pyramid/api/services/__init__.py @@ -0,0 +1,5 @@ +from .configuration import IConfigurationService +from .discord import IDiscordService +from .information import IInformationService +from .logger import ILoggerService +from .socket_server import ISocketServerService diff --git a/src/pyramid/api/services/tools/register.py b/src/pyramid/api/services/tools/register.py index 1b40368..44a06de 100644 --- a/src/pyramid/api/services/tools/register.py +++ b/src/pyramid/api/services/tools/register.py @@ -14,6 +14,15 @@ class ServiceRegister: __SERVICE_TO_REGISTER: dict[str, type[ServiceInjector]] = {} __SERVICE_REGISTERED: dict[str, ServiceInjector] = {} + @classmethod + def import_services(cls): + package_name = "pyramid.services" + package = importlib.import_module(package_name) + + for loader, module_name, is_pkg in pkgutil.iter_modules(package.__path__): + full_module_name = f"{package_name}.{module_name}" + importlib.import_module(full_module_name) + @classmethod def register_service(cls, name: str, type: type[object]): if not issubclass(type, ServiceInjector): @@ -26,15 +35,6 @@ def register_service(cls, name: str, type: type[object]): ) cls.__SERVICE_TO_REGISTER[name] = type - @classmethod - def import_services(cls): - package_name = "pyramid.services" - package = importlib.import_module(package_name) - - for loader, module_name, is_pkg in pkgutil.iter_modules(package.__path__): - full_module_name = f"{package_name}.{module_name}" - importlib.import_module(full_module_name) - @classmethod def create_services(cls): for name, service_type in cls.__SERVICE_TO_REGISTER.items(): diff --git a/src/pyramid/data/functional/main.py b/src/pyramid/data/functional/main.py index 2a7f694..9421d8b 100644 --- a/src/pyramid/data/functional/main.py +++ b/src/pyramid/data/functional/main.py @@ -17,20 +17,19 @@ class Main: - def __init__(self): + # def __init__(self): # Program information - self._info = ApplicationInfo() # self._health = HealthModules() - self._discord_bot = None + # self._discord_bot = None - # Argument management def args(self): parser = argparse.ArgumentParser(description="Music Bot Discord using Deezer.") parser.add_argument("--version", action="store_true", help="Print version", required=False) args = parser.parse_args() if args.version: - print(f"{self._info.get_version()}") + info = ApplicationInfo() + print(info.get_version()) sys.exit(0) def start(self): diff --git a/src/pyramid/services/configuration.py b/src/pyramid/services/configuration.py index 75b3205..47413fa 100644 --- a/src/pyramid/services/configuration.py +++ b/src/pyramid/services/configuration.py @@ -1,5 +1,4 @@ -from pyramid.api.services.configuration import IConfigurationService -from pyramid.api.services.logger import ILoggerService +from pyramid.api.services import IConfigurationService, ILoggerService from pyramid.api.services.tools.annotation import pyramid_service from pyramid.api.services.tools.injector import ServiceInjector from pyramid.tools import utils diff --git a/src/pyramid/services/discord.py b/src/pyramid/services/discord.py index b102ea2..21838e9 100644 --- a/src/pyramid/services/discord.py +++ b/src/pyramid/services/discord.py @@ -1,7 +1,4 @@ -from pyramid.api.services.configuration import IConfigurationService -from pyramid.api.services.discord import IDiscordService -from pyramid.api.services.information import IInformationService -from pyramid.api.services.logger import ILoggerService +from pyramid.api.services import IConfigurationService, IDiscordService, IInformationService, ILoggerService from pyramid.api.services.tools.annotation import pyramid_service from pyramid.api.services.tools.injector import ServiceInjector from pyramid.connector.discord.bot import DiscordBot diff --git a/src/pyramid/services/information.py b/src/pyramid/services/information.py index ddc2d0c..2f6a6f5 100644 --- a/src/pyramid/services/information.py +++ b/src/pyramid/services/information.py @@ -1,5 +1,5 @@ -from pyramid.api.services.information import IInformationService +from pyramid.api.services import IInformationService from pyramid.api.services.tools.annotation import pyramid_service from pyramid.api.services.tools.injector import ServiceInjector from pyramid.data.functional.application_info import ApplicationInfo diff --git a/src/pyramid/services/logger.py b/src/pyramid/services/logger.py index 861997c..adcded4 100644 --- a/src/pyramid/services/logger.py +++ b/src/pyramid/services/logger.py @@ -5,11 +5,9 @@ import coloredlogs from datetime import datetime -from pyramid.api.services.information import IInformationService +from pyramid.api.services import IInformationService from pyramid.api.services.tools.annotation import pyramid_service from pyramid.api.services.tools.injector import ServiceInjector -from pyramid.data.environment import Environment -from pyramid.api.services.configuration import IConfigurationService from pyramid.api.services.logger import ILoggerService from pyramid.tools import utils diff --git a/src/pyramid/services/logger_level.py b/src/pyramid/services/logger_level.py index a31353c..698c850 100644 --- a/src/pyramid/services/logger_level.py +++ b/src/pyramid/services/logger_level.py @@ -4,8 +4,7 @@ from pyramid.api.services.tools.annotation import pyramid_service from pyramid.api.services.tools.injector import ServiceInjector from pyramid.data.environment import Environment -from pyramid.api.services.configuration import IConfigurationService -from pyramid.api.services.logger import ILoggerService +from pyramid.api.services import IConfigurationService, ILoggerService @pyramid_service() class LoggerLevelService(ServiceInjector): diff --git a/src/pyramid/services/socket_server.py b/src/pyramid/services/socket_server.py index 46f9f20..5adf270 100644 --- a/src/pyramid/services/socket_server.py +++ b/src/pyramid/services/socket_server.py @@ -1,8 +1,7 @@ from threading import Thread -from pyramid.api.services.logger import ILoggerService -from pyramid.api.services.socket_server import ISocketServerService +from pyramid.api.services import ILoggerService, ISocketServerService from pyramid.api.services.tools.annotation import pyramid_service from pyramid.api.services.tools.injector import ServiceInjector from pyramid.client.server import SocketServer From e00b600fb8875c6fff2904602a6f9a7cdb8438c9 Mon Sep 17 00:00:00 2001 From: Tristiisch Date: Fri, 20 Sep 2024 21:32:00 +0200 Subject: [PATCH 11/32] feat: move discord bot as service --- src/pyramid/api/services/discord.py | 19 +- src/pyramid/connector/discord/bot.py | 173 ------------------ .../discord/commands/tools/register.py | 25 ++- src/pyramid/data/functional/main.py | 5 +- src/pyramid/data/guild_instance.py | 19 ++ src/pyramid/services/discord.py | 136 +++++++++++++- .../discord_commads.py} | 79 ++++---- .../discord_listener.py} | 29 ++- src/pyramid/services/logger_level.py | 6 +- src/pyramid/tasks/discord.py | 5 +- 10 files changed, 247 insertions(+), 249 deletions(-) delete mode 100644 src/pyramid/connector/discord/bot.py create mode 100644 src/pyramid/data/guild_instance.py rename src/pyramid/{connector/discord/bot_cmd.py => services/discord_commads.py} (77%) rename src/pyramid/{connector/discord/bot_listener.py => services/discord_listener.py} (70%) diff --git a/src/pyramid/api/services/discord.py b/src/pyramid/api/services/discord.py index 8b39261..36322ea 100644 --- a/src/pyramid/api/services/discord.py +++ b/src/pyramid/api/services/discord.py @@ -1,7 +1,22 @@ from abc import ABC, abstractmethod -from pyramid.connector.discord.bot import DiscordBot +from discord import Guild +from discord.ext.commands import Bot + +from pyramid.connector.discord.guild_cmd import GuildCmd class IDiscordService(ABC): def __init__(self): - self.discord_bot: DiscordBot + self.bot: Bot + + @abstractmethod + async def connect_bot(self): + pass + + @abstractmethod + async def disconnect_bot(self): + pass + + @abstractmethod + def get_guild_cmd(self, guild: Guild) -> GuildCmd: + pass \ No newline at end of file diff --git a/src/pyramid/connector/discord/bot.py b/src/pyramid/connector/discord/bot.py deleted file mode 100644 index c1d42a2..0000000 --- a/src/pyramid/connector/discord/bot.py +++ /dev/null @@ -1,173 +0,0 @@ -import logging -import sys -import time -import traceback -from logging import Logger -from typing import Dict - -import discord -from discord import ( - Guild, - Interaction, - PrivilegedIntentsRequired, -) -from discord.ext.commands import Bot, Context -from discord.ext.commands.errors import ( - CommandNotFound, - MissingPermissions, - MissingRequiredArgument, - CommandError, -) -from discord.app_commands.errors import AppCommandError, CommandInvokeError -from pyramid.api.services.configuration import IConfigurationService -from pyramid.data.functional.application_info import ApplicationInfo -from pyramid.data.environment import Environment -from pyramid.data.guild_data import GuildData -from pyramid.connector.discord.bot_cmd import BotCmd -from pyramid.connector.discord.bot_listener import BotListener -from pyramid.connector.discord.guild_cmd import GuildCmd -from pyramid.connector.discord.guild_queue import GuildQueue -from pyramid.data.functional.engine_source import EngineSource -from pyramid.data.exceptions import DiscordMessageException -from pyramid.data.functional.messages.message_sender_queued import MessageSenderQueued -from pyramid.data.health import HealthModules -from pyramid.connector.discord.music_player_interface import MusicPlayerInterface -from pyramid.tools.configuration.configuration import Configuration - - -class DiscordBot: - def __init__(self, logger: logging.Logger, information: ApplicationInfo, config: IConfigurationService): - self.__logger = logger - self.__information = information - self.__token = config.discord__token - self.__ffmpeg = config.discord__ffmpeg - self.__environment: Environment = config.mode - self.__engine_source = EngineSource(config) - - intents = discord.Intents.default() - # intents.members = True - intents.message_content = True - - bot = Bot( - command_prefix="$$", - intents=intents, - activity=discord.Activity( - type=discord.ActivityType.listening, - name=f"{self.__information.get_version()}", - ), - ) - self.bot = bot - - self.guilds_instances: Dict[int, GuildInstances] = {} - - # def create(self, health: HealthModules): - def create(self): - self.__logger.info("Discord bot creating with discord.py v%s ...", discord.__version__) - self.listeners = BotListener(self.bot, self.__logger) - self.cmd = BotCmd( - self.bot, - self.__get_guild_cmd, - self.__logger, - self.__information, - self.__environment - ) - # self._health = health - - @self.bot.event - async def on_command_error(ctx: Context, error: CommandError): - if isinstance(error, CommandNotFound): - await ctx.send("That command didn't exists !") - return - elif isinstance(error, MissingRequiredArgument): - await ctx.send("Please pass in all requirements.") - return - elif isinstance(error, MissingPermissions): - await ctx.send( - "You dont have all the requirements or permissions for using this command :angry:" - ) - return - self.__logger.error("Command error from on_command_error : %s", error) - - @self.bot.event - async def on_error(event, *args, **kwargs): - # message = args[0] # Message object - # traceback.extract_stack - self.__logger.error("Error from on_error : %s", traceback.format_exc()) - # await bot.send_message(message.channel, "You caused an error!") - - async def on_tree_error(ctx: Interaction, app_error: AppCommandError, /): - ms = MessageSenderQueued(ctx) - - if isinstance(app_error, CommandInvokeError): - msg = ", ".join(app_error.args) - error = app_error.original - else: - msg = "Error from on_tree_error" - error = app_error - trace = "".join(traceback.format_exception(type(error), error, error.__traceback__)) - self.__logger.error("%s :\n%s", msg, trace) - - discord_explanation = ":warning: You caused an error!" - if isinstance(error, DiscordMessageException): - ms.add_message(discord_explanation) - else: - attributes_dict = vars(ctx.namespace) - formatted_attributes = " ".join( - f"{key}: {value}" - for key, value in attributes_dict.items() # TODO Handle ENUM name instead of value - ) - discord_explanation = ( - ":warning: An error occurred while processing the command `/%s%s`" - % ( - ctx.command.name if ctx.command else "", - f" {formatted_attributes}" if formatted_attributes != "" else "", - ) - ) - if self.__environment is not Environment.PRODUCTION: - ms.add_code_message(trace, discord_explanation) - else: - ms.add_message(discord_explanation) - - self.bot.tree.on_error = on_tree_error - - @self.bot.event - async def on_command(ctx: Context): - self.__logger.debug("on_command : %s", ctx.author) - - self.cmd.register() - self.listeners.register() - - async def start(self): - try: - self.__logger.info("Discord bot login") - await self.bot.login(self.__token) - self.__logger.info("Discord bot connecting") - # self._health.discord = True - await self.bot.connect() - except PrivilegedIntentsRequired as ex: - # self._health.discord = False - self.__logger.critical(ex) - sys.exit(1) - - async def stop(self): - # self.bot.clear() - self.__logger.info("Discord bot stop") - await self.bot.close() - self.__logger.info("Discord bot stopped") - - def __get_guild_cmd(self, guild: Guild) -> GuildCmd: - if guild.id not in self.guilds_instances: - self.guilds_instances[guild.id] = GuildInstances( - guild, self.__logger.getChild(guild.name), self.__engine_source, self.__ffmpeg - ) - - return self.guilds_instances[guild.id].cmds - - -class GuildInstances: - def __init__(self, guild: Guild, logger: Logger, engine_source: EngineSource, ffmpeg_path: str): - self.data = GuildData(guild, engine_source) - self.mpi = MusicPlayerInterface(self.data.guild.preferred_locale, self.data.track_list) - self.songs = GuildQueue(self.data, ffmpeg_path, self.mpi) - self.cmds = GuildCmd(logger, self.data, self.songs, engine_source) - self.mpi.set_queue_action(self.cmds) diff --git a/src/pyramid/connector/discord/commands/tools/register.py b/src/pyramid/connector/discord/commands/tools/register.py index 520433b..a316eb7 100644 --- a/src/pyramid/connector/discord/commands/tools/register.py +++ b/src/pyramid/connector/discord/commands/tools/register.py @@ -3,6 +3,7 @@ import logging import pkgutil from discord.ext.commands import Bot +from pyramid.api.services.tools.register import ServiceRegister from pyramid.connector.discord.commands.tools.abc import AbstractCommand from pyramid.connector.discord.commands.tools.parameters import ParametersCommand @@ -24,11 +25,21 @@ def import_commands(cls): importlib.import_module(full_module_name) @classmethod - def create_commands(cls, services: dict[str, object], bot: Bot, logger: logging.Logger, command_prefix: str | None = None): - for cls, parameters in cls.__COMMANDS_TO_REGISTER.items(): - class_instance = cls(parameters, bot, logger) + def create_commands(cls, bot: Bot, logger: logging.Logger, command_prefix: str | None = None): + for type, parameters in cls.__COMMANDS_TO_REGISTER.items(): + class_instance = type(parameters, bot, logger) class_instance.register(command_prefix) - signature = inspect.signature(class_instance.injectService) - parameters = list(signature.parameters.values()) - dependencies = [services[param.annotation.__name__] for param in parameters] - class_instance.injectService(*dependencies) + + @classmethod + def inject_tasks(cls): + for type, parameters in cls.__COMMANDS_TO_REGISTER.items(): + signature = inspect.signature(type.injectService) + method_parameters = list(signature.parameters.values()) + + services_dependencies = [] + for method_parameter in method_parameters: + dependency_cls = method_parameter.annotation + dependency_instance = ServiceRegister.get_service(dependency_cls) + services_dependencies.append(dependency_instance) + + type.injectService(*services_dependencies) diff --git a/src/pyramid/data/functional/main.py b/src/pyramid/data/functional/main.py index 9421d8b..d6af0e0 100644 --- a/src/pyramid/data/functional/main.py +++ b/src/pyramid/data/functional/main.py @@ -5,13 +5,10 @@ import signal from threading import Thread -from pyramid.api.services.configuration import IConfigurationService -from pyramid.api.services.information import IInformationService -from pyramid.api.services.logger import ILoggerService +from pyramid.api.services import ILoggerService from pyramid.api.services.tools.register import ServiceRegister from pyramid.api.tasks.tools.register import TaskRegister from pyramid.data.functional.application_info import ApplicationInfo -from pyramid.connector.discord.bot import DiscordBot from pyramid.tools.custom_queue import Queue from pyramid.tools.main_queue import MainQueue diff --git a/src/pyramid/data/guild_instance.py b/src/pyramid/data/guild_instance.py new file mode 100644 index 0000000..2ba42ba --- /dev/null +++ b/src/pyramid/data/guild_instance.py @@ -0,0 +1,19 @@ + + +from logging import Logger +from discord import Guild + +from pyramid.connector.discord.guild_cmd import GuildCmd +from pyramid.connector.discord.guild_queue import GuildQueue +from pyramid.connector.discord.music_player_interface import MusicPlayerInterface +from pyramid.data.functional.engine_source import EngineSource +from pyramid.data.guild_data import GuildData + + +class GuildInstances: + def __init__(self, guild: Guild, logger: Logger, engine_source: EngineSource, ffmpeg_path: str): + self.data = GuildData(guild, engine_source) + self.mpi = MusicPlayerInterface(self.data.guild.preferred_locale, self.data.track_list) + self.songs = GuildQueue(self.data, ffmpeg_path, self.mpi) + self.cmds = GuildCmd(logger, self.data, self.songs, engine_source) + self.mpi.set_queue_action(self.cmds) diff --git a/src/pyramid/services/discord.py b/src/pyramid/services/discord.py index 21838e9..a7e8109 100644 --- a/src/pyramid/services/discord.py +++ b/src/pyramid/services/discord.py @@ -1,7 +1,32 @@ +import sys +import traceback +from typing import Dict + +import discord +from discord import ( + Guild, + Interaction, + PrivilegedIntentsRequired, +) +from discord.ext.commands import Bot, Context +from discord.ext.commands.errors import ( + CommandNotFound, + MissingPermissions, + MissingRequiredArgument, + CommandError, +) +from discord.app_commands.errors import AppCommandError, CommandInvokeError +from pyramid.api.services.configuration import IConfigurationService +from pyramid.data.environment import Environment +from pyramid.connector.discord.guild_cmd import GuildCmd +from pyramid.data.functional.engine_source import EngineSource +from pyramid.data.exceptions import DiscordMessageException +from pyramid.data.functional.messages.message_sender_queued import MessageSenderQueued +from pyramid.data.guild_instance import GuildInstances + from pyramid.api.services import IConfigurationService, IDiscordService, IInformationService, ILoggerService from pyramid.api.services.tools.annotation import pyramid_service from pyramid.api.services.tools.injector import ServiceInjector -from pyramid.connector.discord.bot import DiscordBot @pyramid_service(interface=IDiscordService) @@ -12,10 +37,113 @@ def injectService(self, information_service: IInformationService, configuration_service: IConfigurationService ): - self.__logger_service = logger_service + self.__logger = logger_service self.__information_service = information_service self.__configuration_service = configuration_service def start(self): - self.discord_bot = DiscordBot(self.__logger_service.getChild("Discord"), self.__information_service.get(), self.__configuration_service) - self.discord_bot.create() + self.__engine_source = EngineSource(self.__configuration_service) + + intents = discord.Intents.default() + # intents.members = True + intents.message_content = True + + self.bot = Bot( + command_prefix="$$", + intents=intents, + activity=discord.Activity( + type=discord.ActivityType.listening, + name=self.__information_service.get().get_version(), + ), + ) + + self.guilds_instances: Dict[int, GuildInstances] = {} + self.__logger.info("Discord bot creating with discord.py v%s ...", discord.__version__) + # self._health = health + + @self.bot.event + async def on_command_error(ctx: Context, error: CommandError): + if isinstance(error, CommandNotFound): + await ctx.send("That command didn't exists !") + return + elif isinstance(error, MissingRequiredArgument): + await ctx.send("Please pass in all requirements.") + return + elif isinstance(error, MissingPermissions): + await ctx.send( + "You dont have all the requirements or permissions for using this command :angry:" + ) + return + self.__logger.error("Command error from on_command_error : %s", error) + + @self.bot.event + async def on_error(event, *args, **kwargs): + # message = args[0] # Message object + # traceback.extract_stack + self.__logger.error("Error from on_error : %s", traceback.format_exc()) + # await bot.send_message(message.channel, "You caused an error!") + + async def on_tree_error(ctx: Interaction, app_error: AppCommandError, /): + ms = MessageSenderQueued(ctx) + + if isinstance(app_error, CommandInvokeError): + msg = ", ".join(app_error.args) + error = app_error.original + else: + msg = "Error from on_tree_error" + error = app_error + trace = "".join(traceback.format_exception(type(error), error, error.__traceback__)) + self.__logger.error("%s :\n%s", msg, trace) + + discord_explanation = ":warning: You caused an error!" + if isinstance(error, DiscordMessageException): + ms.add_message(discord_explanation) + else: + attributes_dict = vars(ctx.namespace) + formatted_attributes = " ".join( + f"{key}: {value}" + for key, value in attributes_dict.items() # TODO Handle ENUM name instead of value + ) + discord_explanation = ( + ":warning: An error occurred while processing the command `/%s%s`" + % ( + ctx.command.name if ctx.command else "", + f" {formatted_attributes}" if formatted_attributes != "" else "", + ) + ) + if self.__configuration_service.mode is not Environment.PRODUCTION: + ms.add_code_message(trace, discord_explanation) + else: + ms.add_message(discord_explanation) + + self.bot.tree.on_error = on_tree_error + + @self.bot.event + async def on_command(ctx: Context): + self.__logger.debug("on_command : %s", ctx.author) + + async def connect_bot(self): + try: + self.__logger.info("Discord bot login") + await self.bot.login(self.__configuration_service.discord__token) + self.__logger.info("Discord bot connecting") + # self._health.discord = True + await self.bot.connect() + except PrivilegedIntentsRequired as ex: + # self._health.discord = False + self.__logger.critical(ex) + sys.exit(1) + + async def disconnect_bot(self): + # self.bot.clear() + self.__logger.info("Discord bot stop") + await self.bot.close() + self.__logger.info("Discord bot stopped") + + def get_guild_cmd(self, guild: Guild) -> GuildCmd: + if guild.id not in self.guilds_instances: + self.guilds_instances[guild.id] = GuildInstances( + guild, self.__logger.getChild(guild.name), self.__engine_source, self.__configuration_service.discord__ffmpeg + ) + + return self.guilds_instances[guild.id].cmds diff --git a/src/pyramid/connector/discord/bot_cmd.py b/src/pyramid/services/discord_commads.py similarity index 77% rename from src/pyramid/connector/discord/bot_cmd.py rename to src/pyramid/services/discord_commads.py index a752a0c..5efba02 100644 --- a/src/pyramid/connector/discord/bot_cmd.py +++ b/src/pyramid/services/discord_commads.py @@ -3,6 +3,12 @@ from discord import Guild, Interaction from discord.ext.commands import Bot +from pyramid.api.services.configuration import IConfigurationService +from pyramid.api.services.discord import IDiscordService +from pyramid.api.services.information import IInformationService +from pyramid.api.services.logger import ILoggerService +from pyramid.api.services.tools.annotation import pyramid_service +from pyramid.api.services.tools.injector import ServiceInjector from pyramid.connector.discord.commands.tools.register import CommandRegister from pyramid.connector.discord.guild_cmd import GuildCmd from pyramid.data.environment import Environment @@ -11,36 +17,23 @@ from pyramid.data.functional.engine_source import SourceType -class BotCmd: - def __init__( - self, - bot: Bot, - get_guild_cmd: Callable[[Guild], GuildCmd], - logger: Logger, - info: ApplicationInfo, - environment: Environment - ): - self.__bot = bot - self.__get_guild_cmd = get_guild_cmd - self.__logger = logger - self.__info = info - self.__environment = environment - - def register(self): - bot = self.__bot - - services: dict[str, object] = dict() - - service = self.__environment - service_name = service.__class__.__name__ - services[service_name] = service - - service = self.__info - service_name = service.__class__.__name__ - services[service_name] = service - +@pyramid_service() +class DiscordCommands(ServiceInjector): + + def injectService(self, + logger_service: ILoggerService, + configuration_service: IConfigurationService, + discord_service: IDiscordService + ): + self.__logger = logger_service + self.__configuration_service = configuration_service + self.__discord_service = discord_service + + def start(self): + bot = self.__discord_service.bot + CommandRegister.import_commands() - CommandRegister.create_commands(services, self.__bot, self.__logger, self.__environment.name.lower()) + CommandRegister.create_commands(bot, self.__logger.getChild("Discord"), self.__configuration_service.mode.name.lower()) # ping = PingCommand(ParametersCommand("ping"), self.__bot, self.__logger) # ping.register(self.__environment.name.lower()) @@ -58,7 +51,7 @@ async def cmd_play(ctx: Interaction, input: str, engine: SourceType | None): ms = MessageSenderQueued(ctx) await ms.thinking() guild: Guild = ctx.guild # type: ignore - guild_cmd: GuildCmd = self.__get_guild_cmd(guild) + guild_cmd: GuildCmd = self.__discord_service.get_guild_cmd(guild) await guild_cmd.play(ms, ctx, input, engine) @@ -69,7 +62,7 @@ async def cmd_play_next(ctx: Interaction, input: str, engine: SourceType | None) ms = MessageSenderQueued(ctx) await ms.thinking() guild: Guild = ctx.guild # type: ignore - guild_cmd: GuildCmd = self.__get_guild_cmd(guild) + guild_cmd: GuildCmd = self.__discord_service.get_guild_cmd(guild) await guild_cmd.play(ms, ctx, input, engine, at_end=False) @@ -80,7 +73,7 @@ async def cmd_pause(ctx: Interaction): ms = MessageSenderQueued(ctx) await ms.thinking() guild: Guild = ctx.guild # type: ignore - guild_cmd: GuildCmd = self.__get_guild_cmd(guild) + guild_cmd: GuildCmd = self.__discord_service.get_guild_cmd(guild) await guild_cmd.pause(ms, ctx) @@ -91,7 +84,7 @@ async def cmd_resume(ctx: Interaction): ms = MessageSenderQueued(ctx) await ms.thinking() guild: Guild = ctx.guild # type: ignore - guild_cmd: GuildCmd = self.__get_guild_cmd(guild) + guild_cmd: GuildCmd = self.__discord_service.get_guild_cmd(guild) await guild_cmd.resume(ms, ctx) @@ -102,7 +95,7 @@ async def cmd_stop(ctx: Interaction): ms = MessageSenderQueued(ctx) await ms.thinking() guild: Guild = ctx.guild # type: ignore - guild_cmd: GuildCmd = self.__get_guild_cmd(guild) + guild_cmd: GuildCmd = self.__discord_service.get_guild_cmd(guild) await guild_cmd.stop(ms, ctx) @@ -113,7 +106,7 @@ async def cmd_next(ctx: Interaction): ms = MessageSenderQueued(ctx) await ms.thinking() guild: Guild = ctx.guild # type: ignore - guild_cmd: GuildCmd = self.__get_guild_cmd(guild) + guild_cmd: GuildCmd = self.__discord_service.get_guild_cmd(guild) await guild_cmd.next(ms, ctx) @@ -124,7 +117,7 @@ async def cmd_shuffle(ctx: Interaction): ms = MessageSenderQueued(ctx) await ms.thinking() guild: Guild = ctx.guild # type: ignore - guild_cmd: GuildCmd = self.__get_guild_cmd(guild) + guild_cmd: GuildCmd = self.__discord_service.get_guild_cmd(guild) await guild_cmd.shuffle(ms, ctx) @@ -135,7 +128,7 @@ async def cmd_remove(ctx: Interaction, number_in_queue: int): ms = MessageSenderQueued(ctx) await ms.thinking() guild: Guild = ctx.guild # type: ignore - guild_cmd: GuildCmd = self.__get_guild_cmd(guild) + guild_cmd: GuildCmd = self.__discord_service.get_guild_cmd(guild) await guild_cmd.remove(ms, ctx, number_in_queue) @@ -146,7 +139,7 @@ async def cmd_goto(ctx: Interaction, number_in_queue: int): ms = MessageSenderQueued(ctx) await ms.thinking() guild: Guild = ctx.guild # type: ignore - guild_cmd: GuildCmd = self.__get_guild_cmd(guild) + guild_cmd: GuildCmd = self.__discord_service.get_guild_cmd(guild) await guild_cmd.goto(ms, ctx, number_in_queue) @@ -157,7 +150,7 @@ async def cmd_queue(ctx: Interaction): ms = MessageSenderQueued(ctx) await ms.thinking() guild: Guild = ctx.guild # type: ignore - guild_cmd: GuildCmd = self.__get_guild_cmd(guild) + guild_cmd: GuildCmd = self.__discord_service.get_guild_cmd(guild) guild_cmd.queue_list(ms, ctx) @@ -168,7 +161,7 @@ async def cmd_queue(ctx: Interaction): # ms = MessageSenderQueued(ctx) # await ms.thinking() # guild: Guild = ctx.guild # type: ignore - # guild_cmd: GuildCmd = self.__get_guild_cmd(guild) + # guild_cmd: GuildCmd = self.__discord_service.get_guild_cmd(guild) # await guild_cmd.searchV1(ms, input, engine) @@ -179,7 +172,7 @@ async def cmd_search(ctx: Interaction, input: str, engine: SourceType | None): ms = MessageSenderQueued(ctx) await ms.thinking() guild: Guild = ctx.guild # type: ignore - guild_cmd: GuildCmd = self.__get_guild_cmd(guild) + guild_cmd: GuildCmd = self.__discord_service.get_guild_cmd(guild) await guild_cmd.search(ms, input, engine) @@ -192,7 +185,7 @@ async def cmd_play_url(ctx: Interaction, url: str): ms = MessageSenderQueued(ctx) await ms.thinking() guild: Guild = ctx.guild # type: ignore - guild_cmd: GuildCmd = self.__get_guild_cmd(guild) + guild_cmd: GuildCmd = self.__discord_service.get_guild_cmd(guild) await guild_cmd.play_url(ms, ctx, url) @@ -206,7 +199,7 @@ async def cmd_play_url_next(ctx: Interaction, url: str): ms = MessageSenderQueued(ctx) await ms.thinking() guild: Guild = ctx.guild # type: ignore - guild_cmd: GuildCmd = self.__get_guild_cmd(guild) + guild_cmd: GuildCmd = self.__discord_service.get_guild_cmd(guild) await guild_cmd.play_url(ms, ctx, url, at_end=False) diff --git a/src/pyramid/connector/discord/bot_listener.py b/src/pyramid/services/discord_listener.py similarity index 70% rename from src/pyramid/connector/discord/bot_listener.py rename to src/pyramid/services/discord_listener.py index b17a894..bf8e2f6 100644 --- a/src/pyramid/connector/discord/bot_listener.py +++ b/src/pyramid/services/discord_listener.py @@ -1,18 +1,25 @@ -from logging import Logger from discord import ( Guild, Role, ) -from discord.ext.commands import Bot +from pyramid.api.services import ILoggerService, IDiscordService +from pyramid.api.services.tools.annotation import pyramid_service +from pyramid.api.services.tools.injector import ServiceInjector -class BotListener: - def __init__(self, bot: Bot, logger: Logger): - self.__bot = bot - self.__logger = logger - def register(self): - bot = self.__bot +@pyramid_service() +class DiscordListeners(ServiceInjector): + + def injectService(self, + logger_service: ILoggerService, + discord_service: IDiscordService + ): + self.__logger = logger_service + self.__discord_service = discord_service + + def start(self): + bot = self.__discord_service.bot @bot.event async def on_ready(): # TODO changed to -> await bot.setup_hook() @@ -41,6 +48,8 @@ async def on_guild_remove(guild: Guild): self.show_info_guild(guild) def show_info_guild(self, guild: Guild): + bot = self.__discord_service.bot + self.__logger.info( "'%s' has %d members. Creator is %s.", guild.name, @@ -48,10 +57,10 @@ def show_info_guild(self, guild: Guild): guild.owner, ) - if self.__bot.user is None: + if bot.user is None: self.__logger.warning("Enable to get discord bot - Unable to get his roles") else: - bot_member = guild.get_member(self.__bot.user.id) + bot_member = guild.get_member(bot.user.id) if bot_member is None: self.__logger.warning(" Enable to get discord bot role on %s", guild.name) else: diff --git a/src/pyramid/services/logger_level.py b/src/pyramid/services/logger_level.py index 698c850..9641eb9 100644 --- a/src/pyramid/services/logger_level.py +++ b/src/pyramid/services/logger_level.py @@ -13,12 +13,12 @@ def injectService(self, logger_service: ILoggerService, configuration_service: IConfigurationService ): - self.__logger_service = logger_service + self.__logger = logger_service self.__configuration_service = configuration_service def start(self): - logger = self.__logger_service.getLogger() - # logger_colored = self.__logger_service.getLogger() + logger = self.__logger.getLogger() + # logger_colored = self.__logger.getLogger() logger_discord = logging.getLogger("discord") # logger_aiohttpweb = logging.getLogger("aiohttpweb") # logger_urllib3 = logging.getLogger("urllib3") diff --git a/src/pyramid/tasks/discord.py b/src/pyramid/tasks/discord.py index 27f3ebd..6f22e66 100644 --- a/src/pyramid/tasks/discord.py +++ b/src/pyramid/tasks/discord.py @@ -15,8 +15,7 @@ def injectService(self, self.__discord_service = discord_service async def worker_asyc(self): - self.discord_bot = self.__discord_service.discord_bot - await self.discord_bot.start() + await self.__discord_service.connect_bot() async def stop_asyc(self): - await self.discord_bot.stop() + await self.__discord_service.disconnect_bot() From ac112cc6a703e35bcb5b2b7172dfe0159d47ca9d Mon Sep 17 00:00:00 2001 From: Tristiisch Date: Sat, 21 Sep 2024 00:01:19 +0200 Subject: [PATCH 12/32] feat: create more service --- src/pyramid/api/services/__init__.py | 1 + src/pyramid/api/services/deezer_downloader.py | 13 + src/pyramid/api/services/deezer_search.py | 70 +++++ src/pyramid/api/services/source_service.py | 23 ++ src/pyramid/api/services/spotify_client.py | 36 +++ src/pyramid/api/services/spotify_search.py | 34 +++ .../api/services/spotify_search_base.py | 24 ++ src/pyramid/api/services/spotify_search_id.py | 28 ++ src/pyramid/api/services/tools/injector.py | 3 +- src/pyramid/connector/deezer/deezer_type.py | 9 + src/pyramid/connector/deezer/downloader.py | 6 +- src/pyramid/connector/deezer/search.py | 39 +-- src/pyramid/connector/deezer/tools.py | 34 +++ src/pyramid/connector/discord/guild_cmd.py | 5 +- .../connector/discord/guild_cmd_tools.py | 4 +- src/pyramid/connector/spotify/cli_spotify.py | 25 +- src/pyramid/connector/spotify/search.py | 44 +-- .../connector/spotify/spotify_tools.py | 30 ++ src/pyramid/connector/spotify/spotify_type.py | 11 + src/pyramid/data/guild_data.py | 6 +- src/pyramid/data/guild_instance.py | 8 +- src/pyramid/data/source_type.py | 8 + src/pyramid/services/deezer_downloader.py | 142 ++++++++++ src/pyramid/services/deezer_search.py | 263 ++++++++++++++++++ src/pyramid/services/discord.py | 9 +- src/pyramid/services/discord_commads.py | 13 +- .../engine_source.py => services/source.py} | 33 ++- src/pyramid/services/spotify_client.py | 138 +++++++++ src/pyramid/services/spotify_search.py | 116 ++++++++ src/pyramid/services/spotify_search_base.py | 65 +++++ src/pyramid/services/spotify_search_id.py | 71 +++++ 31 files changed, 1184 insertions(+), 127 deletions(-) create mode 100644 src/pyramid/api/services/deezer_downloader.py create mode 100644 src/pyramid/api/services/deezer_search.py create mode 100644 src/pyramid/api/services/source_service.py create mode 100644 src/pyramid/api/services/spotify_client.py create mode 100644 src/pyramid/api/services/spotify_search.py create mode 100644 src/pyramid/api/services/spotify_search_base.py create mode 100644 src/pyramid/api/services/spotify_search_id.py create mode 100644 src/pyramid/connector/deezer/deezer_type.py create mode 100644 src/pyramid/connector/deezer/tools.py create mode 100644 src/pyramid/connector/spotify/spotify_tools.py create mode 100644 src/pyramid/connector/spotify/spotify_type.py create mode 100644 src/pyramid/data/source_type.py create mode 100644 src/pyramid/services/deezer_downloader.py create mode 100644 src/pyramid/services/deezer_search.py rename src/pyramid/{data/functional/engine_source.py => services/source.py} (85%) create mode 100644 src/pyramid/services/spotify_client.py create mode 100644 src/pyramid/services/spotify_search.py create mode 100644 src/pyramid/services/spotify_search_base.py create mode 100644 src/pyramid/services/spotify_search_id.py diff --git a/src/pyramid/api/services/__init__.py b/src/pyramid/api/services/__init__.py index 321640f..84fbfed 100644 --- a/src/pyramid/api/services/__init__.py +++ b/src/pyramid/api/services/__init__.py @@ -3,3 +3,4 @@ from .information import IInformationService from .logger import ILoggerService from .socket_server import ISocketServerService +# from .source_service import ISourceService diff --git a/src/pyramid/api/services/deezer_downloader.py b/src/pyramid/api/services/deezer_downloader.py new file mode 100644 index 0000000..396bb9f --- /dev/null +++ b/src/pyramid/api/services/deezer_downloader.py @@ -0,0 +1,13 @@ +from abc import ABC, abstractmethod +from typing import Any +from pyramid.data.track import Track + +class IDeezerDownloaderService(ABC): + + @abstractmethod + async def check_credentials(self) -> dict[str, Any]: + pass + + @abstractmethod + async def dl_track_by_id(self, track_id) -> Track | None: + pass diff --git a/src/pyramid/api/services/deezer_search.py b/src/pyramid/api/services/deezer_search.py new file mode 100644 index 0000000..cfeeeca --- /dev/null +++ b/src/pyramid/api/services/deezer_search.py @@ -0,0 +1,70 @@ + + +from abc import ABC, abstractmethod +from pyramid.data.a_search import ASearch +from pyramid.data.track import TrackMinimalDeezer + + +class IDeezerSearchService(ASearch): + + @abstractmethod + async def search_track(self, search) -> TrackMinimalDeezer | None: + pass + + @abstractmethod + async def get_track_by_id(self, track_id: int) -> TrackMinimalDeezer | None: + pass + + @abstractmethod + async def get_track_by_isrc(self, isrc: str) -> TrackMinimalDeezer | None: + pass + + @abstractmethod + async def search_tracks( + self, search, limit: int | None = None + ) -> list[TrackMinimalDeezer] | None: + pass + + @abstractmethod + async def get_playlist_tracks(self, playlist_name) -> list[TrackMinimalDeezer] | None: + pass + + @abstractmethod + async def get_playlist_tracks_by_id( + self, playlist_id: int + ) -> tuple[list[TrackMinimalDeezer], list[TrackMinimalDeezer]] | None: + pass + + @abstractmethod + async def get_album_tracks(self, album_name) -> list[TrackMinimalDeezer] | None: + pass + + @abstractmethod + async def get_album_tracks_by_id( + self, album_id: int + ) -> tuple[list[TrackMinimalDeezer], list[TrackMinimalDeezer]] | None: + pass + + @abstractmethod + async def get_top_artist( + self, artist_name, limit: int | None = None + ) -> list[TrackMinimalDeezer] | None: + pass + + @abstractmethod + async def get_top_artist_by_id( + self, artist_id: int, limit: int | None = None + ) -> tuple[list[TrackMinimalDeezer], list[TrackMinimalDeezer]] | None: + pass + + @abstractmethod + async def get_by_url( + self, url + ) -> tuple[list[TrackMinimalDeezer], list[TrackMinimalDeezer]] | TrackMinimalDeezer | None: + pass + + @abstractmethod + async def search_exact_track( + self, artist_name, album_title, track_title + ) -> TrackMinimalDeezer | None: + pass diff --git a/src/pyramid/api/services/source_service.py b/src/pyramid/api/services/source_service.py new file mode 100644 index 0000000..42c1fe4 --- /dev/null +++ b/src/pyramid/api/services/source_service.py @@ -0,0 +1,23 @@ +from abc import ABC, abstractmethod +from pyramid.data.source_type import SourceType +from pyramid.data.track import Track, TrackMinimal, TrackMinimalDeezer + +class ISourceService(ABC): + + @abstractmethod + async def download_track(self, track: TrackMinimal) -> Track | None: + pass + + @abstractmethod + async def search_by_url(self, url: str) -> (tuple[list[TrackMinimalDeezer] | list[TrackMinimal], list[TrackMinimal]] | TrackMinimalDeezer): + pass + + @abstractmethod + async def search_track(self, input: str, engine: SourceType | None) -> TrackMinimalDeezer: + pass + + @abstractmethod + async def search_tracks( + self, input: str, engine: SourceType | None, limit: int | None = None + ) -> tuple[list[TrackMinimal], list[TrackMinimal]]: + pass diff --git a/src/pyramid/api/services/spotify_client.py b/src/pyramid/api/services/spotify_client.py new file mode 100644 index 0000000..6c52bdd --- /dev/null +++ b/src/pyramid/api/services/spotify_client.py @@ -0,0 +1,36 @@ +from abc import abstractmethod +from typing import Any + +class ISpotifyClientService: + + @abstractmethod + async def async_search(self, q, limit=10, offset=0, type="track", market=None) -> dict[str, Any]: + pass + + @abstractmethod + async def async_track(self, track_id, market=None) -> dict[str, Any]: + pass + + @abstractmethod + async def async_playlist_items( + self, + playlist_id, + fields=None, + limit=100, + offset=0, + market=None, + additional_types=("track", "episode"), + ) -> dict[str, Any]: + pass + + @abstractmethod + async def async_album_tracks(self, album_id, limit=50, offset=0, market=None) -> dict[str, Any]: + pass + + @abstractmethod + async def async_artist_top_tracks(self, artist_id, country="US") -> dict[str, Any]: + pass + + @abstractmethod + async def async_next(self, result) -> dict[str, Any] | None: + pass diff --git a/src/pyramid/api/services/spotify_search.py b/src/pyramid/api/services/spotify_search.py new file mode 100644 index 0000000..ee7d06f --- /dev/null +++ b/src/pyramid/api/services/spotify_search.py @@ -0,0 +1,34 @@ +from abc import abstractmethod +from pyramid.data.a_search import ASearch +from pyramid.data.track import TrackMinimalSpotify + + +class ISpotifySearchService(ASearch): + + @abstractmethod + async def search_tracks( + self, search, limit: int | None = None + ) -> list[TrackMinimalSpotify] | None: + pass + + @abstractmethod + async def search_track(self, search) -> TrackMinimalSpotify | None: + pass + + @abstractmethod + async def get_playlist_tracks(self, playlist_name) -> list[TrackMinimalSpotify] | None: + pass + + @abstractmethod + async def get_album_tracks(self, album_name) -> list[TrackMinimalSpotify] | None: + pass + + async def get_top_artist( + self, artist_name, limit: int | None = None + ) -> list[TrackMinimalSpotify] | None: + pass + + async def get_by_url( + self, url + ) -> tuple[list[TrackMinimalSpotify], list[TrackMinimalSpotify]] | TrackMinimalSpotify | None: + pass \ No newline at end of file diff --git a/src/pyramid/api/services/spotify_search_base.py b/src/pyramid/api/services/spotify_search_base.py new file mode 100644 index 0000000..0f6492d --- /dev/null +++ b/src/pyramid/api/services/spotify_search_base.py @@ -0,0 +1,24 @@ + + + +from abc import ABC, abstractmethod +from typing import Any + +class ISpotifySearchBaseService(ABC): + + @abstractmethod + async def items( + self, + results: dict[str, Any], + item_name="items" + ) -> list[dict[str, Any]] | None: + pass + + @abstractmethod + async def items_max( + self, + results: dict[str, Any], + limit: int | None = None, + item_name="items" + ) -> list[Any] | None: + pass diff --git a/src/pyramid/api/services/spotify_search_id.py b/src/pyramid/api/services/spotify_search_id.py new file mode 100644 index 0000000..373edcd --- /dev/null +++ b/src/pyramid/api/services/spotify_search_id.py @@ -0,0 +1,28 @@ +from abc import abstractmethod + +from pyramid.data.a_search import ASearchId +from pyramid.data.track import TrackMinimalSpotify + +class ISpotifySearchIdService(ASearchId): + + @abstractmethod + async def get_track_by_id(self, track_id: str) -> TrackMinimalSpotify | None: + pass + + @abstractmethod + async def get_playlist_tracks_by_id( + self, playlist_id: str + ) -> tuple[list[TrackMinimalSpotify], list[TrackMinimalSpotify]] | None: + pass + + @abstractmethod + async def get_album_tracks_by_id( + self, album_id: str + ) -> tuple[list[TrackMinimalSpotify], list[TrackMinimalSpotify]] | None: + pass + + @abstractmethod + async def get_top_artist_by_id( + self, artist_id: str, limit: int | None = None + ) -> tuple[list[TrackMinimalSpotify], list[TrackMinimalSpotify]] | None: + pass diff --git a/src/pyramid/api/services/tools/injector.py b/src/pyramid/api/services/tools/injector.py index 8849087..2eb2359 100644 --- a/src/pyramid/api/services/tools/injector.py +++ b/src/pyramid/api/services/tools/injector.py @@ -1,5 +1,4 @@ -from abc import ABC, abstractmethod -from discord.ext.commands import Bot +from abc import ABC class ServiceInjector(ABC): diff --git a/src/pyramid/connector/deezer/deezer_type.py b/src/pyramid/connector/deezer/deezer_type.py new file mode 100644 index 0000000..15c5d0d --- /dev/null +++ b/src/pyramid/connector/deezer/deezer_type.py @@ -0,0 +1,9 @@ + +from enum import Enum + + +class DeezerType(Enum): + PLAYLIST = 1 + ARTIST = 2 + ALBUM = 3 + TRACK = 4 diff --git a/src/pyramid/connector/deezer/downloader.py b/src/pyramid/connector/deezer/downloader.py index e86e375..2c3165d 100644 --- a/src/pyramid/connector/deezer/downloader.py +++ b/src/pyramid/connector/deezer/downloader.py @@ -9,12 +9,14 @@ from pyramid.connector.deezer.py_deezer import PyDeezer from pyramid.data.track import Track from pyramid.tools.generate_token import DeezerTokenProvider +from pyramid.tools.deprecated_class import deprecated_class +from pyramid.tools.generate_token import DeezerTokenProvider from pydeezer.constants import track_formats from urllib3.exceptions import MaxRetryError from pyramid.data.exceptions import CustomException, DeezerTokensUnavailableException, DeezerTokenInvalidException, DeezerTokenOverflowException - +@deprecated_class class DeezerDownloader: def __init__(self, folder: str, arl: Optional[str] = None): self.folder_path = folder @@ -77,7 +79,7 @@ async def __dl_track(self, track_info, file_name: str) -> bool: except CustomException as error: trace = "".join(traceback.format_exception(type(error), error, error.__traceback__)) - logging.warning("%s :\n%s", error.msg, trace) + logging.warning("%s :\n%s", error.args, trace) return False except Exception: diff --git a/src/pyramid/connector/deezer/search.py b/src/pyramid/connector/deezer/search.py index dd1b925..fad9c3b 100644 --- a/src/pyramid/connector/deezer/search.py +++ b/src/pyramid/connector/deezer/search.py @@ -15,8 +15,10 @@ CliDeezerRateLimitError, CliPaginatedList, ) +from pyramid.services.deezer_search import DeezerTools, DeezerType +from pyramid.tools.deprecated_class import deprecated_class - +@deprecated_class class DeezerSearch(ASearchId, ASearch): def __init__(self, default_limit: int): self.default_limit = default_limit @@ -250,38 +252,3 @@ def __remove_special_chars( return "".join(result) - -class DeezerType(Enum): - PLAYLIST = 1 - ARTIST = 2 - ALBUM = 3 - TRACK = 4 - - -class DeezerTools(AEngineTools): - async def extract_from_url(self, url) -> tuple[int, DeezerType | None] | tuple[None, None]: - # Resolve if URL is a deezer.page.link URL - if "deezer.page.link" in url: - async with aiohttp.ClientSession() as session: - async with session.get(url, allow_redirects=True) as response: - url = str(response.url) - - # Extract ID and type using regex - pattern = r"(?<=deezer.com/fr/)(\w+)/(?P\d+)" - match = re.search(pattern, url) - if not match: - return None, None - deezer_type_str = match.group(1).upper() - if deezer_type_str == "PLAYLIST": - deezer_type = DeezerType.PLAYLIST - elif deezer_type_str == "ARTIST": - deezer_type = DeezerType.ARTIST - elif deezer_type_str == "ALBUM": - deezer_type = DeezerType.ALBUM - elif deezer_type_str == "TRACK": - deezer_type = DeezerType.TRACK - else: - deezer_type = None - - deezer_id = int(match.group("id")) - return deezer_id, deezer_type diff --git a/src/pyramid/connector/deezer/tools.py b/src/pyramid/connector/deezer/tools.py new file mode 100644 index 0000000..a66ecb8 --- /dev/null +++ b/src/pyramid/connector/deezer/tools.py @@ -0,0 +1,34 @@ + +import re +import aiohttp +from pyramid.data.a_engine_tools import AEngineTools +from pyramid.services.deezer_search import DeezerType + + +class DeezerTools(AEngineTools): + async def extract_from_url(self, url) -> tuple[int, DeezerType | None] | tuple[None, None]: + # Resolve if URL is a deezer.page.link URL + if "deezer.page.link" in url: + async with aiohttp.ClientSession() as session: + async with session.get(url, allow_redirects=True) as response: + url = str(response.url) + + # Extract ID and type using regex + pattern = r"(?<=deezer.com/fr/)(\w+)/(?P\d+)" + match = re.search(pattern, url) + if not match: + return None, None + deezer_type_str = match.group(1).upper() + if deezer_type_str == "PLAYLIST": + deezer_type = DeezerType.PLAYLIST + elif deezer_type_str == "ARTIST": + deezer_type = DeezerType.ARTIST + elif deezer_type_str == "ALBUM": + deezer_type = DeezerType.ALBUM + elif deezer_type_str == "TRACK": + deezer_type = DeezerType.TRACK + else: + deezer_type = None + + deezer_id = int(match.group("id")) + return deezer_id, deezer_type diff --git a/src/pyramid/connector/discord/guild_cmd.py b/src/pyramid/connector/discord/guild_cmd.py index 3083564..f5f791f 100644 --- a/src/pyramid/connector/discord/guild_cmd.py +++ b/src/pyramid/connector/discord/guild_cmd.py @@ -3,13 +3,14 @@ from discord import Interaction, Member, User, VoiceChannel import discord +from pyramid.api.services.source_service import ISourceService from pyramid.data import tracklist as utils_list_track from pyramid.data.guild_data import GuildData +from pyramid.data.source_type import SourceType from pyramid.data.track import TrackMinimal from pyramid.connector.discord.guild_cmd_tools import GuildCmdTools from pyramid.connector.discord.guild_queue import GuildQueue from pyramid.data.functional.messages.message_sender_queued import MessageSenderQueued -from pyramid.data.functional.engine_source import EngineSource, SourceType from pyramid.data.exceptions import DiscordMessageException from pyramid.data.a_guild_cmd import AGuildCmd from pyramid.data.select_view import SelectView @@ -21,7 +22,7 @@ def __init__( logger: Logger, guild_data: GuildData, guild_queue: GuildQueue, - engine_source: EngineSource, + engine_source: ISourceService, ): self.logger = logger self.engine_source = engine_source diff --git a/src/pyramid/connector/discord/guild_cmd_tools.py b/src/pyramid/connector/discord/guild_cmd_tools.py index 3b12f56..baa78ca 100644 --- a/src/pyramid/connector/discord/guild_cmd_tools.py +++ b/src/pyramid/connector/discord/guild_cmd_tools.py @@ -5,12 +5,12 @@ from discord import Member, StageChannel, TextChannel, User, VoiceChannel, VoiceClient, VoiceState from pyramid.data.exceptions import DeezerTokenException +from pyramid.api.services.source_service import ISourceService from pyramid.data.track import Track, TrackMinimal, TrackMinimalDeezer from pyramid.data.guild_data import GuildData from pyramid.data.tracklist import TrackList from pyramid.connector.discord.guild_queue import GuildQueue from pyramid.data.functional.messages.message_sender_queued import MessageSenderQueued -from pyramid.data.functional.engine_source import EngineSource class GuildCmdTools: @@ -18,7 +18,7 @@ def __init__( self, guild_data: GuildData, guild_queue: GuildQueue, - engine_source: EngineSource, + engine_source: ISourceService, ): self.engine_source = engine_source self.data = guild_data diff --git a/src/pyramid/connector/spotify/cli_spotify.py b/src/pyramid/connector/spotify/cli_spotify.py index a6d9f62..202e981 100644 --- a/src/pyramid/connector/spotify/cli_spotify.py +++ b/src/pyramid/connector/spotify/cli_spotify.py @@ -1,19 +1,23 @@ import json import logging +from typing import Any from venv import logger import aiohttp from spotipy import Spotify from spotipy.exceptions import SpotifyException +from pyramid.tools.deprecated_class import deprecated_class +@deprecated_class class CliSpotify(Spotify): - async def async_search(self, q, limit=10, offset=0, type="track", market=None): + + async def async_search(self, q, limit=10, offset=0, type="track", market=None) -> dict[str, Any]: return await self._get_async( "search", q=q, limit=limit, offset=offset, type=type, market=market ) - async def async_track(self, track_id, market=None): + async def async_track(self, track_id, market=None) -> dict[str, Any]: trid = self._get_id("track", track_id) return await self._get_async("tracks/" + trid, market=market) @@ -25,7 +29,7 @@ async def async_playlist_items( offset=0, market=None, additional_types=("track", "episode"), - ): + ) -> dict[str, Any]: plid = self._get_id("playlist", playlist_id) return await self._get_async( "playlists/%s/tracks" % (plid), @@ -36,29 +40,28 @@ async def async_playlist_items( additional_types=",".join(additional_types), ) - async def async_album_tracks(self, album_id, limit=50, offset=0, market=None): + async def async_album_tracks(self, album_id, limit=50, offset=0, market=None) -> dict[str, Any]: trid = self._get_id("album", album_id) return await self._get_async( "albums/" + trid + "/tracks/", limit=limit, offset=offset, market=market ) - async def async_artist_top_tracks(self, artist_id, country="US"): + async def async_artist_top_tracks(self, artist_id, country="US") -> dict[str, Any]: trid = self._get_id("artist", artist_id) return await self._get_async("artists/" + trid + "/top-tracks", country=country) - async def async_next(self, result): - if result["next"]: - return await self._get_async(result["next"]) - else: + async def async_next(self, result) -> dict[str, Any] | None: + if not result["next"]: return None + return await self._get_async(result["next"]) - async def _get_async(self, url, args=None, payload=None, **kwargs): + async def _get_async(self, url, args=None, payload=None, **kwargs) -> dict[str, Any]: if args: kwargs.update(args) return await self._async_internal_call("GET", url, payload, kwargs) - async def _async_internal_call(self, method, url, payload, params): + async def _async_internal_call(self, method: str, url: str, payload, params) -> dict[str, Any]: args = dict(params=params) if not url.startswith("http"): url = self.prefix + url diff --git a/src/pyramid/connector/spotify/search.py b/src/pyramid/connector/spotify/search.py index f4e3c4c..df02dd0 100644 --- a/src/pyramid/connector/spotify/search.py +++ b/src/pyramid/connector/spotify/search.py @@ -1,15 +1,15 @@ -import re -from enum import Enum from typing import Any -from pyramid.data.a_engine_tools import AEngineTools +from pyramid.connector.spotify.spotify_tools import SpotifyTools +from pyramid.connector.spotify.spotify_type import SpotifyType from pyramid.data.a_search import ASearch, ASearchId from pyramid.data.track import TrackMinimalSpotify from spotipy.oauth2 import SpotifyClientCredentials from pyramid.connector.spotify.cli_spotify import CliSpotify +from pyramid.tools.deprecated_class import deprecated_class - +@deprecated_class class SpotifySearchBase(ASearch): def __init__(self, default_limit: int, client_id: str, client_secret: str): self.default_limit = default_limit @@ -19,7 +19,6 @@ def __init__(self, default_limit: int, client_id: str, client_secret: str): client_id=self.client_id, client_secret=self.client_secret ) self.client = CliSpotify(client_credentials_manager=self.client_credentials_manager) - self.tools = SpotifyTools() async def items( self, results: dict[str, Any], item_name="items" @@ -53,6 +52,7 @@ async def items_max(self, results: dict[str, Any], limit: int | None = None, ite return tracks +@deprecated_class class SpotifySearchId(ASearchId, SpotifySearchBase): def __init__(self, default_limit: int, client_id: str, client_secret: str): super().__init__(default_limit, client_id, client_secret) @@ -105,7 +105,7 @@ async def get_top_artist_by_id( tracks = tracks[:limit] return [TrackMinimalSpotify(element) for element in tracks], [] - +@deprecated_class class SpotifyResponse: def __init__(self, client: CliSpotify, default_limit: int, item_name="items") -> None: self.client = client @@ -131,6 +131,7 @@ async def items(self, results: dict[str, Any], limit: int | None = None): return tracks +@deprecated_class class SpotifySearch(SpotifySearchId): def __init__(self, default_limit: int, client_id: str, client_secret: str): super().__init__(default_limit, client_id, client_secret) @@ -202,7 +203,7 @@ async def get_top_artist( async def get_by_url( self, url ) -> tuple[list[TrackMinimalSpotify], list[TrackMinimalSpotify]] | TrackMinimalSpotify | None: - id, type = self.tools.extract_from_url(url) + id, type = SpotifyTools.extract_from_url(url) if id is None: return None @@ -226,32 +227,3 @@ async def get_by_url( return tracks - -class SpotifyType(Enum): - PLAYLIST = 1 - ARTIST = 2 - ALBUM = 3 - TRACK = 4 - - -class SpotifyTools(AEngineTools): - def extract_from_url(self, url) -> tuple[str, SpotifyType | None] | tuple[None, None]: - # Extract ID and type using regex - pattern = r"(?<=open\.spotify\.com/)(intl-(?P\w+)/)?(?P\w+)/(?P\w+)" - match = re.search(pattern, url) - if not match: - return None, None - type_str = match.group("type").upper() - if type_str == "PLAYLIST": - type = SpotifyType.PLAYLIST - elif type_str == "ARTIST": - type = SpotifyType.ARTIST - elif type_str == "ALBUM": - type = SpotifyType.ALBUM - elif type_str == "TRACK": - type = SpotifyType.TRACK - else: - type = None - - id = match.group("id") - return id, type diff --git a/src/pyramid/connector/spotify/spotify_tools.py b/src/pyramid/connector/spotify/spotify_tools.py new file mode 100644 index 0000000..53b013b --- /dev/null +++ b/src/pyramid/connector/spotify/spotify_tools.py @@ -0,0 +1,30 @@ + + +import re +from pyramid.connector.spotify.spotify_type import SpotifyType +from pyramid.data.a_engine_tools import AEngineTools + + +class SpotifyTools(AEngineTools): + + @classmethod + def extract_from_url(cls, url) -> tuple[str, SpotifyType | None] | tuple[None, None]: + # Extract ID and type using regex + pattern = r"(?<=open\.spotify\.com/)(intl-(?P\w+)/)?(?P\w+)/(?P\w+)" + match = re.search(pattern, url) + if not match: + return None, None + type_str = match.group("type").upper() + if type_str == "PLAYLIST": + type = SpotifyType.PLAYLIST + elif type_str == "ARTIST": + type = SpotifyType.ARTIST + elif type_str == "ALBUM": + type = SpotifyType.ALBUM + elif type_str == "TRACK": + type = SpotifyType.TRACK + else: + type = None + + id = match.group("id") + return id, type diff --git a/src/pyramid/connector/spotify/spotify_type.py b/src/pyramid/connector/spotify/spotify_type.py new file mode 100644 index 0000000..286bb85 --- /dev/null +++ b/src/pyramid/connector/spotify/spotify_type.py @@ -0,0 +1,11 @@ + + + +from enum import Enum + + +class SpotifyType(Enum): + PLAYLIST = 1 + ARTIST = 2 + ALBUM = 3 + TRACK = 4 diff --git a/src/pyramid/data/guild_data.py b/src/pyramid/data/guild_data.py index 6f5f23e..c7ba7f8 100644 --- a/src/pyramid/data/guild_data.py +++ b/src/pyramid/data/guild_data.py @@ -1,12 +1,12 @@ from discord import Guild, VoiceClient +from pyramid.api.services.source_service import ISourceService from pyramid.data.tracklist import TrackList -from pyramid.data.functional.engine_source import EngineSource class GuildData: - def __init__(self, guild: Guild, engine_source: EngineSource): + def __init__(self, guild: Guild, source_service: ISourceService): self.guild: Guild = guild self.track_list: TrackList = TrackList() self.voice_client: VoiceClient = None # type: ignore - self.search_engine = engine_source + self.search_engine = source_service diff --git a/src/pyramid/data/guild_instance.py b/src/pyramid/data/guild_instance.py index 2ba42ba..bf44d12 100644 --- a/src/pyramid/data/guild_instance.py +++ b/src/pyramid/data/guild_instance.py @@ -3,17 +3,17 @@ from logging import Logger from discord import Guild +from pyramid.api.services.source_service import ISourceService from pyramid.connector.discord.guild_cmd import GuildCmd from pyramid.connector.discord.guild_queue import GuildQueue from pyramid.connector.discord.music_player_interface import MusicPlayerInterface -from pyramid.data.functional.engine_source import EngineSource from pyramid.data.guild_data import GuildData class GuildInstances: - def __init__(self, guild: Guild, logger: Logger, engine_source: EngineSource, ffmpeg_path: str): - self.data = GuildData(guild, engine_source) + def __init__(self, guild: Guild, logger: Logger, source_service: ISourceService, ffmpeg_path: str): + self.data = GuildData(guild, source_service) self.mpi = MusicPlayerInterface(self.data.guild.preferred_locale, self.data.track_list) self.songs = GuildQueue(self.data, ffmpeg_path, self.mpi) - self.cmds = GuildCmd(logger, self.data, self.songs, engine_source) + self.cmds = GuildCmd(logger, self.data, self.songs, source_service) self.mpi.set_queue_action(self.cmds) diff --git a/src/pyramid/data/source_type.py b/src/pyramid/data/source_type.py new file mode 100644 index 0000000..4426344 --- /dev/null +++ b/src/pyramid/data/source_type.py @@ -0,0 +1,8 @@ + + +from enum import Enum + + +class SourceType(Enum): + Spotify = 1 + Deezer = 2 diff --git a/src/pyramid/services/deezer_downloader.py b/src/pyramid/services/deezer_downloader.py new file mode 100644 index 0000000..858cc80 --- /dev/null +++ b/src/pyramid/services/deezer_downloader.py @@ -0,0 +1,142 @@ +import asyncio +import os +import traceback +from typing import Any + +import pydeezer.util +from pyramid.api.services import IConfigurationService, ILoggerService +from pyramid.api.services.deezer_downloader import IDeezerDownloaderService +from pyramid.api.services.tools.annotation import pyramid_service +from pyramid.api.services.tools.injector import ServiceInjector +from pyramid.connector.deezer.downloader_progress_bar import DownloaderProgressBar +from pyramid.connector.deezer.py_deezer import PyDeezer +from pyramid.data.track import Track +from pyramid.tools.generate_token import DeezerTokenProvider, DeezerTokenEmptyException, DeezerTokenOverflowException +from pydeezer.constants import track_formats +from pydeezer.exceptions import LoginError +from urllib3.exceptions import MaxRetryError + +from pyramid.data.exceptions import CustomException + + +@pyramid_service(interface=IDeezerDownloaderService) +class DeezerDownloaderService(IDeezerDownloaderService, ServiceInjector): + + def injectService(self, + logger_service: ILoggerService, + configuration_service: IConfigurationService + ): + self.__logger = logger_service + self.__configuration_service = configuration_service + + def start(self): + # arl = self.__configuration_service.deezer__arl + arl = None + if arl is not None and arl != "": + self.__deezer_dl_api = PyDeezer(arl) + self.__token_provider = None + else: + self.__deezer_dl_api = None + self.__token_provider = DeezerTokenProvider() + self.music_format = track_formats.MP3_128 + os.makedirs(self.__configuration_service.deezer__folder, exist_ok=True) + + async def check_credentials(self) -> dict[str, Any]: + if not self.__deezer_dl_api: + raise Exception("deezer_dl_api not init") + try: + await self.__deezer_dl_api.get_user_data() + return self.__deezer_dl_api.user + except LoginError as err: + raise err # Arl is invalid + + async def dl_track_by_id(self, track_id) -> Track | None: + client = await self._get_client() + # try: + track_info = await client.get_track_info(track_id) + # except APIRequestError as err: + # self.__logger.warn(f"Unable to download deezer song {track_id} : {err}", exc_info=True) + # return None # Track unvailable in this country + + if not track_info: + self.__logger.error(f"Unable to find deezer song to download {track_id} : Unknown error") + return None + + file_name = pydeezer.util.clean_filename( + f"{track_info['ART_NAME']} - {track_info['SNG_TITLE']}" + ) + file_path = os.path.join(self.__configuration_service.deezer__folder, file_name) + ".mp3" + + if os.path.exists(file_path) is False: + is_dl = await self.__dl_track(track_info, file_name) + if not is_dl: + return None + + track_downloaded = Track(track_info, file_path) + return track_downloaded + + async def __dl_track(self, track_info, file_name: str) -> bool: + try: + client = await self._get_client() + await client.download_track( + track_info, + self.__configuration_service.deezer__folder, + self.music_format, + True, # fallback quality if not available + file_name, + False, # renew track info + False, # metadata + False, # lyrics + ", ", # separator for multiple artists + False, # show messages + DownloaderProgressBar(), # Custom progress bar + ) + return True + except MaxRetryError: + track = Track(track_info, None) + self.__logger.warning("Downloader MaxRetryError %s", track) + await asyncio.sleep(5) + return await self.__dl_track(track_info, file_name) + + except CustomException as error: + trace = "".join(traceback.format_exception(type(error), error, error.__traceback__)) + self.__logger.warning("%s :\n%s", error.msg, trace) + return False + + except Exception: + track = Track(track_info, None) + self.__logger.warning("Unable to dl track %s", track, exc_info=True) + return False + + async def _get_client(self) -> PyDeezer: + i = 0 + max_error = 10 + + if self.__deezer_dl_api: + return self.__deezer_dl_api + if not self.__token_provider: + raise Exception("token_provider not init") + + while True: + try: + token = self.__token_provider.next() + self.__deezer_dl_api = PyDeezer(token.token) + await self.check_credentials() + break + + except DeezerTokenEmptyException as err: + if i > max_error: + raise err + self.__token_provider = DeezerTokenProvider() + + except DeezerTokenOverflowException as err: + if i > max_error: + raise err + self.__token_provider = DeezerTokenProvider() + + except LoginError as err: + if i > max_error: + raise err + i += 1 + + return self.__deezer_dl_api diff --git a/src/pyramid/services/deezer_search.py b/src/pyramid/services/deezer_search.py new file mode 100644 index 0000000..f4f4669 --- /dev/null +++ b/src/pyramid/services/deezer_search.py @@ -0,0 +1,263 @@ +import asyncio +import logging +import re +from enum import Enum +import aiohttp + +import deezer +from pyramid.api.services.configuration import IConfigurationService +from pyramid.api.services.deezer_search import IDeezerSearchService +from pyramid.api.services.tools.annotation import pyramid_service +from pyramid.api.services.tools.injector import ServiceInjector +from pyramid.connector.deezer.deezer_type import DeezerType +from pyramid.connector.deezer.tools import DeezerTools +from pyramid.data.a_search import ASearch, ASearchId +from pyramid.data.track import TrackMinimalDeezer + +from pyramid.connector.deezer.cli_deezer import ( + CliDeezer, + CliDeezerNoDataException, + CliDeezerRateLimitError, + CliPaginatedList, +) + + +@pyramid_service(interface=IDeezerSearchService) +class DeezerSearchService(IDeezerSearchService, ASearchId, ASearch, ServiceInjector): + + def injectService(self, + configuration_service: IConfigurationService + ): + self.__configuration_service = configuration_service + + def start(self): + self.client = CliDeezer() + self.tools = DeezerTools() + self.strict = False + + async def search_track(self, search) -> TrackMinimalDeezer | None: + result = self.client.search(query=search) + track = await result.get_first() + if not track: + return None + return TrackMinimalDeezer(track) + + async def get_track_by_id(self, track_id: int) -> TrackMinimalDeezer | None: + track = await self.client.async_get_track(track_id) # TODO handle HTTP errors + if not track: + return None + return TrackMinimalDeezer(track) + + async def get_track_by_isrc(self, isrc: str) -> TrackMinimalDeezer | None: + try: + track: deezer.Track = await self.client.async_request("GET", f"track/isrc:{isrc}") # type: ignore + if not track: + return None + return TrackMinimalDeezer(track) + except CliDeezerNoDataException: + return None + + async def search_tracks( + self, search, limit: int | None = None + ) -> list[TrackMinimalDeezer] | None: + if limit is None: + limit = self.__configuration_service.general__limit_tracks + + pagination_results = self.client.search(query=search, strict=self.strict) + tracks = await pagination_results.get_maximum(limit) + if not tracks: + return None + + return [TrackMinimalDeezer(element) for element in tracks] + + async def get_playlist_tracks(self, playlist_name) -> list[TrackMinimalDeezer] | None: + pagination_results = self.client.search_playlists(query=playlist_name, strict=self.strict) + playlist = await pagination_results.get_first() + if not playlist: + return None + pagination_tracks: CliPaginatedList[deezer.Track] = playlist.get_tracks() # type: ignore + tracks = await pagination_tracks.get_all() + return [TrackMinimalDeezer(element) for element in tracks] + + async def get_playlist_tracks_by_id( + self, playlist_id: int + ) -> tuple[list[TrackMinimalDeezer], list[TrackMinimalDeezer]] | None: + playlist = await self.client.async_get_playlist(playlist_id) # TODO handle HTTP errors + if not playlist: + return None + # Tracks id are the not the good one + playlist_tracks: CliPaginatedList[deezer.Track] = playlist.get_tracks() # type: ignore + + # So we search the id for same name and artist + real_tracks: list[TrackMinimalDeezer] = [] + unfindable_track: list[TrackMinimalDeezer] = [] + + async for chunk_tracks in playlist_tracks: + for t in chunk_tracks: + track = await self.search_exact_track(t.artist.name, t.album.title, t.title) + # logging.info("DEBUG song '%s' - '%s' - '%s'", t.artist.name, t.title, t.album.title) + if track is None: + if not t.readable: + logging.warning( + "Unavailable track in playlist '%s' - '%s'", t.artist.name, t.title + ) + else: + logging.warning( + "Unknown track searched in playlist '%s' - '%s'", t.artist.name, t.title + ) + unfindable_track.append(TrackMinimalDeezer(t)) + continue + real_tracks.append(track) + + return real_tracks, unfindable_track + + async def get_album_tracks(self, album_name) -> list[TrackMinimalDeezer] | None: + pagination_results = self.client.search_albums(query=album_name, strict=self.strict) + album = await pagination_results.get_first() + if not album: + return None + pagination_tracks: CliPaginatedList[deezer.Track] = album.get_tracks() # type: ignore + tracks = await pagination_tracks.get_all() + return [TrackMinimalDeezer(element) for element in tracks] + + async def get_album_tracks_by_id( + self, album_id: int + ) -> tuple[list[TrackMinimalDeezer], list[TrackMinimalDeezer]] | None: + album = await self.client.async_get_album(album_id) # TODO handle HTTP errors + if not album: + return None + pagination_tracks: CliPaginatedList[deezer.Track] = album.get_tracks() # type: ignore + tracks = await pagination_tracks.get_all() + return [TrackMinimalDeezer(element) for element in tracks], [] + + async def get_top_artist( + self, artist_name, limit: int | None = None + ) -> list[TrackMinimalDeezer] | None: + if limit is None: + limit = self.__configuration_service.general__limit_tracks + pagination_results = self.client.search_artists(query=artist_name, strict=self.strict) + artist = await pagination_results.get_first() + if not artist: + return None + pagination_tracks: CliPaginatedList[deezer.Track] = artist.get_top() # type: ignore + tracks = await pagination_tracks.get_maximum(limit) + return [TrackMinimalDeezer(element) for element in tracks] + + async def get_top_artist_by_id( + self, artist_id: int, limit: int | None = None + ) -> tuple[list[TrackMinimalDeezer], list[TrackMinimalDeezer]] | None: + if limit is None: + limit = self.__configuration_service.general__limit_tracks + artist = await self.client.async_get_artist(artist_id) # TODO handle HTTP errors + if not artist: + return None + pagination_tracks: CliPaginatedList[deezer.Track] = artist.get_top() # type: ignore + tracks = await pagination_tracks.get_maximum(limit) + return [TrackMinimalDeezer(element) for element in tracks], [] + + async def get_by_url( + self, url + ) -> tuple[list[TrackMinimalDeezer], list[TrackMinimalDeezer]] | TrackMinimalDeezer | None: + id, type = await self.tools.extract_from_url(url) + + if id is None: + return None + if type is None: + raise NotImplementedError(f"The type of deezer info '{url}' is not implemented") + + tracks: ( + tuple[list[TrackMinimalDeezer], list[TrackMinimalDeezer]] | TrackMinimalDeezer | None + ) + + if type == DeezerType.PLAYLIST: + # future = asyncio.get_event_loop().run_in_executor( + # None, self.get_playlist_tracks_by_id, id + # ) + # tracks = await asyncio.wrap_future(future) + tracks = await self.get_playlist_tracks_by_id(id) + elif type == DeezerType.ARTIST: + tracks = await self.get_top_artist_by_id(id) + elif type == DeezerType.ALBUM: + tracks = await self.get_album_tracks_by_id(id) + elif type == DeezerType.TRACK: + tracks = await self.get_track_by_id(id) + else: + raise NotImplementedError(f"The type of deezer info '{type}' can't be resolve") + + return tracks + + async def search_exact_track( + self, artist_name, album_title, track_title + ) -> TrackMinimalDeezer | None: + clean_artist = self.__remove_special_chars(artist_name) + clean_album = self.__remove_special_chars(album_title) + clean_track = self.__remove_special_chars(track_title) + # logging.info("Song CLEANED '%s' - '%s' - '%s'", clean_artist, clean_track, clean_album) + + track = await self._search_exact_track(clean_artist, clean_album, clean_track) + if track is None: + track = await self._search_exact_track(clean_artist, None, clean_track) + if track is None: + track = await self._search_exact_track(None, clean_album, clean_track) + if track is None: + track = await self._search_exact_track(None, None, clean_track) + # if track is not None: + # logging.warning("Find with title '%s' - '%s' - '%s'", clean_artist, clean_track, clean_album) + # else: + # logging.warning("Find with album & title '%s' - '%s' - '%s'", clean_artist, clean_track, clean_album) + # else: + # logging.warning("Find with artist & title '%s' - '%s' - '%s'", clean_artist, clean_track, clean_album) + return track + + async def _search_exact_track( + self, artist_name, album_title, track_title + ) -> TrackMinimalDeezer | None: + try: + pagination_results = self.client.search( + artist=artist_name, album=album_title, track=track_title + ) + logging.info("_search_exact_track %s - %s - %s", artist_name, album_title, track_title) + track = await pagination_results.get_first() + if track is None: + return None + return TrackMinimalDeezer(track) + + except CliDeezerRateLimitError: + logging.error("Search Deezer RateLimit %s - %s", artist_name, track_title) + await asyncio.sleep(5) + return await self._search_exact_track(artist_name, album_title, track_title) + + def __remove_special_chars( + self, input_string: str | None, allowed_brackets: tuple = ("(", ")", "[", "]") + ): + if input_string is None or input_string == "": + return None + + open_brackets = [b for i, b in enumerate(allowed_brackets) if i % 2 == 0] + close_brackets = [b for i, b in enumerate(allowed_brackets) if i % 2 != 0] + stack: list[str] = [] + result: list[str] = [] + last_char: str | None = None # Keep track of the last processed character + + for char in input_string: + if char in open_brackets: + stack.append(char) + elif char in close_brackets: + if stack: + open_bracket = stack.pop() + if open_brackets.index(open_bracket) == close_brackets.index(char): + continue + if last_char != " ": # Append only if the previous character is not a space + result.append(char) + elif char.isspace(): + if last_char != " ": # Append only if the previous character is not a space + result.append(char) + # elif not stack and (char.isalnum() or char == "'" or char == "/"): + elif not stack: + result.append(char) + else: + continue + last_char = char # Update last_char + + return "".join(result) + diff --git a/src/pyramid/services/discord.py b/src/pyramid/services/discord.py index a7e8109..21bc09c 100644 --- a/src/pyramid/services/discord.py +++ b/src/pyramid/services/discord.py @@ -17,9 +17,9 @@ ) from discord.app_commands.errors import AppCommandError, CommandInvokeError from pyramid.api.services.configuration import IConfigurationService +from pyramid.api.services.source_service import ISourceService from pyramid.data.environment import Environment from pyramid.connector.discord.guild_cmd import GuildCmd -from pyramid.data.functional.engine_source import EngineSource from pyramid.data.exceptions import DiscordMessageException from pyramid.data.functional.messages.message_sender_queued import MessageSenderQueued from pyramid.data.guild_instance import GuildInstances @@ -35,14 +35,15 @@ class DiscordBotService(IDiscordService, ServiceInjector): def injectService(self, logger_service: ILoggerService, information_service: IInformationService, - configuration_service: IConfigurationService + configuration_service: IConfigurationService, + source_service: ISourceService ): self.__logger = logger_service self.__information_service = information_service self.__configuration_service = configuration_service + self.__source_service = source_service def start(self): - self.__engine_source = EngineSource(self.__configuration_service) intents = discord.Intents.default() # intents.members = True @@ -143,7 +144,7 @@ async def disconnect_bot(self): def get_guild_cmd(self, guild: Guild) -> GuildCmd: if guild.id not in self.guilds_instances: self.guilds_instances[guild.id] = GuildInstances( - guild, self.__logger.getChild(guild.name), self.__engine_source, self.__configuration_service.discord__ffmpeg + guild, self.__logger.getChild(guild.name), self.__source_service, self.__configuration_service.discord__ffmpeg ) return self.guilds_instances[guild.id].cmds diff --git a/src/pyramid/services/discord_commads.py b/src/pyramid/services/discord_commads.py index 5efba02..2e2db6c 100644 --- a/src/pyramid/services/discord_commads.py +++ b/src/pyramid/services/discord_commads.py @@ -1,20 +1,11 @@ -from logging import Logger -from typing import Callable - from discord import Guild, Interaction -from discord.ext.commands import Bot -from pyramid.api.services.configuration import IConfigurationService -from pyramid.api.services.discord import IDiscordService -from pyramid.api.services.information import IInformationService -from pyramid.api.services.logger import ILoggerService +from pyramid.api.services import IConfigurationService, IDiscordService, ILoggerService from pyramid.api.services.tools.annotation import pyramid_service from pyramid.api.services.tools.injector import ServiceInjector from pyramid.connector.discord.commands.tools.register import CommandRegister from pyramid.connector.discord.guild_cmd import GuildCmd -from pyramid.data.environment import Environment -from pyramid.data.functional.application_info import ApplicationInfo from pyramid.data.functional.messages.message_sender_queued import MessageSenderQueued -from pyramid.data.functional.engine_source import SourceType +from pyramid.data.source_type import SourceType @pyramid_service() diff --git a/src/pyramid/data/functional/engine_source.py b/src/pyramid/services/source.py similarity index 85% rename from src/pyramid/data/functional/engine_source.py rename to src/pyramid/services/source.py index 4bfc568..8386dbd 100644 --- a/src/pyramid/data/functional/engine_source.py +++ b/src/pyramid/services/source.py @@ -1,28 +1,33 @@ -from enum import Enum from typing import Dict from pyramid.api.services.configuration import IConfigurationService -from pyramid.connector.deezer.downloader import DeezerDownloader -from pyramid.connector.deezer.search import DeezerSearch +from pyramid.api.services.deezer_downloader import IDeezerDownloaderService +from pyramid.api.services.deezer_search import IDeezerSearchService +from pyramid.api.services.source_service import ISourceService +from pyramid.api.services.spotify_search import ISpotifySearchService +from pyramid.api.services.tools.annotation import pyramid_service +from pyramid.api.services.tools.injector import ServiceInjector from pyramid.connector.spotify.search import SpotifySearch from pyramid.data.a_search import ASearch from pyramid.data.exceptions import EngineSourceNotFoundException, TrackNotFoundException +from pyramid.data.source_type import SourceType from pyramid.data.track import Track, TrackMinimal, TrackMinimalDeezer -from pyramid.tools.configuration.configuration import Configuration -class SourceType(Enum): - Spotify = 1 - Deezer = 2 +@pyramid_service(interface=ISourceService) +class SourceService(ISourceService, ServiceInjector): + def injectService(self, + downloader_service: IDeezerDownloaderService, + deezer_search_service: IDeezerSearchService, + spotify_search_service: ISpotifySearchService, + + ): + self.__downloader = downloader_service + self.__deezer_search = deezer_search_service + self.__spotify_search = spotify_search_service -class EngineSource: - def __init__(self, config: IConfigurationService): - self.__downloader = DeezerDownloader(config.deezer__folder, config.deezer__arl) - self.__deezer_search = DeezerSearch(config.general__limit_tracks) - self.__spotify_search = SpotifySearch( - config.general__limit_tracks, config.spotify__client_id, config.spotify__client_secret - ) + def start(self): self.__default_source: ASearch = self.__deezer_search self.__downloader_source = self.__deezer_search self.__sources: Dict[SourceType, ASearch] = dict( diff --git a/src/pyramid/services/spotify_client.py b/src/pyramid/services/spotify_client.py new file mode 100644 index 0000000..b4564db --- /dev/null +++ b/src/pyramid/services/spotify_client.py @@ -0,0 +1,138 @@ +import json +import logging +from typing import Any +from venv import logger + +import aiohttp +from spotipy import Spotify +from spotipy.exceptions import SpotifyException + +from pyramid.api.services.spotify_client import ISpotifyClientService +from pyramid.api.services.tools.annotation import pyramid_service +from pyramid.api.services.tools.injector import ServiceInjector + + +@pyramid_service(interface=ISpotifyClientService) +class SpotifyClientService(Spotify, ISpotifyClientService, ServiceInjector): + + async def async_search(self, q, limit=10, offset=0, type="track", market=None) -> dict[str, Any]: + return await self._get_async( + "search", q=q, limit=limit, offset=offset, type=type, market=market + ) + + async def async_track(self, track_id, market=None) -> dict[str, Any]: + trid = self._get_id("track", track_id) + return await self._get_async("tracks/" + trid, market=market) + + async def async_playlist_items( + self, + playlist_id, + fields=None, + limit=100, + offset=0, + market=None, + additional_types=("track", "episode"), + ) -> dict[str, Any]: + plid = self._get_id("playlist", playlist_id) + return await self._get_async( + "playlists/%s/tracks" % (plid), + limit=limit, + offset=offset, + fields=fields, + market=market, + additional_types=",".join(additional_types), + ) + + async def async_album_tracks(self, album_id, limit=50, offset=0, market=None) -> dict[str, Any]: + trid = self._get_id("album", album_id) + return await self._get_async( + "albums/" + trid + "/tracks/", limit=limit, offset=offset, market=market + ) + + async def async_artist_top_tracks(self, artist_id, country="US") -> dict[str, Any]: + trid = self._get_id("artist", artist_id) + return await self._get_async("artists/" + trid + "/top-tracks", country=country) + + async def async_next(self, result) -> dict[str, Any] | None: + if not result["next"]: + return None + return await self._get_async(result["next"]) + + async def _get_async(self, url, args=None, payload=None, **kwargs) -> dict[str, Any]: + if args: + kwargs.update(args) + + return await self._async_internal_call("GET", url, payload, kwargs) + + async def _async_internal_call(self, method: str, url: str, payload, params) -> dict[str, Any]: + args = dict(params=params) + if not url.startswith("http"): + url = self.prefix + url + headers = self._auth_headers() + + if "content_type" in args["params"]: + headers["Content-Type"] = args["params"]["content_type"] + del args["params"]["content_type"] + if payload: + args["data"] = payload + else: + headers["Content-Type"] = "application/json" + if payload: + args["data"] = json.dumps(payload) + + if self.language is not None: + headers["Accept-Language"] = self.language + + params = ( + {key: value for key, value in args["params"].items() if value is not None} + if "params" in args + else dict() + ) + logging.debug( + "Sending %s to %s with Params: %s Headers: %s and Body: %r ", + method, + url, + params, + headers, + args.get("data"), + ) + async with aiohttp.ClientSession() as session: + async with session.request( + method, + url, + headers=headers, + proxy=self.proxies, + timeout=self.requests_timeout, + params=params, + ) as response: + try: + response.raise_for_status() + results = await response.json() + except aiohttp.ClientResponseError: + try: + json_response = await response.json() + error = json_response.get("error", {}) + msg = error.get("message") + reason = error.get("reason") + except json.JSONDecodeError: + msg = await response.text() or None + reason = None + + logger.error( + "HTTP Error for %s to %s with Params: %s returned %s due to %s", + method, + url, + args.get("params"), + response.status, + msg, + ) + raise SpotifyException( + response.status, + -1, + "%s:\n %s" % (response.url, msg), + reason=reason, + headers=response.headers, + ) + + logger.debug("RESULTS: %s", results) + return results diff --git a/src/pyramid/services/spotify_search.py b/src/pyramid/services/spotify_search.py new file mode 100644 index 0000000..186a3f8 --- /dev/null +++ b/src/pyramid/services/spotify_search.py @@ -0,0 +1,116 @@ +from pyramid.api.services.configuration import IConfigurationService +from pyramid.api.services.spotify_client import ISpotifyClientService +from pyramid.api.services.spotify_search import ISpotifySearchService +from pyramid.api.services.spotify_search_base import ISpotifySearchBaseService +from pyramid.api.services.spotify_search_id import ISpotifySearchIdService +from pyramid.api.services.tools.annotation import pyramid_service +from pyramid.api.services.tools.injector import ServiceInjector +from pyramid.connector.spotify.spotify_tools import SpotifyTools +from pyramid.connector.spotify.spotify_type import SpotifyType +from pyramid.data.track import TrackMinimalSpotify + + +@pyramid_service(interface=ISpotifySearchService) +class SpotifySearchService(ISpotifySearchService, ServiceInjector): + + def injectService(self, + configuration_service: IConfigurationService, + spotify_client: ISpotifyClientService, + spotify_search_base: ISpotifySearchBaseService, + spotify_search_id: ISpotifySearchIdService + ): + self.__configuration_service = configuration_service + self.__spotify_client = spotify_client + self.__spotify_search_base = spotify_search_base + self.__spotify_search_id = spotify_search_id + + async def search_tracks( + self, search, limit: int | None = None + ) -> list[TrackMinimalSpotify] | None: + if limit is None: + limit = self.__configuration_service.general__limit_tracks + if limit > 50: + req_limit = 50 + else: + req_limit = limit + results = await self.__spotify_client.async_search(q=search, limit=req_limit, type="track") + tracks = await self.__spotify_search_base.items_max(results, limit) + if not tracks: + return None + return [TrackMinimalSpotify(element) for element in tracks] + + async def search_track(self, search) -> TrackMinimalSpotify | None: + results = await self.__spotify_client.async_search(q=search, limit=1, type="track") + + if not results or not results.get("tracks") or not results["tracks"].get("items"): + return None + + tracks = results["tracks"]["items"] + track = tracks[0] + + return TrackMinimalSpotify(track) + + async def get_playlist_tracks(self, playlist_name) -> list[TrackMinimalSpotify] | None: + results = await self.__spotify_client.async_search(q=playlist_name, limit=1, type="playlist") + + if not results or not results.get("playlists") or not results["playlists"].get("items"): + return None + + playlist_id = results["playlists"]["items"][0]["id"] + tracks = await self.__spotify_search_id.get_playlist_tracks_by_id(playlist_id) + if not tracks: + return None + return tracks[0] + + async def get_album_tracks(self, album_name) -> list[TrackMinimalSpotify] | None: + results = await self.__spotify_client.async_search(q=album_name, limit=1, type="album") + + if not results or not results.get("albums") or not results["albums"].get("items"): + return None + + album_id = results["albums"]["items"][0]["id"] + tracks = await self.__spotify_search_id.get_album_tracks_by_id(album_id) + if not tracks: + return None + return tracks[0] + + async def get_top_artist( + self, artist_name, limit: int | None = None + ) -> list[TrackMinimalSpotify] | None: + results = await self.__spotify_client.async_search(q=artist_name, limit=1, type="artist") + + if not results or not results.get("artists") or not results["artists"].get("items"): + return None + + artist_id = results["artists"]["items"][0]["id"] + tracks = await self.__spotify_search_id.get_top_artist_by_id(artist_id) + if not tracks: + return None + return tracks[0] + + async def get_by_url( + self, url + ) -> tuple[list[TrackMinimalSpotify], list[TrackMinimalSpotify]] | TrackMinimalSpotify | None: + id, type = SpotifyTools.extract_from_url(url) + + if id is None: + return None + if type is None: + raise NotImplementedError(f"The type of spotify info '{url}' is not implemented") + + tracks: ( + tuple[list[TrackMinimalSpotify], list[TrackMinimalSpotify]] | TrackMinimalSpotify | None + ) + + if type == SpotifyType.PLAYLIST: + tracks = await self.__spotify_search_id.get_playlist_tracks_by_id(id) + elif type == SpotifyType.ARTIST: + tracks = await self.__spotify_search_id.get_top_artist_by_id(id) + elif type == SpotifyType.ALBUM: + tracks = await self.__spotify_search_id.get_album_tracks_by_id(id) + elif type == SpotifyType.TRACK: + tracks = await self.__spotify_search_id.get_track_by_id(id) + else: + raise NotImplementedError(f"The type of spotify info '{type}' can't be resolve") + + return tracks diff --git a/src/pyramid/services/spotify_search_base.py b/src/pyramid/services/spotify_search_base.py new file mode 100644 index 0000000..10e21f2 --- /dev/null +++ b/src/pyramid/services/spotify_search_base.py @@ -0,0 +1,65 @@ + + + +from typing import Any +from pyramid.api.services.configuration import IConfigurationService +from pyramid.api.services.spotify_search_base import ISpotifySearchBaseService +from pyramid.api.services.tools.annotation import pyramid_service +from pyramid.api.services.tools.injector import ServiceInjector +from pyramid.connector.spotify.spotify_tools import SpotifyTools +from spotipy.oauth2 import SpotifyClientCredentials + +from pyramid.connector.spotify.cli_spotify import CliSpotify + +@pyramid_service(interface=ISpotifySearchBaseService) +class SpotifySearchBaseService(ISpotifySearchBaseService, ServiceInjector): + + def injectService(self, + configuration_service: IConfigurationService, + ): + self.__configuration_service = configuration_service + + def start(self): + self.client_credentials_manager = SpotifyClientCredentials( + client_id=self.__configuration_service.spotify__client_id, + client_secret=self.__configuration_service.spotify__client_secret + ) + self.client = CliSpotify(client_credentials_manager=self.client_credentials_manager) + + async def items( + self, + results: dict[str, Any], + item_name="items" + ) -> list[dict[str, Any]] | None: + if not results: + return None + tracks: list = results[item_name] + + while results["next"]: + results = await self.client.async_next(results) # type: ignore + tracks.extend(results[item_name]) + + return tracks + + async def items_max( + self, + results: dict[str, Any], + limit: int | None = None, + item_name="items" + ) -> list[Any] | None: + if not results or not results.get("tracks") or not results["tracks"].get(item_name): + return None + + if limit is None: + limit = self.__configuration_service.general__limit_tracks + tracks: list[Any] = results["tracks"][item_name] + + results_tracks: dict[str, Any] = results["tracks"] + while results["tracks"]["next"] and limit > len(tracks): + results = await self.client.async_next(results_tracks) # type: ignore + tracks.extend(results_tracks[item_name]) + + if len(tracks) > limit: + return tracks[:limit] + + return tracks diff --git a/src/pyramid/services/spotify_search_id.py b/src/pyramid/services/spotify_search_id.py new file mode 100644 index 0000000..4d6fccf --- /dev/null +++ b/src/pyramid/services/spotify_search_id.py @@ -0,0 +1,71 @@ +from typing import Any + +from pyramid.api.services.configuration import IConfigurationService +from pyramid.api.services.spotify_client import ISpotifyClientService +from pyramid.api.services.spotify_search_base import ISpotifySearchBaseService +from pyramid.api.services.spotify_search_id import ISpotifySearchIdService +from pyramid.api.services.tools.annotation import pyramid_service +from pyramid.api.services.tools.injector import ServiceInjector +from pyramid.data.track import TrackMinimalSpotify + +@pyramid_service(interface=ISpotifySearchIdService) +class SpotifySearchIdService(ISpotifySearchIdService, ServiceInjector): + + def injectService(self, + configuration_service: IConfigurationService, + spotify_client: ISpotifyClientService, + spotify_search_base: ISpotifySearchBaseService + ): + self.__configuration_service = configuration_service + self.__spotify_client = spotify_client + self.__spotify_search_base = spotify_search_base + + async def get_track_by_id(self, track_id: str) -> TrackMinimalSpotify | None: + result = await self.__spotify_client.async_track(track_id=track_id) + if not result: + return None + return TrackMinimalSpotify(result) + + async def get_playlist_tracks_by_id( + self, playlist_id: str + ) -> tuple[list[TrackMinimalSpotify], list[TrackMinimalSpotify]] | None: + tracks_playlist = await self.__spotify_search_base.items( + await self.__spotify_client.async_playlist_items(playlist_id=playlist_id) + ) + if not tracks_playlist: + return None + return [TrackMinimalSpotify(element["track"]) for element in tracks_playlist], [] + + async def get_album_tracks_by_id( + self, album_id: str + ) -> tuple[list[TrackMinimalSpotify], list[TrackMinimalSpotify]] | None: + tracks = await self.__spotify_search_base.items( + await self.__spotify_client.async_album_tracks(album_id=album_id) + ) + if not tracks: + return None + + readable_tracks = [] + unreadable_tracks = [] + for t in tracks: + track = await self.get_track_by_id(t["id"]) + if track is None: + unreadable_tracks.append(t) + else: + readable_tracks.append(track) + return readable_tracks, unreadable_tracks + + async def get_top_artist_by_id( + self, artist_id: str, limit: int | None = None + ) -> tuple[list[TrackMinimalSpotify], list[TrackMinimalSpotify]] | None: + if limit is None: + limit = self.__configuration_service.general__limit_tracks + results = await self.__spotify_client.async_artist_top_tracks(artist_id) + + if not results or not results.get("tracks"): + return None + + tracks = results["tracks"] + if len(tracks) > limit: + tracks = tracks[:limit] + return [TrackMinimalSpotify(element) for element in tracks], [] \ No newline at end of file From 77aefba40ef8e6ac7be8d6e53cd4e39426ae9fbe Mon Sep 17 00:00:00 2001 From: Tristiisch Date: Sat, 21 Sep 2024 02:18:29 +0200 Subject: [PATCH 13/32] feat: make deezer client a service, split all files that have multiple class --- src/pyramid/api/services/deezer_client.py | 78 ++++++ src/pyramid/api/services/deezer_downloader.py | 2 +- src/pyramid/api/services/deezer_search.py | 2 +- src/pyramid/api/services/source_service.py | 4 +- src/pyramid/api/services/spotify_search.py | 2 +- src/pyramid/api/services/spotify_search_id.py | 2 +- src/pyramid/connector/deezer/cli_deezer.py | 262 ++---------------- .../deezer/client/a_deezer_client.py | 19 ++ .../connector/deezer/client/exceptions.py | 60 ++++ .../connector/deezer/client/list_paginated.py | 132 +++++++++ .../deezer/client/rate_limiter_async.py | 31 +++ .../{py_deezer.py => download/client.py} | 89 +----- .../connector/deezer/download/decrypt.py | 88 ++++++ .../connector/deezer/download/exception.py | 7 + src/pyramid/connector/deezer/downloader.py | 5 +- src/pyramid/connector/deezer/search.py | 29 +- src/pyramid/connector/deezer/tools.py | 4 +- src/pyramid/connector/discord/guild_cmd.py | 2 +- .../connector/discord/guild_cmd_tools.py | 4 +- src/pyramid/connector/discord/guild_queue.py | 4 +- .../connector/discord/music/player_button.py | 35 +++ .../player_interface.py} | 37 +-- src/pyramid/connector/spotify/search.py | 2 +- src/pyramid/data/a_guid_queue.py | 2 +- src/pyramid/data/a_search.py | 2 +- src/pyramid/data/guild_instance.py | 2 +- src/pyramid/data/music/track.py | 38 +++ src/pyramid/data/music/track_minimal.py | 35 +++ .../data/music/track_minimal_deezer.py | 25 ++ .../data/music/track_minimal_spotify.py | 25 ++ src/pyramid/data/track.py | 118 -------- src/pyramid/data/tracklist.py | 4 +- src/pyramid/services/deezer_client.py | 201 ++++++++++++++ src/pyramid/services/deezer_downloader.py | 4 +- src/pyramid/services/deezer_search.py | 57 ++-- src/pyramid/services/source.py | 8 +- src/pyramid/services/spotify_client.py | 15 +- src/pyramid/services/spotify_search.py | 2 +- src/pyramid/services/spotify_search_base.py | 17 +- src/pyramid/services/spotify_search_id.py | 2 +- 40 files changed, 889 insertions(+), 568 deletions(-) create mode 100644 src/pyramid/api/services/deezer_client.py create mode 100644 src/pyramid/connector/deezer/client/a_deezer_client.py create mode 100644 src/pyramid/connector/deezer/client/exceptions.py create mode 100644 src/pyramid/connector/deezer/client/list_paginated.py create mode 100644 src/pyramid/connector/deezer/client/rate_limiter_async.py rename src/pyramid/connector/deezer/{py_deezer.py => download/client.py} (79%) create mode 100644 src/pyramid/connector/deezer/download/decrypt.py create mode 100644 src/pyramid/connector/deezer/download/exception.py create mode 100644 src/pyramid/connector/discord/music/player_button.py rename src/pyramid/connector/discord/{music_player_interface.py => music/player_interface.py} (64%) create mode 100644 src/pyramid/data/music/track.py create mode 100644 src/pyramid/data/music/track_minimal.py create mode 100644 src/pyramid/data/music/track_minimal_deezer.py create mode 100644 src/pyramid/data/music/track_minimal_spotify.py delete mode 100644 src/pyramid/data/track.py create mode 100644 src/pyramid/services/deezer_client.py diff --git a/src/pyramid/api/services/deezer_client.py b/src/pyramid/api/services/deezer_client.py new file mode 100644 index 0000000..f925c47 --- /dev/null +++ b/src/pyramid/api/services/deezer_client.py @@ -0,0 +1,78 @@ +from abc import abstractmethod + +from deezer import Album, Artist, Playlist, Track +from pyramid.connector.deezer.client.a_deezer_client import ADeezerClient +from pyramid.connector.deezer.client.list_paginated import DeezerListPaginated + +class IDeezerClientService(ADeezerClient): + + @abstractmethod + def _search( + self, + path: str, + query: str = "", + strict: bool | None = None, + ordering: str | None = None, + **advanced_params: str | int | None, + ) -> DeezerListPaginated: + pass + + @abstractmethod + def search( + self, + query: str = "", + strict: bool | None = None, + ordering: str | None = None, + artist: str | None = None, + album: str | None = None, + track: str | None = None, + label: str | None = None, + dur_min: int | None = None, + dur_max: int | None = None, + bpm_min: int | None = None, + bpm_max: int | None = None, + ) -> DeezerListPaginated: + pass + + @abstractmethod + def search_playlists( + self, + query: str = "", + strict: bool | None = None, + ordering: str | None = None, + ) -> DeezerListPaginated[Playlist]: + pass + + @abstractmethod + def search_albums( + self, + query: str = "", + strict: bool | None = None, + ordering: str | None = None, + ) -> DeezerListPaginated[Album]: + pass + + @abstractmethod + def search_artists( + self, + query: str = "", + strict: bool | None = None, + ordering: str | None = None, + ) -> DeezerListPaginated[Artist]: + pass + + @abstractmethod + async def async_get_playlist(self, playlist_id: int) -> Playlist: + pass + + @abstractmethod + async def async_get_album(self, album_id: int) -> Album: + pass + + @abstractmethod + async def async_get_artist(self, artist_id: int) -> Artist: + pass + + @abstractmethod + async def async_get_track(self, track_id: int) -> Track: + pass diff --git a/src/pyramid/api/services/deezer_downloader.py b/src/pyramid/api/services/deezer_downloader.py index 396bb9f..9e9b1c6 100644 --- a/src/pyramid/api/services/deezer_downloader.py +++ b/src/pyramid/api/services/deezer_downloader.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod from typing import Any -from pyramid.data.track import Track +from pyramid.data.music.track import Track class IDeezerDownloaderService(ABC): diff --git a/src/pyramid/api/services/deezer_search.py b/src/pyramid/api/services/deezer_search.py index cfeeeca..6570c63 100644 --- a/src/pyramid/api/services/deezer_search.py +++ b/src/pyramid/api/services/deezer_search.py @@ -2,7 +2,7 @@ from abc import ABC, abstractmethod from pyramid.data.a_search import ASearch -from pyramid.data.track import TrackMinimalDeezer +from pyramid.data.music.track_minimal_deezer import TrackMinimalDeezer class IDeezerSearchService(ASearch): diff --git a/src/pyramid/api/services/source_service.py b/src/pyramid/api/services/source_service.py index 42c1fe4..cf07bb7 100644 --- a/src/pyramid/api/services/source_service.py +++ b/src/pyramid/api/services/source_service.py @@ -1,6 +1,8 @@ from abc import ABC, abstractmethod from pyramid.data.source_type import SourceType -from pyramid.data.track import Track, TrackMinimal, TrackMinimalDeezer +from pyramid.data.music.track import Track +from pyramid.data.music.track_minimal_deezer import TrackMinimalDeezer +from pyramid.data.music.track_minimal import TrackMinimal class ISourceService(ABC): diff --git a/src/pyramid/api/services/spotify_search.py b/src/pyramid/api/services/spotify_search.py index ee7d06f..c6c2f1b 100644 --- a/src/pyramid/api/services/spotify_search.py +++ b/src/pyramid/api/services/spotify_search.py @@ -1,6 +1,6 @@ from abc import abstractmethod from pyramid.data.a_search import ASearch -from pyramid.data.track import TrackMinimalSpotify +from pyramid.data.music.track_minimal_spotify import TrackMinimalSpotify class ISpotifySearchService(ASearch): diff --git a/src/pyramid/api/services/spotify_search_id.py b/src/pyramid/api/services/spotify_search_id.py index 373edcd..97af150 100644 --- a/src/pyramid/api/services/spotify_search_id.py +++ b/src/pyramid/api/services/spotify_search_id.py @@ -1,7 +1,7 @@ from abc import abstractmethod from pyramid.data.a_search import ASearchId -from pyramid.data.track import TrackMinimalSpotify +from pyramid.data.music.track_minimal_spotify import TrackMinimalSpotify class ISpotifySearchIdService(ASearchId): diff --git a/src/pyramid/connector/deezer/cli_deezer.py b/src/pyramid/connector/deezer/cli_deezer.py index 500f27d..63f22eb 100644 --- a/src/pyramid/connector/deezer/cli_deezer.py +++ b/src/pyramid/connector/deezer/cli_deezer.py @@ -1,185 +1,16 @@ -import asyncio -import time -from abc import ABC, abstractmethod -from typing import Any, Generic, Literal, Self -from urllib.parse import parse_qs, urlparse +from typing import Any import aiohttp from deezer import Album, Artist, Client, Playlist, Resource, Track -from deezer.exceptions import DeezerAPIException, DeezerErrorResponse -from deezer.pagination import ResourceType +from pyramid.connector.deezer.client.a_deezer_client import ADeezerClient +from pyramid.connector.deezer.client.exceptions import CliDeezerErrorResponse, CliDeezerHTTPError +from pyramid.connector.deezer.client.list_paginated import DeezerListPaginated +from pyramid.connector.deezer.client.rate_limiter_async import RateLimiterAsync +from pyramid.tools.deprecated_class import deprecated_class -class AsyncRateLimiter: - def __init__(self, max_requests, time_interval): - self.max_requests = max_requests - self.time_interval = time_interval - self.requests: list[float] = [] - self.lock = asyncio.Lock() - - def _clean_old_requests(self): - current_time = time.time() - self.requests = [t for t in self.requests if self.time_interval > current_time - t] - - async def _wait_if_needed(self): - async with self.lock: - self._clean_old_requests() - if len(self.requests) >= self.max_requests: - sleep_time = self.requests[0] + self.time_interval - time.time() - if sleep_time > 0: - # logging.warning("Detect Deezer RateLimit - wait %f secs", sleep_time) - await asyncio.sleep(sleep_time) - self._clean_old_requests() - - async def check(self): - await self._wait_if_needed() - - async def add(self): - async with self.lock: - self.requests.append(time.time()) - - -class ACliDeezer(ABC): - @abstractmethod - async def async_request( - self, - method: str, - path: str, - parent: Resource | None = None, - resource_type: type[Resource] | None = None, - resource_id: int | None = None, - paginate_list=False, - **params, - ) -> Any | list[Any] | dict[str, Any] | Resource | Literal[True]: - ... - - -class CliPaginatedList(Generic[ResourceType]): - def __init__( - self, - client: ACliDeezer, - base_path: str, - parent: Resource | None = None, - **params, - ): - self.__client = client - self.__base_path = base_path - self.__base_params: dict[str, Any] = params - self.__next_path: str | None = base_path - self.__next_params: dict[str, Any] = params - self.__parent = parent - self.__total: int | None = None - - async def get_first(self) -> ResourceType | None: - return await self.get_single(0) - - async def get_single(self, index: int) -> ResourceType | None: - if 0 > index: - raise ValueError("index can't be less than 0. (%d received)", index) - - elements = await self._req(False, {"index": index, "limit": 1}) - - if index >= len(elements): - return None - - return elements[0] - - async def get_maximum(self, limit: int): - if 0 > limit: - raise ValueError("limit can't be less than 0. (%d received)", limit) - if limit > 100: - raise ValueError("You need to use async iterator when limit is bigger than %d", limit) - elements = await self._req(False, {"limit": limit}) - - return elements - - async def get_all(self) -> list[ResourceType]: - elements: list[ResourceType] = [] - - while self._could_grow(): - elements.extend(await self._async_fetch_next_page()) - - return elements - - async def get_page(self, item_per_page: int, page: int): - if 0 > item_per_page: - raise ValueError("item_per_page can't be less than 0. (%d received)", item_per_page) - if item_per_page > 100: - raise NotImplementedError( - "Pagination with more than 100 items per page is not supported yet." - ) - - index = (page - 1) * item_per_page - elements = await self._req(False, {"index": index, "limit": item_per_page}) - return elements - - async def total(self) -> int: - if self.__total is not None: - return self.__total - - await self._req(False, {"index": 0, "limit": 1}) - assert self.__total is not None - return self.__total - - def __aiter__(self): - return self - - async def __anext__(self) -> list[ResourceType]: - if not self._could_grow(): - raise StopAsyncIteration - - elements = await self._async_fetch_next_page() - return elements - - def _could_grow(self) -> bool: - return self.__next_path is not None - - async def _async_fetch_next_page(self) -> list[ResourceType]: - return await self._req(True) - - async def _req( - self, use_iterator: bool = False, custom_params: dict[str, Any] | None = None - ) -> list[ResourceType]: - if use_iterator: - assert self.__next_path is not None - path = self.__next_path - params = self.__next_params - else: - path = self.__base_path - params = self.__base_params - - if custom_params is not None: - params = params.copy() - params.update(custom_params) - - response_payload: dict[str, Any] = await self.__client.async_request( - "GET", - path, - parent=self.__parent, - paginate_list=True, - resource_type=None, - resource_id=None, - **params, - ) # type: ignore - - if self.__total is None: - self.__total = response_payload.get("total") - - if use_iterator: - next_url: str = response_payload.get("next", None) - if next_url: - url_bits = urlparse(next_url) - self.__next_path = url_bits.path.lstrip("/") - self.__next_params = parse_qs(url_bits.query) - else: - self.__next_path = None - - elements: list[ResourceType] = response_payload["data"] - - return elements - - -class CliDeezer(ACliDeezer, Client): +@deprecated_class +class CliDeezer(ADeezerClient, Client): def __init__(self, app_id=None, app_secret=None, access_token=None, headers=None, **kwargs): # super().__init__(app_id, app_secret, access_token, headers, **kwargs) @@ -192,14 +23,14 @@ def __init__(self, app_id=None, app_secret=None, access_token=None, headers=None # self.session.headers.update(headers) # self.session.close() # self.async_session = aiohttp.ClientSession() - self.rate_limiter = AsyncRateLimiter(max_requests=50, time_interval=5) + self.rate_limiter = RateLimiterAsync(max_requests=50, time_interval=5) def get_paginated_list( self, relation: str, **kwargs, - ) -> CliPaginatedList: - return CliPaginatedList( + ) -> DeezerListPaginated: + return DeezerListPaginated( client=self.client, base_path=f"{self.type}/{self.id}/{relation}", parent=self, @@ -227,7 +58,7 @@ def _search( strict: bool | None = None, ordering: str | None = None, **advanced_params: str | int | None, - ) -> CliPaginatedList: + ) -> DeezerListPaginated: return super()._search(path, query, strict, ordering, **advanced_params) # type: ignore def search( @@ -243,7 +74,7 @@ def search( dur_max: int | None = None, bpm_min: int | None = None, bpm_max: int | None = None, - ) -> CliPaginatedList: + ) -> DeezerListPaginated: return self._search( "", query=query, @@ -264,7 +95,7 @@ def search_playlists( query: str = "", strict: bool | None = None, ordering: str | None = None, - ) -> CliPaginatedList[Playlist]: + ) -> DeezerListPaginated[Playlist]: return self._search( path="playlist", query=query, @@ -277,7 +108,7 @@ def search_albums( query: str = "", strict: bool | None = None, ordering: str | None = None, - ) -> CliPaginatedList[Album]: + ) -> DeezerListPaginated[Album]: return self._search( path="album", query=query, @@ -290,7 +121,7 @@ def search_artists( query: str = "", strict: bool | None = None, ordering: str | None = None, - ) -> CliPaginatedList[Artist]: + ) -> DeezerListPaginated[Artist]: return self._search( path="artist", query=query, @@ -310,8 +141,8 @@ async def async_get_artist(self, artist_id: int) -> Artist: async def async_get_track(self, track_id: int) -> Track: return await self.async_request("GET", f"track/{track_id}") # type: ignore - def _get_paginated_list(self, path, **params) -> CliPaginatedList: - return CliPaginatedList(client=self, base_path=path, **params) + def _get_paginated_list(self, path, **params) -> DeezerListPaginated: + return DeezerListPaginated(client=self, base_path=path, **params) async def async_request( self, @@ -366,60 +197,3 @@ async def async_request( resource_id=resource_id, paginate_list=paginate_list, ) - - -class CliDeezerHTTPError(DeezerAPIException): - """Specialisation wrapping HTTPError from the requests library.""" - - def __init__(self, http_exception: aiohttp.ClientResponseError, *args: object) -> None: - url = http_exception.request_info.url - status_code = http_exception.code - text = http_exception.message - super().__init__(status_code, url, text, *args) - - @classmethod - def from_status_code(cls, exc: aiohttp.ClientResponseError) -> Self: - """Initialise the appropriate internal exception from a HTTPError.""" - if exc.code in {502, 503, 504}: - return CliDeezerRetryableHTTPError(exc) - if exc.code == 403: - return CliDeezerForbiddenError(exc) - if exc.code == 404: - return CliDeezerNotFoundError(exc) - return cls(exc) - - -class CliDeezerRetryableException(DeezerAPIException): - """A request failing with this might work if retried.""" - - -class CliDeezerRetryableHTTPError(CliDeezerRetryableException, CliDeezerHTTPError): - """A HTTP error due to a potentially temporary issue.""" - - -class CliDeezerForbiddenError(CliDeezerHTTPError): - """A HTTP error cause by permission denied error.""" - - -class CliDeezerNotFoundError(CliDeezerHTTPError): - """For 404 HTTP errors.""" - - -class CliDeezerErrorResponse(DeezerErrorResponse): - @classmethod - def from_body(cls, json_data: dict[str, Any]) -> Self: - err_json = json_data["error"] - code = int(err_json["code"]) - if code == 4: - return CliDeezerRateLimitError(json_data) - elif code == 800: - return CliDeezerNoDataException(json_data) - return cls(json_data) - - -class CliDeezerRateLimitError(CliDeezerErrorResponse): - pass - - -class CliDeezerNoDataException(CliDeezerErrorResponse): - pass diff --git a/src/pyramid/connector/deezer/client/a_deezer_client.py b/src/pyramid/connector/deezer/client/a_deezer_client.py new file mode 100644 index 0000000..9967199 --- /dev/null +++ b/src/pyramid/connector/deezer/client/a_deezer_client.py @@ -0,0 +1,19 @@ +from abc import ABC, abstractmethod +from typing import Any, Literal + +from deezer import Resource + + +class ADeezerClient(ABC): + @abstractmethod + async def async_request( + self, + method: str, + path: str, + parent: Resource | None = None, + resource_type: type[Resource] | None = None, + resource_id: int | None = None, + paginate_list=False, + **params, + ) -> Any | list[Any] | dict[str, Any] | Resource | Literal[True]: + ... diff --git a/src/pyramid/connector/deezer/client/exceptions.py b/src/pyramid/connector/deezer/client/exceptions.py new file mode 100644 index 0000000..e12ab9f --- /dev/null +++ b/src/pyramid/connector/deezer/client/exceptions.py @@ -0,0 +1,60 @@ + +from typing import Any, Self +import aiohttp +from deezer.exceptions import DeezerAPIException, DeezerErrorResponse + +class CliDeezerHTTPError(DeezerAPIException): + """Specialisation wrapping HTTPError from the requests library.""" + + def __init__(self, http_exception: aiohttp.ClientResponseError, *args: object) -> None: + url = http_exception.request_info.url + status_code = http_exception.code + text = http_exception.message + super().__init__(status_code, url, text, *args) + + @classmethod + def from_status_code(cls, exc: aiohttp.ClientResponseError) -> DeezerAPIException: + """Initialise the appropriate internal exception from a HTTPError.""" + if exc.code in {502, 503, 504}: + return CliDeezerRetryableHTTPError(exc) + if exc.code == 403: + return CliDeezerForbiddenError(exc) + if exc.code == 404: + return CliDeezerNotFoundError(exc) + return cls(exc) + + +class CliDeezerRetryableException(DeezerAPIException): + """A request failing with this might work if retried.""" + + +class CliDeezerRetryableHTTPError(CliDeezerRetryableException, CliDeezerHTTPError): + """A HTTP error due to a potentially temporary issue.""" + + +class CliDeezerForbiddenError(CliDeezerHTTPError): + """A HTTP error cause by permission denied error.""" + + +class CliDeezerNotFoundError(CliDeezerHTTPError): + """For 404 HTTP errors.""" + + +class CliDeezerErrorResponse(DeezerErrorResponse): + @classmethod + def from_body(cls, json_data: dict[str, Any]) -> DeezerErrorResponse: + err_json = json_data["error"] + code = int(err_json["code"]) + if code == 4: + return CliDeezerRateLimitError(json_data) + elif code == 800: + return CliDeezerNoDataException(json_data) + return cls(json_data) + + +class CliDeezerRateLimitError(DeezerErrorResponse): + pass + + +class CliDeezerNoDataException(DeezerErrorResponse): + pass diff --git a/src/pyramid/connector/deezer/client/list_paginated.py b/src/pyramid/connector/deezer/client/list_paginated.py new file mode 100644 index 0000000..bcd923c --- /dev/null +++ b/src/pyramid/connector/deezer/client/list_paginated.py @@ -0,0 +1,132 @@ +from typing import Any, Generic +from urllib.parse import parse_qs, urlparse + +from deezer import Resource +from deezer.pagination import ResourceType + +from pyramid.connector.deezer.client.a_deezer_client import ADeezerClient + + +class DeezerListPaginated(Generic[ResourceType]): + def __init__( + self, + client: ADeezerClient, + base_path: str, + parent: Resource | None = None, + **params, + ): + self.__client = client + self.__base_path = base_path + self.__base_params: dict[str, Any] = params + self.__next_path: str | None = base_path + self.__next_params: dict[str, Any] = params + self.__parent = parent + self.__total: int | None = None + + async def get_first(self) -> ResourceType | None: + return await self.get_single(0) + + async def get_single(self, index: int) -> ResourceType | None: + if 0 > index: + raise ValueError("index can't be less than 0. (%d received)", index) + + elements = await self._req(False, {"index": index, "limit": 1}) + + if index >= len(elements): + return None + + return elements[0] + + async def get_maximum(self, limit: int): + if 0 > limit: + raise ValueError("limit can't be less than 0. (%d received)", limit) + if limit > 100: + raise ValueError("You need to use async iterator when limit is bigger than %d", limit) + elements = await self._req(False, {"limit": limit}) + + return elements + + async def get_all(self) -> list[ResourceType]: + elements: list[ResourceType] = [] + + while self._could_grow(): + elements.extend(await self._async_fetch_next_page()) + + return elements + + async def get_page(self, item_per_page: int, page: int): + if 0 > item_per_page: + raise ValueError("item_per_page can't be less than 0. (%d received)", item_per_page) + if item_per_page > 100: + raise NotImplementedError( + "Pagination with more than 100 items per page is not supported yet." + ) + + index = (page - 1) * item_per_page + elements = await self._req(False, {"index": index, "limit": item_per_page}) + return elements + + async def total(self) -> int: + if self.__total is not None: + return self.__total + + await self._req(False, {"index": 0, "limit": 1}) + assert self.__total is not None + return self.__total + + def __aiter__(self): + return self + + async def __anext__(self) -> list[ResourceType]: + if not self._could_grow(): + raise StopAsyncIteration + + elements = await self._async_fetch_next_page() + return elements + + def _could_grow(self) -> bool: + return self.__next_path is not None + + async def _async_fetch_next_page(self) -> list[ResourceType]: + return await self._req(True) + + async def _req( + self, use_iterator: bool = False, custom_params: dict[str, Any] | None = None + ) -> list[ResourceType]: + if use_iterator: + assert self.__next_path is not None + path = self.__next_path + params = self.__next_params + else: + path = self.__base_path + params = self.__base_params + + if custom_params is not None: + params = params.copy() + params.update(custom_params) + + response_payload: dict[str, Any] = await self.__client.async_request( + "GET", + path, + parent=self.__parent, + paginate_list=True, + resource_type=None, + resource_id=None, + **params, + ) # type: ignore + + if self.__total is None: + self.__total = response_payload.get("total") + + if use_iterator: + next_url: str = response_payload.get("next", None) + if next_url: + url_bits = urlparse(next_url) + self.__next_path = url_bits.path.lstrip("/") + self.__next_params = parse_qs(url_bits.query) + else: + self.__next_path = None + + elements: list[ResourceType] = response_payload["data"] + + return elements diff --git a/src/pyramid/connector/deezer/client/rate_limiter_async.py b/src/pyramid/connector/deezer/client/rate_limiter_async.py new file mode 100644 index 0000000..09ed699 --- /dev/null +++ b/src/pyramid/connector/deezer/client/rate_limiter_async.py @@ -0,0 +1,31 @@ +import asyncio +import time + + +class RateLimiterAsync: + def __init__(self, max_requests, time_interval): + self.max_requests = max_requests + self.time_interval = time_interval + self.requests: list[float] = [] + self.lock = asyncio.Lock() + + def _clean_old_requests(self): + current_time = time.time() + self.requests = [t for t in self.requests if self.time_interval > current_time - t] + + async def _wait_if_needed(self): + async with self.lock: + self._clean_old_requests() + if len(self.requests) >= self.max_requests: + sleep_time = self.requests[0] + self.time_interval - time.time() + if sleep_time > 0: + # logging.warning("Detect Deezer RateLimit - wait %f secs", sleep_time) + await asyncio.sleep(sleep_time) + self._clean_old_requests() + + async def check(self): + await self._wait_if_needed() + + async def add(self): + async with self.lock: + self.requests.append(time.time()) diff --git a/src/pyramid/connector/deezer/py_deezer.py b/src/pyramid/connector/deezer/download/client.py similarity index 79% rename from src/pyramid/connector/deezer/py_deezer.py rename to src/pyramid/connector/deezer/download/client.py index bdcd880..4ad7785 100644 --- a/src/pyramid/connector/deezer/py_deezer.py +++ b/src/pyramid/connector/deezer/download/client.py @@ -4,103 +4,20 @@ import warnings from os import path -import aiofiles import aiohttp from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes -from pyramid.data.exceptions import CustomException, DeezerTokenInvalidException +from pyramid.data.exceptions import DeezerTokenInvalidException -with warnings.catch_warnings(): - warnings.simplefilter("ignore") - from cryptography.hazmat.primitives.ciphers.algorithms import Blowfish +from pyramid.connector.deezer.download.decrypt import DecryptDeezer +from pyramid.connector.deezer.download.exception import DlDeezerNotUrlFoundException from pydeezer import Deezer, util from pydeezer.constants import api_methods, api_urls, networking_settings, track_formats from pydeezer.exceptions import APIRequestError, DownloadLinkDecryptionError from pydeezer.ProgressHandler import BaseProgressHandler, DefaultProgressHandler - -class DlDeezerNotUrlFoundException(CustomException): - pass - - -class DecryptDeezer: - def __init__(self, blowfish_key: bytes, progress_handler: BaseProgressHandler) -> None: - self.chunk_length = 6144 - self.decrypt_chunk_length = 2048 - self.cipher = Cipher( - Blowfish(blowfish_key), - modes.CBC(bytes([i for i in range(8)])), - default_backend(), - ) - self.progress_handler = progress_handler - - async def output_file(self, filesize: int, file_path: str, res: aiohttp.ClientResponse): - async with aiofiles.open(file_path, "wb") as f: - await f.seek(0) - - chunk_index = 0 - downloaded_size = 0 - previous_chunk = None - async for chunk, _ in res.content.iter_chunks(): - chunk_size = len(chunk) - self.progress_handler.update(current_chunk_size=chunk_size) - - downloaded_size += chunk_size - if previous_chunk: - previous_chunk, chunks_used = await self._transform_chunk( - f, previous_chunk + chunk - ) - else: - previous_chunk, chunks_used = await self._transform_chunk(f, chunk) - chunk_index += chunks_used - - if previous_chunk: - await self._write_file(f, previous_chunk) - if downloaded_size != filesize: - missing = filesize - downloaded_size - raise Exception("[%s] %d bytes are missing" % (filesize, missing)) - - async def _transform_chunk( - self, f: aiofiles.threadpool.binary.AsyncBufferedIOBase, bytes_chunked: bytes - ) -> tuple[bytes, int] | tuple[None, int]: - # Calculate the number of chunks needed - length_bytes = len(bytes_chunked) - chunks_nb = int(length_bytes / self.chunk_length) + 1 - - # Iterate over the chunks and call the callback for each one - for i in range(chunks_nb - 1): - chunk = bytes_chunked[i * self.chunk_length : (i + 1) * self.chunk_length] - await self._write_file(f, chunk) - - last_chunk_start = (chunks_nb - 1) * self.chunk_length - last_chunk = bytes_chunked[last_chunk_start:] - last_length = len(last_chunk) - - if last_length == self.chunk_length: - await self._write_file(f, last_chunk) - return None, chunks_nb - elif last_length < self.chunk_length: - return last_chunk, chunks_nb - 1 - raise Exception( - "Last chunk has wrong size %d (under %d is excepted)", last_length, self.chunk_length - ) - - async def _write_file( - self, f: aiofiles.threadpool.binary.AsyncBufferedIOBase, new_chunk: bytes - ): - chunk_size = len(new_chunk) - if self.decrypt_chunk_length > chunk_size: - await f.write(new_chunk) - else: - chunk_to_decrypt = new_chunk[: self.decrypt_chunk_length] - decryptor = self.cipher.decryptor() - dec_data = decryptor.update(chunk_to_decrypt) + decryptor.finalize() - await f.write(dec_data) - await f.write(new_chunk[self.decrypt_chunk_length :]) - - class PyDeezer(Deezer): def __init__(self, arl=None): super().__init__() diff --git a/src/pyramid/connector/deezer/download/decrypt.py b/src/pyramid/connector/deezer/download/decrypt.py new file mode 100644 index 0000000..5d7094b --- /dev/null +++ b/src/pyramid/connector/deezer/download/decrypt.py @@ -0,0 +1,88 @@ +import warnings + +import aiofiles +import aiohttp +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives.ciphers import Cipher, modes + +with warnings.catch_warnings(): + warnings.simplefilter("ignore") + from cryptography.hazmat.primitives.ciphers.algorithms import Blowfish + +from pydeezer.ProgressHandler import BaseProgressHandler + + +class DecryptDeezer: + def __init__(self, blowfish_key: bytes, progress_handler: BaseProgressHandler) -> None: + self.chunk_length = 6144 + self.decrypt_chunk_length = 2048 + self.cipher = Cipher( + Blowfish(blowfish_key), + modes.CBC(bytes([i for i in range(8)])), + default_backend(), + ) + self.progress_handler = progress_handler + + async def output_file(self, filesize: int, file_path: str, res: aiohttp.ClientResponse): + async with aiofiles.open(file_path, "wb") as f: + await f.seek(0) + + chunk_index = 0 + downloaded_size = 0 + previous_chunk = None + async for chunk, _ in res.content.iter_chunks(): + chunk_size = len(chunk) + self.progress_handler.update(current_chunk_size=chunk_size) + + downloaded_size += chunk_size + if previous_chunk: + previous_chunk, chunks_used = await self._transform_chunk( + f, previous_chunk + chunk + ) + else: + previous_chunk, chunks_used = await self._transform_chunk(f, chunk) + chunk_index += chunks_used + + if previous_chunk: + await self._write_file(f, previous_chunk) + if downloaded_size != filesize: + missing = filesize - downloaded_size + raise Exception("[%s] %d bytes are missing" % (filesize, missing)) + + async def _transform_chunk( + self, f: aiofiles.threadpool.binary.AsyncBufferedIOBase, bytes_chunked: bytes + ) -> tuple[bytes, int] | tuple[None, int]: + # Calculate the number of chunks needed + length_bytes = len(bytes_chunked) + chunks_nb = int(length_bytes / self.chunk_length) + 1 + + # Iterate over the chunks and call the callback for each one + for i in range(chunks_nb - 1): + chunk = bytes_chunked[i * self.chunk_length : (i + 1) * self.chunk_length] + await self._write_file(f, chunk) + + last_chunk_start = (chunks_nb - 1) * self.chunk_length + last_chunk = bytes_chunked[last_chunk_start:] + last_length = len(last_chunk) + + if last_length == self.chunk_length: + await self._write_file(f, last_chunk) + return None, chunks_nb + elif last_length < self.chunk_length: + return last_chunk, chunks_nb - 1 + raise Exception( + "Last chunk has wrong size %d (under %d is excepted)", last_length, self.chunk_length + ) + + async def _write_file( + self, f: aiofiles.threadpool.binary.AsyncBufferedIOBase, new_chunk: bytes + ): + chunk_size = len(new_chunk) + if self.decrypt_chunk_length > chunk_size: + await f.write(new_chunk) + else: + chunk_to_decrypt = new_chunk[: self.decrypt_chunk_length] + decryptor = self.cipher.decryptor() + dec_data = decryptor.update(chunk_to_decrypt) + decryptor.finalize() + await f.write(dec_data) + await f.write(new_chunk[self.decrypt_chunk_length :]) diff --git a/src/pyramid/connector/deezer/download/exception.py b/src/pyramid/connector/deezer/download/exception.py new file mode 100644 index 0000000..93a054e --- /dev/null +++ b/src/pyramid/connector/deezer/download/exception.py @@ -0,0 +1,7 @@ + + +from pyramid.data.exceptions import CustomException + + +class DlDeezerNotUrlFoundException(CustomException): + pass diff --git a/src/pyramid/connector/deezer/downloader.py b/src/pyramid/connector/deezer/downloader.py index 2c3165d..a70e261 100644 --- a/src/pyramid/connector/deezer/downloader.py +++ b/src/pyramid/connector/deezer/downloader.py @@ -5,10 +5,9 @@ from typing import Optional import pydeezer.util +from pyramid.connector.deezer.download.client import PyDeezer from pyramid.connector.deezer.downloader_progress_bar import DownloaderProgressBar -from pyramid.connector.deezer.py_deezer import PyDeezer -from pyramid.data.track import Track -from pyramid.tools.generate_token import DeezerTokenProvider +from pyramid.data.music.track import Track from pyramid.tools.deprecated_class import deprecated_class from pyramid.tools.generate_token import DeezerTokenProvider from pydeezer.constants import track_formats diff --git a/src/pyramid/connector/deezer/search.py b/src/pyramid/connector/deezer/search.py index fad9c3b..b4e08c0 100644 --- a/src/pyramid/connector/deezer/search.py +++ b/src/pyramid/connector/deezer/search.py @@ -5,16 +5,12 @@ import aiohttp import deezer -from pyramid.data.a_engine_tools import AEngineTools +from pyramid.connector.deezer.client.exceptions import CliDeezerNoDataException, CliDeezerRateLimitError +from pyramid.connector.deezer.client.list_paginated import DeezerListPaginated from pyramid.data.a_search import ASearch, ASearchId -from pyramid.data.track import TrackMinimalDeezer - -from pyramid.connector.deezer.cli_deezer import ( - CliDeezer, - CliDeezerNoDataException, - CliDeezerRateLimitError, - CliPaginatedList, -) +from pyramid.data.music.track_minimal_deezer import TrackMinimalDeezer + +from pyramid.connector.deezer.cli_deezer import CliDeezer from pyramid.services.deezer_search import DeezerTools, DeezerType from pyramid.tools.deprecated_class import deprecated_class @@ -23,7 +19,6 @@ class DeezerSearch(ASearchId, ASearch): def __init__(self, default_limit: int): self.default_limit = default_limit self.client = CliDeezer() - self.tools = DeezerTools() self.strict = False async def search_track(self, search) -> TrackMinimalDeezer | None: @@ -66,7 +61,7 @@ async def get_playlist_tracks(self, playlist_name) -> list[TrackMinimalDeezer] | playlist = await pagination_results.get_first() if not playlist: return None - pagination_tracks: CliPaginatedList[deezer.Track] = playlist.get_tracks() # type: ignore + pagination_tracks: DeezerListPaginated[deezer.Track] = playlist.get_tracks() # type: ignore tracks = await pagination_tracks.get_all() return [TrackMinimalDeezer(element) for element in tracks] @@ -77,7 +72,7 @@ async def get_playlist_tracks_by_id( if not playlist: return None # Tracks id are the not the good one - playlist_tracks: CliPaginatedList[deezer.Track] = playlist.get_tracks() # type: ignore + playlist_tracks: DeezerListPaginated[deezer.Track] = playlist.get_tracks() # type: ignore # So we search the id for same name and artist real_tracks: list[TrackMinimalDeezer] = [] @@ -107,7 +102,7 @@ async def get_album_tracks(self, album_name) -> list[TrackMinimalDeezer] | None: album = await pagination_results.get_first() if not album: return None - pagination_tracks: CliPaginatedList[deezer.Track] = album.get_tracks() # type: ignore + pagination_tracks: DeezerListPaginated[deezer.Track] = album.get_tracks() # type: ignore tracks = await pagination_tracks.get_all() return [TrackMinimalDeezer(element) for element in tracks] @@ -117,7 +112,7 @@ async def get_album_tracks_by_id( album = await self.client.async_get_album(album_id) # TODO handle HTTP errors if not album: return None - pagination_tracks: CliPaginatedList[deezer.Track] = album.get_tracks() # type: ignore + pagination_tracks: DeezerListPaginated[deezer.Track] = album.get_tracks() # type: ignore tracks = await pagination_tracks.get_all() return [TrackMinimalDeezer(element) for element in tracks], [] @@ -130,7 +125,7 @@ async def get_top_artist( artist = await pagination_results.get_first() if not artist: return None - pagination_tracks: CliPaginatedList[deezer.Track] = artist.get_top() # type: ignore + pagination_tracks: DeezerListPaginated[deezer.Track] = artist.get_top() # type: ignore tracks = await pagination_tracks.get_maximum(limit) return [TrackMinimalDeezer(element) for element in tracks] @@ -142,14 +137,14 @@ async def get_top_artist_by_id( artist = await self.client.async_get_artist(artist_id) # TODO handle HTTP errors if not artist: return None - pagination_tracks: CliPaginatedList[deezer.Track] = artist.get_top() # type: ignore + pagination_tracks: DeezerListPaginated[deezer.Track] = artist.get_top() # type: ignore tracks = await pagination_tracks.get_maximum(limit) return [TrackMinimalDeezer(element) for element in tracks], [] async def get_by_url( self, url ) -> tuple[list[TrackMinimalDeezer], list[TrackMinimalDeezer]] | TrackMinimalDeezer | None: - id, type = await self.tools.extract_from_url(url) + id, type = await DeezerTools.extract_from_url(url) if id is None: return None diff --git a/src/pyramid/connector/deezer/tools.py b/src/pyramid/connector/deezer/tools.py index a66ecb8..9310015 100644 --- a/src/pyramid/connector/deezer/tools.py +++ b/src/pyramid/connector/deezer/tools.py @@ -6,7 +6,9 @@ class DeezerTools(AEngineTools): - async def extract_from_url(self, url) -> tuple[int, DeezerType | None] | tuple[None, None]: + + @classmethod + async def extract_from_url(cls, url) -> tuple[int, DeezerType | None] | tuple[None, None]: # Resolve if URL is a deezer.page.link URL if "deezer.page.link" in url: async with aiohttp.ClientSession() as session: diff --git a/src/pyramid/connector/discord/guild_cmd.py b/src/pyramid/connector/discord/guild_cmd.py index f5f791f..3fcabfa 100644 --- a/src/pyramid/connector/discord/guild_cmd.py +++ b/src/pyramid/connector/discord/guild_cmd.py @@ -7,7 +7,7 @@ from pyramid.data import tracklist as utils_list_track from pyramid.data.guild_data import GuildData from pyramid.data.source_type import SourceType -from pyramid.data.track import TrackMinimal +from pyramid.data.music.track_minimal import TrackMinimal from pyramid.connector.discord.guild_cmd_tools import GuildCmdTools from pyramid.connector.discord.guild_queue import GuildQueue from pyramid.data.functional.messages.message_sender_queued import MessageSenderQueued diff --git a/src/pyramid/connector/discord/guild_cmd_tools.py b/src/pyramid/connector/discord/guild_cmd_tools.py index baa78ca..489cb9a 100644 --- a/src/pyramid/connector/discord/guild_cmd_tools.py +++ b/src/pyramid/connector/discord/guild_cmd_tools.py @@ -6,7 +6,9 @@ from pyramid.data.exceptions import DeezerTokenException from pyramid.api.services.source_service import ISourceService -from pyramid.data.track import Track, TrackMinimal, TrackMinimalDeezer +from pyramid.data.music.track import Track +from pyramid.data.music.track_minimal_deezer import TrackMinimalDeezer +from pyramid.data.music.track_minimal import TrackMinimal from pyramid.data.guild_data import GuildData from pyramid.data.tracklist import TrackList from pyramid.connector.discord.guild_queue import GuildQueue diff --git a/src/pyramid/connector/discord/guild_queue.py b/src/pyramid/connector/discord/guild_queue.py index ab98d73..c295fa2 100644 --- a/src/pyramid/connector/discord/guild_queue.py +++ b/src/pyramid/connector/discord/guild_queue.py @@ -4,10 +4,10 @@ import discord from discord import VoiceChannel, VoiceClient -from pyramid.data.track import Track +from pyramid.data.music.track import Track from pyramid.data.tracklist import TrackList from pyramid.data.guild_data import GuildData -from pyramid.connector.discord.music_player_interface import MusicPlayerInterface +from pyramid.connector.discord.music.player_interface import MusicPlayerInterface from pyramid.data.functional.messages.message_sender_queued import MessageSenderQueued from pyramid.data.a_guid_queue import AGuildQueue diff --git a/src/pyramid/connector/discord/music/player_button.py b/src/pyramid/connector/discord/music/player_button.py new file mode 100644 index 0000000..138769c --- /dev/null +++ b/src/pyramid/connector/discord/music/player_button.py @@ -0,0 +1,35 @@ +import discord +from discord import ButtonStyle, Interaction +from discord.ui import Button + +from pyramid.data.a_guild_cmd import AGuildCmd +from pyramid.data.functional.messages.message_sender_queued import MessageSenderQueued + +class PlayerButton(discord.ui.View): + def __init__(self, queue_action: AGuildCmd, timeout: float | None = 180): + super().__init__(timeout=timeout) + self.queue_action = queue_action + + @discord.ui.button(emoji="⏯️", style=ButtonStyle.primary) + async def next_play_or_pause(self, ctx: Interaction, button: Button): + ms = MessageSenderQueued(ctx) + await ms.thinking() + await self.queue_action.resume_or_pause(ms, ctx) + + @discord.ui.button(emoji="⏭️", style=ButtonStyle.primary) + async def next_track(self, ctx: Interaction, button: Button): + ms = MessageSenderQueued(ctx) + await ms.thinking() + await self.queue_action.next(ms, ctx) + + @discord.ui.button(emoji="🔀", style=ButtonStyle.primary) + async def shuffle_queue(self, ctx: Interaction, button: Button): + ms = MessageSenderQueued(ctx) + await ms.thinking() + await self.queue_action.shuffle(ms, ctx) + + @discord.ui.button(emoji="⏹️", style=ButtonStyle.primary) + async def stop_queue(self, ctx: Interaction, button: Button): + ms = MessageSenderQueued(ctx) + await ms.thinking() + await self.queue_action.stop(ms, ctx) \ No newline at end of file diff --git a/src/pyramid/connector/discord/music_player_interface.py b/src/pyramid/connector/discord/music/player_interface.py similarity index 64% rename from src/pyramid/connector/discord/music_player_interface.py rename to src/pyramid/connector/discord/music/player_interface.py index e24326d..7d90455 100644 --- a/src/pyramid/connector/discord/music_player_interface.py +++ b/src/pyramid/connector/discord/music/player_interface.py @@ -1,42 +1,11 @@ import discord -from discord import ButtonStyle, Embed, Interaction, Locale, Message +from discord import Embed, Locale, Message from discord.abc import Messageable -from discord.ui import Button -from pyramid.data.track import Track +from pyramid.connector.discord.music.player_button import PlayerButton +from pyramid.data.music.track import Track from pyramid.data.tracklist import TrackList from pyramid.data.a_guild_cmd import AGuildCmd -from pyramid.data.functional.messages.message_sender_queued import MessageSenderQueued - - -class PlayerButton(discord.ui.View): - def __init__(self, queue_action: AGuildCmd, timeout: float | None = 180): - super().__init__(timeout=timeout) - self.queue_action = queue_action - - @discord.ui.button(emoji="⏯️", style=ButtonStyle.primary) - async def next_play_or_pause(self, ctx: Interaction, button: Button): - ms = MessageSenderQueued(ctx) - await ms.thinking() - await self.queue_action.resume_or_pause(ms, ctx) - - @discord.ui.button(emoji="⏭️", style=ButtonStyle.primary) - async def next_track(self, ctx: Interaction, button: Button): - ms = MessageSenderQueued(ctx) - await ms.thinking() - await self.queue_action.next(ms, ctx) - - @discord.ui.button(emoji="🔀", style=ButtonStyle.primary) - async def shuffle_queue(self, ctx: Interaction, button: Button): - ms = MessageSenderQueued(ctx) - await ms.thinking() - await self.queue_action.shuffle(ms, ctx) - - @discord.ui.button(emoji="⏹️", style=ButtonStyle.primary) - async def stop_queue(self, ctx: Interaction, button: Button): - ms = MessageSenderQueued(ctx) - await ms.thinking() - await self.queue_action.stop(ms, ctx) class MusicPlayerInterface: diff --git a/src/pyramid/connector/spotify/search.py b/src/pyramid/connector/spotify/search.py index df02dd0..7fb5146 100644 --- a/src/pyramid/connector/spotify/search.py +++ b/src/pyramid/connector/spotify/search.py @@ -3,7 +3,7 @@ from pyramid.connector.spotify.spotify_tools import SpotifyTools from pyramid.connector.spotify.spotify_type import SpotifyType from pyramid.data.a_search import ASearch, ASearchId -from pyramid.data.track import TrackMinimalSpotify +from pyramid.data.music.track_minimal_spotify import TrackMinimalSpotify from spotipy.oauth2 import SpotifyClientCredentials from pyramid.connector.spotify.cli_spotify import CliSpotify diff --git a/src/pyramid/data/a_guid_queue.py b/src/pyramid/data/a_guid_queue.py index 5383ed8..c85ab30 100644 --- a/src/pyramid/data/a_guid_queue.py +++ b/src/pyramid/data/a_guid_queue.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from typing import Optional -from pyramid.data.track import Track +from pyramid.data.music.track import Track class AGuildQueue(ABC): diff --git a/src/pyramid/data/a_search.py b/src/pyramid/data/a_search.py index 6359e5b..a220fc1 100644 --- a/src/pyramid/data/a_search.py +++ b/src/pyramid/data/a_search.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod -from pyramid.data.track import TrackMinimal +from pyramid.data.music.track_minimal import TrackMinimal class ASearch(ABC): diff --git a/src/pyramid/data/guild_instance.py b/src/pyramid/data/guild_instance.py index bf44d12..8ccf68e 100644 --- a/src/pyramid/data/guild_instance.py +++ b/src/pyramid/data/guild_instance.py @@ -6,7 +6,7 @@ from pyramid.api.services.source_service import ISourceService from pyramid.connector.discord.guild_cmd import GuildCmd from pyramid.connector.discord.guild_queue import GuildQueue -from pyramid.connector.discord.music_player_interface import MusicPlayerInterface +from pyramid.connector.discord.music.player_interface import MusicPlayerInterface from pyramid.data.guild_data import GuildData diff --git a/src/pyramid/data/music/track.py b/src/pyramid/data/music/track.py new file mode 100644 index 0000000..d7401f0 --- /dev/null +++ b/src/pyramid/data/music/track.py @@ -0,0 +1,38 @@ +from datetime import datetime + +from pyramid.data.music.track_minimal import TrackMinimal +from pyramid.tools import utils + +class Track(TrackMinimal): + def __init__(self, data, file_path): + self.author_name: str = data["ART_NAME"] + self.author_picture: str = f"https://e-cdn-images.dzcdn.net/images/artist/{data['ART_PICTURE']}/512x512-000000-80-0-0.jpg" + self.authors: list[str] = [element["ART_NAME"] for element in data["ARTISTS"]] + self.name: str = data["SNG_TITLE"] + self.album_title: str = data["ALB_TITLE"] + self.album_picture: str = f"https://e-cdn-images.dzcdn.net/images/cover/{data['ALB_PICTURE']}/1024x1024-000000-80-0-0.jpg" + self.actual_seconds: int = int(0) + self.duration_seconds: int = int(data["DURATION"]) + self.duration: str = self.format_duration(int(data["DURATION"])) + self.file_size: int = int(data["FILESIZE"]) + if self.__is_valid_date(data["PHYSICAL_RELEASE_DATE"]): + self.date = datetime.strptime(data["PHYSICAL_RELEASE_DATE"], "%Y-%m-%d") + else: + self.date = None + self.file_local: str = file_path + + def get_date(self, locale: str = "en-US") -> str | None: + if self.date is None: + return None + date_formatted = utils.format_date_by_country(self.date, locale) + return date_formatted + + def __is_valid_date(self, date: str): + # Check if the format is exactly "YYYY-MM-DD" + parts = date.split("-") + if len(parts) == 3 and len(parts[0]) == 4 and len(parts[1]) == 2 and len(parts[2]) == 2: + year, month, day = parts + if year.isdigit() and month.isdigit() and day.isdigit(): + if 1 <= int(month) <= 12 and 1 <= int(day) <= 31: + return True + return False diff --git a/src/pyramid/data/music/track_minimal.py b/src/pyramid/data/music/track_minimal.py new file mode 100644 index 0000000..ef5a021 --- /dev/null +++ b/src/pyramid/data/music/track_minimal.py @@ -0,0 +1,35 @@ +from abc import ABC + + +class TrackMinimal(ABC): + def __init__(self, data): + self.id: str + self.author_name: str + self.author_picture: str + self.name: str + self.album_title: str + self.album_picture: str + self.available = True + + def get_full_name(self) -> str: + return f"{self.author_name} - {self.name}" + + def format_duration(self, input: int) -> str: + seconds = int(input) + + minutes, seconds = divmod(seconds, 60) + hours, minutes = divmod(minutes, 60) + days, hours = divmod(hours, 24) + + time_format = "{:02}:{:02}:{:02}".format(hours, minutes, seconds) + + if days > 0: + time_format = "{:02}d ".format(days) + time_format + + if days == 0 and hours == 0: + time_format = "{:02}:{:02}".format(minutes, seconds) + + return time_format + + def __str__(self): + return f"{self.author_name} - {self.name} - {self.album_title}" diff --git a/src/pyramid/data/music/track_minimal_deezer.py b/src/pyramid/data/music/track_minimal_deezer.py new file mode 100644 index 0000000..725527b --- /dev/null +++ b/src/pyramid/data/music/track_minimal_deezer.py @@ -0,0 +1,25 @@ + +import deezer + +from pyramid.data.music.track_minimal import TrackMinimal + +class TrackMinimalDeezer(TrackMinimal): + def __init__(self, data: deezer.Track): + self.id: str = str(data.id) + self.author_name: str = data.artist.name + # self.author_picture: str = data.artist.picture_medium + self.name: str = data.title + self.album_title: str = data.album.title + self.album_picture: str = data.album.cover_xl + # self.disk_number: int = data.disk_number + # self.track_number: int = data.track_position + self.explicit: bool = data.explicit_lyrics + self.duration: int = data.duration + self.rank: int = data.rank + # self.isrc: str = data.isrc + # self.available_countries: list[str] = data.available_countries + if not data.readable: + # logging.warning("%s - %s is unreadable", self.author_name, self.name) + self.available = False + else: + self.available = True diff --git a/src/pyramid/data/music/track_minimal_spotify.py b/src/pyramid/data/music/track_minimal_spotify.py new file mode 100644 index 0000000..d327739 --- /dev/null +++ b/src/pyramid/data/music/track_minimal_spotify.py @@ -0,0 +1,25 @@ + + +from pyramid.data.music.track_minimal import TrackMinimal + + +class TrackMinimalSpotify(TrackMinimal): + def __init__(self, data): + # author_picture = data['artists'][0]['images'][0]['url'] if data['artists'][0]['images'] else "" + # author_picture : str = "" # TODO Fix it + album_picture = data["album"]["images"][0]["url"] if data["album"]["images"] else "" + self.id: str = data["id"] + self.author_name: str = data["artists"][0]["name"] + self.author_picture: str = "" + self.name: str = data["name"] + self.album_title: str = data["album"]["name"] + self.album_picture: str = album_picture + self.disk_number: int = data["disc_number"] + self.track_number: int = data["track_number"] + self.explicit: bool = data["explicit"] + self.duration: int = round(data["duration_ms"] / 1000) + self.isrc: str | None = ( + data["external_ids"]["isrc"] if "isrc" in data["external_ids"] else None + ) + # self.available_countries: list[str] = data["available_markets"] + self.available = not data["is_local"] diff --git a/src/pyramid/data/track.py b/src/pyramid/data/track.py deleted file mode 100644 index b7e1789..0000000 --- a/src/pyramid/data/track.py +++ /dev/null @@ -1,118 +0,0 @@ -from abc import ABC -from datetime import datetime - -import deezer -from pyramid.tools import utils - - -class TrackMinimal(ABC): - def __init__(self, data): - self.id: str - self.author_name: str - self.author_picture: str - self.name: str - self.album_title: str - self.album_picture: str - self.available = True - - def get_full_name(self) -> str: - return f"{self.author_name} - {self.name}" - - def format_duration(self, input: int) -> str: - seconds = int(input) - - minutes, seconds = divmod(seconds, 60) - hours, minutes = divmod(minutes, 60) - days, hours = divmod(hours, 24) - - time_format = "{:02}:{:02}:{:02}".format(hours, minutes, seconds) - - if days > 0: - time_format = "{:02}d ".format(days) + time_format - - if days == 0 and hours == 0: - time_format = "{:02}:{:02}".format(minutes, seconds) - - return time_format - - def __str__(self): - return f"{self.author_name} - {self.name} - {self.album_title}" - - -class TrackMinimalSpotify(TrackMinimal): - def __init__(self, data): - # author_picture = data['artists'][0]['images'][0]['url'] if data['artists'][0]['images'] else "" - # author_picture : str = "" # TODO Fix it - album_picture = data["album"]["images"][0]["url"] if data["album"]["images"] else "" - self.id: str = data["id"] - self.author_name: str = data["artists"][0]["name"] - self.author_picture: str = "" - self.name: str = data["name"] - self.album_title: str = data["album"]["name"] - self.album_picture: str = album_picture - self.disk_number: int = data["disc_number"] - self.track_number: int = data["track_number"] - self.explicit: bool = data["explicit"] - self.duration: int = round(data["duration_ms"] / 1000) - self.isrc: str | None = ( - data["external_ids"]["isrc"] if "isrc" in data["external_ids"] else None - ) - # self.available_countries: list[str] = data["available_markets"] - self.available = not data["is_local"] - - -class TrackMinimalDeezer(TrackMinimal): - def __init__(self, data: deezer.Track): - self.id: str = str(data.id) - self.author_name: str = data.artist.name - # self.author_picture: str = data.artist.picture_medium - self.name: str = data.title - self.album_title: str = data.album.title - self.album_picture: str = data.album.cover_xl - # self.disk_number: int = data.disk_number - # self.track_number: int = data.track_position - self.explicit: bool = data.explicit_lyrics - self.duration: int = data.duration - self.rank: int = data.rank - # self.isrc: str = data.isrc - # self.available_countries: list[str] = data.available_countries - if not data.readable: - # logging.warning("%s - %s is unreadable", self.author_name, self.name) - self.available = False - else: - self.available = True - - -class Track(TrackMinimal): - def __init__(self, data, file_path): - self.author_name: str = data["ART_NAME"] - self.author_picture: str = f"https://e-cdn-images.dzcdn.net/images/artist/{data['ART_PICTURE']}/512x512-000000-80-0-0.jpg" - self.authors: list[str] = [element["ART_NAME"] for element in data["ARTISTS"]] - self.name: str = data["SNG_TITLE"] - self.album_title: str = data["ALB_TITLE"] - self.album_picture: str = f"https://e-cdn-images.dzcdn.net/images/cover/{data['ALB_PICTURE']}/1024x1024-000000-80-0-0.jpg" - self.actual_seconds: int = int(0) - self.duration_seconds: int = int(data["DURATION"]) - self.duration: str = self.format_duration(int(data["DURATION"])) - self.file_size: int = int(data["FILESIZE"]) - if self.__is_valid_date(data["PHYSICAL_RELEASE_DATE"]): - self.date = datetime.strptime(data["PHYSICAL_RELEASE_DATE"], "%Y-%m-%d") - else: - self.date = None - self.file_local: str = file_path - - def get_date(self, locale: str = "en-US") -> str | None: - if self.date is None: - return None - date_formatted = utils.format_date_by_country(self.date, locale) - return date_formatted - - def __is_valid_date(self, date: str): - # Check if the format is exactly "YYYY-MM-DD" - parts = date.split("-") - if len(parts) == 3 and len(parts[0]) == 4 and len(parts[1]) == 2 and len(parts[2]) == 2: - year, month, day = parts - if year.isdigit() and month.isdigit() and day.isdigit(): - if 1 <= int(month) <= 12 and 1 <= int(day) <= 31: - return True - return False diff --git a/src/pyramid/data/tracklist.py b/src/pyramid/data/tracklist.py index fbe8d92..7b2724e 100644 --- a/src/pyramid/data/tracklist.py +++ b/src/pyramid/data/tracklist.py @@ -2,7 +2,9 @@ import random from pyramid.tools import utils -from pyramid.data.track import Track, TrackMinimal, TrackMinimalDeezer +from pyramid.data.music.track import Track +from pyramid.data.music.track_minimal_deezer import TrackMinimalDeezer +from pyramid.data.music.track_minimal import TrackMinimal class TrackList: diff --git a/src/pyramid/services/deezer_client.py b/src/pyramid/services/deezer_client.py new file mode 100644 index 0000000..bda205e --- /dev/null +++ b/src/pyramid/services/deezer_client.py @@ -0,0 +1,201 @@ +from typing import Any + +import aiohttp +from deezer import Album, Artist, Client, Playlist, Resource, Track + +from pyramid.api.services.deezer_client import IDeezerClientService +from pyramid.api.services.tools.annotation import pyramid_service +from pyramid.api.services.tools.injector import ServiceInjector +from pyramid.connector.deezer.client.a_deezer_client import ADeezerClient +from pyramid.connector.deezer.client.exceptions import CliDeezerErrorResponse, CliDeezerHTTPError +from pyramid.connector.deezer.client.list_paginated import DeezerListPaginated +from pyramid.connector.deezer.client.rate_limiter_async import RateLimiterAsync + +@pyramid_service(interface=IDeezerClientService) +class DeezerClientService(IDeezerClientService, ADeezerClient, Client, ServiceInjector): + + def __init__(self, app_id=None, app_secret=None, access_token=None, headers=None, **kwargs): + # super(Client, self).__init__(app_id, app_secret, access_token, headers, **kwargs) + + self.app_id = app_id + self.app_secret = app_secret + self.access_token = access_token + # self.session = requests.Session() + + # headers = headers or {} + # self.session.headers.update(headers) + # self.session.close() + # self.async_session = aiohttp.ClientSession() + self.rate_limiter = RateLimiterAsync(max_requests=50, time_interval=5) + + def get_paginated_list( + self, + relation: str, + **kwargs, + ) -> DeezerListPaginated: + return DeezerListPaginated( + client=self.client, + base_path=f"{self.type}/{self.id}/{relation}", + parent=self, + **kwargs, + ) + Resource.get_paginated_list = get_paginated_list # type: ignore + + # def __getattr__(self, item: str) -> Any: + # try: + # return object.__getattribute__(self, item) + # except AttributeError: + # print(f"Attribute '{item}' not found.") + # Resource.__getattr__ = __getattr__ + + def get(self) -> Any: + raise AttributeError("%s has a missing attribute." % self.__class__.__name__) + + Resource.get = get + + def _search( + self, + path: str, + query: str = "", + strict: bool | None = None, + ordering: str | None = None, + **advanced_params: str | int | None, + ) -> DeezerListPaginated: + return Client._search(self, path, query, strict, ordering, **advanced_params) # type: ignore + + def search( + self, + query: str = "", + strict: bool | None = None, + ordering: str | None = None, + artist: str | None = None, + album: str | None = None, + track: str | None = None, + label: str | None = None, + dur_min: int | None = None, + dur_max: int | None = None, + bpm_min: int | None = None, + bpm_max: int | None = None, + ) -> DeezerListPaginated: + return self._search( + "", + query=query, + strict=strict, + ordering=ordering, + artist=artist, + album=album, + track=track, + label=label, + dur_min=dur_min, + dur_max=dur_max, + bpm_min=bpm_min, + bpm_max=bpm_max, + ) + + def search_playlists( + self, + query: str = "", + strict: bool | None = None, + ordering: str | None = None, + ) -> DeezerListPaginated[Playlist]: + return self._search( + path="playlist", + query=query, + strict=strict, + ordering=ordering, + ) + + def search_albums( + self, + query: str = "", + strict: bool | None = None, + ordering: str | None = None, + ) -> DeezerListPaginated[Album]: + return self._search( + path="album", + query=query, + strict=strict, + ordering=ordering, + ) + + def search_artists( + self, + query: str = "", + strict: bool | None = None, + ordering: str | None = None, + ) -> DeezerListPaginated[Artist]: + return self._search( + path="artist", + query=query, + strict=strict, + ordering=ordering, + ) + + async def async_get_playlist(self, playlist_id: int) -> Playlist: + return await self.async_request("GET", f"playlist/{playlist_id}") # type: ignore + + async def async_get_album(self, album_id: int) -> Album: + return await self.async_request("GET", f"album/{album_id}") # type: ignore + + async def async_get_artist(self, artist_id: int) -> Artist: + return await self.async_request("GET", f"artist/{artist_id}") # type: ignore + + async def async_get_track(self, track_id: int) -> Track: + return await self.async_request("GET", f"track/{track_id}") # type: ignore + + def _get_paginated_list(self, path, **params) -> DeezerListPaginated: + return DeezerListPaginated(client=self, base_path=path, **params) + + async def async_request( + self, + method: str, + path: str, + parent: Resource | None = None, + resource_type: type[Resource] | None = None, + resource_id: int | None = None, + paginate_list=False, + **params, + ): + """ + Make an asynchronous request to the API and parse the response. + + :param method: HTTP verb to use: GET, POST, DELETE, ... + :param path: The path to make the API call to (e.g. 'artist/1234'). + :param parent: A reference to the parent resource, to avoid fetching again. + :param resource_type: The resource class to use as the top level. + :param resource_id: The resource id to use as the top level. + :param paginate_list: Whether to wrap list into a pagination object. + :param params: Query parameters to add to the request + """ + + if self.access_token is not None: + params["access_token"] = str(self.access_token) + + async with aiohttp.ClientSession() as session: + await self.rate_limiter.check() + async with session.request( + method, + f"{self.base_url}/{path}", + params=params, + ) as response: + await self.rate_limiter.add() + try: + response.raise_for_status() + except aiohttp.ClientResponseError as exc: + raise CliDeezerHTTPError.from_status_code(exc) from exc + + json_data = await response.json() + + if not isinstance(json_data, dict): + return json_data + + if "error" in json_data and json_data["error"]: + raise CliDeezerErrorResponse.from_body(json_data) + + return self._process_json( + json_data, + parent=parent, + resource_type=resource_type, + resource_id=resource_id, + paginate_list=paginate_list, + ) diff --git a/src/pyramid/services/deezer_downloader.py b/src/pyramid/services/deezer_downloader.py index 858cc80..e6fe102 100644 --- a/src/pyramid/services/deezer_downloader.py +++ b/src/pyramid/services/deezer_downloader.py @@ -9,8 +9,8 @@ from pyramid.api.services.tools.annotation import pyramid_service from pyramid.api.services.tools.injector import ServiceInjector from pyramid.connector.deezer.downloader_progress_bar import DownloaderProgressBar -from pyramid.connector.deezer.py_deezer import PyDeezer -from pyramid.data.track import Track +from pyramid.connector.deezer.download.client import PyDeezer +from pyramid.data.music.track import Track from pyramid.tools.generate_token import DeezerTokenProvider, DeezerTokenEmptyException, DeezerTokenOverflowException from pydeezer.constants import track_formats from pydeezer.exceptions import LoginError diff --git a/src/pyramid/services/deezer_search.py b/src/pyramid/services/deezer_search.py index f4f4669..7383a7f 100644 --- a/src/pyramid/services/deezer_search.py +++ b/src/pyramid/services/deezer_search.py @@ -1,56 +1,49 @@ import asyncio import logging -import re -from enum import Enum -import aiohttp import deezer from pyramid.api.services.configuration import IConfigurationService +from pyramid.api.services.deezer_client import IDeezerClientService from pyramid.api.services.deezer_search import IDeezerSearchService from pyramid.api.services.tools.annotation import pyramid_service from pyramid.api.services.tools.injector import ServiceInjector +from pyramid.connector.deezer.client.exceptions import CliDeezerNoDataException, CliDeezerRateLimitError +from pyramid.connector.deezer.client.list_paginated import DeezerListPaginated from pyramid.connector.deezer.deezer_type import DeezerType from pyramid.connector.deezer.tools import DeezerTools from pyramid.data.a_search import ASearch, ASearchId -from pyramid.data.track import TrackMinimalDeezer - -from pyramid.connector.deezer.cli_deezer import ( - CliDeezer, - CliDeezerNoDataException, - CliDeezerRateLimitError, - CliPaginatedList, -) +from pyramid.data.music.track_minimal_deezer import TrackMinimalDeezer @pyramid_service(interface=IDeezerSearchService) class DeezerSearchService(IDeezerSearchService, ASearchId, ASearch, ServiceInjector): def injectService(self, - configuration_service: IConfigurationService + configuration_service: IConfigurationService, + deezer_client: IDeezerClientService ): self.__configuration_service = configuration_service + self.__deezer_client = deezer_client def start(self): - self.client = CliDeezer() - self.tools = DeezerTools() self.strict = False async def search_track(self, search) -> TrackMinimalDeezer | None: - result = self.client.search(query=search) + result = self.__deezer_client.search(query=search) track = await result.get_first() if not track: return None return TrackMinimalDeezer(track) async def get_track_by_id(self, track_id: int) -> TrackMinimalDeezer | None: - track = await self.client.async_get_track(track_id) # TODO handle HTTP errors + track = await self.__deezer_client.async_get_track(track_id) # TODO handle HTTP errors if not track: return None return TrackMinimalDeezer(track) async def get_track_by_isrc(self, isrc: str) -> TrackMinimalDeezer | None: try: - track: deezer.Track = await self.client.async_request("GET", f"track/isrc:{isrc}") # type: ignore + track: deezer.Track = await self.__deezer_client.async_request("GET", f"track/isrc:{isrc}") # type: ignore if not track: return None return TrackMinimalDeezer(track) @@ -63,7 +56,7 @@ async def search_tracks( if limit is None: limit = self.__configuration_service.general__limit_tracks - pagination_results = self.client.search(query=search, strict=self.strict) + pagination_results = self.__deezer_client.search(query=search, strict=self.strict) tracks = await pagination_results.get_maximum(limit) if not tracks: return None @@ -71,22 +64,22 @@ async def search_tracks( return [TrackMinimalDeezer(element) for element in tracks] async def get_playlist_tracks(self, playlist_name) -> list[TrackMinimalDeezer] | None: - pagination_results = self.client.search_playlists(query=playlist_name, strict=self.strict) + pagination_results = self.__deezer_client.search_playlists(query=playlist_name, strict=self.strict) playlist = await pagination_results.get_first() if not playlist: return None - pagination_tracks: CliPaginatedList[deezer.Track] = playlist.get_tracks() # type: ignore + pagination_tracks: DeezerListPaginated[deezer.Track] = playlist.get_tracks() # type: ignore tracks = await pagination_tracks.get_all() return [TrackMinimalDeezer(element) for element in tracks] async def get_playlist_tracks_by_id( self, playlist_id: int ) -> tuple[list[TrackMinimalDeezer], list[TrackMinimalDeezer]] | None: - playlist = await self.client.async_get_playlist(playlist_id) # TODO handle HTTP errors + playlist = await self.__deezer_client.async_get_playlist(playlist_id) # TODO handle HTTP errors if not playlist: return None # Tracks id are the not the good one - playlist_tracks: CliPaginatedList[deezer.Track] = playlist.get_tracks() # type: ignore + playlist_tracks: DeezerListPaginated[deezer.Track] = playlist.get_tracks() # type: ignore # So we search the id for same name and artist real_tracks: list[TrackMinimalDeezer] = [] @@ -112,21 +105,21 @@ async def get_playlist_tracks_by_id( return real_tracks, unfindable_track async def get_album_tracks(self, album_name) -> list[TrackMinimalDeezer] | None: - pagination_results = self.client.search_albums(query=album_name, strict=self.strict) + pagination_results = self.__deezer_client.search_albums(query=album_name, strict=self.strict) album = await pagination_results.get_first() if not album: return None - pagination_tracks: CliPaginatedList[deezer.Track] = album.get_tracks() # type: ignore + pagination_tracks: DeezerListPaginated[deezer.Track] = album.get_tracks() # type: ignore tracks = await pagination_tracks.get_all() return [TrackMinimalDeezer(element) for element in tracks] async def get_album_tracks_by_id( self, album_id: int ) -> tuple[list[TrackMinimalDeezer], list[TrackMinimalDeezer]] | None: - album = await self.client.async_get_album(album_id) # TODO handle HTTP errors + album = await self.__deezer_client.async_get_album(album_id) # TODO handle HTTP errors if not album: return None - pagination_tracks: CliPaginatedList[deezer.Track] = album.get_tracks() # type: ignore + pagination_tracks: DeezerListPaginated[deezer.Track] = album.get_tracks() # type: ignore tracks = await pagination_tracks.get_all() return [TrackMinimalDeezer(element) for element in tracks], [] @@ -135,11 +128,11 @@ async def get_top_artist( ) -> list[TrackMinimalDeezer] | None: if limit is None: limit = self.__configuration_service.general__limit_tracks - pagination_results = self.client.search_artists(query=artist_name, strict=self.strict) + pagination_results = self.__deezer_client.search_artists(query=artist_name, strict=self.strict) artist = await pagination_results.get_first() if not artist: return None - pagination_tracks: CliPaginatedList[deezer.Track] = artist.get_top() # type: ignore + pagination_tracks: DeezerListPaginated[deezer.Track] = artist.get_top() # type: ignore tracks = await pagination_tracks.get_maximum(limit) return [TrackMinimalDeezer(element) for element in tracks] @@ -148,17 +141,17 @@ async def get_top_artist_by_id( ) -> tuple[list[TrackMinimalDeezer], list[TrackMinimalDeezer]] | None: if limit is None: limit = self.__configuration_service.general__limit_tracks - artist = await self.client.async_get_artist(artist_id) # TODO handle HTTP errors + artist = await self.__deezer_client.async_get_artist(artist_id) # TODO handle HTTP errors if not artist: return None - pagination_tracks: CliPaginatedList[deezer.Track] = artist.get_top() # type: ignore + pagination_tracks: DeezerListPaginated[deezer.Track] = artist.get_top() # type: ignore tracks = await pagination_tracks.get_maximum(limit) return [TrackMinimalDeezer(element) for element in tracks], [] async def get_by_url( self, url ) -> tuple[list[TrackMinimalDeezer], list[TrackMinimalDeezer]] | TrackMinimalDeezer | None: - id, type = await self.tools.extract_from_url(url) + id, type = await DeezerTools.extract_from_url(url) if id is None: return None @@ -213,7 +206,7 @@ async def _search_exact_track( self, artist_name, album_title, track_title ) -> TrackMinimalDeezer | None: try: - pagination_results = self.client.search( + pagination_results = self.__deezer_client.search( artist=artist_name, album=album_title, track=track_title ) logging.info("_search_exact_track %s - %s - %s", artist_name, album_title, track_title) diff --git a/src/pyramid/services/source.py b/src/pyramid/services/source.py index 8386dbd..589dadd 100644 --- a/src/pyramid/services/source.py +++ b/src/pyramid/services/source.py @@ -1,17 +1,17 @@ from typing import Dict -from pyramid.api.services.configuration import IConfigurationService from pyramid.api.services.deezer_downloader import IDeezerDownloaderService from pyramid.api.services.deezer_search import IDeezerSearchService from pyramid.api.services.source_service import ISourceService from pyramid.api.services.spotify_search import ISpotifySearchService from pyramid.api.services.tools.annotation import pyramid_service from pyramid.api.services.tools.injector import ServiceInjector -from pyramid.connector.spotify.search import SpotifySearch from pyramid.data.a_search import ASearch from pyramid.data.exceptions import EngineSourceNotFoundException, TrackNotFoundException from pyramid.data.source_type import SourceType -from pyramid.data.track import Track, TrackMinimal, TrackMinimalDeezer +from pyramid.data.music.track import Track +from pyramid.data.music.track_minimal_deezer import TrackMinimalDeezer +from pyramid.data.music.track_minimal import TrackMinimal @pyramid_service(interface=ISourceService) @@ -115,7 +115,7 @@ def _get_engine_name(self, engine: ASearch): for key, value in self.__sources.items(): if value == engine: return key.name - return None + raise Exception("Engine %s not found" % engine.__class__.__name__) def _resolve_engine(self, engine: SourceType | None): if engine is None: diff --git a/src/pyramid/services/spotify_client.py b/src/pyramid/services/spotify_client.py index b4564db..321d68f 100644 --- a/src/pyramid/services/spotify_client.py +++ b/src/pyramid/services/spotify_client.py @@ -6,15 +6,28 @@ import aiohttp from spotipy import Spotify from spotipy.exceptions import SpotifyException +from spotipy.oauth2 import SpotifyClientCredentials +from pyramid.api.services.configuration import IConfigurationService from pyramid.api.services.spotify_client import ISpotifyClientService from pyramid.api.services.tools.annotation import pyramid_service from pyramid.api.services.tools.injector import ServiceInjector - @pyramid_service(interface=ISpotifyClientService) class SpotifyClientService(Spotify, ISpotifyClientService, ServiceInjector): + def injectService(self, + configuration_service: IConfigurationService + ): + self.__configuration_service = configuration_service + + def start(self): + self.client_credentials_manager = SpotifyClientCredentials( + client_id=self.__configuration_service.spotify__client_id, + client_secret=self.__configuration_service.spotify__client_secret + ) + self.auth_manager = None + async def async_search(self, q, limit=10, offset=0, type="track", market=None) -> dict[str, Any]: return await self._get_async( "search", q=q, limit=limit, offset=offset, type=type, market=market diff --git a/src/pyramid/services/spotify_search.py b/src/pyramid/services/spotify_search.py index 186a3f8..0da0c9d 100644 --- a/src/pyramid/services/spotify_search.py +++ b/src/pyramid/services/spotify_search.py @@ -7,7 +7,7 @@ from pyramid.api.services.tools.injector import ServiceInjector from pyramid.connector.spotify.spotify_tools import SpotifyTools from pyramid.connector.spotify.spotify_type import SpotifyType -from pyramid.data.track import TrackMinimalSpotify +from pyramid.data.music.track_minimal_spotify import TrackMinimalSpotify @pyramid_service(interface=ISpotifySearchService) diff --git a/src/pyramid/services/spotify_search_base.py b/src/pyramid/services/spotify_search_base.py index 10e21f2..db5ab28 100644 --- a/src/pyramid/services/spotify_search_base.py +++ b/src/pyramid/services/spotify_search_base.py @@ -3,11 +3,10 @@ from typing import Any from pyramid.api.services.configuration import IConfigurationService +from pyramid.api.services.spotify_client import ISpotifyClientService from pyramid.api.services.spotify_search_base import ISpotifySearchBaseService from pyramid.api.services.tools.annotation import pyramid_service from pyramid.api.services.tools.injector import ServiceInjector -from pyramid.connector.spotify.spotify_tools import SpotifyTools -from spotipy.oauth2 import SpotifyClientCredentials from pyramid.connector.spotify.cli_spotify import CliSpotify @@ -16,15 +15,13 @@ class SpotifySearchBaseService(ISpotifySearchBaseService, ServiceInjector): def injectService(self, configuration_service: IConfigurationService, + spotify_client: ISpotifyClientService, ): self.__configuration_service = configuration_service + self.__spotify_client = spotify_client - def start(self): - self.client_credentials_manager = SpotifyClientCredentials( - client_id=self.__configuration_service.spotify__client_id, - client_secret=self.__configuration_service.spotify__client_secret - ) - self.client = CliSpotify(client_credentials_manager=self.client_credentials_manager) + # def start(self): + # self.__spotify_client = CliSpotify(client_credentials_manager=self.client_credentials_manager) async def items( self, @@ -36,7 +33,7 @@ async def items( tracks: list = results[item_name] while results["next"]: - results = await self.client.async_next(results) # type: ignore + results = await self.__spotify_client.async_next(results) # type: ignore tracks.extend(results[item_name]) return tracks @@ -56,7 +53,7 @@ async def items_max( results_tracks: dict[str, Any] = results["tracks"] while results["tracks"]["next"] and limit > len(tracks): - results = await self.client.async_next(results_tracks) # type: ignore + results = await self.__spotify_client.async_next(results_tracks) # type: ignore tracks.extend(results_tracks[item_name]) if len(tracks) > limit: diff --git a/src/pyramid/services/spotify_search_id.py b/src/pyramid/services/spotify_search_id.py index 4d6fccf..d02411e 100644 --- a/src/pyramid/services/spotify_search_id.py +++ b/src/pyramid/services/spotify_search_id.py @@ -6,7 +6,7 @@ from pyramid.api.services.spotify_search_id import ISpotifySearchIdService from pyramid.api.services.tools.annotation import pyramid_service from pyramid.api.services.tools.injector import ServiceInjector -from pyramid.data.track import TrackMinimalSpotify +from pyramid.data.music.track_minimal_spotify import TrackMinimalSpotify @pyramid_service(interface=ISpotifySearchIdService) class SpotifySearchIdService(ISpotifySearchIdService, ServiceInjector): From 269cae9da19882909eaf154bcac69d658d3b27eb Mon Sep 17 00:00:00 2001 From: Tristiisch Date: Mon, 23 Sep 2024 01:35:29 +0200 Subject: [PATCH 14/32] feat: remove deprecated class, use service in tests --- src/cli.py | 10 +- src/pyramid/api/services/configuration.py | 25 +- src/pyramid/api/services/tools/annotation.py | 5 +- src/pyramid/api/services/tools/exceptions.py | 2 +- src/pyramid/api/services/tools/register.py | 54 ++-- src/pyramid/api/services/tools/tester.py | 46 ++++ src/pyramid/connector/deezer/cli_deezer.py | 199 -------------- src/pyramid/connector/deezer/downloader.py | 141 ---------- src/pyramid/connector/deezer/search.py | 249 ------------------ src/pyramid/connector/spotify/cli_spotify.py | 135 ---------- src/pyramid/connector/spotify/search.py | 229 ---------------- src/pyramid/services/builder/configuration.py | 48 ++++ src/pyramid/services/configuration.py | 17 +- src/pyramid/services/deezer_downloader.py | 94 ++++--- src/pyramid/services/spotify_client.py | 1 - src/pyramid/services/spotify_search_base.py | 5 - .../tools/configuration/configuration.py | 66 ----- src/pyramid/tools/logs_handler.py | 89 ------- tests/config_test.py | 10 +- tests/deezer_download_test.py | 18 +- tests/deezer_search_test.py | 40 +-- tests/spotify_test.py | 39 +-- 22 files changed, 272 insertions(+), 1250 deletions(-) create mode 100644 src/pyramid/api/services/tools/tester.py delete mode 100644 src/pyramid/connector/deezer/cli_deezer.py delete mode 100644 src/pyramid/connector/deezer/downloader.py delete mode 100644 src/pyramid/connector/deezer/search.py delete mode 100644 src/pyramid/connector/spotify/cli_spotify.py delete mode 100644 src/pyramid/connector/spotify/search.py create mode 100644 src/pyramid/services/builder/configuration.py delete mode 100644 src/pyramid/tools/configuration/configuration.py delete mode 100644 src/pyramid/tools/logs_handler.py diff --git a/src/cli.py b/src/cli.py index 421b757..1330ce7 100644 --- a/src/cli.py +++ b/src/cli.py @@ -1,12 +1,12 @@ import argparse -from pyramid.data.functional.application_info import ApplicationInfo +from pyramid.api.services.information import IInformationService +from pyramid.api.services.tools.tester import ServiceStandalone from pyramid.client.client import SocketClient from pyramid.client.requests.health import HealthRequest -from pyramid.tools.logs_handler import LogsHandler def startup_cli(): - info = ApplicationInfo() + ServiceStandalone.import_services() parser = argparse.ArgumentParser(description="Readme at https://github.com/tristiisch/PyRamid") parser.add_argument("--version", action="store_true", help="Print version", required=False) @@ -23,7 +23,9 @@ def startup_cli(): args = parser.parse_args() if args.version: - print(info.get_version()) + information_service = ServiceStandalone.get_service(IInformationService) + information = information_service.get() + print(information.get_version()) elif args.health: sc = SocketClient(args.host, args.port) diff --git a/src/pyramid/api/services/configuration.py b/src/pyramid/api/services/configuration.py index acfe945..6309967 100644 --- a/src/pyramid/api/services/configuration.py +++ b/src/pyramid/api/services/configuration.py @@ -1,4 +1,4 @@ -from abc import ABC +from abc import ABC, abstractmethod from pyramid.data.environment import Environment class IConfigurationService(ABC): @@ -13,3 +13,26 @@ def __init__(self): self.general__limit_tracks: int = 0 self.mode: Environment = Environment.PRODUCTION self.version: str = "" + + @abstractmethod + def load(self, config_file: str = "config.yml", use_env_vars: bool = True) -> bool: + """ + Loads configuration values from environment variables and/or a configuration file. + + Parameters: + - use_env_vars (bool): If True, loads configuration values from environment variables. + + Returns: + - bool: True if the loading process is successful, False otherwise. + """ + pass + + @abstractmethod + def save(self, file_name: str): + """ + Saves the configuration values to a YAML file. + + Parameters: + - file_name (str): The name of the file to which the configuration will be saved. + """ + pass diff --git a/src/pyramid/api/services/tools/annotation.py b/src/pyramid/api/services/tools/annotation.py index f8360ea..3330e82 100644 --- a/src/pyramid/api/services/tools/annotation.py +++ b/src/pyramid/api/services/tools/annotation.py @@ -5,10 +5,7 @@ def pyramid_service(*, interface: Optional[type] = None): def decorator(cls): - class_name = cls.__name__ - if not issubclass(cls, ServiceInjector): - raise TypeError("Class %s must inherit from ServiceInjector" % class_name) - + class_name: str = cls.__name__ service_name = class_name if interface is not None: service_name = interface.__name__ diff --git a/src/pyramid/api/services/tools/exceptions.py b/src/pyramid/api/services/tools/exceptions.py index f181181..d00c891 100644 --- a/src/pyramid/api/services/tools/exceptions.py +++ b/src/pyramid/api/services/tools/exceptions.py @@ -4,7 +4,7 @@ class ServiceRegisterException(Exception): class ServiceAlreadyRegisterException(ServiceRegisterException): pass -class ServiceAlreadyNotRegisterException(ServiceRegisterException): +class ServiceNotRegisterException(ServiceRegisterException): pass class ServiceCicularDependencyException(ServiceRegisterException): diff --git a/src/pyramid/api/services/tools/register.py b/src/pyramid/api/services/tools/register.py index 44a06de..3932d65 100644 --- a/src/pyramid/api/services/tools/register.py +++ b/src/pyramid/api/services/tools/register.py @@ -3,16 +3,16 @@ import importlib import inspect import pkgutil -from typing import Type, TypeVar -from pyramid.api.services.tools.exceptions import ServiceAlreadyNotRegisterException, ServiceAlreadyRegisterException, ServiceCicularDependencyException +from typing import Any, Type, TypeVar +from pyramid.api.services.tools.exceptions import ServiceNotRegisterException, ServiceAlreadyRegisterException, ServiceCicularDependencyException from pyramid.api.services.tools.injector import ServiceInjector T = TypeVar('T') class ServiceRegister: - __SERVICE_TO_REGISTER: dict[str, type[ServiceInjector]] = {} - __SERVICE_REGISTERED: dict[str, ServiceInjector] = {} + __SERVICES_REGISTRED: dict[str, type[ServiceInjector]] = {} + __SERVICES_INSTANCES: dict[str, ServiceInjector] = {} @classmethod def import_services(cls): @@ -21,25 +21,25 @@ def import_services(cls): for loader, module_name, is_pkg in pkgutil.iter_modules(package.__path__): full_module_name = f"{package_name}.{module_name}" - importlib.import_module(full_module_name) + importlib.import_module(full_module_name) @classmethod def register_service(cls, name: str, type: type[object]): if not issubclass(type, ServiceInjector): raise TypeError("Service %s is not a subclass of ServiceInjector and cannot be initialized." % name) - if name in cls.__SERVICE_TO_REGISTER: - already_class_name = cls.__SERVICE_TO_REGISTER[name].__name__ + if name in cls.__SERVICES_REGISTRED: + already_class_name = cls.__SERVICES_REGISTRED[name].__name__ raise ServiceAlreadyRegisterException( "Cannot register %s with %s, it is already registered with the class %s." % (name, type.__name__, already_class_name) ) - cls.__SERVICE_TO_REGISTER[name] = type + cls.__SERVICES_REGISTRED[name] = type @classmethod def create_services(cls): - for name, service_type in cls.__SERVICE_TO_REGISTER.items(): + for name, service_type in cls.__SERVICES_REGISTRED.items(): class_instance = service_type() - cls.__SERVICE_REGISTERED[name] = class_instance + cls.__SERVICES_INSTANCES[name] = class_instance @classmethod def inject_services(cls): @@ -48,8 +48,8 @@ def inject_services(cls): indegree = defaultdict(int) # To track the number of dependencies # Create instances but delay injecting dependencies - for name, service_type in cls.__SERVICE_TO_REGISTER.items(): - class_instance = cls.__SERVICE_REGISTERED[name] + for name, service_type in cls.__SERVICES_REGISTRED.items(): + class_instance = cls.__SERVICES_INSTANCES[name] # Step 2: Parse dependencies for each service signature = inspect.signature(class_instance.injectService) @@ -57,8 +57,8 @@ def inject_services(cls): for method_parameter in method_parameters: dependency_name = method_parameter.annotation.__name__ - if dependency_name not in cls.__SERVICE_REGISTERED: - raise ServiceAlreadyNotRegisterException( + if dependency_name not in cls.__SERVICES_INSTANCES: + raise ServiceNotRegisterException( "Cannot register %s as a dependency for %s because the dependency is not registered." % (dependency_name, name) ) @@ -68,7 +68,7 @@ def inject_services(cls): # Step 3: Perform a topological sort to determine the order of instantiation sorted_services = [] - queue = deque([service for service in cls.__SERVICE_TO_REGISTER if indegree[service] == 0]) + queue = deque([service for service in cls.__SERVICES_REGISTRED if indegree[service] == 0]) while queue: service = queue.popleft() @@ -79,8 +79,8 @@ def inject_services(cls): if indegree[dependent] == 0: queue.append(dependent) - if len(sorted_services) != len(cls.__SERVICE_TO_REGISTER): - unresolved_services = set(cls.__SERVICE_TO_REGISTER) - set(sorted_services) + if len(sorted_services) != len(cls.__SERVICES_REGISTRED): + unresolved_services = set(cls.__SERVICES_REGISTRED) - set(sorted_services) raise ServiceCicularDependencyException( "Circular dependency detected! The following services are involved in a circular dependency: %s" % ', '.join(unresolved_services) @@ -88,14 +88,14 @@ def inject_services(cls): # Step 4: Inject dependencies in the correct order for service_name in sorted_services: - class_instance = cls.__SERVICE_REGISTERED[service_name] + class_instance = cls.__SERVICES_INSTANCES[service_name] signature = inspect.signature(class_instance.injectService) method_parameters = list(signature.parameters.values()) services_dependencies = [] for method_parameter in method_parameters: dependency_name = method_parameter.annotation.__name__ - dependency_instance = cls.__SERVICE_REGISTERED[dependency_name] + dependency_instance = cls.__SERVICES_INSTANCES[dependency_name] services_dependencies.append(dependency_instance) class_instance.injectService(*services_dependencies) @@ -104,7 +104,7 @@ def inject_services(cls): def get_dependency_tree(cls): # Step 1: Build dependency graph dependency_graph = defaultdict(list) - for name, class_instance in cls.__SERVICE_REGISTERED.items(): + for name, class_instance in cls.__SERVICES_INSTANCES.items(): signature = inspect.signature(class_instance.injectService) method_parameters = list(signature.parameters.values()) @@ -131,7 +131,7 @@ def build_tree(node, prefix="", last=True): build_tree(child, prefix, i == len(children) - 1) # Step 4: Find root services (those with no dependencies) - all_services = set(cls.__SERVICE_TO_REGISTER.keys()) + all_services = set(cls.__SERVICES_REGISTRED.keys()) dependent_services = set(dep for deps in dependency_graph.values() for dep in deps) root_services = all_services - dependent_services @@ -146,10 +146,18 @@ def build_tree(node, prefix="", last=True): @classmethod def start_services(cls): - for name, class_instance in cls.__SERVICE_REGISTERED.items(): + for name, class_instance in cls.__SERVICES_INSTANCES.items(): class_instance.start() + @classmethod + def get_service_registred(cls, class_name: str) -> type[ServiceInjector]: + if class_name not in cls.__SERVICES_REGISTRED: + raise ServiceNotRegisterException( + "Cannot get %s because the service is not registered." % (class_name) + ) + return cls.__SERVICES_REGISTRED[class_name] + @classmethod def get_service(cls, class_type: Type[T]) -> T: class_name = class_type.__name__ - return cls.__SERVICE_REGISTERED[class_name] + return cls.__SERVICES_INSTANCES[class_name] diff --git a/src/pyramid/api/services/tools/tester.py b/src/pyramid/api/services/tools/tester.py new file mode 100644 index 0000000..40e17d0 --- /dev/null +++ b/src/pyramid/api/services/tools/tester.py @@ -0,0 +1,46 @@ +import inspect +from typing import Optional, Type, TypeVar, cast +from pyramid.api.services.tools.exceptions import ServiceNotRegisterException +from pyramid.api.services.tools.register import ServiceRegister + +T = TypeVar('T') + +class ServiceStandalone: + + __SERVICE_REGISTERED: dict[str, object] = {} + + @classmethod + def import_services(cls): + ServiceRegister.import_services() + + @classmethod + def set_service(cls, service_interface: Type[T], service_instance: object): + service_name = service_interface.__name__ + cls.__SERVICE_REGISTERED[service_name] = service_instance + + @classmethod + def get_service(cls, service_interface: Type[T]) -> T: + service_name = service_interface.__name__ + + if service_name in cls.__SERVICE_REGISTERED: + return cast(T, cls.__SERVICE_REGISTERED[service_name]) + + service_type = ServiceRegister.get_service_registred(service_name) + class_instance = service_type() + + signature = inspect.signature(class_instance.injectService) + method_parameters = list(signature.parameters.values()) + + services_dependencies = [] + for method_parameter in method_parameters: + dependency = method_parameter.annotation + dependency_instance = cls.get_service(dependency) + + services_dependencies.append(dependency_instance) + + class_instance.injectService(*services_dependencies) + class_instance.start() + + cls.__SERVICE_REGISTERED[service_name] = class_instance + + return cast(T, class_instance) diff --git a/src/pyramid/connector/deezer/cli_deezer.py b/src/pyramid/connector/deezer/cli_deezer.py deleted file mode 100644 index 63f22eb..0000000 --- a/src/pyramid/connector/deezer/cli_deezer.py +++ /dev/null @@ -1,199 +0,0 @@ -from typing import Any - -import aiohttp -from deezer import Album, Artist, Client, Playlist, Resource, Track - -from pyramid.connector.deezer.client.a_deezer_client import ADeezerClient -from pyramid.connector.deezer.client.exceptions import CliDeezerErrorResponse, CliDeezerHTTPError -from pyramid.connector.deezer.client.list_paginated import DeezerListPaginated -from pyramid.connector.deezer.client.rate_limiter_async import RateLimiterAsync -from pyramid.tools.deprecated_class import deprecated_class - -@deprecated_class -class CliDeezer(ADeezerClient, Client): - def __init__(self, app_id=None, app_secret=None, access_token=None, headers=None, **kwargs): - # super().__init__(app_id, app_secret, access_token, headers, **kwargs) - - self.app_id = app_id - self.app_secret = app_secret - self.access_token = access_token - # self.session = requests.Session() - - # headers = headers or {} - # self.session.headers.update(headers) - # self.session.close() - # self.async_session = aiohttp.ClientSession() - self.rate_limiter = RateLimiterAsync(max_requests=50, time_interval=5) - - def get_paginated_list( - self, - relation: str, - **kwargs, - ) -> DeezerListPaginated: - return DeezerListPaginated( - client=self.client, - base_path=f"{self.type}/{self.id}/{relation}", - parent=self, - **kwargs, - ) - - Resource.get_paginated_list = get_paginated_list # type: ignore - - # def __getattr__(self, item: str) -> Any: - # try: - # return object.__getattribute__(self, item) - # except AttributeError: - # print(f"Attribute '{item}' not found.") - # Resource.__getattr__ = __getattr__ - - def get(self) -> Any: - raise AttributeError("%s has a missing attribute." % self.__class__.__name__) - - Resource.get = get - - def _search( - self, - path: str, - query: str = "", - strict: bool | None = None, - ordering: str | None = None, - **advanced_params: str | int | None, - ) -> DeezerListPaginated: - return super()._search(path, query, strict, ordering, **advanced_params) # type: ignore - - def search( - self, - query: str = "", - strict: bool | None = None, - ordering: str | None = None, - artist: str | None = None, - album: str | None = None, - track: str | None = None, - label: str | None = None, - dur_min: int | None = None, - dur_max: int | None = None, - bpm_min: int | None = None, - bpm_max: int | None = None, - ) -> DeezerListPaginated: - return self._search( - "", - query=query, - strict=strict, - ordering=ordering, - artist=artist, - album=album, - track=track, - label=label, - dur_min=dur_min, - dur_max=dur_max, - bpm_min=bpm_min, - bpm_max=bpm_max, - ) - - def search_playlists( - self, - query: str = "", - strict: bool | None = None, - ordering: str | None = None, - ) -> DeezerListPaginated[Playlist]: - return self._search( - path="playlist", - query=query, - strict=strict, - ordering=ordering, - ) - - def search_albums( - self, - query: str = "", - strict: bool | None = None, - ordering: str | None = None, - ) -> DeezerListPaginated[Album]: - return self._search( - path="album", - query=query, - strict=strict, - ordering=ordering, - ) - - def search_artists( - self, - query: str = "", - strict: bool | None = None, - ordering: str | None = None, - ) -> DeezerListPaginated[Artist]: - return self._search( - path="artist", - query=query, - strict=strict, - ordering=ordering, - ) - - async def async_get_playlist(self, playlist_id: int) -> Playlist: - return await self.async_request("GET", f"playlist/{playlist_id}") # type: ignore - - async def async_get_album(self, album_id: int) -> Album: - return await self.async_request("GET", f"album/{album_id}") # type: ignore - - async def async_get_artist(self, artist_id: int) -> Artist: - return await self.async_request("GET", f"artist/{artist_id}") # type: ignore - - async def async_get_track(self, track_id: int) -> Track: - return await self.async_request("GET", f"track/{track_id}") # type: ignore - - def _get_paginated_list(self, path, **params) -> DeezerListPaginated: - return DeezerListPaginated(client=self, base_path=path, **params) - - async def async_request( - self, - method: str, - path: str, - parent: Resource | None = None, - resource_type: type[Resource] | None = None, - resource_id: int | None = None, - paginate_list=False, - **params, - ): - """ - Make an asynchronous request to the API and parse the response. - - :param method: HTTP verb to use: GET, POST, DELETE, ... - :param path: The path to make the API call to (e.g. 'artist/1234'). - :param parent: A reference to the parent resource, to avoid fetching again. - :param resource_type: The resource class to use as the top level. - :param resource_id: The resource id to use as the top level. - :param paginate_list: Whether to wrap list into a pagination object. - :param params: Query parameters to add to the request - """ - - if self.access_token is not None: - params["access_token"] = str(self.access_token) - - async with aiohttp.ClientSession() as session: - await self.rate_limiter.check() - async with session.request( - method, - f"{self.base_url}/{path}", - params=params, - ) as response: - await self.rate_limiter.add() - try: - response.raise_for_status() - except aiohttp.ClientResponseError as exc: - raise CliDeezerHTTPError.from_status_code(exc) from exc - - json_data = await response.json() - - if not isinstance(json_data, dict): - return json_data - - if "error" in json_data and json_data["error"]: - raise CliDeezerErrorResponse.from_body(json_data) - - return self._process_json( - json_data, - parent=parent, - resource_type=resource_type, - resource_id=resource_id, - paginate_list=paginate_list, - ) diff --git a/src/pyramid/connector/deezer/downloader.py b/src/pyramid/connector/deezer/downloader.py deleted file mode 100644 index a70e261..0000000 --- a/src/pyramid/connector/deezer/downloader.py +++ /dev/null @@ -1,141 +0,0 @@ -import asyncio -import logging -import os -import traceback -from typing import Optional - -import pydeezer.util -from pyramid.connector.deezer.download.client import PyDeezer -from pyramid.connector.deezer.downloader_progress_bar import DownloaderProgressBar -from pyramid.data.music.track import Track -from pyramid.tools.deprecated_class import deprecated_class -from pyramid.tools.generate_token import DeezerTokenProvider -from pydeezer.constants import track_formats -from urllib3.exceptions import MaxRetryError - -from pyramid.data.exceptions import CustomException, DeezerTokensUnavailableException, DeezerTokenInvalidException, DeezerTokenOverflowException - -@deprecated_class -class DeezerDownloader: - def __init__(self, folder: str, arl: Optional[str] = None): - self.folder_path = folder - if arl is not None and arl != "": - self.__arls = [arl] - else: - self.__arls = None - self.__token_provider = DeezerTokenProvider() - self.music_format = track_formats.MP3_128 - os.makedirs(self.folder_path, exist_ok=True) - self.__deezer_dl_api = None - - async def dl_track_by_id(self, track_id) -> Track | None: - client = await self._get_client() - # try: - track_info = await client.get_track_info(track_id) - # except APIRequestError as err: - # logging.warn(f"Unable to download deezer song {track_id} : {err}", exc_info=True) - # return None # Track unvailable in this country - - if not track_info: - logging.error(f"Unable to find deezer song to download {track_id} : Unknown error") - return None - - file_name = pydeezer.util.clean_filename( - f"{track_info['ART_NAME']} - {track_info['SNG_TITLE']}" - ) - file_path = os.path.join(self.folder_path, file_name) + ".mp3" - - if os.path.exists(file_path) is False: - is_dl = await self.__dl_track(track_info, file_name) - if not is_dl: - return None - - track_downloaded = Track(track_info, file_path) - return track_downloaded - - async def __dl_track(self, track_info, file_name: str) -> bool: - try: - client = await self._get_client() - await client.download_track( - track_info, - self.folder_path, - self.music_format, - True, # fallback quality if not available - file_name, - False, # renew track info - False, # metadata - False, # lyrics - ", ", # separator for multiple artists - False, # show messages - DownloaderProgressBar(), # Custom progress bar - ) - return True - except MaxRetryError: - track = Track(track_info, None) - logging.warning("Downloader MaxRetryError %s", track) - await asyncio.sleep(5) - return await self.__dl_track(track_info, file_name) - - except CustomException as error: - trace = "".join(traceback.format_exception(type(error), error, error.__traceback__)) - logging.warning("%s :\n%s", error.args, trace) - return False - - except Exception: - track = Track(track_info, None) - logging.warning("Unable to dl track %s", track, exc_info=True) - return False - - async def _get_client(self) -> PyDeezer: - return await self._define_client() - - async def _define_client(self) -> PyDeezer: - last_err_local = None - if self.__arls: - for arl in self.__arls: - deezer_dl_api = PyDeezer(arl) - try: - await deezer_dl_api.get_user_data() - return deezer_dl_api - except DeezerTokenInvalidException as err: - last_err_local = err - continue - - last_err_remote = None - already_overflow = False - while self.__token_provider.count_valids_tokens() != 0: - try: - token = self.__token_provider.next() - deezer_dl_api = PyDeezer(token.token) - await deezer_dl_api.get_user_data() - return deezer_dl_api - - except DeezerTokenInvalidException as err: - last_err_remote = err - continue - - except DeezerTokenOverflowException as err: - last_err_remote = err - if already_overflow is True: - break - already_overflow = True - self.__token_provider = DeezerTokenProvider() - continue - - except DeezerTokensUnavailableException as err: - last_err_remote = err - break - - if last_err_local is not None: - tb = traceback.TracebackException.from_exception(last_err_local) - formatted_tb = ''.join(tb.format()) - logging.warning("Failed to fetch valid Deezer client from local\n%s", formatted_tb) - - if last_err_remote is not None: - tb = traceback.TracebackException.from_exception(last_err_remote) - formatted_tb = ''.join(tb.format()) - logging.warning("Failed to fetch valid Deezer client from remote\n%s", formatted_tb) - - if last_err_remote is not None and last_err_local is not None: - raise last_err_remote - raise Exception("Unknown func exit") diff --git a/src/pyramid/connector/deezer/search.py b/src/pyramid/connector/deezer/search.py deleted file mode 100644 index b4e08c0..0000000 --- a/src/pyramid/connector/deezer/search.py +++ /dev/null @@ -1,249 +0,0 @@ -import asyncio -import logging -import re -from enum import Enum -import aiohttp - -import deezer -from pyramid.connector.deezer.client.exceptions import CliDeezerNoDataException, CliDeezerRateLimitError -from pyramid.connector.deezer.client.list_paginated import DeezerListPaginated -from pyramid.data.a_search import ASearch, ASearchId -from pyramid.data.music.track_minimal_deezer import TrackMinimalDeezer - -from pyramid.connector.deezer.cli_deezer import CliDeezer -from pyramid.services.deezer_search import DeezerTools, DeezerType -from pyramid.tools.deprecated_class import deprecated_class - -@deprecated_class -class DeezerSearch(ASearchId, ASearch): - def __init__(self, default_limit: int): - self.default_limit = default_limit - self.client = CliDeezer() - self.strict = False - - async def search_track(self, search) -> TrackMinimalDeezer | None: - result = self.client.search(query=search) - track = await result.get_first() - if not track: - return None - return TrackMinimalDeezer(track) - - async def get_track_by_id(self, track_id: int) -> TrackMinimalDeezer | None: - track = await self.client.async_get_track(track_id) # TODO handle HTTP errors - if not track: - return None - return TrackMinimalDeezer(track) - - async def get_track_by_isrc(self, isrc: str) -> TrackMinimalDeezer | None: - try: - track: deezer.Track = await self.client.async_request("GET", f"track/isrc:{isrc}") # type: ignore - if not track: - return None - return TrackMinimalDeezer(track) - except CliDeezerNoDataException: - return None - - async def search_tracks( - self, search, limit: int | None = None - ) -> list[TrackMinimalDeezer] | None: - if limit is None: - limit = self.default_limit - - pagination_results = self.client.search(query=search, strict=self.strict) - tracks = await pagination_results.get_maximum(limit) - if not tracks: - return None - - return [TrackMinimalDeezer(element) for element in tracks] - - async def get_playlist_tracks(self, playlist_name) -> list[TrackMinimalDeezer] | None: - pagination_results = self.client.search_playlists(query=playlist_name, strict=self.strict) - playlist = await pagination_results.get_first() - if not playlist: - return None - pagination_tracks: DeezerListPaginated[deezer.Track] = playlist.get_tracks() # type: ignore - tracks = await pagination_tracks.get_all() - return [TrackMinimalDeezer(element) for element in tracks] - - async def get_playlist_tracks_by_id( - self, playlist_id: int - ) -> tuple[list[TrackMinimalDeezer], list[TrackMinimalDeezer]] | None: - playlist = await self.client.async_get_playlist(playlist_id) # TODO handle HTTP errors - if not playlist: - return None - # Tracks id are the not the good one - playlist_tracks: DeezerListPaginated[deezer.Track] = playlist.get_tracks() # type: ignore - - # So we search the id for same name and artist - real_tracks: list[TrackMinimalDeezer] = [] - unfindable_track: list[TrackMinimalDeezer] = [] - - async for chunk_tracks in playlist_tracks: - for t in chunk_tracks: - track = await self.search_exact_track(t.artist.name, t.album.title, t.title) - # logging.info("DEBUG song '%s' - '%s' - '%s'", t.artist.name, t.title, t.album.title) - if track is None: - if not t.readable: - logging.warning( - "Unavailable track in playlist '%s' - '%s'", t.artist.name, t.title - ) - else: - logging.warning( - "Unknown track searched in playlist '%s' - '%s'", t.artist.name, t.title - ) - unfindable_track.append(TrackMinimalDeezer(t)) - continue - real_tracks.append(track) - - return real_tracks, unfindable_track - - async def get_album_tracks(self, album_name) -> list[TrackMinimalDeezer] | None: - pagination_results = self.client.search_albums(query=album_name, strict=self.strict) - album = await pagination_results.get_first() - if not album: - return None - pagination_tracks: DeezerListPaginated[deezer.Track] = album.get_tracks() # type: ignore - tracks = await pagination_tracks.get_all() - return [TrackMinimalDeezer(element) for element in tracks] - - async def get_album_tracks_by_id( - self, album_id: int - ) -> tuple[list[TrackMinimalDeezer], list[TrackMinimalDeezer]] | None: - album = await self.client.async_get_album(album_id) # TODO handle HTTP errors - if not album: - return None - pagination_tracks: DeezerListPaginated[deezer.Track] = album.get_tracks() # type: ignore - tracks = await pagination_tracks.get_all() - return [TrackMinimalDeezer(element) for element in tracks], [] - - async def get_top_artist( - self, artist_name, limit: int | None = None - ) -> list[TrackMinimalDeezer] | None: - if limit is None: - limit = self.default_limit - pagination_results = self.client.search_artists(query=artist_name, strict=self.strict) - artist = await pagination_results.get_first() - if not artist: - return None - pagination_tracks: DeezerListPaginated[deezer.Track] = artist.get_top() # type: ignore - tracks = await pagination_tracks.get_maximum(limit) - return [TrackMinimalDeezer(element) for element in tracks] - - async def get_top_artist_by_id( - self, artist_id: int, limit: int | None = None - ) -> tuple[list[TrackMinimalDeezer], list[TrackMinimalDeezer]] | None: - if limit is None: - limit = self.default_limit - artist = await self.client.async_get_artist(artist_id) # TODO handle HTTP errors - if not artist: - return None - pagination_tracks: DeezerListPaginated[deezer.Track] = artist.get_top() # type: ignore - tracks = await pagination_tracks.get_maximum(limit) - return [TrackMinimalDeezer(element) for element in tracks], [] - - async def get_by_url( - self, url - ) -> tuple[list[TrackMinimalDeezer], list[TrackMinimalDeezer]] | TrackMinimalDeezer | None: - id, type = await DeezerTools.extract_from_url(url) - - if id is None: - return None - if type is None: - raise NotImplementedError(f"The type of deezer info '{url}' is not implemented") - - tracks: ( - tuple[list[TrackMinimalDeezer], list[TrackMinimalDeezer]] | TrackMinimalDeezer | None - ) - - if type == DeezerType.PLAYLIST: - # future = asyncio.get_event_loop().run_in_executor( - # None, self.get_playlist_tracks_by_id, id - # ) - # tracks = await asyncio.wrap_future(future) - tracks = await self.get_playlist_tracks_by_id(id) - elif type == DeezerType.ARTIST: - tracks = await self.get_top_artist_by_id(id) - elif type == DeezerType.ALBUM: - tracks = await self.get_album_tracks_by_id(id) - elif type == DeezerType.TRACK: - tracks = await self.get_track_by_id(id) - else: - raise NotImplementedError(f"The type of deezer info '{type}' can't be resolve") - - return tracks - - async def search_exact_track( - self, artist_name, album_title, track_title - ) -> TrackMinimalDeezer | None: - clean_artist = self.__remove_special_chars(artist_name) - clean_album = self.__remove_special_chars(album_title) - clean_track = self.__remove_special_chars(track_title) - # logging.info("Song CLEANED '%s' - '%s' - '%s'", clean_artist, clean_track, clean_album) - - track = await self._search_exact_track(clean_artist, clean_album, clean_track) - if track is None: - track = await self._search_exact_track(clean_artist, None, clean_track) - if track is None: - track = await self._search_exact_track(None, clean_album, clean_track) - if track is None: - track = await self._search_exact_track(None, None, clean_track) - # if track is not None: - # logging.warning("Find with title '%s' - '%s' - '%s'", clean_artist, clean_track, clean_album) - # else: - # logging.warning("Find with album & title '%s' - '%s' - '%s'", clean_artist, clean_track, clean_album) - # else: - # logging.warning("Find with artist & title '%s' - '%s' - '%s'", clean_artist, clean_track, clean_album) - return track - - async def _search_exact_track( - self, artist_name, album_title, track_title - ) -> TrackMinimalDeezer | None: - try: - pagination_results = self.client.search( - artist=artist_name, album=album_title, track=track_title - ) - logging.info("_search_exact_track %s - %s - %s", artist_name, album_title, track_title) - track = await pagination_results.get_first() - if track is None: - return None - return TrackMinimalDeezer(track) - - except CliDeezerRateLimitError: - logging.error("Search Deezer RateLimit %s - %s", artist_name, track_title) - await asyncio.sleep(5) - return await self._search_exact_track(artist_name, album_title, track_title) - - def __remove_special_chars( - self, input_string: str | None, allowed_brackets: tuple = ("(", ")", "[", "]") - ): - if input_string is None or input_string == "": - return None - - open_brackets = [b for i, b in enumerate(allowed_brackets) if i % 2 == 0] - close_brackets = [b for i, b in enumerate(allowed_brackets) if i % 2 != 0] - stack: list[str] = [] - result: list[str] = [] - last_char: str | None = None # Keep track of the last processed character - - for char in input_string: - if char in open_brackets: - stack.append(char) - elif char in close_brackets: - if stack: - open_bracket = stack.pop() - if open_brackets.index(open_bracket) == close_brackets.index(char): - continue - if last_char != " ": # Append only if the previous character is not a space - result.append(char) - elif char.isspace(): - if last_char != " ": # Append only if the previous character is not a space - result.append(char) - # elif not stack and (char.isalnum() or char == "'" or char == "/"): - elif not stack: - result.append(char) - else: - continue - last_char = char # Update last_char - - return "".join(result) - diff --git a/src/pyramid/connector/spotify/cli_spotify.py b/src/pyramid/connector/spotify/cli_spotify.py deleted file mode 100644 index 202e981..0000000 --- a/src/pyramid/connector/spotify/cli_spotify.py +++ /dev/null @@ -1,135 +0,0 @@ -import json -import logging -from typing import Any -from venv import logger - -import aiohttp -from spotipy import Spotify -from spotipy.exceptions import SpotifyException - -from pyramid.tools.deprecated_class import deprecated_class - -@deprecated_class -class CliSpotify(Spotify): - - async def async_search(self, q, limit=10, offset=0, type="track", market=None) -> dict[str, Any]: - return await self._get_async( - "search", q=q, limit=limit, offset=offset, type=type, market=market - ) - - async def async_track(self, track_id, market=None) -> dict[str, Any]: - trid = self._get_id("track", track_id) - return await self._get_async("tracks/" + trid, market=market) - - async def async_playlist_items( - self, - playlist_id, - fields=None, - limit=100, - offset=0, - market=None, - additional_types=("track", "episode"), - ) -> dict[str, Any]: - plid = self._get_id("playlist", playlist_id) - return await self._get_async( - "playlists/%s/tracks" % (plid), - limit=limit, - offset=offset, - fields=fields, - market=market, - additional_types=",".join(additional_types), - ) - - async def async_album_tracks(self, album_id, limit=50, offset=0, market=None) -> dict[str, Any]: - trid = self._get_id("album", album_id) - return await self._get_async( - "albums/" + trid + "/tracks/", limit=limit, offset=offset, market=market - ) - - async def async_artist_top_tracks(self, artist_id, country="US") -> dict[str, Any]: - trid = self._get_id("artist", artist_id) - return await self._get_async("artists/" + trid + "/top-tracks", country=country) - - async def async_next(self, result) -> dict[str, Any] | None: - if not result["next"]: - return None - return await self._get_async(result["next"]) - - async def _get_async(self, url, args=None, payload=None, **kwargs) -> dict[str, Any]: - if args: - kwargs.update(args) - - return await self._async_internal_call("GET", url, payload, kwargs) - - async def _async_internal_call(self, method: str, url: str, payload, params) -> dict[str, Any]: - args = dict(params=params) - if not url.startswith("http"): - url = self.prefix + url - headers = self._auth_headers() - - if "content_type" in args["params"]: - headers["Content-Type"] = args["params"]["content_type"] - del args["params"]["content_type"] - if payload: - args["data"] = payload - else: - headers["Content-Type"] = "application/json" - if payload: - args["data"] = json.dumps(payload) - - if self.language is not None: - headers["Accept-Language"] = self.language - - params = ( - {key: value for key, value in args["params"].items() if value is not None} - if "params" in args - else dict() - ) - logging.debug( - "Sending %s to %s with Params: %s Headers: %s and Body: %r ", - method, - url, - params, - headers, - args.get("data"), - ) - async with aiohttp.ClientSession() as session: - async with session.request( - method, - url, - headers=headers, - proxy=self.proxies, - timeout=self.requests_timeout, - params=params, - ) as response: - try: - response.raise_for_status() - results = await response.json() - except aiohttp.ClientResponseError: - try: - json_response = await response.json() - error = json_response.get("error", {}) - msg = error.get("message") - reason = error.get("reason") - except json.JSONDecodeError: - msg = await response.text() or None - reason = None - - logger.error( - "HTTP Error for %s to %s with Params: %s returned %s due to %s", - method, - url, - args.get("params"), - response.status, - msg, - ) - raise SpotifyException( - response.status, - -1, - "%s:\n %s" % (response.url, msg), - reason=reason, - headers=response.headers, - ) - - logger.debug("RESULTS: %s", results) - return results diff --git a/src/pyramid/connector/spotify/search.py b/src/pyramid/connector/spotify/search.py deleted file mode 100644 index 7fb5146..0000000 --- a/src/pyramid/connector/spotify/search.py +++ /dev/null @@ -1,229 +0,0 @@ -from typing import Any - -from pyramid.connector.spotify.spotify_tools import SpotifyTools -from pyramid.connector.spotify.spotify_type import SpotifyType -from pyramid.data.a_search import ASearch, ASearchId -from pyramid.data.music.track_minimal_spotify import TrackMinimalSpotify -from spotipy.oauth2 import SpotifyClientCredentials - -from pyramid.connector.spotify.cli_spotify import CliSpotify -from pyramid.tools.deprecated_class import deprecated_class - -@deprecated_class -class SpotifySearchBase(ASearch): - def __init__(self, default_limit: int, client_id: str, client_secret: str): - self.default_limit = default_limit - self.client_id = client_id - self.client_secret = client_secret - self.client_credentials_manager = SpotifyClientCredentials( - client_id=self.client_id, client_secret=self.client_secret - ) - self.client = CliSpotify(client_credentials_manager=self.client_credentials_manager) - - async def items( - self, results: dict[str, Any], item_name="items" - ) -> None | list[dict[str, Any]]: - if not results: - return None - tracks: list = results[item_name] - - while results["next"]: - results = await self.client.async_next(results) # type: ignore - tracks.extend(results[item_name]) - - return tracks - - async def items_max(self, results: dict[str, Any], limit: int | None = None, item_name="items"): - if not results or not results.get("tracks") or not results["tracks"].get(item_name): - return None - - if limit is None: - limit = self.default_limit - tracks: list[Any] = results["tracks"][item_name] - - results_tracks: dict[str, Any] = results["tracks"] - while results["tracks"]["next"] and limit > len(tracks): - results = await self.client.async_next(results_tracks) # type: ignore - tracks.extend(results_tracks[item_name]) - - if len(tracks) > limit: - return tracks[:limit] - - return tracks - - -@deprecated_class -class SpotifySearchId(ASearchId, SpotifySearchBase): - def __init__(self, default_limit: int, client_id: str, client_secret: str): - super().__init__(default_limit, client_id, client_secret) - - async def get_track_by_id(self, track_id: str) -> TrackMinimalSpotify | None: - result = await self.client.async_track(track_id=track_id) - if not result: - return None - return TrackMinimalSpotify(result) - - async def get_playlist_tracks_by_id( - self, playlist_id: str - ) -> tuple[list[TrackMinimalSpotify], list[TrackMinimalSpotify]] | None: - tracks_playlist = await self.items( - await self.client.async_playlist_items(playlist_id=playlist_id) - ) - if not tracks_playlist: - return None - return [TrackMinimalSpotify(element["track"]) for element in tracks_playlist], [] - - async def get_album_tracks_by_id( - self, album_id: str - ) -> tuple[list[TrackMinimalSpotify], list[TrackMinimalSpotify]] | None: - tracks = await self.items(await self.client.async_album_tracks(album_id=album_id)) - if not tracks: - return None - - readable_tracks = [] - unreadable_tracks = [] - for t in tracks: - track = await self.get_track_by_id(t["id"]) - if track is None: - unreadable_tracks.append(t) - else: - readable_tracks.append(track) - return readable_tracks, unreadable_tracks - - async def get_top_artist_by_id( - self, artist_id: str, limit: int | None = None - ) -> tuple[list[TrackMinimalSpotify], list[TrackMinimalSpotify]] | None: - if limit is None: - limit = self.default_limit - results = await self.client.async_artist_top_tracks(artist_id) - - if not results or not results.get("tracks"): - return None - - tracks = results["tracks"] - if len(tracks) > limit: - tracks = tracks[:limit] - return [TrackMinimalSpotify(element) for element in tracks], [] - -@deprecated_class -class SpotifyResponse: - def __init__(self, client: CliSpotify, default_limit: int, item_name="items") -> None: - self.client = client - self.default_limit = default_limit - self.item_name = item_name - - async def items(self, results: dict[str, Any], limit: int | None = None): - if not results or not results.get("tracks") or not results["tracks"].get(self.item_name): - return None - - if limit is None: - limit = self.default_limit - tracks: list[Any] = results["tracks"][self.item_name] - - results_tracks: dict[str, Any] = results["tracks"] - while results["tracks"]["next"] and limit > len(tracks): - results = await self.client.async_next(results_tracks) # type: ignore - tracks.extend(results_tracks[self.item_name]) - - if len(tracks) > limit: - return tracks[:limit] - - return tracks - - -@deprecated_class -class SpotifySearch(SpotifySearchId): - def __init__(self, default_limit: int, client_id: str, client_secret: str): - super().__init__(default_limit, client_id, client_secret) - - async def search_tracks( - self, search, limit: int | None = None - ) -> list[TrackMinimalSpotify] | None: - if limit is None: - limit = self.default_limit - if limit > 50: - req_limit = 50 - else: - req_limit = limit - results = await self.client.async_search(q=search, limit=req_limit, type="track") - tracks = await self.items_max(results, limit) - if not tracks: - return None - return [TrackMinimalSpotify(element) for element in tracks] - - async def search_track(self, search) -> TrackMinimalSpotify | None: - results = await self.client.async_search(q=search, limit=1, type="track") - - if not results or not results.get("tracks") or not results["tracks"].get("items"): - return None - - tracks = results["tracks"]["items"] - track = tracks[0] - - return TrackMinimalSpotify(track) - - async def get_playlist_tracks(self, playlist_name) -> list[TrackMinimalSpotify] | None: - results = await self.client.async_search(q=playlist_name, limit=1, type="playlist") - - if not results or not results.get("playlists") or not results["playlists"].get("items"): - return None - - playlist_id = results["playlists"]["items"][0]["id"] - tracks = await self.get_playlist_tracks_by_id(playlist_id) - if not tracks: - return None - return tracks[0] - - async def get_album_tracks(self, album_name) -> list[TrackMinimalSpotify] | None: - results = await self.client.async_search(q=album_name, limit=1, type="album") - - if not results or not results.get("albums") or not results["albums"].get("items"): - return None - - album_id = results["albums"]["items"][0]["id"] - tracks = await self.get_album_tracks_by_id(album_id) - if not tracks: - return None - return tracks[0] - - async def get_top_artist( - self, artist_name, limit: int | None = None - ) -> list[TrackMinimalSpotify] | None: - results = await self.client.async_search(q=artist_name, limit=1, type="artist") - - if not results or not results.get("artists") or not results["artists"].get("items"): - return None - - artist_id = results["artists"]["items"][0]["id"] - tracks = await self.get_top_artist_by_id(artist_id) - if not tracks: - return None - return tracks[0] - - async def get_by_url( - self, url - ) -> tuple[list[TrackMinimalSpotify], list[TrackMinimalSpotify]] | TrackMinimalSpotify | None: - id, type = SpotifyTools.extract_from_url(url) - - if id is None: - return None - if type is None: - raise NotImplementedError(f"The type of spotify info '{url}' is not implemented") - - tracks: ( - tuple[list[TrackMinimalSpotify], list[TrackMinimalSpotify]] | TrackMinimalSpotify | None - ) - - if type == SpotifyType.PLAYLIST: - tracks = await self.get_playlist_tracks_by_id(id) - elif type == SpotifyType.ARTIST: - tracks = await self.get_top_artist_by_id(id) - elif type == SpotifyType.ALBUM: - tracks = await self.get_album_tracks_by_id(id) - elif type == SpotifyType.TRACK: - tracks = await self.get_track_by_id(id) - else: - raise NotImplementedError(f"The type of spotify info '{type}' can't be resolve") - - return tracks - diff --git a/src/pyramid/services/builder/configuration.py b/src/pyramid/services/builder/configuration.py new file mode 100644 index 0000000..1944b8c --- /dev/null +++ b/src/pyramid/services/builder/configuration.py @@ -0,0 +1,48 @@ +from typing import Self +from pyramid.api.services.configuration import IConfigurationService +from pyramid.data.environment import Environment +from pyramid.services.configuration import ConfigurationService + +class ConfigurationBuilder(): + + def __init__(self): + self.service = ConfigurationService() + + def discord_token(self, discord_token: str) -> Self: + self.service.discord__token = discord_token + return self + + def discord_ffmpeg(self, discord_ffmpeg: str) -> Self: + self.service.discord__ffmpeg = discord_ffmpeg + return self + + def deezer_arl(self, deezer_arl: str) -> Self: + self.service.deezer__arl = deezer_arl + return self + + def deezer_folder(self, deezer_folder: str) -> Self: + self.service.deezer__folder = deezer_folder + return self + + def spotify_client_id(self, spotify_client_id: str) -> Self: + self.service.spotify__client_id = spotify_client_id + return self + + def spotify_client_secret(self, spotify_client_secret: str) -> Self: + self.service.spotify__client_secret = spotify_client_secret + return self + + def general_limit_tracks(self, limit_tracks: int) -> Self: + self.service.general__limit_tracks = limit_tracks + return self + + def mode(self, mode: Environment) -> Self: + self.service.mode = mode + return self + + def version(self, version: str) -> Self: + self.service.version = version + return self + + def build(self) -> IConfigurationService: + return self.service diff --git a/src/pyramid/services/configuration.py b/src/pyramid/services/configuration.py index 47413fa..9435119 100644 --- a/src/pyramid/services/configuration.py +++ b/src/pyramid/services/configuration.py @@ -18,15 +18,6 @@ def start(self): self.load() def load(self, config_file: str = "config.yml", use_env_vars: bool = True) -> bool: - """ - Loads configuration values from environment variables and/or a configuration file. - - Parameters: - - use_env_vars (bool): If True, loads configuration values from environment variables. - - Returns: - - bool: True if the loading process is successful, False otherwise. - """ keys_length = utils.count_public_variables(self) # Load from environment variables if enabled @@ -55,11 +46,5 @@ def load(self, config_file: str = "config.yml", use_env_vars: bool = True) -> bo return result_1 and result_2 - def save(self, file_name): - """ - Saves the configuration values to a YAML file. - - Parameters: - - file_name (str): The name of the file to which the configuration will be saved. - """ + def save(self, file_name: str): self._save_to_yaml(file_name) diff --git a/src/pyramid/services/deezer_downloader.py b/src/pyramid/services/deezer_downloader.py index e6fe102..38e4548 100644 --- a/src/pyramid/services/deezer_downloader.py +++ b/src/pyramid/services/deezer_downloader.py @@ -1,4 +1,5 @@ import asyncio +import logging import os import traceback from typing import Any @@ -11,12 +12,12 @@ from pyramid.connector.deezer.downloader_progress_bar import DownloaderProgressBar from pyramid.connector.deezer.download.client import PyDeezer from pyramid.data.music.track import Track -from pyramid.tools.generate_token import DeezerTokenProvider, DeezerTokenEmptyException, DeezerTokenOverflowException +from pyramid.tools.generate_token import DeezerTokenProvider from pydeezer.constants import track_formats from pydeezer.exceptions import LoginError from urllib3.exceptions import MaxRetryError -from pyramid.data.exceptions import CustomException +from pyramid.data.exceptions import CustomException, DeezerTokenInvalidException, DeezerTokenOverflowException, DeezerTokensUnavailableException @pyramid_service(interface=IDeezerDownloaderService) @@ -30,25 +31,13 @@ def injectService(self, self.__configuration_service = configuration_service def start(self): - # arl = self.__configuration_service.deezer__arl - arl = None + arl = self.__configuration_service.deezer__arl if arl is not None and arl != "": - self.__deezer_dl_api = PyDeezer(arl) - self.__token_provider = None + self.__arls = [arl] else: - self.__deezer_dl_api = None - self.__token_provider = DeezerTokenProvider() + self.__arls = None + self.__token_provider = DeezerTokenProvider() self.music_format = track_formats.MP3_128 - os.makedirs(self.__configuration_service.deezer__folder, exist_ok=True) - - async def check_credentials(self) -> dict[str, Any]: - if not self.__deezer_dl_api: - raise Exception("deezer_dl_api not init") - try: - await self.__deezer_dl_api.get_user_data() - return self.__deezer_dl_api.user - except LoginError as err: - raise err # Arl is invalid async def dl_track_by_id(self, track_id) -> Track | None: client = await self._get_client() @@ -100,7 +89,7 @@ async def __dl_track(self, track_info, file_name: str) -> bool: except CustomException as error: trace = "".join(traceback.format_exception(type(error), error, error.__traceback__)) - self.__logger.warning("%s :\n%s", error.msg, trace) + self.__logger.warning("%s :\n%s", error.args, trace) return False except Exception: @@ -109,34 +98,55 @@ async def __dl_track(self, track_info, file_name: str) -> bool: return False async def _get_client(self) -> PyDeezer: - i = 0 - max_error = 10 - - if self.__deezer_dl_api: - return self.__deezer_dl_api - if not self.__token_provider: - raise Exception("token_provider not init") - - while True: + return await self._define_client() + + async def _define_client(self) -> PyDeezer: + last_err_local = None + if self.__arls: + for arl in self.__arls: + deezer_dl_api = PyDeezer(arl) + try: + await deezer_dl_api.get_user_data() + return deezer_dl_api + except DeezerTokenInvalidException as err: + last_err_local = err + continue + + last_err_remote = None + already_overflow = False + while self.__token_provider.count_valids_tokens() != 0: try: token = self.__token_provider.next() - self.__deezer_dl_api = PyDeezer(token.token) - await self.check_credentials() - break + deezer_dl_api = PyDeezer(token.token) + await deezer_dl_api.get_user_data() + return deezer_dl_api - except DeezerTokenEmptyException as err: - if i > max_error: - raise err - self.__token_provider = DeezerTokenProvider() + except DeezerTokenInvalidException as err: + last_err_remote = err + continue except DeezerTokenOverflowException as err: - if i > max_error: - raise err + last_err_remote = err + if already_overflow is True: + break + already_overflow = True self.__token_provider = DeezerTokenProvider() + continue + + except DeezerTokensUnavailableException as err: + last_err_remote = err + break + + if last_err_local is not None: + tb = traceback.TracebackException.from_exception(last_err_local) + formatted_tb = ''.join(tb.format()) + logging.warning("Failed to fetch valid Deezer client from local\n%s", formatted_tb) - except LoginError as err: - if i > max_error: - raise err - i += 1 + if last_err_remote is not None: + tb = traceback.TracebackException.from_exception(last_err_remote) + formatted_tb = ''.join(tb.format()) + logging.warning("Failed to fetch valid Deezer client from remote\n%s", formatted_tb) - return self.__deezer_dl_api + if last_err_remote is not None and last_err_local is not None: + raise last_err_remote + raise Exception("Unknown func exit") diff --git a/src/pyramid/services/spotify_client.py b/src/pyramid/services/spotify_client.py index 321d68f..8047d29 100644 --- a/src/pyramid/services/spotify_client.py +++ b/src/pyramid/services/spotify_client.py @@ -147,5 +147,4 @@ async def _async_internal_call(self, method: str, url: str, payload, params) -> headers=response.headers, ) - logger.debug("RESULTS: %s", results) return results diff --git a/src/pyramid/services/spotify_search_base.py b/src/pyramid/services/spotify_search_base.py index db5ab28..f740721 100644 --- a/src/pyramid/services/spotify_search_base.py +++ b/src/pyramid/services/spotify_search_base.py @@ -8,8 +8,6 @@ from pyramid.api.services.tools.annotation import pyramid_service from pyramid.api.services.tools.injector import ServiceInjector -from pyramid.connector.spotify.cli_spotify import CliSpotify - @pyramid_service(interface=ISpotifySearchBaseService) class SpotifySearchBaseService(ISpotifySearchBaseService, ServiceInjector): @@ -20,9 +18,6 @@ def injectService(self, self.__configuration_service = configuration_service self.__spotify_client = spotify_client - # def start(self): - # self.__spotify_client = CliSpotify(client_credentials_manager=self.client_credentials_manager) - async def items( self, results: dict[str, Any], diff --git a/src/pyramid/tools/configuration/configuration.py b/src/pyramid/tools/configuration/configuration.py deleted file mode 100644 index b824e01..0000000 --- a/src/pyramid/tools/configuration/configuration.py +++ /dev/null @@ -1,66 +0,0 @@ -from pyramid.tools import utils -from pyramid.tools.deprecated_class import deprecated_class -from pyramid.data.environment import Environment -from pyramid.tools.configuration.configuration_load import ConfigurationFromEnv, ConfigurationFromYAML -from pyramid.tools.configuration.configuration_save import ConfigurationToYAML - - -@deprecated_class -class Configuration(ConfigurationFromYAML, ConfigurationToYAML, ConfigurationFromEnv): - def __init__(self): - self.discord__token: str = "" - self.discord__ffmpeg: str = "" - self.deezer__arl: str = "" - self.deezer__folder: str = "" - self.spotify__client_id: str = "" - self.spotify__client_secret: str = "" - self.general__limit_tracks: int = 0 - self.mode: Environment = Environment.PRODUCTION - self.version: str = "" - - def load(self, config_file: str = "config.yml", use_env_vars: bool = True) -> bool: - """ - Loads configuration values from environment variables and/or a configuration file. - - Parameters: - - use_env_vars (bool): If True, loads configuration values from environment variables. - - Returns: - - bool: True if the loading process is successful, False otherwise. - """ - keys_length = utils.count_public_variables(self) - - # Load from environment variables if enabled - result_1 = True - if use_env_vars: - raw_values_env = self._get_env_vars() - result_values = self._transform_all(raw_values_env, keys_length) - result_1 = self._validate_all( - raw_values_env, result_values, "env vars", True, keys_length - ) - - # Load raw values from environment variables and config file - try: - raw_values_file = self._get_file_vars(config_file) - result_values = self._transform_all(raw_values_file, keys_length) - result_2 = self._validate_all( - raw_values_file, result_values, "file", keys_length=keys_length - ) - except FileNotFoundError as err: - if not result_1: - self.logger.critical( - "Unable to read configuration file '%s' :\n%s", config_file, err - ) - return False - result_2 = True - - return result_1 and result_2 - - def save(self, file_name): - """ - Saves the configuration values to a YAML file. - - Parameters: - - file_name (str): The name of the file to which the configuration will be saved. - """ - self._save_to_yaml(file_name) diff --git a/src/pyramid/tools/logs_handler.py b/src/pyramid/tools/logs_handler.py deleted file mode 100644 index 8c61334..0000000 --- a/src/pyramid/tools/logs_handler.py +++ /dev/null @@ -1,89 +0,0 @@ -import logging -import logging.handlers -import os -import sys - -import coloredlogs -from pyramid.tools import utils -from pyramid.data.functional.application_info import ApplicationInfo -from pyramid.data.environment import Environment -from pyramid.tools.deprecated_class import deprecated_class - - -@deprecated_class -class LogsHandler: - def __init__(self): - self.__date = "%d/%m/%Y %H:%M:%S" - self.__console_format = "%(asctime)s %(levelname)s %(message)s" - self.__file_format = "[{asctime}] [{levelname:<8}] {name}: {message}" - - def init(self, info: ApplicationInfo, logs_dir: str, log_filename: str, error_filename: str): - self.__info = info - self.__logs_dir = logs_dir - self.__log_filename = log_filename - self.__error_filename = error_filename - - self.logger = logging.getLogger() - - self.log_to_console() - self.log_to_file() - self.log_to_file_exceptions() - - def log_to_console(self): - coloredlogs.install(fmt=self.__console_format, datefmt=self.__date, isatty=True) - - def log_to_file(self): - log_filename = os.path.join(self.__logs_dir, self.__log_filename) - utils.create_parent_directories(log_filename) - - file_handler = logging.handlers.RotatingFileHandler( - filename=log_filename, - encoding="utf-8", - maxBytes=512 * 1024 * 1024, # 512 Mo - ) - - formatter = logging.Formatter(self.__file_format, self.__date, style="{") - file_handler.setFormatter(formatter) - - self.logger.addHandler(file_handler) - - def log_to_file_exceptions(self): - log_filename = os.path.join(self.__logs_dir, self.__error_filename) - utils.create_parent_directories(log_filename) - - file_handler = logging.handlers.RotatingFileHandler( - filename=log_filename, - encoding="utf-8", - maxBytes=10 * 1024 * 1024, # 10 Mo - backupCount=10, - ) - - formatter = logging.Formatter(self.__file_format, self.__date, style="{") - file_handler.setFormatter(formatter) - - # Retrieves warning exceptions and above - file_handler.setLevel("WARNING") - logging.getLogger().addHandler(file_handler) - - # Retrieves unhandled exceptions - self.logger_unhandled_exception = logging.getLogger("Unhandled Exception") - self.logger_unhandled_exception.addHandler(file_handler) - sys.excepthook = self.__handle_unhandled_exception - - def __handle_unhandled_exception(self, exc_type, exc_value, exc_traceback): - if issubclass(exc_type, KeyboardInterrupt): - # Will call default excepthook - sys.__excepthook__(exc_type, exc_value, exc_traceback) - return - # Create a critical level log message with info from the except hook. - self.logger_unhandled_exception.critical( - self.__info, exc_info=(exc_type, exc_value, exc_traceback) - ) - - def set_log_level(self, mode: Environment): - if mode == Environment.PRODUCTION: - self.logger.setLevel("INFO") - else: - self.logger.setLevel("DEBUG") - # coloredlogs.set_level("DEBUG") - coloredlogs.set_level("INFO") diff --git a/tests/config_test.py b/tests/config_test.py index 07f4f7b..e63d6b3 100644 --- a/tests/config_test.py +++ b/tests/config_test.py @@ -2,15 +2,13 @@ import unittest from unittest.mock import patch -from pyramid.tools.configuration.configuration import Configuration +from pyramid.api.services.configuration import IConfigurationService +from pyramid.api.services.tools.tester import ServiceStandalone class ConfigurationTest(unittest.TestCase): def setUp(self): - self.config = Configuration() - - def test_default_logger(self): - self.assertIsNotNone(logging.getLogger("config")) - self.assertEqual(logging.getLogger("config").name, "config") + ServiceStandalone.import_services() + self.config = ServiceStandalone.get_service(IConfigurationService) @patch("pyramid.tools.configuration.configuration_load.ConfigurationFromEnv._get_env_vars") @patch("pyramid.tools.configuration.configuration_load.ConfigurationFromYAML._get_file_vars") diff --git a/tests/deezer_download_test.py b/tests/deezer_download_test.py index 4712714..f9fa22c 100644 --- a/tests/deezer_download_test.py +++ b/tests/deezer_download_test.py @@ -2,18 +2,26 @@ import shutil import unittest -from pyramid.connector.deezer.downloader import DeezerDownloader +from pyramid.api.services.configuration import IConfigurationService +from pyramid.api.services.deezer_downloader import IDeezerDownloaderService +from pyramid.api.services.tools.tester import ServiceStandalone +from pyramid.services.builder.configuration import ConfigurationBuilder class DeezerDownloadTest(unittest.IsolatedAsyncioTestCase): def __init__(self, methodName: str = "runTest") -> None: super().__init__(methodName) - arl = os.getenv("DEEZER__ARL") - self.path = "./test-songs" - self.cli = DeezerDownloader(self.path, arl) + self.path = "./tests-songs" + os.makedirs(self.path, exist_ok=True) + + ServiceStandalone.import_services() + builder = ConfigurationBuilder().deezer_folder(self.path).deezer_arl(os.getenv("DEEZER__ARL") or "") + ServiceStandalone.set_service(IConfigurationService, builder.build()) + + self.deezer_downloader = ServiceStandalone.get_service(IDeezerDownloaderService) async def test_download(self): - track = await self.cli.dl_track_by_id(2308590) + track = await self.deezer_downloader.dl_track_by_id(2308590) self.assertIsNotNone(track) def tearDown(self): diff --git a/tests/deezer_search_test.py b/tests/deezer_search_test.py index 038ad62..4d9c0ae 100644 --- a/tests/deezer_search_test.py +++ b/tests/deezer_search_test.py @@ -1,66 +1,70 @@ -import os import unittest -from pyramid.connector.deezer.downloader import DeezerDownloader -from pyramid.connector.deezer.search import DeezerSearch +from pyramid.api.services.configuration import IConfigurationService +from pyramid.api.services.deezer_search import IDeezerSearchService +from pyramid.api.services.tools.tester import ServiceStandalone +from pyramid.services.builder.configuration import ConfigurationBuilder class DeezerSearchTest(unittest.IsolatedAsyncioTestCase): def __init__(self, methodName: str = "runTest") -> None: super().__init__(methodName) - limit = int(os.getenv("GENERAL__LIMIT_TRACKS", 100)) - self.cli = DeezerSearch(limit) + ServiceStandalone.import_services() + builder = ConfigurationBuilder().general_limit_tracks(100) + ServiceStandalone.set_service(IConfigurationService, builder.build()) + + self.deezer_search = ServiceStandalone.get_service(IDeezerSearchService) async def test_search_single(self): - track = await self.cli.search_track("Johnny Hallyday - Allumer le feu") + track = await self.deezer_search.search_track("Johnny Hallyday - Allumer le feu") self.assertIsNotNone(track) async def test_search_multiple(self): - tracks = await self.cli.search_tracks("Johnny Hallyday") + tracks = await self.deezer_search.search_tracks("Johnny Hallyday") self.assertIsNotNone(tracks) async def test_search_playlist(self): - tracks = await self.cli.get_playlist_tracks("Best of Johnny Hallyday - live") + tracks = await self.deezer_search.get_playlist_tracks("Best of Johnny Hallyday - live") self.assertIsNotNone(tracks) async def test_search_album(self): - tracks = await self.cli.get_album_tracks("Anthologie Vol. 1") + tracks = await self.deezer_search.get_album_tracks("Anthologie Vol. 1") self.assertIsNotNone(tracks) async def test_search_top(self): - tracks = await self.cli.get_top_artist("Johnny Hallyday") + tracks = await self.deezer_search.get_top_artist("Johnny Hallyday") self.assertIsNotNone(tracks) async def test_url_artist(self): - tracks = await self.cli.get_by_url("https://www.deezer.com/fr/artist/1060") + tracks = await self.deezer_search.get_by_url("https://www.deezer.com/fr/artist/1060") self.assertIsNotNone(tracks) # async def test_url_artist_2nd_format(self): - # tracks = await self.cli.get_by_url("https://deezer.page.link/HWapYqfpsmSukE6T7") + # tracks = await self.deezer_search.get_by_url("https://deezer.page.link/HWapYqfpsmSukE6T7") # self.assertIsNotNone(tracks) async def test_url_album(self): - tracks = await self.cli.get_by_url("https://www.deezer.com/fr/album/53012892") + tracks = await self.deezer_search.get_by_url("https://www.deezer.com/fr/album/53012892") self.assertIsNotNone(tracks) # async def test_url_album_2nd_format(self): - # tracks = await self.cli.get_by_url("https://deezer.page.link/gvryHN1VUn62CnCJ7") + # tracks = await self.deezer_search.get_by_url("https://deezer.page.link/gvryHN1VUn62CnCJ7") # self.assertIsNotNone(tracks) async def test_url_track(self): - tracks = await self.cli.get_by_url("https://www.deezer.com/fr/track/2308590") + tracks = await self.deezer_search.get_by_url("https://www.deezer.com/fr/track/2308590") self.assertIsNotNone(tracks) # async def test_url_track_2nd_format(self): - # tracks = await self.cli.get_by_url("https://deezer.page.link/qF6ucYP2wSGsLiMB6") + # tracks = await self.deezer_search.get_by_url("https://deezer.page.link/qF6ucYP2wSGsLiMB6") # self.assertIsNotNone(tracks) async def test_url_playlist(self): - tracks = await self.cli.get_by_url("https://www.deezer.com/fr/playlist/987181371") + tracks = await self.deezer_search.get_by_url("https://www.deezer.com/fr/playlist/987181371") self.assertIsNotNone(tracks) # async def test_url_playlist_2nd_format(self): - # tracks = await self.cli.get_by_url("https://deezer.page.link/ibwojNjEKAQjsKgZ9") + # tracks = await self.deezer_search.get_by_url("https://deezer.page.link/ibwojNjEKAQjsKgZ9") # self.assertIsNotNone(tracks) if __name__ == "__main__": diff --git a/tests/spotify_test.py b/tests/spotify_test.py index 0b1b9cf..aa36492 100644 --- a/tests/spotify_test.py +++ b/tests/spotify_test.py @@ -1,64 +1,71 @@ import os import unittest -from pyramid.connector.spotify.search import SpotifySearch +from pyramid.api.services.configuration import IConfigurationService +from pyramid.api.services.spotify_search import ISpotifySearchService +from pyramid.api.services.tools.tester import ServiceStandalone +from pyramid.services.builder.configuration import ConfigurationBuilder class SpotifySearchTest(unittest.IsolatedAsyncioTestCase): def __init__(self, methodName: str = "runTest") -> None: super().__init__(methodName) - - limit = int(os.getenv("GENERAL__LIMIT_TRACKS", 100)) - client_id = os.getenv("SPOTIFY__CLIENT_ID", "c760faf4271a4e30b0f5a1a228b0a8d6") - client_secret = os.getenv("SPOTIFY__CLIENT_SECRET", "a122b8f9fb5741b99e5ce88805bebe35") - self.cli = SpotifySearch(limit, client_id, client_secret) + + ServiceStandalone.import_services() + builder = ConfigurationBuilder() + builder.general_limit_tracks(100) + builder.spotify_client_id("9a84b8a2093f4ca9ae0aa1dc2f184b3f") + builder.spotify_client_secret("4c575d7b7e5e47ee90fcc143672c800b") + ServiceStandalone.set_service(IConfigurationService, builder.build()) + + self.spotify_search = ServiceStandalone.get_service(ISpotifySearchService) async def test_search_single(self): - track = await self.cli.search_track("Johnny Hallyday - Allumer le feu") + track = await self.spotify_search.search_track("Johnny Hallyday - Allumer le feu") self.assertIsNotNone(track) async def test_search_multiple(self): - tracks = await self.cli.search_tracks("Johnny Hallyday") + tracks = await self.spotify_search.search_tracks("Johnny Hallyday") self.assertIsNotNone(tracks) async def test_search_multiple_limit(self): limit = 75 - tracks = await self.cli.search_tracks("Johnny Hallyday", limit) + tracks = await self.spotify_search.search_tracks("Johnny Hallyday", limit) size = len(tracks) if tracks is not None else 0 self.assertEqual(limit, size) async def test_search_playlist(self): - tracks = await self.cli.get_playlist_tracks("Best of Johnny Hallyday - live") + tracks = await self.spotify_search.get_playlist_tracks("Best of Johnny Hallyday - live") self.assertIsNotNone(tracks) async def test_search_album(self): - tracks = await self.cli.get_album_tracks("Anthologie Vol. 1") + tracks = await self.spotify_search.get_album_tracks("Anthologie Vol. 1") self.assertIsNotNone(tracks) async def test_search_top(self): - tracks = await self.cli.get_top_artist("Johnny Hallyday") + tracks = await self.spotify_search.get_top_artist("Johnny Hallyday") self.assertIsNotNone(tracks) async def test_url_artist(self): - tracks = await self.cli.get_by_url( + tracks = await self.spotify_search.get_by_url( "https://open.spotify.com/intl-fr/artist/2HALYSe657tNJ1iKVXP2xA?si=e5x_arTGSqWtnRrJjCCTdQ" ) self.assertIsNotNone(tracks) async def test_url_album(self): - tracks = await self.cli.get_by_url( + tracks = await self.spotify_search.get_by_url( "https://open.spotify.com/album/1mhVZVEHbIYUN7b5DkXF7d?si=U9HaPDJtRdSnT9qZaQxjXA" ) self.assertIsNotNone(tracks) async def test_url_track(self): - tracks = await self.cli.get_by_url( + tracks = await self.spotify_search.get_by_url( "https://open.spotify.com/intl-fr/track/1mzZP8UA2RZUXDw33QNmn4?si=4c27f893cd7c4055" ) self.assertIsNotNone(tracks) async def test_url_playlist(self): - tracks = await self.cli.get_by_url( + tracks = await self.spotify_search.get_by_url( "https://open.spotify.com/playlist/37i9dQZF1DZ06evO1ymAtQ?si=5c4b6ff36a0d4c98" ) self.assertIsNotNone(tracks) From f5cb7202e6363e8c944fdaff0c61139e5c0666fc Mon Sep 17 00:00:00 2001 From: Tristiisch Date: Fri, 27 Sep 2024 00:51:52 +0200 Subject: [PATCH 15/32] feat: socket server as a service --- .docker/Dockerfile | 6 +- .vscode/launch.json | 16 +++ Makefile | 16 ++- docker-compose.yml | 1 + src/pyramid/api/services/__init__.py | 1 - src/pyramid/api/services/socket_server.py | 12 +- src/pyramid/api/services/tools/register.py | 13 +- src/pyramid/api/tasks/tools/register.py | 1 + src/{cli.py => pyramid/cli/startup.py} | 12 +- src/pyramid/client/client.py | 16 +-- src/pyramid/client/common.py | 9 +- src/pyramid/client/requests/a_request.py | 4 +- src/pyramid/client/requests/health.py | 27 ---- src/pyramid/client/requests/ping.py | 26 ++++ src/pyramid/client/responses/a_response.py | 17 ++- .../client/responses/a_response_header.py | 12 ++ src/pyramid/client/server.py | 98 -------------- src/pyramid/data/functional/main.py | 4 - src/pyramid/data/health.py | 7 - src/pyramid/data/ping.py | 6 + src/pyramid/services/configuration.py | 2 +- src/pyramid/services/logger.py | 4 + src/pyramid/services/socket_server.py | 120 +++++++++++++++--- src/pyramid/tasks/discord.py | 7 +- src/pyramid/tasks/socket_server.py | 17 +++ src/pyramid/tools/custom_queue.py | 2 +- src/startup_cli.py | 4 + src/startup_cli_dev.py | 10 ++ src/{dev.py => startup_dev.py} | 0 29 files changed, 266 insertions(+), 204 deletions(-) rename src/{cli.py => pyramid/cli/startup.py} (87%) delete mode 100644 src/pyramid/client/requests/health.py create mode 100644 src/pyramid/client/requests/ping.py create mode 100644 src/pyramid/client/responses/a_response_header.py delete mode 100644 src/pyramid/client/server.py delete mode 100644 src/pyramid/data/health.py create mode 100644 src/pyramid/data/ping.py create mode 100644 src/pyramid/tasks/socket_server.py create mode 100644 src/startup_cli.py create mode 100644 src/startup_cli_dev.py rename src/{dev.py => startup_dev.py} (100%) diff --git a/.docker/Dockerfile b/.docker/Dockerfile index cf07ec9..9156fdd 100644 --- a/.docker/Dockerfile +++ b/.docker/Dockerfile @@ -41,7 +41,7 @@ ARG APP_GROUP LABEL org.opencontainers.image.source="https://github.com/tristiisch/PyRamid" \ org.opencontainers.image.authors="tristiisch" -HEALTHCHECK --interval=30s --retries=3 --timeout=30s CMD python ./src/cli.py health +# HEALTHCHECK --interval=30s --retries=3 --timeout=30s CMD python ./src/startup_cli.py health # Expose port for health check EXPOSE 49150 @@ -100,6 +100,8 @@ ARG APP_GROUP ARG PROJECT_VERSION ENV PROJECT_VERSION=$PROJECT_VERSION +# HEALTHCHECK --interval=30s --retries=3 --timeout=30s CMD python -Xfrozen_modules=off ./src/startup_cli_dev.py health + COPY --chown=root:$APP_GROUP --chmod=550 --from=builder-dev /opt/venv /opt/venv ENV PATH="/opt/venv/bin:$PATH" @@ -107,7 +109,7 @@ COPY --chown=root:$APP_GROUP --chmod=750 ./src ./src USER $APP_USER -CMD ["python", "-Xfrozen_modules=off", "./src/dev.py"] +CMD ["python", "-Xfrozen_modules=off", "./src/startup_dev.py"] # ============================ Test Image ============================ FROM base AS tests diff --git a/.vscode/launch.json b/.vscode/launch.json index 8ad9737..92e4591 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -35,6 +35,22 @@ } ], "justMyCode": false + }, + { + "name": "Python: Remote Attach CLI", + "type": "python", + "request": "attach", + "connect": { + "host": "localhost", + "port": 5679 + }, + "pathMappings": [ + { + "localRoot": "${workspaceFolder}", + "remoteRoot": "/app" + } + ], + "justMyCode": false } ] } diff --git a/Makefile b/Makefile index c020e50..e761241 100644 --- a/Makefile +++ b/Makefile @@ -5,7 +5,7 @@ DOCKER_COMPOSE_FILE_PREPROD := ./.docker/docker-compose.preprod.yml DOCKER_SERVICE_PREPROD := pyramid_preprod_pyramid DOCKER_CONTEXT_PREPROD := cookie-pulsheberg -.PHONY: logs +# Basics all: up-b logs @@ -45,6 +45,8 @@ logs: exec: @docker compose exec $(COMPOSE_SERVICE) sh +# Other envs + exec-pp: @scripts/docker_service_exec.sh $(DOCKER_SERVICE_PREPROD) $(DOCKER_CONTEXT_PREPROD) @@ -56,6 +58,14 @@ tests: @mkdir -p ./coverage && chmod 777 ./coverage @docker run --rm --env-file ./.env -v ./coverage:/app/coverage -it pyramid:tests +healthcheck: + @docker compose exec $(COMPOSE_SERVICE) sh -c "python ./src/startup_cli.py health" + +healthcheck-dev: + @docker compose exec $(COMPOSE_SERVICE) sh -c "python -Xfrozen_modules=off ./src/startup_cli_dev.py health" + +# Pythons scripts + img-b: @python scripts/environnement.py --build @@ -68,4 +78,6 @@ img-c: clean: @python scripts/environnement.py --clean -.PHONY: build tests \ No newline at end of file +# Other + +.PHONY: build tests logs diff --git a/docker-compose.yml b/docker-compose.yml index 10534c0..83c817c 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -13,6 +13,7 @@ services: - pyramid_network ports: - 5678:5678 + - 5679:5679 env_file: .env networks: diff --git a/src/pyramid/api/services/__init__.py b/src/pyramid/api/services/__init__.py index 84fbfed..fc1e4f4 100644 --- a/src/pyramid/api/services/__init__.py +++ b/src/pyramid/api/services/__init__.py @@ -2,5 +2,4 @@ from .discord import IDiscordService from .information import IInformationService from .logger import ILoggerService -from .socket_server import ISocketServerService # from .source_service import ISourceService diff --git a/src/pyramid/api/services/socket_server.py b/src/pyramid/api/services/socket_server.py index eb35498..06961ac 100644 --- a/src/pyramid/api/services/socket_server.py +++ b/src/pyramid/api/services/socket_server.py @@ -1,8 +1,12 @@ -from abc import ABC, abstractmethod -from pyramid.data.functional.application_info import ApplicationInfo +from abc import abstractmethod -class ISocketServerService(ABC): + +class ISocketServerService: + + @abstractmethod + async def open(self): + pass @abstractmethod - def start(self): + def close(self): pass diff --git a/src/pyramid/api/services/tools/register.py b/src/pyramid/api/services/tools/register.py index 3932d65..dbae533 100644 --- a/src/pyramid/api/services/tools/register.py +++ b/src/pyramid/api/services/tools/register.py @@ -24,16 +24,17 @@ def import_services(cls): importlib.import_module(full_module_name) @classmethod - def register_service(cls, name: str, type: type[object]): + def register_service(cls, interface_name: str, type: type[object]): + type_name = type.__name__ if not issubclass(type, ServiceInjector): - raise TypeError("Service %s is not a subclass of ServiceInjector and cannot be initialized." % name) - if name in cls.__SERVICES_REGISTRED: - already_class_name = cls.__SERVICES_REGISTRED[name].__name__ + raise TypeError("Service %s is not a subclass of ServiceInjector and cannot be initialized." % type_name) + if interface_name in cls.__SERVICES_REGISTRED: + already_class_name = cls.__SERVICES_REGISTRED[interface_name].__name__ raise ServiceAlreadyRegisterException( "Cannot register %s with %s, it is already registered with the class %s." - % (name, type.__name__, already_class_name) + % (interface_name, type_name, already_class_name) ) - cls.__SERVICES_REGISTRED[name] = type + cls.__SERVICES_REGISTRED[interface_name] = type @classmethod def create_services(cls): diff --git a/src/pyramid/api/tasks/tools/register.py b/src/pyramid/api/tasks/tools/register.py index 44d6bd3..aaa1e95 100644 --- a/src/pyramid/api/tasks/tools/register.py +++ b/src/pyramid/api/tasks/tools/register.py @@ -82,3 +82,4 @@ def running(loop: asyncio.AbstractEventLoop): parameters.thread.join() signal.signal(signal.SIGTERM, previous_handler) + logging.info("All tasks are stopped") diff --git a/src/cli.py b/src/pyramid/cli/startup.py similarity index 87% rename from src/cli.py rename to src/pyramid/cli/startup.py index 1330ce7..b2d83a2 100644 --- a/src/cli.py +++ b/src/pyramid/cli/startup.py @@ -1,9 +1,10 @@ import argparse +import sys from pyramid.api.services.information import IInformationService from pyramid.api.services.tools.tester import ServiceStandalone from pyramid.client.client import SocketClient -from pyramid.client.requests.health import HealthRequest +from pyramid.client.requests.ping import PingRequest def startup_cli(): ServiceStandalone.import_services() @@ -29,11 +30,10 @@ def startup_cli(): elif args.health: sc = SocketClient(args.host, args.port) - health = HealthRequest() - sc.send(health) + health = PingRequest() + result = sc.send(health) + if result is not True: + sys.exit(1) else: parser.print_help() - -if __name__ == "__main__": - startup_cli() diff --git a/src/pyramid/client/client.py b/src/pyramid/client/client.py index 32be670..81e0c48 100644 --- a/src/pyramid/client/client.py +++ b/src/pyramid/client/client.py @@ -27,7 +27,7 @@ def send(self, req: ARequest) -> bool: try: # Convert the request data to JSON - json_request = self.__common.serialize(req.ask) + json_request = SocketCommon.serialize(req.ask) # Connect to the server self.__logger.info("Connect to %s:%d", self.__host, self.__port) @@ -35,17 +35,16 @@ def send(self, req: ARequest) -> bool: # Send the JSON request to the server self.__logger.debug("Send '%s'", json_request) - self.__common.send_chunk(client_socket, json_request) + SocketCommon.send_chunk(client_socket, json_request) # Receive the response from the server response_str = self.__common.receive_chunk(client_socket) if not response_str: self.__logger.warning("Received empty request") return False - self.__logger.debug("Received '%s'", response_str) - self.receive(req, response_str) - return True + result = self.receive(req, response_str) + return result except OverflowError: self.__logger.warning( @@ -67,8 +66,8 @@ def send(self, req: ARequest) -> bool: client_socket.close() return False - def receive(self, action: ARequest, response_str: str): - response: SocketResponse = SocketResponse.from_json(self.__common.deserialize(response_str)) + def receive(self, action: ARequest, response_str: str) -> bool: + response: SocketResponse = SocketResponse.from_str(response_str) if not response.header: raise ValueError("No header received") @@ -76,4 +75,5 @@ def receive(self, action: ARequest, response_str: str): raise ValueError("No data received") response_data = action.load_data(**(self.__common.deserialize(response.data))) - action.client_receive(response.header, response_data) + result = action.client_receive(response.header, response_data) + return result diff --git a/src/pyramid/client/common.py b/src/pyramid/client/common.py index ddf3e2e..72ce430 100644 --- a/src/pyramid/client/common.py +++ b/src/pyramid/client/common.py @@ -28,14 +28,16 @@ def receive_chunk(self, client_socket: sock): return None return received_data.decode("utf-8") - def send_chunk(self, client_socket: sock, response: str): + @classmethod + def send_chunk(cls, client_socket: sock, response: str): # for chunk in [ # response[i : i + self.buffer_size] for i in range(0, len(response), self.buffer_size) # ]: # client_socket.send(chunk.encode("utf-8")) client_socket.send(response.encode("utf-8")) - def serialize(self, obj): + @classmethod + def serialize(cls, obj): def default(obj): if hasattr(obj, "__dict__"): # if isinstance(obj, SocketResponse): @@ -46,7 +48,8 @@ def default(obj): return json.dumps(obj, default=default) - def deserialize(self, obj: str, object_hook=None): + @classmethod + def deserialize(cls, obj: str, object_hook=None): return json.loads(obj, object_hook=object_hook) diff --git a/src/pyramid/client/requests/a_request.py b/src/pyramid/client/requests/a_request.py index aa393bf..6ad9b1b 100644 --- a/src/pyramid/client/requests/a_request.py +++ b/src/pyramid/client/requests/a_request.py @@ -2,9 +2,9 @@ from typing import Any # from pyramid.client.common import SocketHeader -from pyramid.client.responses.a_response import ResponseHeader from pyramid.client.a_socket import ASocket from pyramid.client.requests.ask_request import AskRequest +from pyramid.client.responses.a_response_header import ResponseHeader class ARequest(ASocket): @@ -17,5 +17,5 @@ def load_data(self, data) -> Any: pass @abstractmethod - def client_receive(self, header: ResponseHeader, data: Any): + def client_receive(self, header: ResponseHeader, data: Any) -> bool: pass diff --git a/src/pyramid/client/requests/health.py b/src/pyramid/client/requests/health.py deleted file mode 100644 index ba82f1d..0000000 --- a/src/pyramid/client/requests/health.py +++ /dev/null @@ -1,27 +0,0 @@ -import json -import logging -import sys -from pyramid.client.requests.a_request import ARequest -from pyramid.client.requests.ask_request import AskRequest -from pyramid.client.responses.a_response import ResponseHeader -from pyramid.data.health import HealthModules - - -class HealthRequest(ARequest): - def __init__(self) -> None: - super().__init__(AskRequest("health")) - - def load_data(self, **data) -> HealthModules: - return HealthModules(**data) - - def client_receive(self, header: ResponseHeader, data: HealthModules): - data_json = json.dumps(data.__dict__, indent=4) - - if not data.is_ok(): - logging.warn("Health check failed") - print(data_json) - sys.exit(1) - else: - logging.info("Health check valid") - print(data_json) - sys.exit(0) diff --git a/src/pyramid/client/requests/ping.py b/src/pyramid/client/requests/ping.py new file mode 100644 index 0000000..0c47434 --- /dev/null +++ b/src/pyramid/client/requests/ping.py @@ -0,0 +1,26 @@ +import json +import logging +import sys +from pyramid.client.requests.a_request import ARequest +from pyramid.client.requests.ask_request import AskRequest +from pyramid.client.responses.a_response_header import ResponseHeader +from pyramid.data.ping import PingSocket + + +class PingRequest(ARequest): + def __init__(self) -> None: + super().__init__(AskRequest("health")) + + def load_data(self, **data) -> PingSocket: + return PingSocket(**data) + + def client_receive(self, header: ResponseHeader, data: PingSocket) -> bool: + data_json = json.dumps(data.__dict__, indent=4) + + if not data.is_ok(): + logging.warning("Health check failed") + print(data_json) + return False + logging.info("Health check valid") + print(data_json) + return True diff --git a/src/pyramid/client/responses/a_response.py b/src/pyramid/client/responses/a_response.py index 9bf40f7..48ee1e0 100644 --- a/src/pyramid/client/responses/a_response.py +++ b/src/pyramid/client/responses/a_response.py @@ -1,17 +1,10 @@ from typing import Any, Self from pyramid.client.a_socket import ASocket -from pyramid.client.common import ResponseCode +from pyramid.client.common import ResponseCode, SocketCommon +from pyramid.client.responses.a_response_header import ResponseHeader # from pyramid.client.common import ResponseCode, SocketHeader -# class ReponseHeader(SocketHeader): -class ResponseHeader: - def __init__(self, code: ResponseCode, message: str | None) -> None: - # super().__init__(self.__class__) - self.code = code - self.message = message - - class SocketResponse(ASocket): def __init__( self, @@ -45,6 +38,12 @@ def to_json(self, serializer): "error_data": serializer(self.error_data) if self.error_data else None, } + @classmethod + def from_str(cls, data: str) -> Self: + json = SocketCommon.deserialize(data) + self = cls.from_json(json) + return self + @classmethod def from_json(cls, json_dict: dict) -> Self: self: Self = cls() diff --git a/src/pyramid/client/responses/a_response_header.py b/src/pyramid/client/responses/a_response_header.py new file mode 100644 index 0000000..9c6910d --- /dev/null +++ b/src/pyramid/client/responses/a_response_header.py @@ -0,0 +1,12 @@ +from typing import Any, Self +from pyramid.client.a_socket import ASocket +from pyramid.client.common import ResponseCode, SocketCommon +# from pyramid.client.common import ResponseCode, SocketHeader + + +# class ReponseHeader(SocketHeader): +class ResponseHeader: + def __init__(self, code: ResponseCode, message: str | None) -> None: + # super().__init__(self.__class__) + self.code = code + self.message = message diff --git a/src/pyramid/client/server.py b/src/pyramid/client/server.py deleted file mode 100644 index 7b73eca..0000000 --- a/src/pyramid/client/server.py +++ /dev/null @@ -1,98 +0,0 @@ -import logging -import socket -from logging import Logger -from socket import socket as sock -from typing import Any - -from pyramid.client.common import ResponseCode, SocketCommon -from pyramid.client.requests.ask_request import AskRequest -from pyramid.client.responses.a_response import SocketResponse -from pyramid.data.health import HealthModules - -# from _socket import _RetAddress - - -class SocketServer: - def __init__(self, logger: Logger, health: HealthModules, host: str = "0.0.0.0") -> None: - self.__common = SocketCommon() - self.__host = host - self.__port = self.__common.port - self.__health = health - if logger: - self.__logger = logger - else: - logger = logging.getLogger("socket") - - def start_server(self): - # Set the host and port for the socket server - - # Create a socket object - server_socket = sock(socket.AF_INET, socket.SOCK_STREAM) - # Bind the socket to a specific address and port - server_socket.bind((self.__host, self.__port)) - # Listen for incoming connections (up to x connections in the queue) - server_socket.listen(10) - - self.__logger.info("Socket server open on %s:%d", self.__host, self.__port) - - client_socket: sock - client_address: Any - while True: - # Wait for a connection from a client - client_socket, client_address = server_socket.accept() - client_ip = client_address[0] - client_port = client_address[1] - # self.__logger.debug("[%s:%d] accepted", client_ip, client_port) - - try: - response_to_send = self.handle_client(client_socket, client_ip, client_port) - - if response_to_send: - # Convert the response data to JSON - response_json = self.__common.serialize( - response_to_send.to_json(self.__common.serialize) - ) - - # Send the JSON response back to the client - # self.__logger.debug("[%s:%d] <- %s", client_ip, client_port, response_json) - self.__common.send_chunk(client_socket, response_json) - except Exception as err: - self.__logger.warning("[%s:%d] %s", client_ip, client_port, err, exc_info=True) - finally: - client_socket.close() - - def handle_client(self, client_socket: sock, client_ip, client_port) -> SocketResponse | None: - # Receive data from the client - data = self.__common.receive_chunk(client_socket) - - if not data: - self.__logger.info("[%s:%d] -> ", client_ip, client_port) - return - - # self.__logger.debug("[%s:%d] -> %s", client_ip, client_port, data) - - def object_hook(json): - if isinstance(json, dict): - return AskRequest(**json) - return json - - json_data: AskRequest = self.__common.deserialize(data, object_hook=object_hook) - - response = SocketResponse() - - # Check the content of the JSON data - if not json_data.action: - response.create(ResponseCode.ERROR, "Missing action field in JSON data") - return response - - if json_data.action == "health": - data = self.__health - response.create(ResponseCode.OK, None, data) - return response - - # If the action is unknown, respond with an error message - response.create(ResponseCode.ERROR, "Unknown action") - self.__logger.info( - "[%s:%d] <- Unknown action '%s'", client_ip, client_port, json_data.action - ) - return response diff --git a/src/pyramid/data/functional/main.py b/src/pyramid/data/functional/main.py index d6af0e0..d5a48a4 100644 --- a/src/pyramid/data/functional/main.py +++ b/src/pyramid/data/functional/main.py @@ -35,10 +35,6 @@ def start(self): ServiceRegister.inject_services() ServiceRegister.start_services() - logger = ServiceRegister.get_service(ILoggerService) - - logger.debug(ServiceRegister.get_dependency_tree()) - MainQueue.init() TaskRegister.import_tasks() diff --git a/src/pyramid/data/health.py b/src/pyramid/data/health.py deleted file mode 100644 index 0990056..0000000 --- a/src/pyramid/data/health.py +++ /dev/null @@ -1,7 +0,0 @@ -class HealthModules: - def __init__(self, configuration: bool = False, discord: bool = False) -> None: - self.configuration = configuration - self.discord = discord - - def is_ok(self) -> bool: - return self.configuration and self.discord diff --git a/src/pyramid/data/ping.py b/src/pyramid/data/ping.py new file mode 100644 index 0000000..175a493 --- /dev/null +++ b/src/pyramid/data/ping.py @@ -0,0 +1,6 @@ +class PingSocket: + def __init__(self, ok: bool): + self.ok = ok + + def is_ok(self) -> bool: + return self.ok diff --git a/src/pyramid/services/configuration.py b/src/pyramid/services/configuration.py index 9435119..4a96f7f 100644 --- a/src/pyramid/services/configuration.py +++ b/src/pyramid/services/configuration.py @@ -29,7 +29,7 @@ def load(self, config_file: str = "config.yml", use_env_vars: bool = True) -> bo raw_values_env, result_values, "env vars", True, keys_length ) - # Load raw values from environment variables and config file + # Load raw values from config file try: raw_values_file = self._get_file_vars(config_file) result_values = self._transform_all(raw_values_file, keys_length) diff --git a/src/pyramid/services/logger.py b/src/pyramid/services/logger.py index adcded4..32f5eb0 100644 --- a/src/pyramid/services/logger.py +++ b/src/pyramid/services/logger.py @@ -45,6 +45,10 @@ def start(self): def critical(self, msg, *args, **kwargs): self.logger.critical(msg, *args, **kwargs) + sys.exit(1) + + def exception(self, msg, *args, **kwargs): + self.logger.exception(msg, *args, **kwargs) def error(self, msg, *args, **kwargs): self.logger.error(msg, *args, **kwargs) diff --git a/src/pyramid/services/socket_server.py b/src/pyramid/services/socket_server.py index 5adf270..a559760 100644 --- a/src/pyramid/services/socket_server.py +++ b/src/pyramid/services/socket_server.py @@ -1,27 +1,113 @@ +import asyncio +import socket +from socket import socket as sock +from typing import Any -from threading import Thread - -from pyramid.api.services import ILoggerService, ISocketServerService +from pyramid.api.services.logger import ILoggerService +from pyramid.api.services.socket_server import ISocketServerService from pyramid.api.services.tools.annotation import pyramid_service from pyramid.api.services.tools.injector import ServiceInjector -from pyramid.client.server import SocketServer -from pyramid.data.health import HealthModules - +from pyramid.client.common import ResponseCode, SocketCommon +from pyramid.client.requests.ask_request import AskRequest +from pyramid.client.responses.a_response import SocketResponse +from pyramid.data.ping import PingSocket @pyramid_service(interface=ISocketServerService) class SocketServerService(ISocketServerService, ServiceInjector): - def __init__(self): - pass + def __init__(self) -> None: + self.__common = SocketCommon() + self.__host = "0.0.0.0" + self.__port = self.__common.port + self.is_running = False + self.server_socket: sock | None = None + + def injectService(self, + logger_service: ILoggerService + ): + self.__logger = logger_service + + async def open(self): + self.server_socket = sock(socket.AF_INET, socket.SOCK_STREAM) + # self.server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + self.server_socket.bind((self.__host, self.__port)) + self.server_socket.listen(10) + + self.__logger.info("Socket server open on %s:%d", self.__host, self.__port) + + self.is_running = True + client_socket: sock | None = None + client_address: Any = None + client_ip: Any = None + client_port: Any = None + + while self.is_running: + try: + client_socket, client_address = self.server_socket.accept() + client_ip = client_address[0] + client_port = client_address[1] + response_to_send = await self.__handle_client(client_socket, client_ip, client_port) + if response_to_send: + # Convert the response data to JSON + response_json = SocketCommon.serialize( + response_to_send.to_json(SocketCommon.serialize) + ) + + # Send the JSON response back to the client + # self.__logger.debug("[%s:%d] <- %s", client_ip, client_port, response_json) + self.__common.send_chunk(client_socket, response_json) + except Exception as err: + if isinstance(err, OSError): + if err.errno == 9: + self.__logger.warning("Socket: [Errno 9] Bad file descriptor") + continue + if client_ip is not None and client_port is not None: + self.__logger.warning("[%s:%d] %s", client_ip, client_port, err, exc_info=True) + finally: + if client_socket is not None: + client_socket.close() + client_socket = None + client_address = None + client_ip = None + client_port = None + self.__logger.info("Socket server closed") + + def close(self): + self.is_running = False + if self.server_socket is None: + return + self.server_socket.shutdown(socket.SHUT_RDWR) + self.server_socket.close() + self.server_socket = None + self.__logger.info("Socket server stop") + + async def __handle_client(self, client_socket: sock, client_ip, client_port) -> SocketResponse | None: + data = self.__common.receive_chunk(client_socket) + + if not data: + self.__logger.info("[%s:%d] -> ", client_ip, client_port) + return + + def object_hook(json): + if isinstance(json, dict): + return AskRequest(**json) + return json + + json_data: AskRequest = SocketCommon.deserialize(data, object_hook=object_hook) + + response = SocketResponse() - def injectService(self, logger_service: ILoggerService): - self.logger_service = logger_service + if not json_data.action: + response.create(ResponseCode.ERROR, "Missing action field in JSON data") + return response - def start(self): - self._health = HealthModules() - self._health.configuration = True - self._health.discord = True + if json_data.action == "health": + data = PingSocket(True) + response.create(ResponseCode.OK, None, data) + return response - self.socket_server = SocketServer(self.logger_service.getChild("socket"), self._health) - thread = Thread(name="Socket", target=self.socket_server.start_server, daemon=True) - thread.start() + response.create(ResponseCode.ERROR, "Unknown action") + self.__logger.info( + "[%s:%d] <- Unknown action '%s'", client_ip, client_port, json_data.action + ) + return response diff --git a/src/pyramid/tasks/discord.py b/src/pyramid/tasks/discord.py index 6f22e66..15c4a3a 100644 --- a/src/pyramid/tasks/discord.py +++ b/src/pyramid/tasks/discord.py @@ -6,12 +6,7 @@ @pyramid_task(parameters=ParametersTask()) class DiscordTask(TaskInjector): - def __init__(self): - pass - - def injectService(self, - discord_service: IDiscordService - ): + def injectService(self, discord_service: IDiscordService): self.__discord_service = discord_service async def worker_asyc(self): diff --git a/src/pyramid/tasks/socket_server.py b/src/pyramid/tasks/socket_server.py new file mode 100644 index 0000000..e5ee8c0 --- /dev/null +++ b/src/pyramid/tasks/socket_server.py @@ -0,0 +1,17 @@ +from pyramid.api.services.socket_server import ISocketServerService +from pyramid.api.tasks.tools.annotation import pyramid_task +from pyramid.api.tasks.tools.injector import TaskInjector +from pyramid.api.tasks.tools.parameters import ParametersTask + + +@pyramid_task(parameters=ParametersTask()) +class SocketServerTask(TaskInjector): + + def injectService(self, socket_service: ISocketServerService): + self.__socket_server = socket_service + + async def worker_asyc(self): + await self.__socket_server.open() + + async def stop_asyc(self): + self.__socket_server.close() diff --git a/src/pyramid/tools/custom_queue.py b/src/pyramid/tools/custom_queue.py index 70e435e..8eaa543 100644 --- a/src/pyramid/tools/custom_queue.py +++ b/src/pyramid/tools/custom_queue.py @@ -101,7 +101,7 @@ def __init__(self, threads=1, name=None): def create_threads(self): for thread_id in range(1, self.__threads + 1): thread = Thread( - name="%s n°%d{thread_id}" % (self.__name, thread_id), + name="%s n°%d" % (self.__name, thread_id), target=self.__worker, args=(self.__queue, thread_id, self.__lock, self.__event), daemon=True, diff --git a/src/startup_cli.py b/src/startup_cli.py new file mode 100644 index 0000000..174ccc6 --- /dev/null +++ b/src/startup_cli.py @@ -0,0 +1,4 @@ +from pyramid.cli.startup import startup_cli + +if __name__ == "__main__": + startup_cli() diff --git a/src/startup_cli_dev.py b/src/startup_cli_dev.py new file mode 100644 index 0000000..c15c04d --- /dev/null +++ b/src/startup_cli_dev.py @@ -0,0 +1,10 @@ +import debugpy +from pyramid.cli.startup import startup_cli + +def startup_cli_dev(): + debugpy.listen(('0.0.0.0', 5679)) + debugpy.wait_for_client() + startup_cli() + +if __name__ == "__main__": + startup_cli_dev() diff --git a/src/dev.py b/src/startup_dev.py similarity index 100% rename from src/dev.py rename to src/startup_dev.py From 5c55c4274960189cba3de634eff62bcfe53b0b3c Mon Sep 17 00:00:00 2001 From: Tristiisch Date: Sun, 29 Sep 2024 23:17:27 +0200 Subject: [PATCH 16/32] fix: close tasks properly --- src/pyramid/api/tasks/tools/register.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pyramid/api/tasks/tools/register.py b/src/pyramid/api/tasks/tools/register.py index aaa1e95..159997f 100644 --- a/src/pyramid/api/tasks/tools/register.py +++ b/src/pyramid/api/tasks/tools/register.py @@ -54,7 +54,7 @@ def __handle_signal(cls, signum: int, frame): for name, parameters in cls.__TASKS_REGISTERED.items(): async def shutdown(loop: asyncio.AbstractEventLoop): await parameters.task_cls.stop_asyc() - parameters.loop.stop() + loop.stop() asyncio.run_coroutine_threadsafe(shutdown(parameters.loop), parameters.loop) @classmethod From 70d1bc6f2542028fbdb964150e7291c6418d5a4d Mon Sep 17 00:00:00 2001 From: Tristiisch Date: Mon, 30 Sep 2024 01:39:09 +0200 Subject: [PATCH 17/32] feat: all commands are now registred to use injection --- src/pyramid/api/services/tools/register.py | 2 +- src/pyramid/api/tasks/tools/parameters.py | 2 +- src/pyramid/api/tasks/tools/register.py | 12 +- .../connector/discord/commands/about.py | 21 ++- .../connector/discord/commands/goto.py | 29 +++ .../connector/discord/commands/next.py | 29 +++ .../connector/discord/commands/pause.py | 29 +++ .../connector/discord/commands/play.py | 30 +++ .../connector/discord/commands/play_next.py | 30 +++ .../connector/discord/commands/play_url.py | 30 +++ .../discord/commands/play_url_next.py | 29 +++ .../connector/discord/commands/queue.py | 29 +++ .../connector/discord/commands/remove.py | 29 +++ .../connector/discord/commands/resume.py | 29 +++ .../connector/discord/commands/search.py | 30 +++ .../connector/discord/commands/shuffle.py | 29 +++ .../connector/discord/commands/stop.py | 29 +++ .../connector/discord/commands/tools/abc.py | 23 +-- .../discord/commands/tools/exception.py | 5 + .../discord/commands/tools/parameters.py | 10 +- .../discord/commands/tools/register.py | 57 ++++-- src/pyramid/connector/discord/guild_cmd.py | 17 -- src/pyramid/data/functional/main.py | 2 +- src/pyramid/services/discord_commads.py | 177 +----------------- src/pyramid/services/source.py | 1 - src/pyramid/services/spotify_client.py | 4 +- src/startup_dev.py | 1 + 27 files changed, 470 insertions(+), 245 deletions(-) create mode 100644 src/pyramid/connector/discord/commands/goto.py create mode 100644 src/pyramid/connector/discord/commands/next.py create mode 100644 src/pyramid/connector/discord/commands/pause.py create mode 100644 src/pyramid/connector/discord/commands/play.py create mode 100644 src/pyramid/connector/discord/commands/play_next.py create mode 100644 src/pyramid/connector/discord/commands/play_url.py create mode 100644 src/pyramid/connector/discord/commands/play_url_next.py create mode 100644 src/pyramid/connector/discord/commands/queue.py create mode 100644 src/pyramid/connector/discord/commands/remove.py create mode 100644 src/pyramid/connector/discord/commands/resume.py create mode 100644 src/pyramid/connector/discord/commands/search.py create mode 100644 src/pyramid/connector/discord/commands/shuffle.py create mode 100644 src/pyramid/connector/discord/commands/stop.py create mode 100644 src/pyramid/connector/discord/commands/tools/exception.py diff --git a/src/pyramid/api/services/tools/register.py b/src/pyramid/api/services/tools/register.py index dbae533..0c4f98c 100644 --- a/src/pyramid/api/services/tools/register.py +++ b/src/pyramid/api/services/tools/register.py @@ -31,7 +31,7 @@ def register_service(cls, interface_name: str, type: type[object]): if interface_name in cls.__SERVICES_REGISTRED: already_class_name = cls.__SERVICES_REGISTRED[interface_name].__name__ raise ServiceAlreadyRegisterException( - "Cannot register %s with %s, it is already registered with the class %s." + "Cannot register service %s with %s, it is already registered with the class %s." % (interface_name, type_name, already_class_name) ) cls.__SERVICES_REGISTRED[interface_name] = type diff --git a/src/pyramid/api/tasks/tools/parameters.py b/src/pyramid/api/tasks/tools/parameters.py index a912a0c..e732421 100644 --- a/src/pyramid/api/tasks/tools/parameters.py +++ b/src/pyramid/api/tasks/tools/parameters.py @@ -9,4 +9,4 @@ class ParametersTask: def __init__(self): self.loop: asyncio.AbstractEventLoop self.thread: Thread - self.task_cls: TaskInjector + self.cls_instance: TaskInjector diff --git a/src/pyramid/api/tasks/tools/register.py b/src/pyramid/api/tasks/tools/register.py index 159997f..c551079 100644 --- a/src/pyramid/api/tasks/tools/register.py +++ b/src/pyramid/api/tasks/tools/register.py @@ -30,13 +30,13 @@ def import_tasks(cls): def register_tasks(cls, type: type[object], parameters: ParametersTask): if not issubclass(type, TaskInjector): raise TypeError("Service %s is not a subclass of TaskInjector and cannot be initialized." % type.__name__) - parameters.task_cls = type() + parameters.cls_instance = type() cls.__TASKS_REGISTERED[type.__name__] = parameters @classmethod def inject_tasks(cls): for name, parameters in cls.__TASKS_REGISTERED.items(): - signature = inspect.signature(parameters.task_cls.injectService) + signature = inspect.signature(parameters.cls_instance.injectService) method_parameters = list(signature.parameters.values()) services_dependencies = [] @@ -45,7 +45,7 @@ def inject_tasks(cls): dependency_instance = ServiceRegister.get_service(dependency_cls) services_dependencies.append(dependency_instance) - parameters.task_cls.injectService(*services_dependencies) + parameters.cls_instance.injectService(*services_dependencies) @classmethod def __handle_signal(cls, signum: int, frame): @@ -53,7 +53,7 @@ def __handle_signal(cls, signum: int, frame): for name, parameters in cls.__TASKS_REGISTERED.items(): async def shutdown(loop: asyncio.AbstractEventLoop): - await parameters.task_cls.stop_asyc() + await parameters.cls_instance.stop_asyc() loop.stop() asyncio.run_coroutine_threadsafe(shutdown(parameters.loop), parameters.loop) @@ -67,7 +67,7 @@ def start_tasks(cls): def running(loop: asyncio.AbstractEventLoop): asyncio.set_event_loop(loop) - loop.create_task(parameters.task_cls.worker_asyc()) + loop.create_task(parameters.cls_instance.worker_asyc()) try: loop.run_forever() finally: @@ -82,4 +82,4 @@ def running(loop: asyncio.AbstractEventLoop): parameters.thread.join() signal.signal(signal.SIGTERM, previous_handler) - logging.info("All tasks are stopped") + logging.info("All registered tasks are stopped") diff --git a/src/pyramid/connector/discord/commands/about.py b/src/pyramid/connector/discord/commands/about.py index 765cbe2..7b1bae7 100644 --- a/src/pyramid/connector/discord/commands/about.py +++ b/src/pyramid/connector/discord/commands/about.py @@ -2,6 +2,9 @@ import time from discord import AppInfo, ClientUser, Color, Embed, Interaction from discord.user import BaseUser +from pyramid.api.services.configuration import IConfigurationService +from pyramid.api.services.information import IInformationService +from pyramid.api.services.logger import ILoggerService from pyramid.connector.discord.commands.tools.abc import AbstractCommand from pyramid.connector.discord.commands.tools.annotation import discord_command from pyramid.connector.discord.commands.tools.parameters import ParametersCommand @@ -12,9 +15,14 @@ @discord_command(parameters=ParametersCommand(description="About the bot")) class AboutCommand(AbstractCommand): - def injectService(self, environment: Environment, info: ApplicationInfo): - self.__environment = environment - self.__info = info + def injectService(self, + information_service: IInformationService, + configuration_service: IConfigurationService, + logger_service: ILoggerService + ): + self.__information_service = information_service + self.logger = logger_service + self.__configuration_service = configuration_service async def execute(self, ctx: Interaction): await ctx.response.defer(thinking=True) @@ -25,7 +33,7 @@ async def execute(self, ctx: Interaction): bot_user = None self.logger.warning("Unable to get self user instance") - info = self.__info + info = self.__information_service.get() embed = Embed(title=info.get_name(), color=Color.gold()) if bot_user is not None and bot_user.avatar is not None: embed.set_thumbnail(url=bot_user.avatar.url) @@ -54,17 +62,18 @@ async def execute(self, ctx: Interaction): text=f"Owned by {owner.display_name}", icon_url=owner.avatar.url if owner.avatar is not None else None, ) + environnement = self.__configuration_service.mode.name.capitalize() embed.add_field(name="Version", value=info.get_version(), inline=True) embed.add_field(name="OS", value=info.get_os(), inline=True) embed.add_field( name="Environment", - value=self.__environment.name.capitalize(), + value=environnement, inline=True, ) embed.add_field( name="Uptime", - value=utils.time_to_duration(int(round(time.time() - self.__info.get_started_at()))), + value=utils.time_to_duration(int(round(time.time() - info.get_started_at()))), inline=True, ) diff --git a/src/pyramid/connector/discord/commands/goto.py b/src/pyramid/connector/discord/commands/goto.py new file mode 100644 index 0000000..d884b38 --- /dev/null +++ b/src/pyramid/connector/discord/commands/goto.py @@ -0,0 +1,29 @@ +from discord import Interaction + +from pyramid.api.services.discord import IDiscordService +from pyramid.connector.discord.commands.tools.abc import AbstractCommand +from pyramid.connector.discord.commands.tools.annotation import discord_command +from pyramid.connector.discord.commands.tools.parameters import ParametersCommand +from pyramid.data.functional.messages.message_sender_queued import MessageSenderQueued + + +@discord_command(parameters=ParametersCommand( + description="Jumps to a specific track in the queue", + only_guild=True +)) +class GotoCommand(AbstractCommand): + + def injectService(self, + discord_service: IDiscordService + ): + self.__discord_service = discord_service + + async def execute(self, ctx: Interaction, number_in_queue: int): + ms = MessageSenderQueued(ctx) + await ms.thinking() + guild = ctx.guild + if not guild: + raise Exception("Command was not executed in a guild") + guild_cmd = self.__discord_service.get_guild_cmd(guild) + + await guild_cmd.goto(ms, ctx, number_in_queue) diff --git a/src/pyramid/connector/discord/commands/next.py b/src/pyramid/connector/discord/commands/next.py new file mode 100644 index 0000000..c274691 --- /dev/null +++ b/src/pyramid/connector/discord/commands/next.py @@ -0,0 +1,29 @@ +from discord import Interaction + +from pyramid.api.services.discord import IDiscordService +from pyramid.connector.discord.commands.tools.abc import AbstractCommand +from pyramid.connector.discord.commands.tools.annotation import discord_command +from pyramid.connector.discord.commands.tools.parameters import ParametersCommand +from pyramid.data.functional.messages.message_sender_queued import MessageSenderQueued + + +@discord_command(parameters=ParametersCommand( + description="Skips to the next track", + only_guild=True +)) +class NextCommand(AbstractCommand): + + def injectService(self, + discord_service: IDiscordService + ): + self.__discord_service = discord_service + + async def execute(self, ctx: Interaction): + ms = MessageSenderQueued(ctx) + await ms.thinking() + guild = ctx.guild + if not guild: + raise Exception("Command was not executed in a guild") + guild_cmd = self.__discord_service.get_guild_cmd(guild) + + await guild_cmd.next(ms, ctx) diff --git a/src/pyramid/connector/discord/commands/pause.py b/src/pyramid/connector/discord/commands/pause.py new file mode 100644 index 0000000..c6a115d --- /dev/null +++ b/src/pyramid/connector/discord/commands/pause.py @@ -0,0 +1,29 @@ +from discord import Guild, Interaction + +from pyramid.api.services.discord import IDiscordService +from pyramid.connector.discord.commands.tools.abc import AbstractCommand +from pyramid.connector.discord.commands.tools.annotation import discord_command +from pyramid.connector.discord.commands.tools.parameters import ParametersCommand +from pyramid.data.functional.messages.message_sender_queued import MessageSenderQueued + + +@discord_command(parameters=ParametersCommand( + description="Pauses the music", + only_guild=True +)) +class PauseCommand(AbstractCommand): + + def injectService(self, + discord_service: IDiscordService + ): + self.__discord_service = discord_service + + async def execute(self, ctx: Interaction): + ms = MessageSenderQueued(ctx) + await ms.thinking() + guild = ctx.guild + if not guild: + raise Exception("Command was not executed in a guild") + guild_cmd = self.__discord_service.get_guild_cmd(guild) + + await guild_cmd.pause(ms, ctx) diff --git a/src/pyramid/connector/discord/commands/play.py b/src/pyramid/connector/discord/commands/play.py new file mode 100644 index 0000000..76260b0 --- /dev/null +++ b/src/pyramid/connector/discord/commands/play.py @@ -0,0 +1,30 @@ +from discord import Interaction + +from pyramid.api.services.discord import IDiscordService +from pyramid.connector.discord.commands.tools.abc import AbstractCommand +from pyramid.connector.discord.commands.tools.annotation import discord_command +from pyramid.connector.discord.commands.tools.parameters import ParametersCommand +from pyramid.data.functional.messages.message_sender_queued import MessageSenderQueued +from pyramid.data.source_type import SourceType + + +@discord_command(parameters=ParametersCommand( + description="Adds a track to the end of the queue and plays it", + only_guild=True +)) +class PlayCommand(AbstractCommand): + + def injectService(self, + discord_service: IDiscordService + ): + self.__discord_service = discord_service + + async def execute(self, ctx: Interaction, input: str, engine: SourceType | None): + ms = MessageSenderQueued(ctx) + await ms.thinking() + guild = ctx.guild + if not guild: + raise Exception("Command was not executed in a guild") + guild_cmd = self.__discord_service.get_guild_cmd(guild) + + await guild_cmd.play(ms, ctx, input, engine) diff --git a/src/pyramid/connector/discord/commands/play_next.py b/src/pyramid/connector/discord/commands/play_next.py new file mode 100644 index 0000000..7939e63 --- /dev/null +++ b/src/pyramid/connector/discord/commands/play_next.py @@ -0,0 +1,30 @@ +from discord import Interaction + +from pyramid.api.services.discord import IDiscordService +from pyramid.connector.discord.commands.tools.abc import AbstractCommand +from pyramid.connector.discord.commands.tools.annotation import discord_command +from pyramid.connector.discord.commands.tools.parameters import ParametersCommand +from pyramid.data.functional.messages.message_sender_queued import MessageSenderQueued +from pyramid.data.source_type import SourceType + + +@discord_command(parameters=ParametersCommand( + description="Plays a track next the current one", + only_guild=True +)) +class PlayNextCommand(AbstractCommand): + + def injectService(self, + discord_service: IDiscordService + ): + self.__discord_service = discord_service + + async def execute(self, ctx: Interaction, input: str, engine: SourceType | None): + ms = MessageSenderQueued(ctx) + await ms.thinking() + guild = ctx.guild + if not guild: + raise Exception("Command was not executed in a guild") + guild_cmd = self.__discord_service.get_guild_cmd(guild) + + await guild_cmd.play(ms, ctx, input, engine, at_end=False) diff --git a/src/pyramid/connector/discord/commands/play_url.py b/src/pyramid/connector/discord/commands/play_url.py new file mode 100644 index 0000000..e574b1d --- /dev/null +++ b/src/pyramid/connector/discord/commands/play_url.py @@ -0,0 +1,30 @@ +from discord import Interaction + +from pyramid.api.services.discord import IDiscordService +from pyramid.connector.discord.commands.tools.abc import AbstractCommand +from pyramid.connector.discord.commands.tools.annotation import discord_command +from pyramid.connector.discord.commands.tools.parameters import ParametersCommand +from pyramid.data.functional.messages.message_sender_queued import MessageSenderQueued +from pyramid.data.source_type import SourceType + + +@discord_command(parameters=ParametersCommand( + description="Plays a track, artist, album, or playlist from a URL", + only_guild=True +)) +class PlayUrlCommand(AbstractCommand): + + def injectService(self, + discord_service: IDiscordService + ): + self.__discord_service = discord_service + + async def execute(self, ctx: Interaction, url: str): + ms = MessageSenderQueued(ctx) + await ms.thinking() + guild = ctx.guild + if not guild: + raise Exception("Command was not executed in a guild") + guild_cmd = self.__discord_service.get_guild_cmd(guild) + + await guild_cmd.play_url(ms, ctx, url) diff --git a/src/pyramid/connector/discord/commands/play_url_next.py b/src/pyramid/connector/discord/commands/play_url_next.py new file mode 100644 index 0000000..e9e2f47 --- /dev/null +++ b/src/pyramid/connector/discord/commands/play_url_next.py @@ -0,0 +1,29 @@ +from discord import Interaction + +from pyramid.api.services.discord import IDiscordService +from pyramid.connector.discord.commands.tools.abc import AbstractCommand +from pyramid.connector.discord.commands.tools.annotation import discord_command +from pyramid.connector.discord.commands.tools.parameters import ParametersCommand +from pyramid.data.functional.messages.message_sender_queued import MessageSenderQueued + + +@discord_command(parameters=ParametersCommand( + description="Plays a track, artist, album, or playlist from a URL next in the queue", + only_guild=True +)) +class PlayNextUrlCommand(AbstractCommand): + + def injectService(self, + discord_service: IDiscordService + ): + self.__discord_service = discord_service + + async def execute(self, ctx: Interaction, url: str): + ms = MessageSenderQueued(ctx) + await ms.thinking() + guild = ctx.guild + if not guild: + raise Exception("Command was not executed in a guild") + guild_cmd = self.__discord_service.get_guild_cmd(guild) + + await guild_cmd.play_url(ms, ctx, url, at_end=False) diff --git a/src/pyramid/connector/discord/commands/queue.py b/src/pyramid/connector/discord/commands/queue.py new file mode 100644 index 0000000..22c77c4 --- /dev/null +++ b/src/pyramid/connector/discord/commands/queue.py @@ -0,0 +1,29 @@ +from discord import Interaction + +from pyramid.api.services.discord import IDiscordService +from pyramid.connector.discord.commands.tools.abc import AbstractCommand +from pyramid.connector.discord.commands.tools.annotation import discord_command +from pyramid.connector.discord.commands.tools.parameters import ParametersCommand +from pyramid.data.functional.messages.message_sender_queued import MessageSenderQueued + + +@discord_command(parameters=ParametersCommand( + description="Displays the current track queue", + only_guild=True +)) +class QueueCommand(AbstractCommand): + + def injectService(self, + discord_service: IDiscordService + ): + self.__discord_service = discord_service + + async def execute(self, ctx: Interaction): + ms = MessageSenderQueued(ctx) + await ms.thinking() + guild = ctx.guild + if not guild: + raise Exception("Command was not executed in a guild") + guild_cmd = self.__discord_service.get_guild_cmd(guild) + + guild_cmd.queue_list(ms, ctx) diff --git a/src/pyramid/connector/discord/commands/remove.py b/src/pyramid/connector/discord/commands/remove.py new file mode 100644 index 0000000..bc83f13 --- /dev/null +++ b/src/pyramid/connector/discord/commands/remove.py @@ -0,0 +1,29 @@ +from discord import Interaction + +from pyramid.api.services.discord import IDiscordService +from pyramid.connector.discord.commands.tools.abc import AbstractCommand +from pyramid.connector.discord.commands.tools.annotation import discord_command +from pyramid.connector.discord.commands.tools.parameters import ParametersCommand +from pyramid.data.functional.messages.message_sender_queued import MessageSenderQueued + + +@discord_command(parameters=ParametersCommand( + description="Removes a track from the queue", + only_guild=True +)) +class RemoveCommand(AbstractCommand): + + def injectService(self, + discord_service: IDiscordService + ): + self.__discord_service = discord_service + + async def execute(self, ctx: Interaction, number_in_queue: int): + ms = MessageSenderQueued(ctx) + await ms.thinking() + guild = ctx.guild + if not guild: + raise Exception("Command was not executed in a guild") + guild_cmd = self.__discord_service.get_guild_cmd(guild) + + await guild_cmd.remove(ms, ctx, number_in_queue) diff --git a/src/pyramid/connector/discord/commands/resume.py b/src/pyramid/connector/discord/commands/resume.py new file mode 100644 index 0000000..33fa336 --- /dev/null +++ b/src/pyramid/connector/discord/commands/resume.py @@ -0,0 +1,29 @@ +from discord import Interaction + +from pyramid.api.services.discord import IDiscordService +from pyramid.connector.discord.commands.tools.abc import AbstractCommand +from pyramid.connector.discord.commands.tools.annotation import discord_command +from pyramid.connector.discord.commands.tools.parameters import ParametersCommand +from pyramid.data.functional.messages.message_sender_queued import MessageSenderQueued + + +@discord_command(parameters=ParametersCommand( + description="Resumes the paused music", + only_guild=True +)) +class ResumeCommand(AbstractCommand): + + def injectService(self, + discord_service: IDiscordService + ): + self.__discord_service = discord_service + + async def execute(self, ctx: Interaction): + ms = MessageSenderQueued(ctx) + await ms.thinking() + guild = ctx.guild + if not guild: + raise Exception("Command was not executed in a guild") + guild_cmd = self.__discord_service.get_guild_cmd(guild) + + await guild_cmd.resume(ms, ctx) diff --git a/src/pyramid/connector/discord/commands/search.py b/src/pyramid/connector/discord/commands/search.py new file mode 100644 index 0000000..99255ba --- /dev/null +++ b/src/pyramid/connector/discord/commands/search.py @@ -0,0 +1,30 @@ +from discord import Interaction + +from pyramid.api.services.discord import IDiscordService +from pyramid.connector.discord.commands.tools.abc import AbstractCommand +from pyramid.connector.discord.commands.tools.annotation import discord_command +from pyramid.connector.discord.commands.tools.parameters import ParametersCommand +from pyramid.data.functional.messages.message_sender_queued import MessageSenderQueued +from pyramid.data.source_type import SourceType + + +@discord_command(parameters=ParametersCommand( + description="Searches for tracks", + only_guild=True +)) +class SearchCommand(AbstractCommand): + + def injectService(self, + discord_service: IDiscordService + ): + self.__discord_service = discord_service + + async def execute(self, ctx: Interaction, input: str, engine: SourceType | None): + ms = MessageSenderQueued(ctx) + await ms.thinking() + guild = ctx.guild + if not guild: + raise Exception("Command was not executed in a guild") + guild_cmd = self.__discord_service.get_guild_cmd(guild) + + await guild_cmd.search(ms, input, engine) diff --git a/src/pyramid/connector/discord/commands/shuffle.py b/src/pyramid/connector/discord/commands/shuffle.py new file mode 100644 index 0000000..d78c602 --- /dev/null +++ b/src/pyramid/connector/discord/commands/shuffle.py @@ -0,0 +1,29 @@ +from discord import Interaction + +from pyramid.api.services.discord import IDiscordService +from pyramid.connector.discord.commands.tools.abc import AbstractCommand +from pyramid.connector.discord.commands.tools.annotation import discord_command +from pyramid.connector.discord.commands.tools.parameters import ParametersCommand +from pyramid.data.functional.messages.message_sender_queued import MessageSenderQueued + + +@discord_command(parameters=ParametersCommand( + description="Randomizes the track queue", + only_guild=True +)) +class ShuffleCommand(AbstractCommand): + + def injectService(self, + discord_service: IDiscordService + ): + self.__discord_service = discord_service + + async def execute(self, ctx: Interaction): + ms = MessageSenderQueued(ctx) + await ms.thinking() + guild = ctx.guild + if not guild: + raise Exception("Command was not executed in a guild") + guild_cmd = self.__discord_service.get_guild_cmd(guild) + + await guild_cmd.shuffle(ms, ctx) diff --git a/src/pyramid/connector/discord/commands/stop.py b/src/pyramid/connector/discord/commands/stop.py new file mode 100644 index 0000000..888e8c3 --- /dev/null +++ b/src/pyramid/connector/discord/commands/stop.py @@ -0,0 +1,29 @@ +from discord import Interaction + +from pyramid.api.services.discord import IDiscordService +from pyramid.connector.discord.commands.tools.abc import AbstractCommand +from pyramid.connector.discord.commands.tools.annotation import discord_command +from pyramid.connector.discord.commands.tools.parameters import ParametersCommand +from pyramid.data.functional.messages.message_sender_queued import MessageSenderQueued + + +@discord_command(parameters=ParametersCommand( + description="Stops the music and leaves the channel", + only_guild=True +)) +class StopCommand(AbstractCommand): + + def injectService(self, + discord_service: IDiscordService + ): + self.__discord_service = discord_service + + async def execute(self, ctx: Interaction): + ms = MessageSenderQueued(ctx) + await ms.thinking() + guild = ctx.guild + if not guild: + raise Exception("Command was not executed in a guild") + guild_cmd = self.__discord_service.get_guild_cmd(guild) + + await guild_cmd.stop(ms, ctx) diff --git a/src/pyramid/connector/discord/commands/tools/abc.py b/src/pyramid/connector/discord/commands/tools/abc.py index f1f4bd0..0c1767a 100644 --- a/src/pyramid/connector/discord/commands/tools/abc.py +++ b/src/pyramid/connector/discord/commands/tools/abc.py @@ -1,35 +1,14 @@ -import logging from abc import ABC, abstractmethod -from typing import Optional from discord import Interaction -from discord.app_commands import Command from discord.ext.commands import Bot from pyramid.connector.discord.commands.tools.parameters import ParametersCommand class AbstractCommand(ABC): - def __init__(self, parameters: ParametersCommand, bot: Bot, logger: logging.Logger): + def __init__(self, parameters: ParametersCommand, bot: Bot): self.parameters = parameters self.bot = bot - self.logger = logger - - def register(self, command_prefix: Optional[str] = None): - if command_prefix is not None: - self.parameters.name = "%s_%s" % (command_prefix, self.parameters.name) - - command = Command( - name=self.parameters.name, - description=self.parameters.description, - callback=self.execute, - nsfw=self.parameters.nsfw, - parent=None, - auto_locale_strings=self.parameters.auto_locale_strings, - extras=self.parameters.extras, - ) - # TODO check this usage - # self.bot.tree.add_command(command, guilds=self.parameters.guilds) - self.bot.tree.add_command(command) def injectService(self): pass diff --git a/src/pyramid/connector/discord/commands/tools/exception.py b/src/pyramid/connector/discord/commands/tools/exception.py new file mode 100644 index 0000000..039b9e0 --- /dev/null +++ b/src/pyramid/connector/discord/commands/tools/exception.py @@ -0,0 +1,5 @@ +class CommandAlreadyRegisterException(Exception): + pass + +class CommandNameAlreadyRegisterException(Exception): + pass diff --git a/src/pyramid/connector/discord/commands/tools/parameters.py b/src/pyramid/connector/discord/commands/tools/parameters.py index 0d1f2f0..3f0fe56 100644 --- a/src/pyramid/connector/discord/commands/tools/parameters.py +++ b/src/pyramid/connector/discord/commands/tools/parameters.py @@ -6,13 +6,14 @@ class ParametersCommand: def __init__(self, - name: Union[str, locale_str] = MISSING, - description: Union[str, locale_str] = MISSING, + name: str = MISSING, + description: str | locale_str = MISSING, nsfw: bool = False, guild: Optional[Snowflake] = MISSING, guilds: Sequence[Snowflake] = MISSING, auto_locale_strings: bool = True, - extras: Dict[Any, Any] = MISSING + extras: Dict[Any, Any] = MISSING, + only_guild = False ): self.name = name self.description = description @@ -20,4 +21,5 @@ def __init__(self, self.guild = guild self.guilds = guilds self.auto_locale_strings = auto_locale_strings - self.extras: Dict[Any, Any] = MISSING + self.extras = extras + self.only_guild = only_guild diff --git a/src/pyramid/connector/discord/commands/tools/register.py b/src/pyramid/connector/discord/commands/tools/register.py index a316eb7..5a488cc 100644 --- a/src/pyramid/connector/discord/commands/tools/register.py +++ b/src/pyramid/connector/discord/commands/tools/register.py @@ -2,18 +2,27 @@ import inspect import logging import pkgutil +from typing import Optional from discord.ext.commands import Bot +from discord.app_commands import Command +from discord.app_commands.installs import AppCommandContext from pyramid.api.services.tools.register import ServiceRegister from pyramid.connector.discord.commands.tools.abc import AbstractCommand +from pyramid.connector.discord.commands.tools.exception import CommandAlreadyRegisterException, CommandNameAlreadyRegisterException from pyramid.connector.discord.commands.tools.parameters import ParametersCommand class CommandRegister: - __COMMANDS_TO_REGISTER: dict[type[AbstractCommand], ParametersCommand] = {} + __COMMANDS_REGISTERED: dict[type[AbstractCommand], ParametersCommand] = {} + __COMMANDS_INSTANCE: dict[str, AbstractCommand] = {} @classmethod - def register_command(cls, type: type[AbstractCommand], parameterCommand: ParametersCommand): - CommandRegister.__COMMANDS_TO_REGISTER[type] = parameterCommand + def register_command(cls, type: type[AbstractCommand], parameters: ParametersCommand): + if type in cls.__COMMANDS_REGISTERED: + raise CommandAlreadyRegisterException( + "Cannot register command %s it is already registered." % (type.__name__) + ) + cls.__COMMANDS_REGISTERED[type] = parameters @classmethod def import_commands(cls): @@ -25,15 +34,41 @@ def import_commands(cls): importlib.import_module(full_module_name) @classmethod - def create_commands(cls, bot: Bot, logger: logging.Logger, command_prefix: str | None = None): - for type, parameters in cls.__COMMANDS_TO_REGISTER.items(): - class_instance = type(parameters, bot, logger) - class_instance.register(command_prefix) + def create_commands(cls, bot: Bot, command_prefix: str | None = None): + for type, parameters in cls.__COMMANDS_REGISTERED.items(): + cls_instance = type(parameters, bot) + if command_prefix is not None: + cls_instance.parameters.name = "%s_%s" % (command_prefix, cls_instance.parameters.name) + if cls_instance.parameters.name in cls.__COMMANDS_INSTANCE: + cls_already_instance = cls.__COMMANDS_INSTANCE[cls_instance.parameters.name] + raise CommandNameAlreadyRegisterException( + "Cannot register command %s with %s, it is already registered with the class %s." + % ( cls_instance.parameters.name, type.__name__, cls_already_instance.__class__.__name__) + ) + + allowed_contexts: Optional[AppCommandContext] = None + if cls_instance.parameters.only_guild is True: + allowed_contexts = AppCommandContext(guild=True) + + discord_command = Command( + name=cls_instance.parameters.name, + description=cls_instance.parameters.description, + callback=cls_instance.execute, + nsfw=cls_instance.parameters.nsfw, + parent=None, + auto_locale_strings=cls_instance.parameters.auto_locale_strings, + extras=cls_instance.parameters.extras, + allowed_contexts=allowed_contexts + ) + # TODO check this usage + # self.bot.tree.add_command(command, guilds=command.parameters.guilds) + cls_instance.bot.tree.add_command(discord_command) + cls.__COMMANDS_INSTANCE[cls_instance.parameters.name] = cls_instance @classmethod - def inject_tasks(cls): - for type, parameters in cls.__COMMANDS_TO_REGISTER.items(): - signature = inspect.signature(type.injectService) + def inject_commands(cls): + for type, cls_instance in cls.__COMMANDS_INSTANCE.items(): + signature = inspect.signature(cls_instance.injectService) method_parameters = list(signature.parameters.values()) services_dependencies = [] @@ -42,4 +77,4 @@ def inject_tasks(cls): dependency_instance = ServiceRegister.get_service(dependency_cls) services_dependencies.append(dependency_instance) - type.injectService(*services_dependencies) + cls_instance.injectService(*services_dependencies) diff --git a/src/pyramid/connector/discord/guild_cmd.py b/src/pyramid/connector/discord/guild_cmd.py index 3fcabfa..c053bb5 100644 --- a/src/pyramid/connector/discord/guild_cmd.py +++ b/src/pyramid/connector/discord/guild_cmd.py @@ -52,7 +52,6 @@ async def play( return await self._execute_play(ms, voice_channel, track, at_end=at_end) async def stop(self, ms: MessageSenderQueued, ctx: Interaction) -> bool: - ctx.channel voice_channel: VoiceChannel | None = await self._verify_voice_channel(ms, ctx.user, ms.txt_channel) if not voice_channel or not await self._verify_bot_channel(ms, voice_channel): return False @@ -197,22 +196,6 @@ def queue_list(self, ms: MessageSenderQueued, ctx: Interaction) -> bool: ms.add_code_message(queue, prefix="Here's the music in the queue :") return True - async def searchV1( - self, ms: MessageSenderQueued, input: str, engine: SourceType | None = None - ) -> bool: - try: - tracks, tracks_unfindable = await self.data.search_engine.search_tracks(input, engine) - except DiscordMessageException as err: - ms.add_message(err.msg) - return False - - hsa = utils_list_track.to_str(tracks) - if tracks_unfindable: - hsa = utils_list_track.to_str(tracks_unfindable) - ms.add_code_message(hsa, prefix=":warning: Can't find the audio for these tracks :") - ms.add_code_message(hsa, prefix="Here are the results of your search :") - return True - async def search( self, ms: MessageSenderQueued, input: str, engine: SourceType | None = None ) -> bool: diff --git a/src/pyramid/data/functional/main.py b/src/pyramid/data/functional/main.py index d5a48a4..1294113 100644 --- a/src/pyramid/data/functional/main.py +++ b/src/pyramid/data/functional/main.py @@ -43,6 +43,6 @@ def start(self): def stop(self): - logging.info("Wait for background tasks to stop") + logging.info("Wait for others tasks to stop ...") Queue.wait_for_end(5) logging.info("Bye bye") diff --git a/src/pyramid/services/discord_commads.py b/src/pyramid/services/discord_commads.py index 2e2db6c..c28d954 100644 --- a/src/pyramid/services/discord_commads.py +++ b/src/pyramid/services/discord_commads.py @@ -22,177 +22,11 @@ def injectService(self, def start(self): bot = self.__discord_service.bot + command_prefix = self.__configuration_service.mode.name.lower() CommandRegister.import_commands() - CommandRegister.create_commands(bot, self.__logger.getChild("Discord"), self.__configuration_service.mode.name.lower()) - - # ping = PingCommand(ParametersCommand("ping"), self.__bot, self.__logger) - # ping.register(self.__environment.name.lower()) - - # about = AboutCommand(self.__bot, self.__logger, self.__started, self.__environment, self.__info) - # about.register(self.__environment.name.lower()) - - # help = HelpCommand(self.__bot, self.__logger) - # help.register(self.__environment.name.lower()) - - @bot.tree.command(name="play", description="Adds a track to the end of the queue and plays it") - async def cmd_play(ctx: Interaction, input: str, engine: SourceType | None): - if (await self.__use_on_guild_only(ctx)) is False: - return - ms = MessageSenderQueued(ctx) - await ms.thinking() - guild: Guild = ctx.guild # type: ignore - guild_cmd: GuildCmd = self.__discord_service.get_guild_cmd(guild) - - await guild_cmd.play(ms, ctx, input, engine) - - @bot.tree.command(name="play_next", description="Plays a track next the current one") - async def cmd_play_next(ctx: Interaction, input: str, engine: SourceType | None): - if (await self.__use_on_guild_only(ctx)) is False: - return - ms = MessageSenderQueued(ctx) - await ms.thinking() - guild: Guild = ctx.guild # type: ignore - guild_cmd: GuildCmd = self.__discord_service.get_guild_cmd(guild) - - await guild_cmd.play(ms, ctx, input, engine, at_end=False) - - @bot.tree.command(name="pause", description="Pauses the music") - async def cmd_pause(ctx: Interaction): - if (await self.__use_on_guild_only(ctx)) is False: - return - ms = MessageSenderQueued(ctx) - await ms.thinking() - guild: Guild = ctx.guild # type: ignore - guild_cmd: GuildCmd = self.__discord_service.get_guild_cmd(guild) - - await guild_cmd.pause(ms, ctx) - - @bot.tree.command(name="resume", description="Resumes the paused music") - async def cmd_resume(ctx: Interaction): - if (await self.__use_on_guild_only(ctx)) is False: - return - ms = MessageSenderQueued(ctx) - await ms.thinking() - guild: Guild = ctx.guild # type: ignore - guild_cmd: GuildCmd = self.__discord_service.get_guild_cmd(guild) - - await guild_cmd.resume(ms, ctx) - - @bot.tree.command(name="stop", description="Stops the music and leaves the channel") - async def cmd_stop(ctx: Interaction): - if (await self.__use_on_guild_only(ctx)) is False: - return - ms = MessageSenderQueued(ctx) - await ms.thinking() - guild: Guild = ctx.guild # type: ignore - guild_cmd: GuildCmd = self.__discord_service.get_guild_cmd(guild) - - await guild_cmd.stop(ms, ctx) - - @bot.tree.command(name="next", description="Skips to the next track") - async def cmd_next(ctx: Interaction): - if (await self.__use_on_guild_only(ctx)) is False: - return - ms = MessageSenderQueued(ctx) - await ms.thinking() - guild: Guild = ctx.guild # type: ignore - guild_cmd: GuildCmd = self.__discord_service.get_guild_cmd(guild) - - await guild_cmd.next(ms, ctx) - - @bot.tree.command(name="shuffle", description="Randomizes the track queue") - async def cmd_shuffle(ctx: Interaction): - if (await self.__use_on_guild_only(ctx)) is False: - return - ms = MessageSenderQueued(ctx) - await ms.thinking() - guild: Guild = ctx.guild # type: ignore - guild_cmd: GuildCmd = self.__discord_service.get_guild_cmd(guild) - - await guild_cmd.shuffle(ms, ctx) - - @bot.tree.command(name="remove", description="Removes a track from the queue") - async def cmd_remove(ctx: Interaction, number_in_queue: int): - if (await self.__use_on_guild_only(ctx)) is False: - return - ms = MessageSenderQueued(ctx) - await ms.thinking() - guild: Guild = ctx.guild # type: ignore - guild_cmd: GuildCmd = self.__discord_service.get_guild_cmd(guild) - - await guild_cmd.remove(ms, ctx, number_in_queue) - - @bot.tree.command(name="goto", description="Jumps to a specific track in the queue") - async def cmd_goto(ctx: Interaction, number_in_queue: int): - if (await self.__use_on_guild_only(ctx)) is False: - return - ms = MessageSenderQueued(ctx) - await ms.thinking() - guild: Guild = ctx.guild # type: ignore - guild_cmd: GuildCmd = self.__discord_service.get_guild_cmd(guild) - - await guild_cmd.goto(ms, ctx, number_in_queue) - - @bot.tree.command(name="queue", description="Displays the current track queue") - async def cmd_queue(ctx: Interaction): - if (await self.__use_on_guild_only(ctx)) is False: - return - ms = MessageSenderQueued(ctx) - await ms.thinking() - guild: Guild = ctx.guild # type: ignore - guild_cmd: GuildCmd = self.__discord_service.get_guild_cmd(guild) - - guild_cmd.queue_list(ms, ctx) - - # @bot.tree.command(name="search_v1", description="Search tracks (old way)") - # async def cmd_search_v1(ctx: Interaction, input: str, engine: SourceType | None): - # if (await self.__use_on_guild_only(ctx)) is False: - # return - # ms = MessageSenderQueued(ctx) - # await ms.thinking() - # guild: Guild = ctx.guild # type: ignore - # guild_cmd: GuildCmd = self.__discord_service.get_guild_cmd(guild) - - # await guild_cmd.searchV1(ms, input, engine) - - @bot.tree.command(name="search", description="Searches for tracks") - async def cmd_search(ctx: Interaction, input: str, engine: SourceType | None): - if (await self.__use_on_guild_only(ctx)) is False: - return - ms = MessageSenderQueued(ctx) - await ms.thinking() - guild: Guild = ctx.guild # type: ignore - guild_cmd: GuildCmd = self.__discord_service.get_guild_cmd(guild) - - await guild_cmd.search(ms, input, engine) - - @bot.tree.command( - name="play_url", description="Plays a track, artist, album, or playlist from a URL" - ) - async def cmd_play_url(ctx: Interaction, url: str): - if (await self.__use_on_guild_only(ctx)) is False: - return - ms = MessageSenderQueued(ctx) - await ms.thinking() - guild: Guild = ctx.guild # type: ignore - guild_cmd: GuildCmd = self.__discord_service.get_guild_cmd(guild) - - await guild_cmd.play_url(ms, ctx, url) - - @bot.tree.command( - name="play_url_next", - description="Plays a track, artist, album, or playlist from a URL next in the queue", - ) - async def cmd_play_url_next(ctx: Interaction, url: str): - if (await self.__use_on_guild_only(ctx)) is False: - return - ms = MessageSenderQueued(ctx) - await ms.thinking() - guild: Guild = ctx.guild # type: ignore - guild_cmd: GuildCmd = self.__discord_service.get_guild_cmd(guild) - - await guild_cmd.play_url(ms, ctx, url, at_end=False) + CommandRegister.create_commands(bot, command_prefix) + CommandRegister.inject_commands() # @bot.tree.command(name="spam", description="Test spam") # async def cmd_spam(ctx: Interaction): @@ -203,8 +37,3 @@ async def cmd_play_url_next(ctx: Interaction, url: str): # ms.add_message(f"Spam n°{i}") # await ctx.response.send_message("Spam ended") - async def __use_on_guild_only(self, ctx: Interaction) -> bool: - if ctx.guild is None: - await ctx.response.send_message("You can use this command only on a guild") - return False - return True diff --git a/src/pyramid/services/source.py b/src/pyramid/services/source.py index 589dadd..6a94230 100644 --- a/src/pyramid/services/source.py +++ b/src/pyramid/services/source.py @@ -21,7 +21,6 @@ def injectService(self, downloader_service: IDeezerDownloaderService, deezer_search_service: IDeezerSearchService, spotify_search_service: ISpotifySearchService, - ): self.__downloader = downloader_service self.__deezer_search = deezer_search_service diff --git a/src/pyramid/services/spotify_client.py b/src/pyramid/services/spotify_client.py index 8047d29..e089c4c 100644 --- a/src/pyramid/services/spotify_client.py +++ b/src/pyramid/services/spotify_client.py @@ -7,6 +7,7 @@ from spotipy import Spotify from spotipy.exceptions import SpotifyException from spotipy.oauth2 import SpotifyClientCredentials +from spotipy.cache_handler import MemoryCacheHandler from pyramid.api.services.configuration import IConfigurationService from pyramid.api.services.spotify_client import ISpotifyClientService @@ -24,7 +25,8 @@ def injectService(self, def start(self): self.client_credentials_manager = SpotifyClientCredentials( client_id=self.__configuration_service.spotify__client_id, - client_secret=self.__configuration_service.spotify__client_secret + client_secret=self.__configuration_service.spotify__client_secret, + cache_handler=MemoryCacheHandler() ) self.auth_manager = None diff --git a/src/startup_dev.py b/src/startup_dev.py index bb5bf71..c70f017 100644 --- a/src/startup_dev.py +++ b/src/startup_dev.py @@ -5,6 +5,7 @@ def startup_dev(): debugpy.listen(('0.0.0.0', 5678)) + # debugpy.wait_for_client() startup() if __name__ == "__main__": From 513f42045f66226a4401082882f120dfba68794c Mon Sep 17 00:00:00 2001 From: Tristiisch Date: Tue, 1 Oct 2024 02:25:59 +0200 Subject: [PATCH 18/32] fix: service start order --- src/pyramid/api/services/tools/register.py | 58 ++++++++++++------- src/pyramid/data/functional/main.py | 5 +- .../messages/message_sender_queued.py | 2 +- src/pyramid/tools/utils.py | 5 +- 4 files changed, 44 insertions(+), 26 deletions(-) diff --git a/src/pyramid/api/services/tools/register.py b/src/pyramid/api/services/tools/register.py index 0c4f98c..b90e934 100644 --- a/src/pyramid/api/services/tools/register.py +++ b/src/pyramid/api/services/tools/register.py @@ -13,9 +13,18 @@ class ServiceRegister: __SERVICES_REGISTRED: dict[str, type[ServiceInjector]] = {} __SERVICES_INSTANCES: dict[str, ServiceInjector] = {} + __ORDERED_SERVICES: list[str] | None = None @classmethod - def import_services(cls): + def enable(cls): + cls.__import_services() + cls.__create_services() + cls.__determine_service_order() + cls.__inject_services() + cls.__start_services() + + @classmethod + def __import_services(cls): package_name = "pyramid.services" package = importlib.import_module(package_name) @@ -37,18 +46,12 @@ def register_service(cls, interface_name: str, type: type[object]): cls.__SERVICES_REGISTRED[interface_name] = type @classmethod - def create_services(cls): - for name, service_type in cls.__SERVICES_REGISTRED.items(): - class_instance = service_type() - cls.__SERVICES_INSTANCES[name] = class_instance - - @classmethod - def inject_services(cls): + def __determine_service_order(cls): # Step 1: Create a graph of dependencies dependency_graph = defaultdict(list) indegree = defaultdict(int) # To track the number of dependencies - # Create instances but delay injecting dependencies + # Parse dependencies but delay injecting for name, service_type in cls.__SERVICES_REGISTRED.items(): class_instance = cls.__SERVICES_INSTANCES[name] @@ -60,8 +63,7 @@ def inject_services(cls): dependency_name = method_parameter.annotation.__name__ if dependency_name not in cls.__SERVICES_INSTANCES: raise ServiceNotRegisterException( - "Cannot register %s as a dependency for %s because the dependency is not registered." - % (dependency_name, name) + f"Cannot register {dependency_name} as a dependency for {name} because the dependency is not registered." ) # Add an edge in the dependency graph dependency_graph[dependency_name].append(name) @@ -83,12 +85,18 @@ def inject_services(cls): if len(sorted_services) != len(cls.__SERVICES_REGISTRED): unresolved_services = set(cls.__SERVICES_REGISTRED) - set(sorted_services) raise ServiceCicularDependencyException( - "Circular dependency detected! The following services are involved in a circular dependency: %s" - % ', '.join(unresolved_services) + f"Circular dependency detected! The following services are involved in a circular dependency: {', '.join(unresolved_services)}" ) - # Step 4: Inject dependencies in the correct order - for service_name in sorted_services: + cls.__ORDERED_SERVICES = sorted_services + + @classmethod + def __inject_services(cls): + if not cls.__ORDERED_SERVICES: + raise Exception("Failed to determine service startup order.") + + # Inject dependencies in the correct order + for service_name in cls.__ORDERED_SERVICES: class_instance = cls.__SERVICES_INSTANCES[service_name] signature = inspect.signature(class_instance.injectService) method_parameters = list(signature.parameters.values()) @@ -101,6 +109,21 @@ def inject_services(cls): class_instance.injectService(*services_dependencies) + @classmethod + def __create_services(cls): + for name, service_type in cls.__SERVICES_REGISTRED.items(): + class_instance = service_type() + cls.__SERVICES_INSTANCES[name] = class_instance + + @classmethod + def __start_services(cls): + if not cls.__ORDERED_SERVICES: + raise Exception("Failed to determine service startup order.") + + for service_name in cls.__ORDERED_SERVICES: + class_instance = cls.__SERVICES_INSTANCES[service_name] + class_instance.start() + @classmethod def get_dependency_tree(cls): # Step 1: Build dependency graph @@ -145,11 +168,6 @@ def build_tree(node, prefix="", last=True): return "Services tree :\n" + "\n".join(buffer) - @classmethod - def start_services(cls): - for name, class_instance in cls.__SERVICES_INSTANCES.items(): - class_instance.start() - @classmethod def get_service_registred(cls, class_name: str) -> type[ServiceInjector]: if class_name not in cls.__SERVICES_REGISTRED: diff --git a/src/pyramid/data/functional/main.py b/src/pyramid/data/functional/main.py index 1294113..ac29397 100644 --- a/src/pyramid/data/functional/main.py +++ b/src/pyramid/data/functional/main.py @@ -30,10 +30,7 @@ def args(self): sys.exit(0) def start(self): - ServiceRegister.import_services() - ServiceRegister.create_services() - ServiceRegister.inject_services() - ServiceRegister.start_services() + ServiceRegister.enable() MainQueue.init() diff --git a/src/pyramid/data/functional/messages/message_sender_queued.py b/src/pyramid/data/functional/messages/message_sender_queued.py index 7dde574..432fd91 100644 --- a/src/pyramid/data/functional/messages/message_sender_queued.py +++ b/src/pyramid/data/functional/messages/message_sender_queued.py @@ -3,7 +3,7 @@ from pyramid.data.functional.messages.message_sender import MessageSender from discord import Interaction, Message, WebhookMessage from discord.utils import MISSING -from pyramid.tools.custom_queue import Queue, QueueItem +from pyramid.tools.custom_queue import QueueItem from pyramid.tools.main_queue import MainQueue MAX_MSG_LENGTH = 2000 diff --git a/src/pyramid/tools/utils.py b/src/pyramid/tools/utils.py index ccd6757..c1a1e86 100644 --- a/src/pyramid/tools/utils.py +++ b/src/pyramid/tools/utils.py @@ -41,6 +41,7 @@ def clear_directory(directory): if not os.path.exists(directory): return + i = 0 for filename in os.listdir(directory): if filename in ['.gitignore', '.gitkeep', '.dockerignore']: continue @@ -48,8 +49,10 @@ def clear_directory(directory): try: if os.path.isfile(file_path): os.unlink(file_path) + i = i + 1 except Exception as e: - logging.warning("Failed to delete %s due to %s", file_path, e) + logging.warning("Failed to delete file '%s' due to %s", file_path, e) + logging.info("Cleared %d files from folder '%s'." % (i, directory)) def split_string_by_length( From 1882d72e52153c90f98b89e19784dc3bb3098f4b Mon Sep 17 00:00:00 2001 From: Tristiisch Date: Tue, 1 Oct 2024 03:07:14 +0200 Subject: [PATCH 19/32] fix: tests service import --- src/pyramid/api/services/tools/register.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/pyramid/api/services/tools/register.py b/src/pyramid/api/services/tools/register.py index b90e934..acbd5a3 100644 --- a/src/pyramid/api/services/tools/register.py +++ b/src/pyramid/api/services/tools/register.py @@ -17,14 +17,14 @@ class ServiceRegister: @classmethod def enable(cls): - cls.__import_services() + cls.import_services() cls.__create_services() cls.__determine_service_order() cls.__inject_services() cls.__start_services() @classmethod - def __import_services(cls): + def import_services(cls): package_name = "pyramid.services" package = importlib.import_module(package_name) From bc2f415be0db8ed6c82b562864c73cbf6dfa6d45 Mon Sep 17 00:00:00 2001 From: Tristiisch Date: Thu, 3 Oct 2024 23:04:53 +0200 Subject: [PATCH 20/32] tests: added for service injection --- src/pyramid/api/services/tools/exceptions.py | 6 +- src/pyramid/api/services/tools/register.py | 122 ++++++---- src/pyramid/api/services/tools/tester.py | 16 +- src/pyramid/data/functional/main.py | 1 + tests/service_injection_test.py | 220 +++++++++++++++++++ tests/service_register_test.py | 93 ++++++++ tests/service_standalone_test.py | 52 +++++ 7 files changed, 453 insertions(+), 57 deletions(-) create mode 100644 tests/service_injection_test.py create mode 100644 tests/service_register_test.py create mode 100644 tests/service_standalone_test.py diff --git a/src/pyramid/api/services/tools/exceptions.py b/src/pyramid/api/services/tools/exceptions.py index d00c891..368c804 100644 --- a/src/pyramid/api/services/tools/exceptions.py +++ b/src/pyramid/api/services/tools/exceptions.py @@ -4,8 +4,12 @@ class ServiceRegisterException(Exception): class ServiceAlreadyRegisterException(ServiceRegisterException): pass -class ServiceNotRegisterException(ServiceRegisterException): +class ServiceNotRegisteredException(ServiceRegisterException): pass class ServiceCicularDependencyException(ServiceRegisterException): pass + +class ServiceWasNotOrdedException(ServiceRegisterException): + pass + diff --git a/src/pyramid/api/services/tools/register.py b/src/pyramid/api/services/tools/register.py index acbd5a3..a349a0c 100644 --- a/src/pyramid/api/services/tools/register.py +++ b/src/pyramid/api/services/tools/register.py @@ -3,25 +3,17 @@ import importlib import inspect import pkgutil -from typing import Any, Type, TypeVar -from pyramid.api.services.tools.exceptions import ServiceNotRegisterException, ServiceAlreadyRegisterException, ServiceCicularDependencyException +from typing import Any, Type, TypeVar, get_type_hints +from pyramid.api.services.tools.exceptions import ServiceNotRegisteredException, ServiceAlreadyRegisterException, ServiceCicularDependencyException, ServiceWasNotOrdedException from pyramid.api.services.tools.injector import ServiceInjector T = TypeVar('T') class ServiceRegister: - __SERVICES_REGISTRED: dict[str, type[ServiceInjector]] = {} - __SERVICES_INSTANCES: dict[str, ServiceInjector] = {} - __ORDERED_SERVICES: list[str] | None = None - - @classmethod - def enable(cls): - cls.import_services() - cls.__create_services() - cls.__determine_service_order() - cls.__inject_services() - cls.__start_services() + SERVICES_REGISTRED: dict[str, type[ServiceInjector]] = {} + SERVICES_INSTANCES: dict[str, ServiceInjector] = {} + ORDERED_SERVICES: list[str] | None = None @classmethod def import_services(cls): @@ -37,32 +29,45 @@ def register_service(cls, interface_name: str, type: type[object]): type_name = type.__name__ if not issubclass(type, ServiceInjector): raise TypeError("Service %s is not a subclass of ServiceInjector and cannot be initialized." % type_name) - if interface_name in cls.__SERVICES_REGISTRED: - already_class_name = cls.__SERVICES_REGISTRED[interface_name].__name__ + if interface_name in cls.SERVICES_REGISTRED: + already_class_name = cls.SERVICES_REGISTRED[interface_name].__module__ + ' ' + cls.SERVICES_REGISTRED[interface_name].__name__ raise ServiceAlreadyRegisterException( "Cannot register service %s with %s, it is already registered with the class %s." % (interface_name, type_name, already_class_name) ) - cls.__SERVICES_REGISTRED[interface_name] = type + cls.SERVICES_REGISTRED[interface_name] = type + + @classmethod + def enable(cls): + cls.create_services() + cls.determine_service_order() + cls.inject_services() + cls.start_services() @classmethod - def __determine_service_order(cls): + def determine_service_order(cls): + """This method is not recommended. + + Please call the `enable` method instead, which takes care of performing + the actions in the correct order. + """ # Step 1: Create a graph of dependencies dependency_graph = defaultdict(list) indegree = defaultdict(int) # To track the number of dependencies # Parse dependencies but delay injecting - for name, service_type in cls.__SERVICES_REGISTRED.items(): - class_instance = cls.__SERVICES_INSTANCES[name] + for name, service_type in cls.SERVICES_REGISTRED.items(): + class_instance = cls.SERVICES_INSTANCES[name] # Step 2: Parse dependencies for each service signature = inspect.signature(class_instance.injectService) method_parameters = list(signature.parameters.values()) + type_hints = get_type_hints(class_instance.injectService) for method_parameter in method_parameters: - dependency_name = method_parameter.annotation.__name__ - if dependency_name not in cls.__SERVICES_INSTANCES: - raise ServiceNotRegisterException( + dependency_name = type_hints[method_parameter.name].__name__ + if dependency_name not in cls.SERVICES_INSTANCES: + raise ServiceNotRegisteredException( f"Cannot register {dependency_name} as a dependency for {name} because the dependency is not registered." ) # Add an edge in the dependency graph @@ -71,7 +76,7 @@ def __determine_service_order(cls): # Step 3: Perform a topological sort to determine the order of instantiation sorted_services = [] - queue = deque([service for service in cls.__SERVICES_REGISTRED if indegree[service] == 0]) + queue = deque([service for service in cls.SERVICES_REGISTRED if indegree[service] == 0]) while queue: service = queue.popleft() @@ -82,59 +87,76 @@ def __determine_service_order(cls): if indegree[dependent] == 0: queue.append(dependent) - if len(sorted_services) != len(cls.__SERVICES_REGISTRED): - unresolved_services = set(cls.__SERVICES_REGISTRED) - set(sorted_services) + if len(sorted_services) != len(cls.SERVICES_REGISTRED): + unresolved_services = set(cls.SERVICES_REGISTRED) - set(sorted_services) raise ServiceCicularDependencyException( f"Circular dependency detected! The following services are involved in a circular dependency: {', '.join(unresolved_services)}" ) - cls.__ORDERED_SERVICES = sorted_services + cls.ORDERED_SERVICES = sorted_services @classmethod - def __inject_services(cls): - if not cls.__ORDERED_SERVICES: - raise Exception("Failed to determine service startup order.") + def inject_services(cls): + """This method is not recommended. + + Please call the `enable` method instead, which takes care of performing + the actions in the correct order. + """ + if not cls.ORDERED_SERVICES: + raise ServiceWasNotOrdedException("Failed to determine service startup order.") # Inject dependencies in the correct order - for service_name in cls.__ORDERED_SERVICES: - class_instance = cls.__SERVICES_INSTANCES[service_name] + for service_name in cls.ORDERED_SERVICES: + class_instance = cls.SERVICES_INSTANCES[service_name] signature = inspect.signature(class_instance.injectService) method_parameters = list(signature.parameters.values()) + type_hints = get_type_hints(class_instance.injectService) services_dependencies = [] for method_parameter in method_parameters: - dependency_name = method_parameter.annotation.__name__ - dependency_instance = cls.__SERVICES_INSTANCES[dependency_name] + dependency_name = type_hints[method_parameter.name].__name__ + dependency_instance = cls.SERVICES_INSTANCES[dependency_name] services_dependencies.append(dependency_instance) class_instance.injectService(*services_dependencies) @classmethod - def __create_services(cls): - for name, service_type in cls.__SERVICES_REGISTRED.items(): + def create_services(cls): + """This method is not recommended. + + Please call the `enable` method instead, which takes care of performing + the actions in the correct order. + """ + for name, service_type in cls.SERVICES_REGISTRED.items(): class_instance = service_type() - cls.__SERVICES_INSTANCES[name] = class_instance + cls.SERVICES_INSTANCES[name] = class_instance @classmethod - def __start_services(cls): - if not cls.__ORDERED_SERVICES: - raise Exception("Failed to determine service startup order.") + def start_services(cls): + """This method is not recommended. - for service_name in cls.__ORDERED_SERVICES: - class_instance = cls.__SERVICES_INSTANCES[service_name] + Please call the `enable` method instead, which takes care of performing + the actions in the correct order. + """ + if not cls.ORDERED_SERVICES: + raise ServiceWasNotOrdedException("Failed to determine service startup order.") + + for service_name in cls.ORDERED_SERVICES: + class_instance = cls.SERVICES_INSTANCES[service_name] class_instance.start() @classmethod def get_dependency_tree(cls): # Step 1: Build dependency graph dependency_graph = defaultdict(list) - for name, class_instance in cls.__SERVICES_INSTANCES.items(): + for name, class_instance in cls.SERVICES_INSTANCES.items(): signature = inspect.signature(class_instance.injectService) method_parameters = list(signature.parameters.values()) + type_hints = get_type_hints(class_instance.injectService) for method_parameter in method_parameters: - dependency_name = method_parameter.annotation.__name__ + dependency_name = type_hints[method_parameter.name].__name__ dependency_graph[dependency_name].append(name) # Step 2: Internal buffer for storing the tree structure @@ -155,7 +177,7 @@ def build_tree(node, prefix="", last=True): build_tree(child, prefix, i == len(children) - 1) # Step 4: Find root services (those with no dependencies) - all_services = set(cls.__SERVICES_REGISTRED.keys()) + all_services = set(cls.SERVICES_REGISTRED.keys()) dependent_services = set(dep for deps in dependency_graph.values() for dep in deps) root_services = all_services - dependent_services @@ -170,13 +192,17 @@ def build_tree(node, prefix="", last=True): @classmethod def get_service_registred(cls, class_name: str) -> type[ServiceInjector]: - if class_name not in cls.__SERVICES_REGISTRED: - raise ServiceNotRegisterException( + if class_name not in cls.SERVICES_REGISTRED: + raise ServiceNotRegisteredException( "Cannot get %s because the service is not registered." % (class_name) ) - return cls.__SERVICES_REGISTRED[class_name] + return cls.SERVICES_REGISTRED[class_name] @classmethod def get_service(cls, class_type: Type[T]) -> T: class_name = class_type.__name__ - return cls.__SERVICES_INSTANCES[class_name] + if class_name not in cls.SERVICES_INSTANCES: + raise ServiceNotRegisteredException( + "Cannot get %s because the service is not started." % (class_name) + ) + return cls.SERVICES_INSTANCES[class_name] # type: ignore diff --git a/src/pyramid/api/services/tools/tester.py b/src/pyramid/api/services/tools/tester.py index 40e17d0..19a4639 100644 --- a/src/pyramid/api/services/tools/tester.py +++ b/src/pyramid/api/services/tools/tester.py @@ -1,13 +1,12 @@ import inspect -from typing import Optional, Type, TypeVar, cast -from pyramid.api.services.tools.exceptions import ServiceNotRegisterException +from typing import Type, TypeVar, cast, get_type_hints from pyramid.api.services.tools.register import ServiceRegister T = TypeVar('T') class ServiceStandalone: - __SERVICE_REGISTERED: dict[str, object] = {} + SERVICE_REGISTERED: dict[str, object] = {} @classmethod def import_services(cls): @@ -16,24 +15,25 @@ def import_services(cls): @classmethod def set_service(cls, service_interface: Type[T], service_instance: object): service_name = service_interface.__name__ - cls.__SERVICE_REGISTERED[service_name] = service_instance + cls.SERVICE_REGISTERED[service_name] = service_instance @classmethod def get_service(cls, service_interface: Type[T]) -> T: service_name = service_interface.__name__ - if service_name in cls.__SERVICE_REGISTERED: - return cast(T, cls.__SERVICE_REGISTERED[service_name]) + if service_name in cls.SERVICE_REGISTERED: + return cast(T, cls.SERVICE_REGISTERED[service_name]) service_type = ServiceRegister.get_service_registred(service_name) class_instance = service_type() signature = inspect.signature(class_instance.injectService) method_parameters = list(signature.parameters.values()) + type_hints = get_type_hints(class_instance.injectService) services_dependencies = [] for method_parameter in method_parameters: - dependency = method_parameter.annotation + dependency = type_hints[method_parameter.name] dependency_instance = cls.get_service(dependency) services_dependencies.append(dependency_instance) @@ -41,6 +41,6 @@ def get_service(cls, service_interface: Type[T]) -> T: class_instance.injectService(*services_dependencies) class_instance.start() - cls.__SERVICE_REGISTERED[service_name] = class_instance + cls.SERVICE_REGISTERED[service_name] = class_instance return cast(T, class_instance) diff --git a/src/pyramid/data/functional/main.py b/src/pyramid/data/functional/main.py index ac29397..dc5ad67 100644 --- a/src/pyramid/data/functional/main.py +++ b/src/pyramid/data/functional/main.py @@ -30,6 +30,7 @@ def args(self): sys.exit(0) def start(self): + ServiceRegister.import_services() ServiceRegister.enable() MainQueue.init() diff --git a/tests/service_injection_test.py b/tests/service_injection_test.py new file mode 100644 index 0000000..43e631b --- /dev/null +++ b/tests/service_injection_test.py @@ -0,0 +1,220 @@ +import unittest +from unittest.mock import patch, Mock +from pyramid.api.services.tools.exceptions import ServiceCicularDependencyException, ServiceNotRegisteredException +from pyramid.api.services.tools.register import ServiceRegister +from pyramid.api.services.tools.injector import ServiceInjector + +class AlphaService(ServiceInjector): + def __init__(self): + self.started = False + + def start(self): + self.started = True + +class BravoService(ServiceInjector): + def __init__(self): + self.started = False + + def injectService(self, alpha_service: AlphaService): + self.alpha_service = alpha_service + + def start(self): + self.started = True + +class CharlieService(ServiceInjector): + def __init__(self): + self.started = False + + def start(self): + self.started = True + +class DeltaService(ServiceInjector): + def __init__(self): + self.started = False + + def injectService(self, alpha_service: AlphaService, bravo_service: BravoService, foxtrot_service: 'FoxtrotService'): + self.alpha_service = alpha_service + self.bravo_service = bravo_service + self.foxtrot_service = foxtrot_service + + def start(self): + self.started = True + +class EchoService(ServiceInjector): + def __init__(self): + self.started = False + + def injectService(self, alpha_service: AlphaService, charlie_service: CharlieService): + self.alpha_service = alpha_service + self.charlie_service = charlie_service + + def start(self): + self.started = True + +class FoxtrotService(ServiceInjector): + def __init__(self): + self.started = False + + def injectService(self, echo_service: EchoService): + self.echo_service = echo_service + + def start(self): + self.started = True + +class GolfService(ServiceInjector): + def __init__(self): + self.started = False + + def start(self): + self.started = True + +class AlphaCircularService(ServiceInjector): + def __init__(self): + self.started = False + + def injectService(self, bravo_circular_service: 'BravoCircularService'): + self.bravo_circular_service = bravo_circular_service + + def start(self): + pass + +class BravoCircularService(ServiceInjector): + def __init__(self): + self.started = False + + def injectService(self, alpha_circular_service: AlphaCircularService): + self.alpha_circular_service = alpha_circular_service + + def start(self): + pass + +class TestServiceInjectionOrder(unittest.TestCase): + + def setUp(self): + ServiceRegister.SERVICES_REGISTRED.clear() + ServiceRegister.SERVICES_INSTANCES.clear() + ServiceRegister.ORDERED_SERVICES = None + + @patch.object(CharlieService, 'start', autospec=True) + @patch.object(BravoService, 'start', autospec=True) + @patch.object(AlphaService, 'start', autospec=True) + def test_simple_injection_order(self, mock_alpha_start: Mock, mock_bravo_start: Mock, mock_charlie_start: Mock): + ServiceRegister.register_service(AlphaService.__name__, AlphaService) + ServiceRegister.register_service(BravoService.__name__, BravoService) + ServiceRegister.register_service(CharlieService.__name__, CharlieService) + + test_instance = self + + def alpha_start(self: AlphaService): + self.started = True + bravo = ServiceRegister.get_service(BravoService) + test_instance.assertFalse(bravo.started) + mock_alpha_start.side_effect = alpha_start + + def bravo_start(self: BravoService): + self.started = True + test_instance.assertTrue(self.alpha_service.started) + mock_bravo_start.side_effect = bravo_start + + def charlie_start(self: CharlieService): + self.started = True + alpha = ServiceRegister.get_service(AlphaService) + bravo = ServiceRegister.get_service(BravoService) + # Alpha was registered first + test_instance.assertTrue(alpha.started) + # Bravo was registered first but had a dependency, + # whereas Charlie was registered later and had no dependencies. + test_instance.assertFalse(bravo.started) + mock_charlie_start.side_effect = charlie_start + + ServiceRegister.enable() + + @patch.object(GolfService, 'start', autospec=True) + @patch.object(FoxtrotService, 'start', autospec=True) + @patch.object(EchoService, 'start', autospec=True) + @patch.object(DeltaService, 'start', autospec=True) + @patch.object(CharlieService, 'start', autospec=True) + @patch.object(BravoService, 'start', autospec=True) + @patch.object(AlphaService, 'start', autospec=True) + def test_complex_injection_order(self, mock_alpha_start: Mock, mock_bravo_start: Mock, + mock_charlie_start: Mock, mock_delta_start: Mock, mock_echo_start: Mock, + mock_foxtrot_start: Mock, mock_golf_start: Mock + ): + ServiceRegister.register_service(AlphaService.__name__, AlphaService) + ServiceRegister.register_service(BravoService.__name__, BravoService) + ServiceRegister.register_service(CharlieService.__name__, CharlieService) + ServiceRegister.register_service(DeltaService.__name__, DeltaService) + ServiceRegister.register_service(EchoService.__name__, EchoService) + ServiceRegister.register_service(FoxtrotService.__name__, FoxtrotService) + ServiceRegister.register_service(GolfService.__name__, FoxtrotService) + + test_instance = self + + def alpha_start(self: AlphaService): + self.started = True + bravo = ServiceRegister.get_service(BravoService) + test_instance.assertFalse(bravo.started) + delta = ServiceRegister.get_service(DeltaService) + test_instance.assertFalse(delta.started) + echo = ServiceRegister.get_service(EchoService) + test_instance.assertFalse(echo.started) + mock_alpha_start.side_effect = alpha_start + + def bravo_start(self: BravoService): + self.started = True + test_instance.assertTrue(self.alpha_service.started) + mock_bravo_start.side_effect = bravo_start + + def charlie_start(self: CharlieService): + self.started = True + echo = ServiceRegister.get_service(EchoService) + test_instance.assertFalse(echo.started) + mock_charlie_start.side_effect = charlie_start + + def delta_start(self: DeltaService): + self.started = True + test_instance.assertTrue(self.alpha_service.started) + test_instance.assertTrue(self.bravo_service.started) + test_instance.assertTrue(self.foxtrot_service.started) + mock_delta_start.side_effect = delta_start + + def echo_start(self: EchoService): + self.started = True + test_instance.assertTrue(self.alpha_service.started) + test_instance.assertTrue(self.charlie_service.started) + foxtrot = ServiceRegister.get_service(FoxtrotService) + test_instance.assertFalse(foxtrot.started) + mock_echo_start.side_effect = echo_start + + def foxtrot_start(self: FoxtrotService): + self.started = True + test_instance.assertTrue(self.echo_service.started) + delta = ServiceRegister.get_service(DeltaService) + test_instance.assertFalse(delta.started) + mock_foxtrot_start.side_effect = foxtrot_start + + def golf_start(self: GolfService): + self.started = True + mock_golf_start.side_effect = golf_start + + ServiceRegister.enable() + + def test_circular_dependency(self): + ServiceRegister.register_service(AlphaCircularService.__name__, AlphaCircularService) + ServiceRegister.register_service(BravoCircularService.__name__, BravoCircularService) + + ServiceRegister.create_services() + + with self.assertRaises(ServiceCicularDependencyException): + ServiceRegister.determine_service_order() + + def test_missing_dependency(self): + ServiceRegister.register_service(BravoService.__name__, BravoService) + + ServiceRegister.create_services() + + with self.assertRaises(ServiceNotRegisteredException): + ServiceRegister.determine_service_order() + +if __name__ == "__main__": + unittest.main() diff --git a/tests/service_register_test.py b/tests/service_register_test.py new file mode 100644 index 0000000..2676add --- /dev/null +++ b/tests/service_register_test.py @@ -0,0 +1,93 @@ +from abc import ABC +from typing import cast +import unittest +from pyramid.api.services.tools.annotation import pyramid_service +from pyramid.api.services.tools.exceptions import ServiceAlreadyRegisterException, ServiceNotRegisteredException +from pyramid.api.services.tools.injector import ServiceInjector +from pyramid.api.services.tools.register import ServiceRegister + +class IService(ABC): + def __init__(self): + self.name = "Carrot" + +class AlphaService(ServiceInjector): + pass + +class BravoService(IService, ServiceInjector): + def __init__(self): + self.age = 21 + self.name = "Broccoli" + +class CharlieService(IService, ServiceInjector): + pass + +class DeltaService(ServiceInjector): + pass + +class EchoService(): + pass + +class FoxtrotService(ServiceInjector): + pass + +class ServiceRegisterDecoratorTest(unittest.TestCase): + + def setUp(self): + ServiceRegister.SERVICES_REGISTRED.clear() + ServiceRegister.SERVICES_INSTANCES.clear() + ServiceRegister.ORDERED_SERVICES = None + ServiceRegister.register_service(AlphaService.__name__, AlphaService) + ServiceRegister.register_service(IService.__name__, BravoService) + + def test_service_registration(self): + self.assertIsNotNone(ServiceRegister.get_service_registred(AlphaService.__name__)) + + def test_service_interface_registration(self): + self.assertIsNotNone(ServiceRegister.get_service_registred(IService.__name__)) + + def test_service_interface(self): + ServiceRegister.create_services() + iservice = ServiceRegister.get_service(IService) + beta_service = cast(BravoService, iservice) + self.assertEqual(beta_service.age, 21) + + def test_service_interface_override(self): + ServiceRegister.create_services() + iservice = ServiceRegister.get_service(IService) + beta_service = cast(BravoService, iservice) + self.assertEqual(beta_service.name, "Broccoli") + + def test_register_duplicate_service(self): + ServiceRegister.register_service(CharlieService.__name__, CharlieService) + with self.assertRaises(ServiceAlreadyRegisterException): + ServiceRegister.register_service(CharlieService.__name__, CharlieService) + + def test_register_duplicate_interface(self): + with self.assertRaises(ServiceAlreadyRegisterException): + ServiceRegister.register_service(IService.__name__, DeltaService) + + def test_register_not_baseclass(self): + with self.assertRaises(TypeError): + ServiceRegister.register_service(EchoService.__name__, EchoService) + + def test_can_get(self): + ServiceRegister.create_services() + ServiceRegister.get_service(AlphaService) + + def test_not_register(self): + ServiceRegister.create_services() + with self.assertRaises(ServiceNotRegisteredException): + ServiceRegister.get_service(FoxtrotService) + + def test_not_register_str(self): + ServiceRegister.create_services() + with self.assertRaises(ServiceNotRegisteredException): + ServiceRegister.get_service_registred(FoxtrotService.__class__.__name__) + + def test_class_not_register(self): + ServiceRegister.create_services() + with self.assertRaises(ServiceNotRegisteredException): + ServiceRegister.get_service(BravoService) + +if __name__ == "__main__": + unittest.main() diff --git a/tests/service_standalone_test.py b/tests/service_standalone_test.py new file mode 100644 index 0000000..6fc7ac8 --- /dev/null +++ b/tests/service_standalone_test.py @@ -0,0 +1,52 @@ +import unittest +from pyramid.api.services.tools.annotation import pyramid_service +from pyramid.api.services.tools.injector import ServiceInjector +from pyramid.api.services.tools.register import ServiceRegister +from pyramid.api.services.tools.tester import ServiceStandalone + +class AlphaService(ServiceInjector): + def __init__(self): + self.started = False + + def start(self): + self.started = True + +class BravoService(ServiceInjector): + def __init__(self): + self.started = False + + def injectService(self, alpha_service: AlphaService): + self.alpha_service = alpha_service + + def start(self): + self.started = True + +class CharlieService(AlphaService): + pass + +class ServiceStandaloneTest(unittest.TestCase): + + def setUp(self): + ServiceRegister.SERVICES_REGISTRED.clear() + ServiceRegister.SERVICES_INSTANCES.clear() + ServiceRegister.ORDERED_SERVICES = None + ServiceRegister.register_service(AlphaService.__name__, AlphaService) + ServiceRegister.register_service(BravoService.__name__, BravoService) + + def test_set_service(self): + ServiceRegister.enable() + + charlie = CharlieService() + ServiceStandalone.set_service(AlphaService, charlie) + + bravo = ServiceStandalone.get_service(BravoService) + self.assertEqual(bravo.alpha_service, charlie) + + def test_get_service(self): + service_instance = AlphaService() + ServiceStandalone.set_service(AlphaService, service_instance) + result = ServiceStandalone.get_service(AlphaService) + self.assertEqual(result, service_instance) + +if __name__ == "__main__": + unittest.main() From 8d25ee1ba32fb4424db1fa2799f824a1a342c708 Mon Sep 17 00:00:00 2001 From: Tristiisch Date: Tue, 8 Oct 2024 00:58:38 +0200 Subject: [PATCH 21/32] feat: adding message queue in a task --- .vscode/launch.json | 8 ++ src/pyramid/api/services/message_queue.py | 22 ++++ src/pyramid/api/tasks/tools/injector.py | 3 - src/pyramid/api/tasks/tools/register.py | 20 ++- .../discord/commands/test_message_queue.py | 33 +++++ src/pyramid/connector/discord/guild_cmd.py | 4 +- .../connector/discord/guild_cmd_tools.py | 4 +- .../{guild_queue.py => music_queue.py} | 4 +- .../messages/message_sender_queued.py | 2 +- src/pyramid/data/guild_instance.py | 4 +- src/pyramid/data/message_queue_item.py | 18 +++ .../data/{a_guid_queue.py => music_queue.py} | 2 +- src/pyramid/data/queue_item.py | 20 +++ src/pyramid/services/discord.py | 1 - src/pyramid/services/message_queue.py | 118 ++++++++++++++++++ src/pyramid/services/socket_server.py | 2 +- src/pyramid/services/todo.yaml | 5 + src/pyramid/tasks/message_queue.py | 17 +++ src/pyramid/tools/custom_queue.py | 22 +--- tests/queue_test.py | 3 +- 20 files changed, 275 insertions(+), 37 deletions(-) create mode 100644 src/pyramid/api/services/message_queue.py create mode 100644 src/pyramid/connector/discord/commands/test_message_queue.py rename src/pyramid/connector/discord/{guild_queue.py => music_queue.py} (98%) create mode 100644 src/pyramid/data/message_queue_item.py rename src/pyramid/data/{a_guid_queue.py => music_queue.py} (96%) create mode 100644 src/pyramid/data/queue_item.py create mode 100644 src/pyramid/services/message_queue.py create mode 100644 src/pyramid/services/todo.yaml create mode 100644 src/pyramid/tasks/message_queue.py diff --git a/.vscode/launch.json b/.vscode/launch.json index 92e4591..3339369 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -32,6 +32,14 @@ { "localRoot": "${workspaceFolder}", "remoteRoot": "/app" + }, + { + "localRoot": "${workspaceFolder}/.venv", + "remoteRoot": "/opt/venv" + }, + { + "localRoot": "/usr/lib/python3.12", + "remoteRoot": "/usr/local/lib/python3.12" } ], "justMyCode": false diff --git a/src/pyramid/api/services/message_queue.py b/src/pyramid/api/services/message_queue.py new file mode 100644 index 0000000..b225cef --- /dev/null +++ b/src/pyramid/api/services/message_queue.py @@ -0,0 +1,22 @@ +from abc import ABC, abstractmethod +from typing import Optional + +from pyramid.data.message_queue_item import MessageQueueItem + +class IMessageQueueService(ABC): + + @abstractmethod + async def worker(self) -> None: + pass + + @abstractmethod + def stop(self) -> None: + pass + + @abstractmethod + def add(self, item: MessageQueueItem, unique_id: Optional[str] = None, priority: int = 5) -> str: + pass + + @abstractmethod + def add_first(self, item: MessageQueueItem, unique_id: Optional[str] = None) -> str: + pass diff --git a/src/pyramid/api/tasks/tools/injector.py b/src/pyramid/api/tasks/tools/injector.py index 8aad739..d99631e 100644 --- a/src/pyramid/api/tasks/tools/injector.py +++ b/src/pyramid/api/tasks/tools/injector.py @@ -11,6 +11,3 @@ async def worker_asyc(self): async def stop_asyc(self): pass - - # def worker(self): - # pass diff --git a/src/pyramid/api/tasks/tools/register.py b/src/pyramid/api/tasks/tools/register.py index c551079..2230207 100644 --- a/src/pyramid/api/tasks/tools/register.py +++ b/src/pyramid/api/tasks/tools/register.py @@ -49,13 +49,27 @@ def inject_tasks(cls): @classmethod def __handle_signal(cls, signum: int, frame): - logging.info("Received signal %d. shutting down ..." % signum) + logging.info("Received signal %d." % signum) + cls.stop() + @classmethod + def stop(cls): + logging.info("Shutting down tasks ...") + test = cls + loop = asyncio.get_event_loop() for name, parameters in cls.__TASKS_REGISTERED.items(): async def shutdown(loop: asyncio.AbstractEventLoop): + if loop.is_closed(): + logging.warning("Loop of task %s is already closed." % name) + return + logging.info("Task %s stopping..." % name) await parameters.cls_instance.stop_asyc() loop.stop() - asyncio.run_coroutine_threadsafe(shutdown(parameters.loop), parameters.loop) + if (name == "DiscordTask"): + asyncio.run_coroutine_threadsafe(shutdown(parameters.loop), parameters.loop) + else: + result = loop.run_until_complete(shutdown(parameters.loop)) + logging.info("Task %s have been asked to stop." % name) @classmethod def start_tasks(cls): @@ -79,7 +93,9 @@ def running(loop: asyncio.AbstractEventLoop): parameters.thread.start() for name, parameters in cls.__TASKS_REGISTERED.items(): + logging.info("JOIN %s" % name) parameters.thread.join() + logging.info("STOP JOIN %s" % name) signal.signal(signal.SIGTERM, previous_handler) logging.info("All registered tasks are stopped") diff --git a/src/pyramid/connector/discord/commands/test_message_queue.py b/src/pyramid/connector/discord/commands/test_message_queue.py new file mode 100644 index 0000000..d3cd965 --- /dev/null +++ b/src/pyramid/connector/discord/commands/test_message_queue.py @@ -0,0 +1,33 @@ + + +from discord import Interaction + +from pyramid.api.services.discord import IDiscordService +from pyramid.api.services.message_queue import IMessageQueueService +from pyramid.connector.discord.commands.tools.abc import AbstractCommand +from pyramid.connector.discord.commands.tools.annotation import discord_command +from pyramid.connector.discord.commands.tools.parameters import ParametersCommand +from pyramid.data.functional.messages.message_sender_queued import MessageSenderQueued +from pyramid.data.message_queue_item import MessageQueueItem + + +@discord_command(parameters=ParametersCommand()) +class TestMessageQueueCommand(AbstractCommand): + + def injectService(self, + message_queue_service: IMessageQueueService + ): + self.__message_queue_service = message_queue_service + + async def execute(self, ctx: Interaction, nb: int = 100, reuse: bool = False, priority: bool = False): + # async def execute(self, ctx: Interaction, nb: int = 100, reuse: bool = False): + await ctx.response.defer(thinking=True) + + id = None + for i in range(1, nb + 1): + if priority: + id = self.__message_queue_service.add(MessageQueueItem(ctx.followup.send, ctx.client.loop, content=f"Priority n°{i}"), id, 1) + else: + id = self.__message_queue_service.add(MessageQueueItem(ctx.followup.send, ctx.client.loop, content=f"Hello n°{i}"), id) + if not reuse: + id = None diff --git a/src/pyramid/connector/discord/guild_cmd.py b/src/pyramid/connector/discord/guild_cmd.py index c053bb5..0875b91 100644 --- a/src/pyramid/connector/discord/guild_cmd.py +++ b/src/pyramid/connector/discord/guild_cmd.py @@ -9,7 +9,7 @@ from pyramid.data.source_type import SourceType from pyramid.data.music.track_minimal import TrackMinimal from pyramid.connector.discord.guild_cmd_tools import GuildCmdTools -from pyramid.connector.discord.guild_queue import GuildQueue +from pyramid.connector.discord.music_queue import MusicQueue from pyramid.data.functional.messages.message_sender_queued import MessageSenderQueued from pyramid.data.exceptions import DiscordMessageException from pyramid.data.a_guild_cmd import AGuildCmd @@ -21,7 +21,7 @@ def __init__( self, logger: Logger, guild_data: GuildData, - guild_queue: GuildQueue, + guild_queue: MusicQueue, engine_source: ISourceService, ): self.logger = logger diff --git a/src/pyramid/connector/discord/guild_cmd_tools.py b/src/pyramid/connector/discord/guild_cmd_tools.py index 489cb9a..a27c5f3 100644 --- a/src/pyramid/connector/discord/guild_cmd_tools.py +++ b/src/pyramid/connector/discord/guild_cmd_tools.py @@ -11,7 +11,7 @@ from pyramid.data.music.track_minimal import TrackMinimal from pyramid.data.guild_data import GuildData from pyramid.data.tracklist import TrackList -from pyramid.connector.discord.guild_queue import GuildQueue +from pyramid.connector.discord.music_queue import MusicQueue from pyramid.data.functional.messages.message_sender_queued import MessageSenderQueued @@ -19,7 +19,7 @@ class GuildCmdTools: def __init__( self, guild_data: GuildData, - guild_queue: GuildQueue, + guild_queue: MusicQueue, engine_source: ISourceService, ): self.engine_source = engine_source diff --git a/src/pyramid/connector/discord/guild_queue.py b/src/pyramid/connector/discord/music_queue.py similarity index 98% rename from src/pyramid/connector/discord/guild_queue.py rename to src/pyramid/connector/discord/music_queue.py index c295fa2..ac13374 100644 --- a/src/pyramid/connector/discord/guild_queue.py +++ b/src/pyramid/connector/discord/music_queue.py @@ -9,10 +9,10 @@ from pyramid.data.guild_data import GuildData from pyramid.connector.discord.music.player_interface import MusicPlayerInterface from pyramid.data.functional.messages.message_sender_queued import MessageSenderQueued -from pyramid.data.a_guid_queue import AGuildQueue +from pyramid.data.music_queue import IMusicQueue -class GuildQueue(AGuildQueue): +class MusicQueue(IMusicQueue): def __init__(self, data: GuildData, ffmpeg_path: str, mpi: MusicPlayerInterface): self.data: GuildData = data self.ffmpeg = ffmpeg_path diff --git a/src/pyramid/data/functional/messages/message_sender_queued.py b/src/pyramid/data/functional/messages/message_sender_queued.py index 432fd91..4f47906 100644 --- a/src/pyramid/data/functional/messages/message_sender_queued.py +++ b/src/pyramid/data/functional/messages/message_sender_queued.py @@ -3,7 +3,7 @@ from pyramid.data.functional.messages.message_sender import MessageSender from discord import Interaction, Message, WebhookMessage from discord.utils import MISSING -from pyramid.tools.custom_queue import QueueItem +from pyramid.data.queue_item import QueueItem from pyramid.tools.main_queue import MainQueue MAX_MSG_LENGTH = 2000 diff --git a/src/pyramid/data/guild_instance.py b/src/pyramid/data/guild_instance.py index 8ccf68e..304faed 100644 --- a/src/pyramid/data/guild_instance.py +++ b/src/pyramid/data/guild_instance.py @@ -5,7 +5,7 @@ from pyramid.api.services.source_service import ISourceService from pyramid.connector.discord.guild_cmd import GuildCmd -from pyramid.connector.discord.guild_queue import GuildQueue +from pyramid.connector.discord.music_queue import MusicQueue from pyramid.connector.discord.music.player_interface import MusicPlayerInterface from pyramid.data.guild_data import GuildData @@ -14,6 +14,6 @@ class GuildInstances: def __init__(self, guild: Guild, logger: Logger, source_service: ISourceService, ffmpeg_path: str): self.data = GuildData(guild, source_service) self.mpi = MusicPlayerInterface(self.data.guild.preferred_locale, self.data.track_list) - self.songs = GuildQueue(self.data, ffmpeg_path, self.mpi) + self.songs = MusicQueue(self.data, ffmpeg_path, self.mpi) self.cmds = GuildCmd(logger, self.data, self.songs, source_service) self.mpi.set_queue_action(self.cmds) diff --git a/src/pyramid/data/message_queue_item.py b/src/pyramid/data/message_queue_item.py new file mode 100644 index 0000000..b682c6f --- /dev/null +++ b/src/pyramid/data/message_queue_item.py @@ -0,0 +1,18 @@ +import asyncio +from typing import Any, Callable + + +class MessageQueueItem: + def __init__( + self, + func: Callable, + loop: asyncio.AbstractEventLoop, + func_success: Callable | None = None, + func_error: Callable[[Exception], Any] | None = None, + **kwargs, + ) -> None: + self.func: Callable = func + self.loop = loop + self.func_success = func_success + self.func_error = func_error + self.kwargs = kwargs diff --git a/src/pyramid/data/a_guid_queue.py b/src/pyramid/data/music_queue.py similarity index 96% rename from src/pyramid/data/a_guid_queue.py rename to src/pyramid/data/music_queue.py index c85ab30..304a9de 100644 --- a/src/pyramid/data/a_guid_queue.py +++ b/src/pyramid/data/music_queue.py @@ -4,7 +4,7 @@ from pyramid.data.music.track import Track -class AGuildQueue(ABC): +class IMusicQueue(ABC): @abstractmethod def is_playing(self) -> bool: pass diff --git a/src/pyramid/data/queue_item.py b/src/pyramid/data/queue_item.py new file mode 100644 index 0000000..6031b39 --- /dev/null +++ b/src/pyramid/data/queue_item.py @@ -0,0 +1,20 @@ +import asyncio +from typing import Any, Callable + + +class QueueItem: + def __init__( + self, + name, + func: Callable, + loop: asyncio.AbstractEventLoop | None = None, + func_sucess: Callable | None = None, + func_error: Callable[[Exception], Any] | None = None, + **kwargs, + ) -> None: + self.name = name + self.func: Callable = func + self.loop = loop + self.func_sucess = func_sucess + self.func_error = func_error + self.kwargs = kwargs diff --git a/src/pyramid/services/discord.py b/src/pyramid/services/discord.py index 21bc09c..026271d 100644 --- a/src/pyramid/services/discord.py +++ b/src/pyramid/services/discord.py @@ -44,7 +44,6 @@ def injectService(self, self.__source_service = source_service def start(self): - intents = discord.Intents.default() # intents.members = True intents.message_content = True diff --git a/src/pyramid/services/message_queue.py b/src/pyramid/services/message_queue.py new file mode 100644 index 0000000..2458aaa --- /dev/null +++ b/src/pyramid/services/message_queue.py @@ -0,0 +1,118 @@ +import asyncio +import inspect +from collections import deque +from threading import Event +from typing import Callable, Deque, Optional +import uuid +from pyramid.api.services.discord import IDiscordService +from pyramid.api.services.logger import ILoggerService +from pyramid.api.services.message_queue import IMessageQueueService +from pyramid.api.services.tools.annotation import pyramid_service +from pyramid.api.services.tools.injector import ServiceInjector +from pyramid.data.message_queue_item import MessageQueueItem + +@pyramid_service(interface=IMessageQueueService) +class MessageQueueService(ServiceInjector): + + MIN_PRIORITY = 1 + MAX_PRIORITY = 3 + + def __init__(self): + self.queue: dict[str, MessageQueueItem] = {} + self.order: dict[int, Deque[str]] = {i: deque() for i in range(self.MIN_PRIORITY, self.MAX_PRIORITY + 1)} + self.event = Event() + self.open = True + + def injectService(self, + logger_service: ILoggerService, + discord_service: IDiscordService + ): + self.logger = logger_service + self.__discord_service = discord_service + self.logger.info("MessageQueueService had IDiscordService %#x" % (id(self.__discord_service))) + + def add(self, item: MessageQueueItem, unique_id: Optional[str] = None, priority: int = (MIN_PRIORITY + MAX_PRIORITY) // 2) -> str: + if priority < self.MIN_PRIORITY or priority > self.MAX_PRIORITY: + raise ValueError("Priority must be between %d and %d." % (self.MIN_PRIORITY, self.MAX_PRIORITY)) + + + if unique_id is None: + unique_id = self.generate_unique_id() + elif unique_id in self.queue: + self.queue[unique_id] = item + for p in range(self.MIN_PRIORITY, self.MAX_PRIORITY + 1): + if unique_id in self.order[p]: + if p != priority: + self.order[p].remove(unique_id) + self.order[priority].append(unique_id) + return unique_id + + self.queue[unique_id] = item + self.order[priority].append(unique_id) + self.event.set() + return unique_id + + async def worker(self) -> None: + self.logger.info("Message queue started") + while self.open: + self.event.wait() + if not self.open: + self.logger.info("Message queue closed, stopping worker") + break + item_poped = False + for priority in range(self.MIN_PRIORITY, self.MAX_PRIORITY + 1): + if not self.order[priority]: + continue + unique_id = self.order[priority].popleft() + item = self.queue.pop(unique_id, None) + if not item: + self.logger.info("Message queue has no items, stopping worker") + break + item_poped = True + + try: + result = self.__run_task(item.func, item.loop, **item.kwargs) + if item.func_success: + item.func_success(result) + except Exception as err: + if item.func_error: + item.func_error(err) + if not item_poped and not self.queue: + self.event.clear() + self.logger.info("Message queue stopped") + + def __run_task(self, func: Callable, loop: asyncio.AbstractEventLoop, **kwargs): + # Async func + if inspect.iscoroutinefunction(func) or inspect.isasyncgenfunction(func): + # Async func in loop + if loop is not None: + # Async func in loop closed + if loop.is_closed(): + raise Exception( + "Unable to call %s.%s cause the loop is closed", + func.__module__, + func.__qualname__, + ) + # Async func in loop open + result = asyncio.run_coroutine_threadsafe(func(**kwargs), loop).result() + # Async func classic + else: + result = asyncio.run(func(**kwargs)) + # Sync func + else: + result = func(**kwargs) + return result + + def stop(self) -> None: + if not self.open: + self.logger.warning("Message queue already stopped") + return + self.logger.info("Message queue stopping") + self.open = False + self.event.set() + + def generate_unique_id(self) -> str: + unique_id = None + while unique_id is None or unique_id in self.queue: + unique_id = str(uuid.uuid4()) + return unique_id diff --git a/src/pyramid/services/socket_server.py b/src/pyramid/services/socket_server.py index a559760..30091ef 100644 --- a/src/pyramid/services/socket_server.py +++ b/src/pyramid/services/socket_server.py @@ -76,7 +76,7 @@ def close(self): self.is_running = False if self.server_socket is None: return - self.server_socket.shutdown(socket.SHUT_RDWR) + # self.server_socket.shutdown(socket.SHUT_RDWR) self.server_socket.close() self.server_socket = None self.__logger.info("Socket server stop") diff --git a/src/pyramid/services/todo.yaml b/src/pyramid/services/todo.yaml new file mode 100644 index 0000000..6f8e076 --- /dev/null +++ b/src/pyramid/services/todo.yaml @@ -0,0 +1,5 @@ +# TODO +Verify stop of tasks +- Message Queue Task should be ok +- Discord is ok but his stop is not like other. Add it to params to handle it properly +- Socket server sometimes never stop. Look into it \ No newline at end of file diff --git a/src/pyramid/tasks/message_queue.py b/src/pyramid/tasks/message_queue.py new file mode 100644 index 0000000..338229f --- /dev/null +++ b/src/pyramid/tasks/message_queue.py @@ -0,0 +1,17 @@ +from pyramid.api.services.message_queue import IMessageQueueService +from pyramid.api.tasks.tools.annotation import pyramid_task +from pyramid.api.tasks.tools.injector import TaskInjector +from pyramid.api.tasks.tools.parameters import ParametersTask + + +@pyramid_task(parameters=ParametersTask()) +class MessageQueueTask(TaskInjector): + + def injectService(self, __message_queue_service: IMessageQueueService): + self.__message_queue_service = __message_queue_service + + async def worker_asyc(self): + await self.__message_queue_service.worker() + + async def stop_asyc(self): + self.__message_queue_service.stop() diff --git a/src/pyramid/tools/custom_queue.py b/src/pyramid/tools/custom_queue.py index 8eaa543..7ea2883 100644 --- a/src/pyramid/tools/custom_queue.py +++ b/src/pyramid/tools/custom_queue.py @@ -5,25 +5,9 @@ from collections import deque from concurrent.futures import CancelledError from threading import Event, Lock, Thread -from typing import Any, Callable, Deque, List - - -class QueueItem: - def __init__( - self, - name, - func: Callable, - loop: asyncio.AbstractEventLoop | None = None, - func_sucess: Callable | None = None, - func_error: Callable[[Exception], Any] | None = None, - **kwargs, - ) -> None: - self.name = name - self.func: Callable = func - self.loop = loop - self.func_sucess = func_sucess - self.func_error = func_error - self.kwargs = kwargs +from typing import Callable, Deque, List + +from pyramid.data.queue_item import QueueItem def worker(q: Deque[QueueItem], thread_id: int, lock: Lock, event: Event): diff --git a/tests/queue_test.py b/tests/queue_test.py index a78ea82..4e54b82 100644 --- a/tests/queue_test.py +++ b/tests/queue_test.py @@ -1,7 +1,8 @@ import time import unittest -from pyramid.tools.custom_queue import Queue, QueueItem # noqa: E402 +from pyramid.data.queue_item import QueueItem +from pyramid.tools.custom_queue import Queue class SimpleQueue(unittest.TestCase): From c8a672f07fc9e9562c2e0c9905457d3e07f7cc2b Mon Sep 17 00:00:00 2001 From: Tristiisch Date: Tue, 8 Oct 2024 01:55:24 +0200 Subject: [PATCH 22/32] feat: use select in socket --- src/pyramid/services/socket_server.py | 103 ++++++++++++++++---------- 1 file changed, 62 insertions(+), 41 deletions(-) diff --git a/src/pyramid/services/socket_server.py b/src/pyramid/services/socket_server.py index 30091ef..bc8f5e3 100644 --- a/src/pyramid/services/socket_server.py +++ b/src/pyramid/services/socket_server.py @@ -1,4 +1,4 @@ -import asyncio +import select import socket from socket import socket as sock from typing import Any @@ -20,7 +20,7 @@ def __init__(self) -> None: self.__host = "0.0.0.0" self.__port = self.__common.port self.is_running = False - self.server_socket: sock | None = None + self.server: sock | None = None def injectService(self, logger_service: ILoggerService @@ -28,57 +28,70 @@ def injectService(self, self.__logger = logger_service async def open(self): - self.server_socket = sock(socket.AF_INET, socket.SOCK_STREAM) - # self.server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - self.server_socket.bind((self.__host, self.__port)) - self.server_socket.listen(10) + self.server = sock(socket.AF_INET, socket.SOCK_STREAM) + self.server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + self.server.bind((self.__host, self.__port)) + self.server.listen(10) + self.server.setblocking(False) self.__logger.info("Socket server open on %s:%d", self.__host, self.__port) self.is_running = True - client_socket: sock | None = None - client_address: Any = None - client_ip: Any = None - client_port: Any = None while self.is_running: - try: - client_socket, client_address = self.server_socket.accept() - client_ip = client_address[0] - client_port = client_address[1] - response_to_send = await self.__handle_client(client_socket, client_ip, client_port) - if response_to_send: - # Convert the response data to JSON - response_json = SocketCommon.serialize( - response_to_send.to_json(SocketCommon.serialize) - ) - - # Send the JSON response back to the client - # self.__logger.debug("[%s:%d] <- %s", client_ip, client_port, response_json) - self.__common.send_chunk(client_socket, response_json) - except Exception as err: - if isinstance(err, OSError): - if err.errno == 9: + client_socket: sock | None = None + client_address: Any = None + client_ip: Any = None + client_port: Any = None + response_to_send: SocketResponse | None = None + response_json: str | None = None + readable: list[sock] + readable, writable, exceptional = select.select([self.server], [], [], None) + + if not self.is_socket_open(self.server): + self.__logger.info("Socket server queue closed, stopping...") + break + + if self.server in readable: + try: + client_socket, client_address = self.server.accept() + client_socket.setblocking(False) + client_ip, client_port = client_address + + response_to_send = await self.__handle_client(client_socket, client_ip, client_port) + if response_to_send: + response_json = SocketCommon.serialize( + response_to_send.to_json(SocketCommon.serialize) + ) + # Send the JSON response back to the client + # self.__logger.debug("[%s:%d] <- %s", client_ip, client_port, response_json) + self.__common.send_chunk(client_socket, response_json) + except Exception as err: + if isinstance(err, OSError) and err.errno == 9: self.__logger.warning("Socket: [Errno 9] Bad file descriptor") - continue - if client_ip is not None and client_port is not None: - self.__logger.warning("[%s:%d] %s", client_ip, client_port, err, exc_info=True) - finally: - if client_socket is not None: - client_socket.close() - client_socket = None - client_address = None - client_ip = None - client_port = None + elif client_ip is not None and client_port is not None: + self.__logger.warning("[%s:%d] %s", client_ip, client_port, err, exc_info=True) + else: + raise err + finally: + if client_socket is not None: + client_socket.close() self.__logger.info("Socket server closed") def close(self): self.is_running = False - if self.server_socket is None: + if not self.server or not self.is_socket_open(self.server): + self.__logger.warning("Socket server already stopped") return - # self.server_socket.shutdown(socket.SHUT_RDWR) - self.server_socket.close() - self.server_socket = None + + try: + self.server.shutdown(socket.SHUT_RDWR) + except OSError as err: + if err.errno != 9: + self.__logger.warning("Error during socket shutdown: %s", err) + else: + raise err + self.server.close() self.__logger.info("Socket server stop") async def __handle_client(self, client_socket: sock, client_ip, client_port) -> SocketResponse | None: @@ -111,3 +124,11 @@ def object_hook(json): "[%s:%d] <- Unknown action '%s'", client_ip, client_port, json_data.action ) return response + + @classmethod + def is_socket_open(cls, sock: sock): + try: + sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) + return True + except socket.error: + return False From 2b16a42093b44be3097220d4c3a52072a55375d3 Mon Sep 17 00:00:00 2001 From: Tristiisch Date: Tue, 8 Oct 2024 02:00:32 +0200 Subject: [PATCH 23/32] chore: clean previous commits --- src/pyramid/api/tasks/tools/parameters.py | 3 ++- src/pyramid/api/tasks/tools/register.py | 7 +++---- src/pyramid/services/todo.yaml | 5 ----- src/pyramid/tasks/discord.py | 2 +- 4 files changed, 6 insertions(+), 11 deletions(-) delete mode 100644 src/pyramid/services/todo.yaml diff --git a/src/pyramid/api/tasks/tools/parameters.py b/src/pyramid/api/tasks/tools/parameters.py index e732421..74162ba 100644 --- a/src/pyramid/api/tasks/tools/parameters.py +++ b/src/pyramid/api/tasks/tools/parameters.py @@ -6,7 +6,8 @@ class ParametersTask: - def __init__(self): + def __init__(self, stop_own_loop = False): self.loop: asyncio.AbstractEventLoop self.thread: Thread self.cls_instance: TaskInjector + self.stop_own_loop = stop_own_loop diff --git a/src/pyramid/api/tasks/tools/register.py b/src/pyramid/api/tasks/tools/register.py index 2230207..1319f61 100644 --- a/src/pyramid/api/tasks/tools/register.py +++ b/src/pyramid/api/tasks/tools/register.py @@ -7,6 +7,7 @@ import signal from threading import Thread from typing import TypeVar + from pyramid.api.services.tools.register import ServiceRegister from pyramid.api.tasks.tools.injector import TaskInjector from pyramid.api.tasks.tools.parameters import ParametersTask @@ -62,10 +63,10 @@ async def shutdown(loop: asyncio.AbstractEventLoop): if loop.is_closed(): logging.warning("Loop of task %s is already closed." % name) return - logging.info("Task %s stopping..." % name) + logging.info("Task %s ask to stopping..." % name) await parameters.cls_instance.stop_asyc() loop.stop() - if (name == "DiscordTask"): + if parameters.stop_own_loop: asyncio.run_coroutine_threadsafe(shutdown(parameters.loop), parameters.loop) else: result = loop.run_until_complete(shutdown(parameters.loop)) @@ -93,9 +94,7 @@ def running(loop: asyncio.AbstractEventLoop): parameters.thread.start() for name, parameters in cls.__TASKS_REGISTERED.items(): - logging.info("JOIN %s" % name) parameters.thread.join() - logging.info("STOP JOIN %s" % name) signal.signal(signal.SIGTERM, previous_handler) logging.info("All registered tasks are stopped") diff --git a/src/pyramid/services/todo.yaml b/src/pyramid/services/todo.yaml deleted file mode 100644 index 6f8e076..0000000 --- a/src/pyramid/services/todo.yaml +++ /dev/null @@ -1,5 +0,0 @@ -# TODO -Verify stop of tasks -- Message Queue Task should be ok -- Discord is ok but his stop is not like other. Add it to params to handle it properly -- Socket server sometimes never stop. Look into it \ No newline at end of file diff --git a/src/pyramid/tasks/discord.py b/src/pyramid/tasks/discord.py index 15c4a3a..62b2f04 100644 --- a/src/pyramid/tasks/discord.py +++ b/src/pyramid/tasks/discord.py @@ -3,7 +3,7 @@ from pyramid.api.tasks.tools.injector import TaskInjector from pyramid.api.tasks.tools.parameters import ParametersTask -@pyramid_task(parameters=ParametersTask()) +@pyramid_task(parameters=ParametersTask(stop_own_loop = True)) class DiscordTask(TaskInjector): def injectService(self, discord_service: IDiscordService): From 4d32719ad6d1dafd72c164834087d10f3adea687 Mon Sep 17 00:00:00 2001 From: Tristiisch Date: Mon, 18 Nov 2024 23:16:56 +0100 Subject: [PATCH 24/32] fix: IDeezerDownloaderService init, and add deezer token checker in tests --- src/pyramid/api/services/deezer_downloader.py | 6 ++-- src/pyramid/services/deezer_downloader.py | 8 +++--- tests/deezer_token_test.py | 28 +++++++++++++++++++ 3 files changed, 35 insertions(+), 7 deletions(-) create mode 100644 tests/deezer_token_test.py diff --git a/src/pyramid/api/services/deezer_downloader.py b/src/pyramid/api/services/deezer_downloader.py index 9e9b1c6..dddfb0d 100644 --- a/src/pyramid/api/services/deezer_downloader.py +++ b/src/pyramid/api/services/deezer_downloader.py @@ -1,13 +1,13 @@ from abc import ABC, abstractmethod -from typing import Any +from pyramid.connector.deezer.download.client import PyDeezer from pyramid.data.music.track import Track class IDeezerDownloaderService(ABC): @abstractmethod - async def check_credentials(self) -> dict[str, Any]: + async def dl_track_by_id(self, track_id) -> Track | None: pass @abstractmethod - async def dl_track_by_id(self, track_id) -> Track | None: + async def get_client(self) -> PyDeezer: pass diff --git a/src/pyramid/services/deezer_downloader.py b/src/pyramid/services/deezer_downloader.py index 38e4548..377bc24 100644 --- a/src/pyramid/services/deezer_downloader.py +++ b/src/pyramid/services/deezer_downloader.py @@ -40,7 +40,7 @@ def start(self): self.music_format = track_formats.MP3_128 async def dl_track_by_id(self, track_id) -> Track | None: - client = await self._get_client() + client = await self.get_client() # try: track_info = await client.get_track_info(track_id) # except APIRequestError as err: @@ -66,7 +66,7 @@ async def dl_track_by_id(self, track_id) -> Track | None: async def __dl_track(self, track_info, file_name: str) -> bool: try: - client = await self._get_client() + client = await self.get_client() await client.download_track( track_info, self.__configuration_service.deezer__folder, @@ -96,8 +96,8 @@ async def __dl_track(self, track_info, file_name: str) -> bool: track = Track(track_info, None) self.__logger.warning("Unable to dl track %s", track, exc_info=True) return False - - async def _get_client(self) -> PyDeezer: + + async def get_client(self) -> PyDeezer: return await self._define_client() async def _define_client(self) -> PyDeezer: diff --git a/tests/deezer_token_test.py b/tests/deezer_token_test.py new file mode 100644 index 0000000..302479c --- /dev/null +++ b/tests/deezer_token_test.py @@ -0,0 +1,28 @@ +import os +import unittest + +from pyramid.api.services.configuration import IConfigurationService +from pyramid.api.services.deezer_downloader import IDeezerDownloaderService +from pyramid.api.services.tools.tester import ServiceStandalone +from pyramid.data.exceptions import DeezerTokenInvalidException, DeezerTokensUnavailableException +from pyramid.services.builder.configuration import ConfigurationBuilder + +class DeezerTokenTest(unittest.IsolatedAsyncioTestCase): + def __init__(self, methodName: str = "runTest") -> None: + super().__init__(methodName) + + ServiceStandalone.import_services() + builder = ConfigurationBuilder().deezer_arl(os.getenv("DEEZER__ARL") or "") + ServiceStandalone.set_service(IConfigurationService, builder.build()) + + self.deezer_downloader = ServiceStandalone.get_service(IDeezerDownloaderService) + + async def test_download(self): + try: + await self.deezer_downloader.get_client() + except (DeezerTokenInvalidException, DeezerTokensUnavailableException) as err: + self.fail(err.args) + + +if __name__ == "__main__": + unittest.main(failfast=True) From 9b15da267b35bde7c554e0d80719645def84c54d Mon Sep 17 00:00:00 2001 From: Tristiisch Date: Tue, 19 Nov 2024 00:37:21 +0100 Subject: [PATCH 25/32] fix: err.args -> err.msg --- tests/deezer_token_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/deezer_token_test.py b/tests/deezer_token_test.py index 302479c..2b67716 100644 --- a/tests/deezer_token_test.py +++ b/tests/deezer_token_test.py @@ -21,7 +21,7 @@ async def test_download(self): try: await self.deezer_downloader.get_client() except (DeezerTokenInvalidException, DeezerTokensUnavailableException) as err: - self.fail(err.args) + self.fail(err.msg) if __name__ == "__main__": From ba78c1b050449b99cd2d23f459f6c36e93dd55c7 Mon Sep 17 00:00:00 2001 From: Tristiisch Date: Fri, 6 Dec 2024 00:05:59 +0100 Subject: [PATCH 26/32] debug: socket server --- src/pyramid/services/socket_server.py | 80 +++++++++++++++++++-------- 1 file changed, 56 insertions(+), 24 deletions(-) diff --git a/src/pyramid/services/socket_server.py b/src/pyramid/services/socket_server.py index bc8f5e3..ae036cf 100644 --- a/src/pyramid/services/socket_server.py +++ b/src/pyramid/services/socket_server.py @@ -21,6 +21,7 @@ def __init__(self) -> None: self.__port = self.__common.port self.is_running = False self.server: sock | None = None + self.connected_clients: set[sock] = set() def injectService(self, logger_service: ILoggerService @@ -28,16 +29,22 @@ def injectService(self, self.__logger = logger_service async def open(self): + if self.is_running: + self.__logger.warning("Socket server is already running.") + return self.server = sock(socket.AF_INET, socket.SOCK_STREAM) self.server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) self.server.bind((self.__host, self.__port)) self.server.listen(10) self.server.setblocking(False) + self.is_running = True self.__logger.info("Socket server open on %s:%d", self.__host, self.__port) + await self.__run() - self.is_running = True - + async def __run(self): + if not self.server: + return while self.is_running: client_socket: sock | None = None client_address: Any = None @@ -46,7 +53,10 @@ async def open(self): response_to_send: SocketResponse | None = None response_json: str | None = None readable: list[sock] + # readable, writable, exceptional = select.select([self.server, *self.connected_clients], [], [], None) + self.__logger.debug("server.select") readable, writable, exceptional = select.select([self.server], [], [], None) + self.__logger.debug("server.selected") if not self.is_socket_open(self.server): self.__logger.info("Socket server queue closed, stopping...") @@ -54,18 +64,15 @@ async def open(self): if self.server in readable: try: + self.__logger.debug("server.accept") client_socket, client_address = self.server.accept() + self.__logger.debug("server.accepted") client_socket.setblocking(False) client_ip, client_port = client_address + self.connected_clients.add(client_socket) + self.__logger.debug("Client %s:%d connected", *client_address) - response_to_send = await self.__handle_client(client_socket, client_ip, client_port) - if response_to_send: - response_json = SocketCommon.serialize( - response_to_send.to_json(SocketCommon.serialize) - ) - # Send the JSON response back to the client - # self.__logger.debug("[%s:%d] <- %s", client_ip, client_port, response_json) - self.__common.send_chunk(client_socket, response_json) + await self.__handle_client(client_socket, client_ip, client_port) except Exception as err: if isinstance(err, OSError) and err.errno == 9: self.__logger.warning("Socket: [Errno 9] Bad file descriptor") @@ -75,6 +82,8 @@ async def open(self): raise err finally: if client_socket is not None: + self.__logger.debug("Client %s:%d disconnected ", *client_address) + self.connected_clients.discard(client_socket) client_socket.close() self.__logger.info("Socket server closed") @@ -87,20 +96,45 @@ def close(self): try: self.server.shutdown(socket.SHUT_RDWR) except OSError as err: - if err.errno != 9: - self.__logger.warning("Error during socket shutdown: %s", err) + if err.errno == 9: + self.__logger.warning("Socket server already closed (Errno 9).") else: - raise err + self.__logger.warning("Error during socket shutdown: %s", err) + return self.server.close() - self.__logger.info("Socket server stop") - - async def __handle_client(self, client_socket: sock, client_ip, client_port) -> SocketResponse | None: - data = self.__common.receive_chunk(client_socket) - if not data: - self.__logger.info("[%s:%d] -> ", client_ip, client_port) - return + for client in list(self.connected_clients): + try: + client.close() + except Exception: + pass + self.connected_clients.clear() + self.__logger.info("Socket server stopped") + async def __handle_client(self, client_socket: sock, client_ip, client_port): + try: + data = self.__common.receive_chunk(client_socket) + + if not data: + self.__logger.info("[%s:%d] -> ", client_ip, client_port) + self.connected_clients.remove(client_socket) + client_socket.close() + return + response = await self.__process_request(data, client_ip, client_port) + if response: + response_json = SocketCommon.serialize( + response.to_json(SocketCommon.serialize) + ) + # Send the JSON response back to the client + # self.__logger.debug("[%s:%d] <- %s", client_ip, client_port, response_json) + self.__common.send_chunk(client_socket, response_json) + + except Exception as err: + self.__logger.warning("Error handling client: %s", err) + self.connected_clients.discard(client_socket) + client_socket.close() + + async def __process_request(self, data: str, client_ip, client_port) -> SocketResponse | None: def object_hook(json): if isinstance(json, dict): return AskRequest(**json) @@ -115,15 +149,13 @@ def object_hook(json): return response if json_data.action == "health": - data = PingSocket(True) - response.create(ResponseCode.OK, None, data) + pingSocket = PingSocket(True) + response.create(ResponseCode.OK, None, pingSocket) return response - response.create(ResponseCode.ERROR, "Unknown action") self.__logger.info( "[%s:%d] <- Unknown action '%s'", client_ip, client_port, json_data.action ) - return response @classmethod def is_socket_open(cls, sock: sock): From d0c0685a541375be9e79bc61d339eb5e8edbebc1 Mon Sep 17 00:00:00 2001 From: Tristiisch Date: Fri, 6 Dec 2024 00:29:09 +0100 Subject: [PATCH 27/32] fix: tests u --- tests/spotify_test.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/spotify_test.py b/tests/spotify_test.py index aa36492..391aef8 100644 --- a/tests/spotify_test.py +++ b/tests/spotify_test.py @@ -48,25 +48,25 @@ async def test_search_top(self): async def test_url_artist(self): tracks = await self.spotify_search.get_by_url( - "https://open.spotify.com/intl-fr/artist/2HALYSe657tNJ1iKVXP2xA?si=e5x_arTGSqWtnRrJjCCTdQ" + "https://open.spotify.com/intl-fr/artist/2HALYSe657tNJ1iKVXP2xA" ) self.assertIsNotNone(tracks) async def test_url_album(self): tracks = await self.spotify_search.get_by_url( - "https://open.spotify.com/album/1mhVZVEHbIYUN7b5DkXF7d?si=U9HaPDJtRdSnT9qZaQxjXA" + "https://open.spotify.com/album/1mhVZVEHbIYUN7b5DkXF7d" ) self.assertIsNotNone(tracks) async def test_url_track(self): tracks = await self.spotify_search.get_by_url( - "https://open.spotify.com/intl-fr/track/1mzZP8UA2RZUXDw33QNmn4?si=4c27f893cd7c4055" + "https://open.spotify.com/intl-fr/track/1mzZP8UA2RZUXDw33QNmn4" ) self.assertIsNotNone(tracks) async def test_url_playlist(self): tracks = await self.spotify_search.get_by_url( - "https://open.spotify.com/playlist/37i9dQZF1DZ06evO1ymAtQ?si=5c4b6ff36a0d4c98" + "https://open.spotify.com/playlist/37i9dQZF1DZ06evO1ymAtQ" ) self.assertIsNotNone(tracks) From 83a3a489fdce7e3974e022677607c9e856118927 Mon Sep 17 00:00:00 2001 From: Tristiisch Date: Fri, 6 Dec 2024 00:32:42 +0100 Subject: [PATCH 28/32] fix: tests u --- tests/spotify_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/spotify_test.py b/tests/spotify_test.py index 391aef8..6ac4d26 100644 --- a/tests/spotify_test.py +++ b/tests/spotify_test.py @@ -66,7 +66,7 @@ async def test_url_track(self): async def test_url_playlist(self): tracks = await self.spotify_search.get_by_url( - "https://open.spotify.com/playlist/37i9dQZF1DZ06evO1ymAtQ" + "https://open.spotify.com/playlist/2CEwnRwSNdIvPn7u6E3whG" ) self.assertIsNotNone(tracks) From f0fb7be52524f4b167ec2ec09fda6fe97deaa40e Mon Sep 17 00:00:00 2001 From: Tristiisch Date: Fri, 6 Dec 2024 04:25:52 +0100 Subject: [PATCH 29/32] feat: discord output add emojis and rework regex for source url --- src/pyramid/connector/deezer/tools.py | 36 --------- src/pyramid/connector/discord/guild_cmd.py | 48 ++++++------ .../connector/spotify/spotify_tools.py | 30 -------- src/pyramid/data/a_engine_tools.py | 8 -- src/pyramid/data/exceptions.py | 7 +- src/pyramid/services/deezer_search.py | 75 ++++++++++++++----- src/pyramid/services/source.py | 2 +- src/pyramid/services/spotify_client.py | 17 +++-- src/pyramid/services/spotify_search.py | 56 ++++++++++---- 9 files changed, 139 insertions(+), 140 deletions(-) delete mode 100644 src/pyramid/connector/deezer/tools.py delete mode 100644 src/pyramid/connector/spotify/spotify_tools.py delete mode 100644 src/pyramid/data/a_engine_tools.py diff --git a/src/pyramid/connector/deezer/tools.py b/src/pyramid/connector/deezer/tools.py deleted file mode 100644 index 9310015..0000000 --- a/src/pyramid/connector/deezer/tools.py +++ /dev/null @@ -1,36 +0,0 @@ - -import re -import aiohttp -from pyramid.data.a_engine_tools import AEngineTools -from pyramid.services.deezer_search import DeezerType - - -class DeezerTools(AEngineTools): - - @classmethod - async def extract_from_url(cls, url) -> tuple[int, DeezerType | None] | tuple[None, None]: - # Resolve if URL is a deezer.page.link URL - if "deezer.page.link" in url: - async with aiohttp.ClientSession() as session: - async with session.get(url, allow_redirects=True) as response: - url = str(response.url) - - # Extract ID and type using regex - pattern = r"(?<=deezer.com/fr/)(\w+)/(?P\d+)" - match = re.search(pattern, url) - if not match: - return None, None - deezer_type_str = match.group(1).upper() - if deezer_type_str == "PLAYLIST": - deezer_type = DeezerType.PLAYLIST - elif deezer_type_str == "ARTIST": - deezer_type = DeezerType.ARTIST - elif deezer_type_str == "ALBUM": - deezer_type = DeezerType.ALBUM - elif deezer_type_str == "TRACK": - deezer_type = DeezerType.TRACK - else: - deezer_type = None - - deezer_id = int(match.group("id")) - return deezer_id, deezer_type diff --git a/src/pyramid/connector/discord/guild_cmd.py b/src/pyramid/connector/discord/guild_cmd.py index 0875b91..f3c0284 100644 --- a/src/pyramid/connector/discord/guild_cmd.py +++ b/src/pyramid/connector/discord/guild_cmd.py @@ -41,7 +41,7 @@ async def play( if not voice_channel: return False - ms.edit_message(f"Searching **{input}**", "search") + ms.edit_message(f"💿 Searching **{input}**", "search") try: track: TrackMinimal | None = await self.data.search_engine.search_track(input, engine) @@ -58,10 +58,10 @@ async def stop(self, ms: MessageSenderQueued, ctx: Interaction) -> bool: self.data.track_list.clear() if await self.queue.exit() is False: - ms.add_message("The bot does not currently play music") + ms.add_message("❌ The bot does not currently play music.") return False - ms.add_message("Music stop") + ms.add_message("✔️ Music stop") return True async def pause(self, ms: MessageSenderQueued, ctx: Interaction) -> bool: @@ -70,10 +70,10 @@ async def pause(self, ms: MessageSenderQueued, ctx: Interaction) -> bool: return False if self.queue.pause() is False: - ms.add_message("The bot does not currently play music") + ms.add_message("❌ The bot does not currently play music.") return False - ms.add_message("Music paused") + ms.add_message("✔️ Music paused") return True async def resume(self, ms: MessageSenderQueued, ctx: Interaction) -> bool: @@ -82,10 +82,10 @@ async def resume(self, ms: MessageSenderQueued, ctx: Interaction) -> bool: return False if self.queue.resume() is False: - ms.add_message("The bot is not currently paused") + ms.add_message("❌ The bot is not currently paused.") return False - ms.add_message("Music resume") + ms.add_message("✔️ Music resume") return True async def resume_or_pause(self, ms: MessageSenderQueued, ctx: Interaction) -> bool: @@ -106,18 +106,18 @@ async def next(self, ms: MessageSenderQueued, ctx: Interaction) -> bool: if self.queue.has_next() is False: if self.queue.stop() is False: - ms.add_message("The bot does not currently play music") + ms.add_message("❌ The bot does not currently play music.") return False else: - ms.add_message("The bot didn't have next music") + ms.add_message("❌ The bot didn't have next music.") return True await self.queue.goto_channel(voice_channel) if self.queue.next() is False: - ms.add_message("Unable to play next music") + ms.add_message("❌ Unable to play next music.") return False - ms.add_message("Skip musique") + ms.add_message("✔️ Skip musique.") return True async def shuffle(self, ms: MessageSenderQueued, ctx: Interaction): @@ -126,10 +126,10 @@ async def shuffle(self, ms: MessageSenderQueued, ctx: Interaction): return False if not self.queue.shuffle(): - ms.add_message("No need to shuffle the queue.") + ms.add_message("❌ No need to shuffle the queue.") return False - ms.add_message("The queue has been shuffled.") + ms.add_message("✔️ The queue has been shuffled.") return True async def remove(self, ms: MessageSenderQueued, ctx: Interaction, number_in_queue: int): @@ -139,20 +139,20 @@ async def remove(self, ms: MessageSenderQueued, ctx: Interaction, number_in_queu if number_in_queue <= 0: ms.add_message( - content=f"Unable to remove element with the number {number_in_queue} in the queue" + content=f"❌ Unable to remove element with the number {number_in_queue} in the queue." ) return False if number_in_queue == 1: ms.add_message( - content="Unable to remove the current track from the queue. Use `next` instead" + content="❌ Unable to remove the current track from the queue. Use `next` instead." ) return False track_deleted = self.queue.remove(number_in_queue - 1) if track_deleted is None: ms.add_message( - content=f"There is no element with the number {number_in_queue} in the queue" + content=f"❌ There is no element with the number {number_in_queue} in the queue." ) return False @@ -166,31 +166,31 @@ async def goto(self, ms: MessageSenderQueued, ctx: Interaction, number_in_queue: if number_in_queue <= 0: ms.add_message( - content=f"Unable to go to element with number {number_in_queue} in the queue" + content=f"❌ Unable to go to element with number {number_in_queue} in the queue" ) return False if number_in_queue == 1: ms.add_message( - content="Unable to remove the current track from the queue. Use `next` instead" + content="❌ Unable to remove the current track from the queue. Use `next` instead." ) return False tracks_removed = self.queue.goto(number_in_queue - 1) if tracks_removed <= 0: ms.add_message( - content=f"There is no element with the number {number_in_queue} in the queue" + content=f"❌ There is no element with the number {number_in_queue} in the queue." ) return False # +1 for current track - ms.add_message(f"f{tracks_removed + 1} tracks has been skipped") + ms.add_message(f"✔️ {tracks_removed + 1} tracks has been skipped.") return True def queue_list(self, ms: MessageSenderQueued, ctx: Interaction) -> bool: queue: str | None = self.queue.queue_list() if queue is None: - ms.add_message("Queue is empty") + ms.add_message("❌ Queue is empty") return False ms.add_code_message(queue, prefix="Here's the music in the queue :") @@ -227,7 +227,7 @@ async def callback(user: User | Member, ms: MessageSenderQueued, t: TrackMinimal if tracks_unfindable: hsa = utils_list_track.to_str(tracks_unfindable) ms.add_code_message(hsa, prefix=":warning: Can't find the audio for these tracks :") - await ms.ctx.followup.send(content="Choose a title from the provided list :", view=view) + await ms.ctx.followup.send(content="📋 Choose a title from the provided list :", view=view) return True async def play_url( @@ -237,7 +237,7 @@ async def play_url( if not voice_channel: return False - ms.edit_message(f"Searching **{url}** ...", "search") + ms.edit_message(f"💿 Searching **{url}** ...", "search") try: result = await self.data.search_engine.search_by_url(url) @@ -253,4 +253,4 @@ async def play_url( tracks = result return await self._execute_play(ms, voice_channel, tracks, at_end=at_end) else: - raise Exception("Unknown type 'res'") + raise Exception("Unknown type 'result'") diff --git a/src/pyramid/connector/spotify/spotify_tools.py b/src/pyramid/connector/spotify/spotify_tools.py deleted file mode 100644 index 53b013b..0000000 --- a/src/pyramid/connector/spotify/spotify_tools.py +++ /dev/null @@ -1,30 +0,0 @@ - - -import re -from pyramid.connector.spotify.spotify_type import SpotifyType -from pyramid.data.a_engine_tools import AEngineTools - - -class SpotifyTools(AEngineTools): - - @classmethod - def extract_from_url(cls, url) -> tuple[str, SpotifyType | None] | tuple[None, None]: - # Extract ID and type using regex - pattern = r"(?<=open\.spotify\.com/)(intl-(?P\w+)/)?(?P\w+)/(?P\w+)" - match = re.search(pattern, url) - if not match: - return None, None - type_str = match.group("type").upper() - if type_str == "PLAYLIST": - type = SpotifyType.PLAYLIST - elif type_str == "ARTIST": - type = SpotifyType.ARTIST - elif type_str == "ALBUM": - type = SpotifyType.ALBUM - elif type_str == "TRACK": - type = SpotifyType.TRACK - else: - type = None - - id = match.group("id") - return id, type diff --git a/src/pyramid/data/a_engine_tools.py b/src/pyramid/data/a_engine_tools.py deleted file mode 100644 index 5daf380..0000000 --- a/src/pyramid/data/a_engine_tools.py +++ /dev/null @@ -1,8 +0,0 @@ -from abc import ABC, abstractmethod -from typing import Any - - -class AEngineTools(ABC): - @abstractmethod - def extract_from_url(self, url) -> tuple[int | str, Any | None] | tuple[None, None]: - pass diff --git a/src/pyramid/data/exceptions.py b/src/pyramid/data/exceptions.py index eb7fd68..2bf66bf 100644 --- a/src/pyramid/data/exceptions.py +++ b/src/pyramid/data/exceptions.py @@ -12,7 +12,6 @@ class DiscordMessageException(CustomException): class EngineSourceNotFoundException(DiscordMessageException): pass - class TrackNotFoundException(DiscordMessageException): pass @@ -27,3 +26,9 @@ class DeezerTokensUnavailableException(DeezerTokenException): class DeezerTokenOverflowException(DeezerTokenException): pass + +class RessourceNotExistsException(DiscordMessageException): + pass + +class RessourceBadFormatException(DiscordMessageException): + pass diff --git a/src/pyramid/services/deezer_search.py b/src/pyramid/services/deezer_search.py index 7383a7f..69c2777 100644 --- a/src/pyramid/services/deezer_search.py +++ b/src/pyramid/services/deezer_search.py @@ -1,17 +1,19 @@ import asyncio import logging +import re +import aiohttp import deezer from pyramid.api.services.configuration import IConfigurationService from pyramid.api.services.deezer_client import IDeezerClientService from pyramid.api.services.deezer_search import IDeezerSearchService from pyramid.api.services.tools.annotation import pyramid_service from pyramid.api.services.tools.injector import ServiceInjector -from pyramid.connector.deezer.client.exceptions import CliDeezerNoDataException, CliDeezerRateLimitError +from pyramid.connector.deezer.client.exceptions import CliDeezerErrorResponse, CliDeezerNoDataException, CliDeezerRateLimitError from pyramid.connector.deezer.client.list_paginated import DeezerListPaginated from pyramid.connector.deezer.deezer_type import DeezerType -from pyramid.connector.deezer.tools import DeezerTools from pyramid.data.a_search import ASearch, ASearchId +from pyramid.data.exceptions import RessourceBadFormatException, RessourceNotExistsException from pyramid.data.music.track_minimal_deezer import TrackMinimalDeezer @@ -151,31 +153,35 @@ async def get_top_artist_by_id( async def get_by_url( self, url ) -> tuple[list[TrackMinimalDeezer], list[TrackMinimalDeezer]] | TrackMinimalDeezer | None: - id, type = await DeezerTools.extract_from_url(url) + id, type = await self.extract_from_url(url) if id is None: return None if type is None: - raise NotImplementedError(f"The type of deezer info '{url}' is not implemented") + raise RessourceBadFormatException("❌ Deezer **%s** is not recognized.", url) tracks: ( tuple[list[TrackMinimalDeezer], list[TrackMinimalDeezer]] | TrackMinimalDeezer | None ) - - if type == DeezerType.PLAYLIST: - # future = asyncio.get_event_loop().run_in_executor( - # None, self.get_playlist_tracks_by_id, id - # ) - # tracks = await asyncio.wrap_future(future) - tracks = await self.get_playlist_tracks_by_id(id) - elif type == DeezerType.ARTIST: - tracks = await self.get_top_artist_by_id(id) - elif type == DeezerType.ALBUM: - tracks = await self.get_album_tracks_by_id(id) - elif type == DeezerType.TRACK: - tracks = await self.get_track_by_id(id) - else: - raise NotImplementedError(f"The type of deezer info '{type}' can't be resolve") + try: + if type == DeezerType.PLAYLIST: + # future = asyncio.get_event_loop().run_in_executor( + # None, self.get_playlist_tracks_by_id, id + # ) + # tracks = await asyncio.wrap_future(future) + tracks = await self.get_playlist_tracks_by_id(id) + elif type == DeezerType.ARTIST: + tracks = await self.get_top_artist_by_id(id) + elif type == DeezerType.ALBUM: + tracks = await self.get_album_tracks_by_id(id) + elif type == DeezerType.TRACK: + tracks = await self.get_track_by_id(id) + else: + raise RessourceBadFormatException("❌ Deezer **%s** is not fully implemented. Try later.", type.name.lower()) + # except CliDeezerErrorResponse as e: + # raise RessourceNotExistsException("❌ Deezer **%s** is not accessible.", url) + except CliDeezerNoDataException as e: + raise RessourceBadFormatException("❌ Deezer **%s** is a wrong URL format.", url) return tracks @@ -254,3 +260,34 @@ def __remove_special_chars( return "".join(result) + @classmethod + async def extract_from_url(cls, url) -> tuple[int, DeezerType | None] | tuple[None, None]: + + # Resolve if URL is a deezer.page.link URL + pattern = r"^(https?:\/\/)?deezer\.page\.link" + if re.match(pattern, url): + async with aiohttp.ClientSession() as session: + async with session.get(url, allow_redirects=True) as response: + url = str(response.url) + if not url: + raise RessourceBadFormatException("❌ Deezer shortcut **%s** is a wrong URL format.", url) + + # Extract ID and type using regex + pattern = r"^(?:https?:\/\/)?(?:www\.)?deezer\.com\/(?:\w{2}\/)?(?P\w+)\/(?P\d+)" + match = re.search(pattern, url) + if not match: + return None, None + deezer_type_str = match.group("type").upper() + if deezer_type_str == "PLAYLIST": + deezer_type = DeezerType.PLAYLIST + elif deezer_type_str == "ARTIST": + deezer_type = DeezerType.ARTIST + elif deezer_type_str == "ALBUM": + deezer_type = DeezerType.ALBUM + elif deezer_type_str == "TRACK": + deezer_type = DeezerType.TRACK + else: + deezer_type = None + + deezer_id = int(match.group("id")) + return deezer_id, deezer_type diff --git a/src/pyramid/services/source.py b/src/pyramid/services/source.py index 6a94230..e076717 100644 --- a/src/pyramid/services/source.py +++ b/src/pyramid/services/source.py @@ -63,7 +63,7 @@ async def search_by_url(self, url: str): break if not result: - raise TrackNotFoundException("URL **%s** not found.", url) + raise TrackNotFoundException("❌ URL **%s** not found. Only Deezer and Spotify are supported.", url) if isinstance(result, tuple): tracks, tracks_unfindable = result diff --git a/src/pyramid/services/spotify_client.py b/src/pyramid/services/spotify_client.py index e089c4c..9f9fc9f 100644 --- a/src/pyramid/services/spotify_client.py +++ b/src/pyramid/services/spotify_client.py @@ -133,14 +133,15 @@ async def _async_internal_call(self, method: str, url: str, payload, params) -> msg = await response.text() or None reason = None - logger.error( - "HTTP Error for %s to %s with Params: %s returned %s due to %s", - method, - url, - args.get("params"), - response.status, - msg, - ) + if response.status != 404 and response.status != 400: + logger.error( + "HTTP Error for %s to %s with Params: %s returned %s due to %s", + method, + url, + args.get("params"), + response.status, + msg, + ) raise SpotifyException( response.status, -1, diff --git a/src/pyramid/services/spotify_search.py b/src/pyramid/services/spotify_search.py index 0da0c9d..e763edd 100644 --- a/src/pyramid/services/spotify_search.py +++ b/src/pyramid/services/spotify_search.py @@ -1,3 +1,4 @@ +import re from pyramid.api.services.configuration import IConfigurationService from pyramid.api.services.spotify_client import ISpotifyClientService from pyramid.api.services.spotify_search import ISpotifySearchService @@ -5,9 +6,10 @@ from pyramid.api.services.spotify_search_id import ISpotifySearchIdService from pyramid.api.services.tools.annotation import pyramid_service from pyramid.api.services.tools.injector import ServiceInjector -from pyramid.connector.spotify.spotify_tools import SpotifyTools from pyramid.connector.spotify.spotify_type import SpotifyType +from pyramid.data.exceptions import DiscordMessageException, RessourceBadFormatException, RessourceNotExistsException from pyramid.data.music.track_minimal_spotify import TrackMinimalSpotify +from spotipy.exceptions import SpotifyException @pyramid_service(interface=ISpotifySearchService) @@ -91,26 +93,54 @@ async def get_top_artist( async def get_by_url( self, url ) -> tuple[list[TrackMinimalSpotify], list[TrackMinimalSpotify]] | TrackMinimalSpotify | None: - id, type = SpotifyTools.extract_from_url(url) + id, type = self.extract_from_url(url) if id is None: return None if type is None: - raise NotImplementedError(f"The type of spotify info '{url}' is not implemented") + raise RessourceBadFormatException("❌ Spotify **%s** is not recognized.", url) tracks: ( tuple[list[TrackMinimalSpotify], list[TrackMinimalSpotify]] | TrackMinimalSpotify | None ) + try: + if type == SpotifyType.PLAYLIST: + tracks = await self.__spotify_search_id.get_playlist_tracks_by_id(id) + elif type == SpotifyType.ARTIST: + tracks = await self.__spotify_search_id.get_top_artist_by_id(id) + elif type == SpotifyType.ALBUM: + tracks = await self.__spotify_search_id.get_album_tracks_by_id(id) + elif type == SpotifyType.TRACK: + tracks = await self.__spotify_search_id.get_track_by_id(id) + else: + raise RessourceBadFormatException("❌ Spotify **%s** is not fully implemented. Try later.", type.name.lower()) + + except SpotifyException as err: + if err.http_status == 400: + raise RessourceBadFormatException("❌ Spotify **%s** is a wrong URL format.", url) + elif err.http_status == 404: + raise RessourceNotExistsException("❌ Spotify **%s** is not accessible.", url) + else: + raise err + return tracks - if type == SpotifyType.PLAYLIST: - tracks = await self.__spotify_search_id.get_playlist_tracks_by_id(id) - elif type == SpotifyType.ARTIST: - tracks = await self.__spotify_search_id.get_top_artist_by_id(id) - elif type == SpotifyType.ALBUM: - tracks = await self.__spotify_search_id.get_album_tracks_by_id(id) - elif type == SpotifyType.TRACK: - tracks = await self.__spotify_search_id.get_track_by_id(id) + @classmethod + def extract_from_url(cls, url) -> tuple[str, SpotifyType | None] | tuple[None, None]: + pattern = r"^(?:https?:\/\/)?(?:www\.)?open\.spotify\.com\/(?:\w{2}\/)?(?P\w+)\/(?P\w+)" + match = re.search(pattern, url) + if not match: + return None, None + type_str = match.group("type").upper() + if type_str == "PLAYLIST": + type = SpotifyType.PLAYLIST + elif type_str == "ARTIST": + type = SpotifyType.ARTIST + elif type_str == "ALBUM": + type = SpotifyType.ALBUM + elif type_str == "TRACK": + type = SpotifyType.TRACK else: - raise NotImplementedError(f"The type of spotify info '{type}' can't be resolve") + type = None - return tracks + id = match.group("id") + return id, type From 67b7c41cd738ede916111a9998f3b4a26f36fc64 Mon Sep 17 00:00:00 2001 From: Tristiisch Date: Fri, 6 Dec 2024 04:30:54 +0100 Subject: [PATCH 30/32] fix: ci-rule --- .github/workflows/python.yml | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml index cdaefbf..7f2a959 100644 --- a/.github/workflows/python.yml +++ b/.github/workflows/python.yml @@ -7,10 +7,11 @@ on: tags: - '[0-9]+.[0-9]+.[0-9]+' paths: - - "src/**/*.py" + - "src/**" + - "tests/**" - "requirements.txt" - ".dockerignore" - - "./.docker/Dockerfile" + - ".docker/Dockerfile" - "docker-compose*.yml" - ".github/workflows/python.yml" pull_request: @@ -18,11 +19,11 @@ on: branches: - "*" paths: - - "src/**/*.py" - - "tests/**/*.py" + - "src/**" + - "tests/**" - "requirements.txt" - ".dockerignore" - - "./.docker/Dockerfile" + - ".docker/Dockerfile" - "docker-compose*.yml" - ".github/workflows/python.yml" workflow_dispatch: From 10011b550d6e17e570ca78f8cea9db5640ba9899 Mon Sep 17 00:00:00 2001 From: Tristiisch Date: Fri, 6 Dec 2024 04:36:20 +0100 Subject: [PATCH 31/32] fix: spotify regex --- src/pyramid/services/spotify_search.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pyramid/services/spotify_search.py b/src/pyramid/services/spotify_search.py index e763edd..647371d 100644 --- a/src/pyramid/services/spotify_search.py +++ b/src/pyramid/services/spotify_search.py @@ -126,7 +126,7 @@ async def get_by_url( @classmethod def extract_from_url(cls, url) -> tuple[str, SpotifyType | None] | tuple[None, None]: - pattern = r"^(?:https?:\/\/)?(?:www\.)?open\.spotify\.com\/(?:\w{2}\/)?(?P\w+)\/(?P\w+)" + pattern = r"^(?:https?:\/\/)?(?:www\.)?open\.spotify\.com\/(?:intl-\w+\/)?(?:\w+\/)?(?P\w+)\/(?P\w+)" match = re.search(pattern, url) if not match: return None, None From fbaf5b6eaccfd0c01618600cc3f3b044394ba804 Mon Sep 17 00:00:00 2001 From: Tristiisch Date: Fri, 6 Dec 2024 04:37:35 +0100 Subject: [PATCH 32/32] feat: add deezer tests --- tests/deezer_search_test.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/tests/deezer_search_test.py b/tests/deezer_search_test.py index 4d9c0ae..8b63307 100644 --- a/tests/deezer_search_test.py +++ b/tests/deezer_search_test.py @@ -39,33 +39,33 @@ async def test_url_artist(self): tracks = await self.deezer_search.get_by_url("https://www.deezer.com/fr/artist/1060") self.assertIsNotNone(tracks) - # async def test_url_artist_2nd_format(self): - # tracks = await self.deezer_search.get_by_url("https://deezer.page.link/HWapYqfpsmSukE6T7") - # self.assertIsNotNone(tracks) + async def test_url_artist_2nd_format(self): + tracks = await self.deezer_search.get_by_url("https://deezer.page.link/HWapYqfpsmSukE6T7") + self.assertIsNotNone(tracks) async def test_url_album(self): tracks = await self.deezer_search.get_by_url("https://www.deezer.com/fr/album/53012892") self.assertIsNotNone(tracks) - # async def test_url_album_2nd_format(self): - # tracks = await self.deezer_search.get_by_url("https://deezer.page.link/gvryHN1VUn62CnCJ7") - # self.assertIsNotNone(tracks) + async def test_url_album_2nd_format(self): + tracks = await self.deezer_search.get_by_url("https://deezer.page.link/gvryHN1VUn62CnCJ7") + self.assertIsNotNone(tracks) async def test_url_track(self): tracks = await self.deezer_search.get_by_url("https://www.deezer.com/fr/track/2308590") self.assertIsNotNone(tracks) - # async def test_url_track_2nd_format(self): - # tracks = await self.deezer_search.get_by_url("https://deezer.page.link/qF6ucYP2wSGsLiMB6") - # self.assertIsNotNone(tracks) + async def test_url_track_2nd_format(self): + tracks = await self.deezer_search.get_by_url("https://deezer.page.link/qF6ucYP2wSGsLiMB6") + self.assertIsNotNone(tracks) async def test_url_playlist(self): tracks = await self.deezer_search.get_by_url("https://www.deezer.com/fr/playlist/987181371") self.assertIsNotNone(tracks) - # async def test_url_playlist_2nd_format(self): - # tracks = await self.deezer_search.get_by_url("https://deezer.page.link/ibwojNjEKAQjsKgZ9") - # self.assertIsNotNone(tracks) + async def test_url_playlist_2nd_format(self): + tracks = await self.deezer_search.get_by_url("https://deezer.page.link/ibwojNjEKAQjsKgZ9") + self.assertIsNotNone(tracks) if __name__ == "__main__": unittest.main(failfast=True)