diff --git a/.github/workflows/compile.yml b/.github/workflows/compile.yml index 2a7a591..080e4c7 100644 --- a/.github/workflows/compile.yml +++ b/.github/workflows/compile.yml @@ -18,6 +18,11 @@ on: - "Dockerfile" - ".github/workflows/compile.yml" +env: + SRC: "./src/pyramid" + TEST: "./src/pyramid-test" + TEST_FILES: "*_test.py" + jobs: compile: @@ -53,24 +58,57 @@ jobs: - name: Test compilation run: | - python -m compileall ./src + python -m compileall ${{ env.SRC }} - name: Save version run: | - FULL_JSON=$(python src/pyramid --version) + FULL_JSON=$(python ${{ env.SRC }} --version) echo "json=$(echo $FULL_JSON | jq -c)" >> $GITHUB_OUTPUT echo "version=$(echo $FULL_JSON | jq -r '.version')" >> $GITHUB_OUTPUT echo "commit_id=$(echo $FULL_JSON | jq -r '.git_info.commit_id')" >> $GITHUB_OUTPUT echo "last_author=$(echo $FULL_JSON | jq -r '.git_info.last_author')" >> $GITHUB_OUTPUT id: version + unit_test: + name: "Unit tests Python" + needs: compile + runs-on: ubuntu-latest + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + with: + clean: false + + - name: Set up Python 3.11 + uses: actions/setup-python@v4 + with: + python-version: "3.11" + + - name: Cache dependencies + uses: actions/cache@v3 + with: + path: ~/.cache/pip + key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }} + restore-keys: | + ${{ runner.os }}-pip + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + if [ -f requirements.txt ]; then pip install -r requirements.txt; fi + + - name: Units tests + run: | + python -m unittest discover -v -s ${{ env.TEST }} -p ${{ env.TEST_FILES }} + version_compatibility: name: "Compile Python" - runs-on: ubuntu-latest needs: compile + runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.8", "3.9", "3.10"] + python-version: ["3.8", "3.9", "3.10", "3.12"] continue-on-error: true steps: @@ -86,9 +124,9 @@ jobs: uses: actions/cache@v3 with: path: ~/.cache/pip - key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }}-${{ matrix.python-version }} + key: ${{ runner.os }}-python_${{ matrix.python-version }}-pip-${{ hashFiles('**/requirements.txt') }}-${{ matrix.python-version }} restore-keys: | - ${{ runner.os }}-pip + ${{ runner.os }}-python_${{ matrix.python-version }}-pip - name: Install dependencies run: | @@ -97,11 +135,120 @@ jobs: - name: Test compilation run: | - python -m compileall ./src + python -m compileall ${{ env.SRC }} + + publish_release: + name: "Publish release" + needs: ["compile", "unit_test"] + runs-on: ubuntu-latest + outputs: + docker_tag: ${{ steps.docker_tag.outputs.tag }} + steps: + - name: Checkout repository + uses: actions/checkout@v4 + with: + fetch-depth: 100 + + - name: Set Docker Image Tag + id: docker_tag + run: | + if [ ${{ github.ref }} = 'refs/heads/main' ]; then + echo "tag=latest" >> $GITHUB_OUTPUT + elif [ ${{ github.ref }} = 'refs/heads/pre-prod' ]; then + echo "tag=pre-prod" >> $GITHUB_OUTPUT + echo "tag_github=unstable" >> $GITHUB_OUTPUT + else + echo "tag=dev" >> $GITHUB_OUTPUT + fi + + - name: Get commit messages + id: get_commit_messages + run: | + COMMIT_MESSAGES=$(git log --pretty=format:"%s" ${{ github.event.before }}..${{ github.sha }} | sed -e 's/^/* /') + echo "commit_messages=${COMMIT_MESSAGES}" >> $GITHUB_OUTPUT + + - name: Get last release tag + if: steps.docker_tag.outputs.tag == 'latest' + id: get_last_release + run: | + RESPONSE=$(curl -s "https://api.github.com/repos/${{ github.repository }}/releases/latest") + if [[ $(echo "$RESPONSE" | jq -r .message) == "Not Found" ]]; then + LAST_RELEASE_TAG=$(git rev-list --max-parents=0 HEAD) + else + LAST_RELEASE_TAG=$(echo "$RESPONSE" | jq -r .tag_name) + fi + echo "last_release_tag=${LAST_RELEASE_TAG}" >> $GITHUB_OUTPUT + + - name: Create stable Release + if: steps.docker_tag.outputs.tag == 'latest' + uses: actions/create-release@latest + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + with: + tag_name: ${{ needs.compile.outputs.version }} + release_name: Stable Release v${{ needs.compile.outputs.version }} + body: | + This has been deployed on the Discord bot `PyRamid#6882`. + To use the latest version the bot, please refer to the instructions outlined at https://github.com/tristiisch/PyRamid/#usage. + + ## Changes + ${{ steps.get_commit_messages.outputs.commit_messages }} + + ## Docker + This version is now accessible through various Docker images. Each image creation corresponds to a unique snapshot of this version, while updating the image corresponds to using an updated Docker image tag. + + ### Images creation + * ${{ secrets.DOCKERHUB_USERNAME }}/pyramid:${{ needs.compile.outputs.version }}-${{ needs.compile.outputs.commit_id }} + + ### Images update + * ${{ secrets.DOCKERHUB_USERNAME }}/pyramid:${{ steps.docker_tag.outputs.tag }} + draft: true + prerelease: false + + - name: Get last pre-release tag + if: steps.docker_tag.outputs.tag != 'latest' + id: get_last_unstable_release + run: | + RESPONSE=$(curl -s "https://api.github.com/repos/${{ github.repository }}/releases") + if [[ $(echo "$RESPONSE" | jq -r .message) == "Not Found" ]]; then + LAST_RELEASE_TAG=${{ github.event.before }} + else + LAST_RELEASE_TAG=$(echo "$RESPONSE" | jq -r '.[0].tag_name') + fi + echo "last_release_tag=${LAST_RELEASE_TAG}" >> $GITHUB_OUTPUT + + - name: Create unstable Release + if: steps.docker_tag.outputs.tag != 'latest' + uses: actions/create-release@latest + env: + GITHUB_TOKEN: ${{ secrets.TOKEN_RELEASE }} + with: + tag_name: ${{ steps.docker_tag.outputs.tag_github }}-${{ needs.compile.outputs.version }}-${{ needs.compile.outputs.commit_id }} + release_name: Unstable release v${{ needs.compile.outputs.version }}-${{ needs.compile.outputs.commit_id }} + body: | + This has been deployed on the Discord bot `PyRamid PRÉ-PROD#6277`. + Feel free to put it to the test by joining the development server at https://discord.gg/pNrZp7U34d. + + ## Changes + ${{ steps.get_commit_messages.outputs.commit_messages }} + + ## Docker + This version is now accessible through various Docker images. Each image creation corresponds to a unique snapshot of this version, while updating the image corresponds to using an updated Docker image tag. + + ### Images creation + * ${{ secrets.DOCKERHUB_USERNAME }}/pyramid:${{ needs.compile.outputs.version }}-${{ needs.compile.outputs.commit_id }} + + ### Images update + * ${{ secrets.DOCKERHUB_USERNAME }}/pyramid:${{ steps.docker_tag.outputs.tag }} + * ${{ secrets.DOCKERHUB_USERNAME }}/pyramid:${{ needs.compile.outputs.version }} + + draft: false + prerelease: true + docker_push: name: "Docker Push" - needs: compile + needs: ["compile", "publish_release"] runs-on: ubuntu-latest if: github.event_name == 'push' @@ -122,25 +269,14 @@ jobs: username: ${{ secrets.DOCKERHUB_USERNAME }} password: ${{ secrets.DOCKERHUB_TOKEN }} - - name: Set Docker Tag - run: | - if [ ${{ github.ref }} = 'refs/heads/main' ]; then - echo "tag=latest" >> $GITHUB_OUTPUT - elif [ ${{ github.ref }} = 'refs/heads/pre-prod' ]; then - echo "tag=pre-prod" >> $GITHUB_OUTPUT - else - echo "tag=dev" >> $GITHUB_OUTPUT - fi - id: set_tag - - name: Build and push PROD - if: steps.set_tag.outputs.tag == 'latest' + if: needs.publish_release.outputs.docker_tag == 'latest' uses: docker/build-push-action@v5 with: context: . push: true tags: | - ${{ secrets.DOCKERHUB_USERNAME }}/pyramid:${{ steps.set_tag.outputs.tag }}, + ${{ secrets.DOCKERHUB_USERNAME }}/pyramid:${{ needs.publish_release.outputs.docker_tag }}, ${{ secrets.DOCKERHUB_USERNAME }}/pyramid:${{ needs.compile.outputs.version }}, ${{ secrets.DOCKERHUB_USERNAME }}/pyramid:${{ needs.compile.outputs.version }}-${{ needs.compile.outputs.commit_id }} @@ -148,13 +284,13 @@ jobs: cache-to: type=gha,mode=max - name: Build and push DEV - if: steps.set_tag.outputs.tag != 'latest' + if: needs.publish_release.outputs.docker_tag != 'latest' uses: docker/build-push-action@v5 with: context: . push: true tags: | - ${{ secrets.DOCKERHUB_USERNAME }}/pyramid:${{ steps.set_tag.outputs.tag }}, + ${{ secrets.DOCKERHUB_USERNAME }}/pyramid:${{ needs.publish_release.outputs.docker_tag }}, ${{ secrets.DOCKERHUB_USERNAME }}/pyramid:${{ needs.compile.outputs.version }}-${{ needs.compile.outputs.commit_id }} cache-from: type=gha diff --git a/.gitignore b/.gitignore index 4b6b56b..a17f022 100644 --- a/.gitignore +++ b/.gitignore @@ -163,3 +163,4 @@ songs/ config.yml logs/ git_info.json +/test/ \ No newline at end of file diff --git a/.vscode/settings.json b/.vscode/settings.json index 5cf979b..4839201 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,4 +1,12 @@ { "python.analysis.typeCheckingMode": "basic", - "python.tensorBoard.logDirectory": "./logs" + "python.testing.unittestArgs": [ + "-v", + "-s", + "./src/pyramid-test", + "-p", + "*_test.py" + ], + "python.testing.pytestEnabled": false, + "python.testing.unittestEnabled": true } diff --git a/README.md b/README.md index 9092453..058f56c 100644 --- a/README.md +++ b/README.md @@ -75,6 +75,9 @@ spotify: client_id: client_secret: +general: + default_limit_tracks: 100 + # Available value: production, pre-production, development # Change message level in logs mode: production diff --git a/auto_update.sh b/auto_update.sh new file mode 100644 index 0000000..f1e7783 --- /dev/null +++ b/auto_update.sh @@ -0,0 +1,9 @@ +#!/bin/bash + +# Get the directory of the script +DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" + +# Change the working directory +cd "$DIR" + +docker compose up --pull always -d diff --git a/config.exemple.yml b/config.exemple.yml index ee4e25c..1dfe802 100644 --- a/config.exemple.yml +++ b/config.exemple.yml @@ -34,6 +34,9 @@ spotify: client_id: client_secret: +general: + default_limit_tracks: 100 + # Available value: production, pre-production, development # Change message level in logs mode: production diff --git a/environnement.ps1 b/environnement.ps1 index 0cacc89..329deb1 100644 --- a/environnement.ps1 +++ b/environnement.ps1 @@ -9,6 +9,10 @@ function Install-Requirement() { pip install -r .\requirements.txt } +function Update-Requirement() { + pip install --upgrade -r .\requirements.txt +} + function Add-Lib($lib) { pip install $lib pip freeze | grep -i $lib >> requirements.txt diff --git a/requirements.txt b/requirements.txt index 4dd1f3b..049a7e0 100644 Binary files a/requirements.txt and b/requirements.txt differ diff --git a/src/pyramid-test/__main__.py b/src/pyramid-test/__main__.py new file mode 100644 index 0000000..16eea9a --- /dev/null +++ b/src/pyramid-test/__main__.py @@ -0,0 +1,6 @@ +import unittest +import queue_test + +if __name__ == "__main__": + suite = unittest.TestLoader().loadTestsFromModule(queue_test) + unittest.TextTestRunner(verbosity=2).run(suite) diff --git a/src/pyramid-test/queue_test.py b/src/pyramid-test/queue_test.py new file mode 100644 index 0000000..21ddb1e --- /dev/null +++ b/src/pyramid-test/queue_test.py @@ -0,0 +1,179 @@ +import os +import sys +import time +import unittest + +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) + +from pyramid.tools.queue import Queue, QueueItem # noqa: E402 + + +class SimpleQueue(unittest.TestCase): + def test_add(self): + queue = Queue(threads=1) + self.assertEqual(queue.length(), 0) + + item = QueueItem(name="test", func=lambda x: x, x=5) + queue.add(item) + self.assertEqual(queue.length(), 1) + + def test_add_at_start(self): + queue = Queue(threads=1) + self.assertEqual(queue.length(), 0) + + item = QueueItem(name="test", func=lambda x: x, x=5) + queue.add_at_start(item) + self.assertEqual(queue.length(), 1) + + def test_worker_start_before(self): + queue = Queue(threads=1) + self.assertEqual(queue.length(), 0) + + queue.start() + item = QueueItem(name="test", func=lambda x: x, x=5) + queue.add(item) + self.assertEqual(queue.length(), 1) + + queue.end() + queue.join() + self.assertEqual(queue.length(), 0) + + def test_worker_start_after(self): + queue = Queue(threads=1) + self.assertEqual(queue.length(), 0) + + item = QueueItem(name="test", func=lambda x: x, x=5) + queue.add(item) + self.assertEqual(queue.length(), 1) + + queue.start() + queue.end() + queue.join() + self.assertEqual(queue.length(), 0) + + def test_wait_for_end(self): + queue = Queue(threads=1) + queue.register_to_wait_on_exit() + queue.start() + + item = QueueItem(name="test", func=lambda x: x, x=5) + queue.add(item) + + Queue.wait_for_end(1) + + +class MediumQueue(unittest.TestCase): + def test_order_simple(self): + thread_nb = 1 + queue = Queue(threads=thread_nb) + results = [] + results_excepted = list(range(1, 10)) + + for i in range(1, 10): + item = QueueItem( + f"test{i}", lambda n : n, None, lambda result: results.append(result), n=i + ) + queue.add(item) + + queue.start() + queue.end() + queue.join() + + self.assertEqual(results, results_excepted) + + def test_order_reverse(self): + thread_nb = 1 + queue = Queue(threads=thread_nb) + results = [] + results_excepted = list(range(9, 0, -1)) + + for i in range(1, 10): + item = QueueItem( + f"test{i}", lambda n : n, None, lambda result: results.append(result), n=i + ) + queue.add_at_start(item) + + queue.start() + queue.end() + queue.join() + + self.assertEqual(results, results_excepted) + + def test_order_mixed(self): + thread_nb = 1 + queue = Queue(threads=thread_nb) + results = [] + results_excepted = list(range(1, 100)) + + for i in range(10, 20): + item = QueueItem( + f"test{i}", lambda n : n, None, lambda result: results.append(result), n=i + ) + queue.add(item) + + for i in range(9, 0, -1): + item = QueueItem( + f"test{i}", lambda n : n, None, lambda result: results.append(result), n=i + ) + queue.add_at_start(item) + + for i in range(20, 100): + item = QueueItem( + f"test{i}", lambda n : n, None, lambda result: results.append(result), n=i + ) + queue.add(item) + + queue.start() + queue.end() + queue.join() + + self.assertEqual(results, results_excepted) + + def test_order_multi_thread(self): + thread_nb = 10 + items = 100 + queue = Queue(threads=thread_nb) + results = [] + results_excepted = list(range(1, items)) + + def sleep_and_return_n(n): + time.sleep(n / 1000) + return n + + for i in range(1, items): + item = QueueItem( + f"test{i}", sleep_and_return_n, None, lambda result: results.append(result), n=i + ) + queue.add(item) + + queue.start() + queue.end() + queue.join() + + self.assertEqual(results, results_excepted) + + def test_wait_for_end_shutdown_threads(self): + thread_nb = 2 + timeout_per_thread = 1 + items = 100 + + queue = Queue(threads=thread_nb) + queue.register_to_wait_on_exit() + + for i in range(items): + item = QueueItem(name=f"test{i}", func=lambda: time.sleep(60)) + queue.add(item) + + self.assertEqual(queue.length(), items) + queue.start() + start_time = time.time() + + queue.join(timeout_per_thread) + + end_time = time.time() + elapsed_time = end_time - start_time + self.assertLessEqual(elapsed_time, thread_nb * timeout_per_thread + 1) + + +if __name__ == "__main__": + unittest.main(failfast=True) diff --git a/src/pyramid/__init__.py b/src/pyramid/__init__.py new file mode 100644 index 0000000..c0ec412 --- /dev/null +++ b/src/pyramid/__init__.py @@ -0,0 +1 @@ +name = "pyramid" \ No newline at end of file diff --git a/src/pyramid/connector/deezer/search.py b/src/pyramid/connector/deezer/search.py index db58ad0..70c0ee2 100644 --- a/src/pyramid/connector/deezer/search.py +++ b/src/pyramid/connector/deezer/search.py @@ -11,17 +11,16 @@ from deezer.client import DeezerErrorResponse from data.track import TrackMinimalDeezer -from data.a_search import ASearch +from data.a_search import ASearch, ASearchId +from data.a_engine_tools import AEngineTools -DEFAULT_LIMIT = 100 - -class DeezerSearch(ASearch): - def __init__(self): +class DeezerSearch(ASearchId, ASearch): + def __init__(self, default_limit: int): + self.default_limit = default_limit self.client = Client() self.tools = DeezerTools() self.strict = False - self.default_limit = DEFAULT_LIMIT def search_track(self, search) -> TrackMinimalDeezer | None: search_results = self.client.search(query=search) @@ -36,7 +35,10 @@ def get_track_by_id(self, track_id: int) -> TrackMinimalDeezer | None: return None return TrackMinimalDeezer(track) - def search_tracks(self, search, limit=DEFAULT_LIMIT) -> list[TrackMinimalDeezer] | None: + def search_tracks(self, search, limit: int | None = None) -> list[TrackMinimalDeezer] | None: + if limit is None: + limit = self.default_limit + search_results = self.client.search(query=search, strict=self.strict) if not search_results or len(search_results) == 0: @@ -109,7 +111,9 @@ def get_album_tracks_by_id( return None return [TrackMinimalDeezer(element) for element in album.get_tracks()], [] - def get_top_artist(self, artist_name, limit=DEFAULT_LIMIT) -> list[TrackMinimalDeezer] | None: + def get_top_artist(self, artist_name, limit: int | None = None) -> list[TrackMinimalDeezer] | None: + if limit is None: + limit = self.default_limit search_results = self.client.search_artists(query=artist_name, strict=self.strict) if not search_results or len(search_results) == 0: return None @@ -118,8 +122,10 @@ def get_top_artist(self, artist_name, limit=DEFAULT_LIMIT) -> list[TrackMinimalD return [TrackMinimalDeezer(element) for element in top_tracks] def get_top_artist_by_id( - self, artist_id: int, limit=DEFAULT_LIMIT + self, artist_id: int, limit: int | None = None ) -> tuple[list[TrackMinimalDeezer], list[TrackMinimalDeezer]] | None: + if limit is None: + limit = self.default_limit artist = self.client.get_artist(artist_id) # TODO handle HTTP errors if not artist: return None @@ -129,7 +135,7 @@ def get_top_artist_by_id( async def get_by_url( self, url ) -> tuple[list[TrackMinimalDeezer], list[TrackMinimalDeezer]] | TrackMinimalDeezer | None: - id, type = self.tools.extract_deezer_info(url) + id, type = self.tools.extract_from_url(url) if id is None: return None @@ -225,8 +231,8 @@ class DeezerType(Enum): TRACK = 4 -class DeezerTools: - def extract_deezer_info(self, url) -> tuple[int, DeezerType | None] | tuple[None, None]: +class DeezerTools(AEngineTools): + def extract_from_url(self, url) -> tuple[int, DeezerType | None] | tuple[None, None]: # Resolve if URL is a deezer.page.link URL if "deezer.page.link" in url: response = requests.get(url, allow_redirects=True) @@ -235,20 +241,19 @@ def extract_deezer_info(self, url) -> tuple[int, DeezerType | None] | tuple[None # Extract ID and type using regex pattern = r"(?<=deezer.com/fr/)(\w+)/(?P\d+)" match = re.search(pattern, url) - if match: - deezer_type_str = match.group(1).upper() - if deezer_type_str == "PLAYLIST": - deezer_type = DeezerType.PLAYLIST - elif deezer_type_str == "ARTIST": - deezer_type = DeezerType.ARTIST - elif deezer_type_str == "ALBUM": - deezer_type = DeezerType.ALBUM - elif deezer_type_str == "TRACK": - deezer_type = DeezerType.TRACK - else: - deezer_type = None - - deezer_id = int(match.group("id")) - return deezer_id, deezer_type - else: + if not match: return None, None + deezer_type_str = match.group(1).upper() + if deezer_type_str == "PLAYLIST": + deezer_type = DeezerType.PLAYLIST + elif deezer_type_str == "ARTIST": + deezer_type = DeezerType.ARTIST + elif deezer_type_str == "ALBUM": + deezer_type = DeezerType.ALBUM + elif deezer_type_str == "TRACK": + deezer_type = DeezerType.TRACK + else: + deezer_type = None + + deezer_id = int(match.group("id")) + return deezer_id, deezer_type diff --git a/src/pyramid/connector/discord/bot.py b/src/pyramid/connector/discord/bot.py index 9a3c4b1..1736070 100644 --- a/src/pyramid/connector/discord/bot.py +++ b/src/pyramid/connector/discord/bot.py @@ -13,16 +13,13 @@ from discord.ext.commands.errors import CommandNotFound, MissingPermissions, MissingRequiredArgument from data.functional.application_info import ApplicationInfo -from data.a_search import ASearch from data.environment import Environment from data.guild_data import GuildData -from connector.deezer.downloader import DeezerDownloader -from connector.deezer.search import DeezerSearch from connector.discord.bot_cmd import BotCmd from connector.discord.bot_listener import BotListener from connector.discord.guild_cmd import GuildCmd from connector.discord.guild_queue import GuildQueue -from connector.spotify.search import SpotifySearch +from data.functional.engine_source import EngineSource from tools.configuration import Configuration @@ -31,21 +28,14 @@ def __init__( self, logger: logging.Logger, information: ApplicationInfo, - config: Configuration, - deezer_dl: DeezerDownloader, + config: Configuration ): self.__logger = logger self.__information = information self.__token = config.discord_token self.__ffmpeg = config.discord_ffmpeg self.__environment: Environment = config.mode - self.__deezer_dl = deezer_dl - self.__search_engines: Dict[str, ASearch] = dict( - { - "spotify": SpotifySearch(config.spotify_client_id, config.spotify_client_secret), - "deezer": DeezerSearch(), - } - ) + self.__engine_source = EngineSource(config) self.__started = time.time() intents = discord.Intents.default() @@ -110,9 +100,8 @@ def __get_guild_cmd(self, guild: Guild) -> GuildCmd: self.guilds_instances[guild.id] = GuildInstances( guild, self.__logger.getChild(guild.name), - self.__deezer_dl, - self.__ffmpeg, - self.__search_engines, + self.__engine_source, + self.__ffmpeg ) return self.guilds_instances[guild.id].cmds @@ -123,10 +112,9 @@ def __init__( self, guild: Guild, logger: Logger, - deezer_downloader: DeezerDownloader, - ffmpeg_path: str, - search_engines: Dict[str, ASearch], + engine_source: EngineSource, + ffmpeg_path: str ): - self.data = GuildData(guild, search_engines) + self.data = GuildData(guild, engine_source) self.songs = GuildQueue(self.data, ffmpeg_path) - self.cmds = GuildCmd(logger, self.data, self.songs, deezer_downloader) + self.cmds = GuildCmd(logger, self.data, self.songs, engine_source) diff --git a/src/pyramid/connector/discord/bot_cmd.py b/src/pyramid/connector/discord/bot_cmd.py index 5f2cf0e..06d4ef2 100644 --- a/src/pyramid/connector/discord/bot_cmd.py +++ b/src/pyramid/connector/discord/bot_cmd.py @@ -285,14 +285,14 @@ async def cmd_play_url_next(ctx: Interaction, url: str): await guild_cmd.play_url(ms, ctx, url, at_end=False) - @bot.tree.command(name="spam", description="Test spam") - async def cmd_spam(ctx: Interaction): - ms = MessageSenderQueued(ctx) - await ms.waiting() - - for i in range(100): - ms.add_message(f"Spam n°{i}") - await ctx.response.send_message("Spam ended") + # @bot.tree.command(name="spam", description="Test spam") + # async def cmd_spam(ctx: Interaction): + # ms = MessageSenderQueued(ctx) + # await ms.waiting() + + # for i in range(100): + # ms.add_message(f"Spam n°{i}") + # await ctx.response.send_message("Spam ended") async def __use_on_guild_only(self, ctx: Interaction) -> bool: if ctx.guild is None: diff --git a/src/pyramid/connector/discord/guild_cmd.py b/src/pyramid/connector/discord/guild_cmd.py index b71376f..ddda32b 100644 --- a/src/pyramid/connector/discord/guild_cmd.py +++ b/src/pyramid/connector/discord/guild_cmd.py @@ -5,10 +5,10 @@ import data.tracklist as utils_list_track from data.guild_data import GuildData from data.track import TrackMinimal -from connector.deezer.downloader import DeezerDownloader from connector.discord.guild_cmd_tools import GuildCmdTools from connector.discord.guild_queue import GuildQueue from data.functional.messages.message_sender_queued import MessageSenderQueued +from data.functional.engine_source import EngineSource class GuildCmd(GuildCmdTools): @@ -17,10 +17,10 @@ def __init__( logger: Logger, guild_data: GuildData, guild_queue: GuildQueue, - deezer_dl: DeezerDownloader, + engine_source: EngineSource, ): self.logger = logger - self.deezer_dl = deezer_dl + self.engine_source = engine_source self.data = guild_data self.queue = guild_queue @@ -31,7 +31,7 @@ async def play(self, ms: MessageSenderQueued, ctx: Interaction, input: str, at_e ms.response_message(content=f"Searching **{input}**") - track: TrackMinimal | None = self.data.search_engine.search_track(input) + track: TrackMinimal | None = self.data.search_engine.default_engine.search_track(input) if not track: ms.response_message(content=f"**{input}** not found.") return False @@ -178,9 +178,9 @@ def search( self, ms: MessageSenderQueued, ctx: Interaction, input: str, engine: str | None ) -> bool: if engine is None: - search_engine = self.data.search_engine + search_engine = self.data.search_engine.default_engine else: - test_value = self.data.search_engines.get(engine.lower()) + test_value = self.data.search_engines.get_engine(engine) if not test_value: ms.response_message(content=f"Search engine **{engine}** not found.") return False @@ -204,7 +204,7 @@ async def play_multiple(self, ms: MessageSenderQueued, ctx: Interaction, input: ms.response_message(content=f"Searching **{input}** ...") - tracks: list[TrackMinimal] | None = self.data.search_engine.search_tracks(input) + tracks: list[TrackMinimal] | None = self.data.search_engine.default_engine.search_tracks(input) if not tracks: ms.response_message(content=f"**{input}** not found.") return False @@ -218,21 +218,20 @@ async def play_url(self, ms: MessageSenderQueued, ctx: Interaction, url: str, at ms.response_message(content=f"Searching **{url}** ...") - # ctx.client.loop - res: ( + result: ( tuple[list[TrackMinimal], list[TrackMinimal]] | TrackMinimal | None - ) = await self.data.search_engine.get_by_url(url) - if not res: + ) = await self.data.search_engine.search_by_url(url) + if not result: ms.response_message(content=f"**{url}** not found.") return False - if isinstance(res, tuple): - tracks, tracks_unfindable = res + if isinstance(result, tuple): + tracks, tracks_unfindable = result return await self._execute_play_multiple( ms, voice_channel, tracks, tracks_unfindable, at_end=at_end ) - elif isinstance(res, TrackMinimal): - tracks = res + elif isinstance(result, TrackMinimal): + tracks = result return await self._execute_play(ms, voice_channel, tracks, at_end=at_end) else: raise Exception("Unknown type 'res'") diff --git a/src/pyramid/connector/discord/guild_cmd_tools.py b/src/pyramid/connector/discord/guild_cmd_tools.py index 99ef77b..4faef32 100644 --- a/src/pyramid/connector/discord/guild_cmd_tools.py +++ b/src/pyramid/connector/discord/guild_cmd_tools.py @@ -3,9 +3,9 @@ from data.track import Track, TrackMinimal from data.guild_data import GuildData from data.tracklist import TrackList -from connector.deezer.downloader import DeezerDownloader from connector.discord.guild_queue import GuildQueue from data.functional.messages.message_sender_queued import MessageSenderQueued +from data.functional.engine_source import EngineSource class GuildCmdTools: @@ -13,9 +13,9 @@ def __init__( self, guild_data: GuildData, guild_queue: GuildQueue, - deezer_dl: DeezerDownloader, + engine_source: EngineSource, ): - self.deezer_dl = deezer_dl + self.engine_source = engine_source self.data = guild_data self.queue = guild_queue @@ -91,9 +91,9 @@ async def _execute_play_multiple( cant_dl = 0 for i, track in enumerate(tracks): - track_downloaded: Track | None = await self.deezer_dl.dl_track_by_id(track.id) + track_downloaded: Track | None = await self.engine_source.download_track(track) if not track_downloaded: - ms.add_message(content=f"**{track.get_full_name()}** can't be downloaded.") + ms.add_message(content=f"ERROR > **{track.get_full_name()}** can't be downloaded.") cant_dl += 1 continue if ( @@ -102,7 +102,7 @@ async def _execute_play_multiple( or tl.add_track_after(track_downloaded)) ): ms.add_message( - content=f"**{track.get_full_name()}** can't be add to the queue." + content=f"ERROR > **{track.get_full_name()}** can't be add to the queue." ) cant_dl += 1 continue @@ -129,9 +129,9 @@ async def _execute_play( tl: TrackList = self.data.track_list ms.response_message(content=f"**{track.get_full_name()}** found ! Downloading ...") - track_downloaded: Track | None = await self.deezer_dl.dl_track_by_id(track.id) + track_downloaded: Track | None = await self.engine_source.download_track(track) if not track_downloaded: - ms.response_message(content=f"**{track.get_full_name()}** can't be downloaded.") + ms.response_message(content=f"ERROR > **{track.get_full_name()}** can't be downloaded.") return False if ( @@ -139,7 +139,7 @@ async def _execute_play( and not (tl.add_track(track_downloaded) or tl.add_track_after(track_downloaded)) ): - ms.add_message(content=f"**{track.get_full_name()}** can't be add to the queue.") + ms.add_message(content=f"ERROR > **{track.get_full_name()}** can't be add to the queue.") return False await self.queue.goto_channel(voice_channel) diff --git a/src/pyramid/connector/discord/guild_queue.py b/src/pyramid/connector/discord/guild_queue.py index d79b0fc..e833aad 100644 --- a/src/pyramid/connector/discord/guild_queue.py +++ b/src/pyramid/connector/discord/guild_queue.py @@ -8,7 +8,7 @@ from data.tracklist import TrackList from data.guild_data import GuildData from connector.discord.music_player_interface import MusicPlayerInterface -from data.functional.messages.message_sender import MessageSender +from data.functional.messages.message_sender_queued import MessageSenderQueued class GuildQueue: @@ -35,7 +35,7 @@ async def goto_channel(self, voice_channel: VoiceChannel) -> bool: return True return False - async def play(self, msg_sender: MessageSender) -> bool: + async def play(self, msg_sender: MessageSenderQueued) -> bool: tl: TrackList = self.data.track_list vc: VoiceClient = self.data.voice_client @@ -164,14 +164,14 @@ def queue_list(self) -> str | None: return humain_str_array # Called after song played - async def __song_end(self, err: Exception | None, msg_sender: MessageSender): + async def __song_end(self, err: Exception | None, msg_sender: MessageSenderQueued): if err is not None: - await msg_sender.add_message(f"An error occurred while playing song: {err}") + msg_sender.add_message(f"An error occurred while playing song: {err}") await self.song_end_action(err, msg_sender) self.song_end_action = self.__song_end_continue - async def __song_end_continue(self, err: Exception | None, msg_sender: MessageSender): + async def __song_end_continue(self, err: Exception | None, msg_sender: MessageSenderQueued): tl: TrackList = self.data.track_list vc: VoiceClient = self.data.voice_client @@ -180,11 +180,11 @@ async def __song_end_continue(self, err: Exception | None, msg_sender: MessageSe if tl.is_empty(): await vc.disconnect() - await msg_sender.response_message(content="Bye bye") + msg_sender.response_message(content="Bye bye") else: await self.play(msg_sender) - async def __song_end_next(self, err: Exception | None, msg_sender: MessageSender): + async def __song_end_next(self, err: Exception | None, msg_sender: MessageSenderQueued): tl: TrackList = self.data.track_list if tl.is_empty() is False: @@ -194,7 +194,7 @@ async def __song_end_next(self, err: Exception | None, msg_sender: MessageSender raise Exception("They is no song to play after") await self.play(msg_sender) - async def __song_end_stop(self, err: Exception | None, msg_sender: MessageSender): + async def __song_end_stop(self, err: Exception | None, msg_sender: MessageSenderQueued): tl: TrackList = self.data.track_list tl.clear() diff --git a/src/pyramid/connector/spotify/search.py b/src/pyramid/connector/spotify/search.py index aa2f6bd..1a7e4f2 100644 --- a/src/pyramid/connector/spotify/search.py +++ b/src/pyramid/connector/spotify/search.py @@ -1,18 +1,72 @@ -import spotipy +import re +from enum import Enum -from spotipy.oauth2 import SpotifyClientCredentials +import spotipy +from data.a_engine_tools import AEngineTools +from data.a_search import ASearch, ASearchId from data.track import TrackMinimalSpotify -from data.a_search import ASearch +from spotipy.oauth2 import SpotifyClientCredentials -class SpotifySearch(ASearch): - def __init__(self, client_id, client_secret): +class SpotifySearchBase(ASearch): + def __init__(self, default_limit: int, client_id: str, client_secret: str): + self.default_limit = default_limit self.client_id = client_id self.client_secret = client_secret self.client_credentials_manager = SpotifyClientCredentials( client_id=self.client_id, client_secret=self.client_secret ) self.client = spotipy.Spotify(client_credentials_manager=self.client_credentials_manager) + self.tools = SpotifyTools() + + +class SpotifySearchId(ASearchId, SpotifySearchBase): + def __init__(self, default_limit: int, client_id: str, client_secret: str): + super().__init__(default_limit, client_id, client_secret) + + + def get_track_by_id(self, track_id: str) -> TrackMinimalSpotify | None: + results = self.client.track(track_id=track_id) + + return TrackMinimalSpotify(results) + + def get_playlist_tracks_by_id( + self, playlist_id: str + ) -> tuple[list[TrackMinimalSpotify], list[TrackMinimalSpotify]] | None: + results = self.client.playlist(playlist_id=playlist_id) + + if not results or not results.get("tracks") or not results["tracks"].get("items"): + return None + + tracks = results["tracks"]["items"] + return [TrackMinimalSpotify(element["track"]) for element in tracks], [] + + def get_album_tracks_by_id( + self, album_id: str + ) -> tuple[list[TrackMinimalSpotify], list[TrackMinimalSpotify]] | None: + results = self.client.album(album_id=album_id) + + if not results or not results.get("tracks") or not results["tracks"].get("items"): + return None + + tracks = results["tracks"]["items"] + return [TrackMinimalSpotify(element["track"]) for element in tracks], [] + + def get_top_artist_by_id( + self, artist_id: str, limit: int | None = None + ) -> tuple[list[TrackMinimalSpotify], list[TrackMinimalSpotify]] | None: + results = self.client.artist_top_tracks(artist_id=artist_id) + + if not results or not results.get("tracks") or not results["tracks"].get("items"): + return None + + tracks = results["tracks"]["items"] + return [TrackMinimalSpotify(element["track"]) for element in tracks], [] + + +class SpotifySearch(SpotifySearchId): + def __init__(self, default_limit: int, client_id: str, client_secret: str): + super().__init__(default_limit, client_id, client_secret) def search_tracks(self, search, limit=10) -> list[TrackMinimalSpotify] | None: results = self.client.search(q=search, limit=limit, type="track") @@ -64,4 +118,56 @@ def get_top_artist(self, artist_name, limit=10) -> list[TrackMinimalSpotify] | N async def get_by_url( self, url ) -> tuple[list[TrackMinimalSpotify], list[TrackMinimalSpotify]] | TrackMinimalSpotify | None: - raise NotImplementedError("Get by url for spotify is not implemted") + id, type = self.tools.extract_from_url(url) + + if id is None: + return None + if type is None: + raise NotImplementedError(f"The type of spotify info '{url}' is not implemented") + + tracks: ( + tuple[list[TrackMinimalSpotify], list[TrackMinimalSpotify]] | TrackMinimalSpotify | None + ) + + if type == SpotifyType.PLAYLIST: + tracks = self.get_playlist_tracks_by_id(id) + elif type == SpotifyType.ARTIST: + tracks = self.get_top_artist_by_id(id) + elif type == SpotifyType.ALBUM: + tracks = self.get_album_tracks_by_id(id) + elif type == SpotifyType.TRACK: + tracks = self.get_track_by_id(id) + else: + raise NotImplementedError(f"The type of spotify info '{type}' can't be resolve") + + return tracks + + +class SpotifyType(Enum): + PLAYLIST = 1 + ARTIST = 2 + ALBUM = 3 + TRACK = 4 + + +class SpotifyTools(AEngineTools): + def extract_from_url(self, url) -> tuple[str, SpotifyType | None] | tuple[None, None]: + # Extract ID and type using regex + pattern = r"(?<=open.spotify.com/)(\w+)/(\w+)" + match = re.search(pattern, url) + if not match: + return None, None + type_str = match.group(1).upper() + if type_str == "PLAYLIST": + type = SpotifyType.PLAYLIST + elif type_str == "ARTIST": + type = SpotifyType.ARTIST + elif type_str == "ALBUM": + type = SpotifyType.ALBUM + elif type_str == "TRACK": + type = SpotifyType.TRACK + else: + type = None + + id = match.group(2) + return id, type diff --git a/src/pyramid/data/a_engine_tools.py b/src/pyramid/data/a_engine_tools.py new file mode 100644 index 0000000..fbe2982 --- /dev/null +++ b/src/pyramid/data/a_engine_tools.py @@ -0,0 +1,9 @@ +import abc +from abc import ABC +from typing import Any + +class AEngineTools(ABC): + + @abc.abstractmethod + def extract_from_url(self, url) -> tuple[int | str, Any | None] | tuple[None, None]: + pass diff --git a/src/pyramid/data/a_search.py b/src/pyramid/data/a_search.py index 77afdcd..d35213e 100644 --- a/src/pyramid/data/a_search.py +++ b/src/pyramid/data/a_search.py @@ -30,3 +30,27 @@ async def get_by_url( self, url ) -> tuple[list[TrackMinimal], list[TrackMinimal]] | TrackMinimal | None: pass + + +class ASearchId(ABC): + @abc.abstractmethod + def get_track_by_id(self, track_id: int | str) -> TrackMinimal | None: + pass + + @abc.abstractmethod + def get_playlist_tracks_by_id( + self, playlist_id: int | str + ) -> tuple[list[TrackMinimal], list[TrackMinimal]] | None: + pass + + @abc.abstractmethod + def get_album_tracks_by_id( + self, album_id: int | str + ) -> tuple[list[TrackMinimal], list[TrackMinimal]] | None: + pass + + @abc.abstractmethod + def get_top_artist_by_id( + self, artist_id: int | str, limit: int | None = None + ) -> tuple[list[TrackMinimal], list[TrackMinimal]] | None: + pass diff --git a/src/pyramid/data/functional/application_info.py b/src/pyramid/data/functional/application_info.py index 3f19f5f..4f8f6e0 100644 --- a/src/pyramid/data/functional/application_info.py +++ b/src/pyramid/data/functional/application_info.py @@ -9,7 +9,7 @@ class ApplicationInfo: def __init__(self): self.name = "pyramid" self.os = get_os().lower() - self.version = "0.1.9" + self.version = "0.2.0" self.git_info = GitInfo() def load_git_info(self): diff --git a/src/pyramid/data/functional/engine_source.py b/src/pyramid/data/functional/engine_source.py new file mode 100644 index 0000000..54c055b --- /dev/null +++ b/src/pyramid/data/functional/engine_source.py @@ -0,0 +1,53 @@ +from typing import Dict +from connector.deezer.downloader import DeezerDownloader +from connector.deezer.search import DeezerSearch +from connector.spotify.search import SpotifySearch +from data.a_search import ASearch +from data.track import Track, TrackMinimal, TrackMinimalDeezer +from tools.configuration import Configuration + + +class EngineSource: + def __init__(self, config: Configuration): + self.__downloader = DeezerDownloader(config.deezer_arl, config.deezer_folder) + self.__deezer_search = DeezerSearch(config.default_limit_track) + self.__spotify_search = SpotifySearch( + config.default_limit_track, config.spotify_client_id, config.spotify_client_secret + ) + self.default_engine: ASearch = self.__deezer_search + self.__downloader_search = self.__deezer_search + self.__search_engines: Dict[str, ASearch] = dict( + { + "spotify": self.__spotify_search, + "deezer": self.__deezer_search, + } + ) + + async def download_track(self, track: TrackMinimal) -> Track | None: + track_used: TrackMinimalDeezer + + if not isinstance(track, TrackMinimalDeezer): + t = self.__downloader_search.search_exact_track(track.author_name, None, track.name) + if t is None: + return None + track_used = t + else: + track_used = track + + return await self.__downloader.dl_track_by_id(track_used.id) + + async def search_by_url(self, url: str): + """ + Search for tracks by URL using multiple search engines. + + :param url: The URL to search for. + """ + for engine in self.__search_engines.values(): + result = await engine.get_by_url(url) + if result is not None: + return result + + return None + + def get_engine(self, name: str): + return self.__search_engines.get(name.lower()) diff --git a/src/pyramid/data/functional/main.py b/src/pyramid/data/functional/main.py index 20ebf5a..50a3dbf 100644 --- a/src/pyramid/data/functional/main.py +++ b/src/pyramid/data/functional/main.py @@ -6,7 +6,6 @@ import tools.utils as tools from data.functional.application_info import ApplicationInfo -from connector.deezer.downloader import DeezerDownloader from connector.discord.bot import DiscordBot from tools.configuration import Configuration from tools.logs_handler import LogsHandler @@ -65,12 +64,9 @@ def clean_data(self): tools.clear_directory(self._config.deezer_folder) def start(self): - # Create Deezer player instance - deezer_dl = DeezerDownloader(self._config.deezer_arl, self._config.deezer_folder) - # Discord Bot Instance discord_bot = DiscordBot( - self.logger.getChild("Discord"), self._info, self._config, deezer_dl + self.logger.getChild("Discord"), self._info, self._config ) # Create bot discord_bot.create() diff --git a/src/pyramid/data/functional/messages/message_sender.py b/src/pyramid/data/functional/messages/message_sender.py index 01c09d3..ccb0116 100644 --- a/src/pyramid/data/functional/messages/message_sender.py +++ b/src/pyramid/data/functional/messages/message_sender.py @@ -9,14 +9,10 @@ MAX_MSG_LENGTH = 2000 -# queue = Queue(1, "MessageSender") -# queue.start() -# queue.register_to_wait_on_exit() - class MessageSender: def __init__(self, ctx: Interaction): - self.__ctx = ctx + self.ctx = ctx if ctx.channel is None: raise NotImplementedError("Unable to create a MessageSender without channel") if not isinstance(ctx.channel, TextChannel): @@ -25,12 +21,6 @@ def __init__(self, ctx: Interaction): self.last_reponse: Message | WebhookMessage | None = None self.loop: asyncio.AbstractEventLoop = ctx.client.loop - """ - Add a message as a response or follow-up. If no message has been sent yet, the message is sent as a response. - Otherwise, the message will be linked to the response (sent as a follow-up message). - If the message exceeds the maximum character limit, it will be truncated. - """ - async def add_message( # def add_message( self, @@ -38,6 +28,11 @@ async def add_message( callback: Callable | None = None, # ) -> None: ) -> Message: + """ + Add a message as a response or follow-up. If no message has been sent yet, the message is sent as a response. + Otherwise, the message will be linked to the response (sent as a follow-up message). + If the message exceeds the maximum character limit, it will be truncated. + """ if content != MISSING and content != "": new_content, is_used = tools.substring_with_end_msg( content, MAX_MSG_LENGTH, "{} more characters..." @@ -45,41 +40,25 @@ async def add_message( if is_used: content = new_content - if not self.__ctx.response.is_done(): + if not self.ctx.response.is_done(): msg = await self.txt_channel.send(content) - # queue.add( - # QueueItem( - # "Send reponse", self.txt_channel.send, self.loop, callback, content=content - # ) - # ) else: - msg = await self.__ctx.followup.send( + msg = await self.ctx.followup.send( content, wait=True, ) - # queue.add( - # QueueItem( - # "Send followup", - # self.__ctx.followup.send, - # self.loop, - # callback, - # content=content, - # wait=True, - # ) - # ) return msg - """ - Send a message as a response. If the response has already been sent, it will be modified. - If it is not possible to modify it, a new message will be sent as a follow-up. - If the message exceeds the maximum character limit, it will be truncated. - """ - async def response_message( # def response_message( self, content: str = MISSING, ): + """ + Send a message as a response. If the response has already been sent, it will be modified. + If it is not possible to modify it, a new message will be sent as a follow-up. + If the message exceeds the maximum character limit, it will be truncated. + """ if content != MISSING and content != "": new_content, is_used = tools.substring_with_end_msg( content, MAX_MSG_LENGTH, "{} more characters..." @@ -90,9 +69,9 @@ async def response_message( if self.last_reponse is not None: self.last_reponse.edit(content=content) - elif self.__ctx.response.is_done(): + elif self.ctx.response.is_done(): try: - await self.__ctx.edit_original_response( + await self.ctx.edit_original_response( content=content, ) # queue.add( @@ -113,7 +92,7 @@ async def response_message( else: raise err else: - await self.__ctx.response.send_message( + await self.ctx.response.send_message( content=content, ) # queue.add( @@ -125,11 +104,10 @@ async def response_message( # ) # ) - """ - Send a message with markdown code formatting. If the character limit is exceeded, send multiple messages. - """ - async def add_code_message(self, content: str, prefix=None, suffix=None): + """ + Send a message with markdown code formatting. If the character limit is exceeded, send multiple messages. + """ # def add_code_message(self, content: str, prefix=None, suffix=None): max_length = MAX_MSG_LENGTH if prefix is None: @@ -145,11 +123,11 @@ async def add_code_message(self, content: str, prefix=None, suffix=None): substrings_generator = tools.split_string_by_length(content, max_length) - if not self.__ctx.response.is_done(): + if not self.ctx.response.is_done(): first_substring = next(substrings_generator, None) if first_substring is not None: first_substring_formatted = f"```{first_substring}```" - await self.__ctx.response.send_message(content=first_substring_formatted) + await self.ctx.response.send_message(content=first_substring_formatted) # queue.add( # QueueItem( # "Send code as response", @@ -161,7 +139,7 @@ async def add_code_message(self, content: str, prefix=None, suffix=None): for substring in substrings_generator: substring_formatted = f"```{substring}```" - await self.__ctx.followup.send(content=substring_formatted) + await self.ctx.followup.send(content=substring_formatted) # queue.add( # QueueItem( # "Send code as followup", diff --git a/src/pyramid/data/functional/messages/message_sender_queued.py b/src/pyramid/data/functional/messages/message_sender_queued.py index 00d535e..e147d97 100644 --- a/src/pyramid/data/functional/messages/message_sender_queued.py +++ b/src/pyramid/data/functional/messages/message_sender_queued.py @@ -1,8 +1,10 @@ +import logging from typing import Callable +import tools.utils as tools +from data.functional.messages.message_sender import MessageSender from discord import Interaction from discord.utils import MISSING -from data.functional.messages.message_sender import MessageSender from tools.queue import Queue, QueueItem MAX_MSG_LENGTH = 2000 @@ -14,6 +16,7 @@ class MessageSenderQueued(MessageSender): def __init__(self, ctx: Interaction): + self.ctx = ctx super().__init__(ctx) async def waiting(self): @@ -34,6 +37,53 @@ def response_message( self, content: str = MISSING, ): + if content != MISSING and content != "": + new_content, is_used = tools.substring_with_end_msg( + content, MAX_MSG_LENGTH, "{} more characters..." + ) + if is_used: + content = new_content + + if self.last_reponse is not None: + queue.add( + QueueItem( + "Edit last response", + self.last_reponse.edit, + self.loop, + content=content, + ) + ) + + elif self.ctx.response.is_done(): + def on_error(err): + if err.code == 50027: # 401 Unauthorized : Invalid Webhook Token + logging.warning( + "Unable to modify original response, send message instead", exc_info=True + ) + self.add_message(content, lambda msg: setattr(self, "last_response", msg)) + else: + raise err + + queue.add( + QueueItem( + "Edit response", + self.ctx.edit_original_response, + self.loop, + None, + on_error, + content=content, + ) + ) + else: + queue.add( + QueueItem( + "Send followup as response", + self.ctx.response.send_message, + self.loop, + content=content, + ) + ) + queue.add( QueueItem("response_message", super().response_message, self.loop, content=content) ) diff --git a/src/pyramid/data/guild_data.py b/src/pyramid/data/guild_data.py index 9633e63..144818f 100644 --- a/src/pyramid/data/guild_data.py +++ b/src/pyramid/data/guild_data.py @@ -1,14 +1,13 @@ -from typing import Dict from discord import Guild, VoiceClient -from data.a_search import ASearch from data.tracklist import TrackList +from data.functional.engine_source import EngineSource class GuildData: - def __init__(self, guild: Guild, search_engines: Dict[str, ASearch]): + def __init__(self, guild: Guild, engine_source: EngineSource): self.guild: Guild = guild self.track_list: TrackList = TrackList() self.voice_client: VoiceClient = None # type: ignore - self.search_engines = search_engines - self.search_engine = self.search_engines["deezer"] + self.search_engines = engine_source + self.search_engine = engine_source diff --git a/src/pyramid/test_queue.py b/src/pyramid/test_queue.py index 27a24f8..bb333d5 100644 --- a/src/pyramid/test_queue.py +++ b/src/pyramid/test_queue.py @@ -6,8 +6,8 @@ from tools.queue import Queue, QueueItem -THREADS_USE = 1 -TIMES = 1 +THREADS_USE = 20 +TIMES = 100_000 ITEM_DIFFUCULTY = 100 @@ -27,20 +27,37 @@ def factorial(n): def pi(n): start = time.time() - for i in range(0, n): - for x in range(1, 1000): - 3.141592 * 2**x - for x in range(1, 10000): - float(x) / 3.141592 - for x in range(1, 10000): - float(3.141592) / x + q, r, t, k, m, x = 1, 0, 1, 1, 3, 3 + decimal_output = "" + + for _ in range(n): + if 4*q+r-t < m*t: + decimal_output += str(m) + if len(decimal_output) == 1: + decimal_output += '.' + nr = 10*(r - m*t) + m = ((10*(3*q+r))//t)-10*m + q *= 10 + r = nr + decimal_output += str(m) + t = t + else: + nr = (2*q+r)*x + nn = (7*q*k+2+r*x)//(x*t) + q *= k + t *= x + x += 2 + k += 1 + m = nn + r = nr + + # print("PI", decimal_output) end = time.time() duration = end - start duration = round(duration, 3) return duration - results = [] * TIMES q = Queue(THREADS_USE) @@ -48,37 +65,50 @@ def pi(n): # Add functions to the queue for item in range(TIMES): # q.add(QueueItem(item, factorial, lambda result: results.append(result), n=ITEM_DIFFUCULTY)) - q.add(QueueItem(item, pi, None, lambda result: results.append(result), n=ITEM_DIFFUCULTY)) + q.add( + QueueItem( + item, + pi, + None, + lambda result: results.append(result), + n=ITEM_DIFFUCULTY, + ) + ) + +print(f"Starting {TIMES} times in {THREADS_USE} threads with difficulty {ITEM_DIFFUCULTY}.") +q.start() +print("All thread started.") start_time = time.time() -print(f"Starting {TIMES} times in {THREADS_USE} threads.") -q.start() - exit_handler_executed = False +# print("exit_handler") +q.end() def exit_handler(): - print("exit_handler") - q.end() - print("join") + global exit_handler_executed + if exit_handler_executed: + return + + # print("join") q.join() - print("join exit") + # print("join exit") end_time = time.time() time_difference = end_time - start_time # print("\n".join(result)) - print(f"Execute in {time_difference} seconds, {len(results)} times in {THREADS_USE} threads.") + print(f"Execute in {time_difference:.2f} seconds, {len(results)} times in {THREADS_USE} threads.") average_benchmark = round(sum(results) / TIMES, 3) - print(f"Average score (from {TIMES} repeats): {str(average_benchmark)}s") + print(f"Average score : {str(average_benchmark)}s") - global exit_handler_executed exit_handler_executed = True atexit.register(exit_handler) +exit_handler() -while not exit_handler_executed: - time.sleep(1) +# while not exit_handler_executed: +# time.sleep(1) diff --git a/src/pyramid/tools/configuration.py b/src/pyramid/tools/configuration.py index fbbfc8b..a01d139 100644 --- a/src/pyramid/tools/configuration.py +++ b/src/pyramid/tools/configuration.py @@ -19,28 +19,26 @@ def __init__(self): def load(self): config_file_path = "config.yml" - with open(config_file_path, "r") as config_file: - config_data = yaml.safe_load(config_file) + try: + with open(config_file_path, "r") as config_file: + config_data: dict = yaml.safe_load(config_file) + except FileNotFoundError as err: + raise err - self.deezer_arl = config_data["deezer"]["arl"] - self.deezer_folder = config_data["deezer"]["folder"] - self.discord_token = config_data["discord"]["token"] + self.deezer_arl = config_data.get("deezer", {}).get("arl", "") + self.deezer_folder = config_data.get("deezer", {}).get("folder", "") + self.discord_token = config_data.get("discord", {}).get("token", "") - ffmpeg = config_data["discord"]["ffmpeg"] + ffmpeg = config_data.get("discord", {}).get("ffmpeg", "") if not os.path.exists(ffmpeg): - raise Exception(f"Binary {ffmpeg} didn't exists on this system") + raise Exception(f"Binary {ffmpeg} doesn't exist on this system") self.discord_ffmpeg = ffmpeg - self.spotify_client_id = config_data["spotify"]["client_id"] - self.spotify_client_secret = config_data["spotify"]["client_secret"] - self.spotify_client_secret = config_data["spotify"]["client_secret"] + self.spotify_client_id = config_data.get("spotify", {}).get("client_id", "") + self.spotify_client_secret = config_data.get("spotify", {}).get("client_secret", "") - mode_upper = str(config_data["mode"]).replace("-", "_").upper() - for mode in Environment: - if mode.name == mode_upper: - self.mode = mode - break - else: - self.mode = Environment.PRODUCTION + mode_upper = str(config_data.get("mode", "")).replace("-", "_").upper() + self.mode = Environment[mode_upper] if mode_upper in Environment.__members__ else Environment.PRODUCTION - self.config_version = config_data["version"] + self.config_version = config_data.get("version", "") + self.default_limit_track = int(config_data.get("general", {}).get("default_limit_tracks", 0)) diff --git a/src/pyramid/tools/queue.py b/src/pyramid/tools/queue.py index 1b4bce4..d074cd4 100644 --- a/src/pyramid/tools/queue.py +++ b/src/pyramid/tools/queue.py @@ -28,7 +28,9 @@ def __init__( def worker(q: Deque[QueueItem], thread_id: int, lock: Lock, event: Event): while True: - event.wait() + if len(q) == 0: + event.wait() + if not q: event.clear() continue @@ -39,29 +41,7 @@ def worker(q: Deque[QueueItem], thread_id: int, lock: Lock, event: Event): break try: - # Async func - if inspect.iscoroutinefunction(item.func) or inspect.isasyncgenfunction(item.func): - # Async func in loop - if item.loop is not None: - # Async func in loop closed - if item.loop.is_closed(): - logging.warning( - "Exception in thread %d :\nUnable to call %s.%s cause the loop is closed", - thread_id, - item.func.__module__, - item.func.__qualname__, - ) - continue - # Async func in loop open - result = asyncio.run_coroutine_threadsafe( - item.func(**item.kwargs), item.loop - ).result() - # Async func classic - else: - result = asyncio.run(item.func(**item.kwargs)) - # Sync func - else: - result = item.func(**item.kwargs) + result = run_task(item.func, item.loop, **item.kwargs) if item.func_sucess is not None: item.func_sucess(result) @@ -83,6 +63,28 @@ def worker(q: Deque[QueueItem], thread_id: int, lock: Lock, event: Event): "".join(traceback.format_exception(type(err), err, err.__traceback__)), ) +def run_task(func: Callable, loop: asyncio.AbstractEventLoop | None, **kwargs): + # Async func + if inspect.iscoroutinefunction(func) or inspect.isasyncgenfunction(func): + # Async func in loop + if loop is not None: + # Async func in loop closed + if loop.is_closed(): + raise Exception("Unable to call %s.%s cause the loop is closed", + func.__module__, + func.__qualname__ + ) + # Async func in loop open + result = asyncio.run_coroutine_threadsafe( + func(**kwargs), loop + ).result() + # Async func classic + else: + result = asyncio.run(func(**kwargs)) + # Sync func + else: + result = func(**kwargs) + return result class Queue: all_queue = deque() @@ -95,11 +97,6 @@ def __init__(self, threads=1, name=None): self.__lock = Lock() self.__worker = worker - if name is None: - name = "Thread" - else: - name = f"Thread {name}" - for thread_id in range(1, self.__threads + 1): thread = Thread( name=f"{name} n°{thread_id}", @@ -138,6 +135,9 @@ def end(self): for _ in range(self.__threads): self.add(None) + def length(self): + return len(self.__queue) + @staticmethod def wait_for_end(timeout_per_threads: float | None = None): for queue in Queue.all_queue: diff --git a/src/pyramid/tools/test_dev.py b/src/pyramid/tools/test_dev.py index 7886566..209ca50 100644 --- a/src/pyramid/tools/test_dev.py +++ b/src/pyramid/tools/test_dev.py @@ -13,7 +13,9 @@ def __init__(self, config: Configuration, logger: Logger): def test_spotify(self, input): spotify_search = SpotifySearch( - self._config.spotify_client_id, self._config.spotify_client_secret + self._config.default_limit_track, + self._config.spotify_client_id, + self._config.spotify_client_secret, ) res = spotify_search.search_tracks(input, limit=10) if res is None: @@ -22,7 +24,7 @@ def test_spotify(self, input): self.logger.info(track) def test_deezer(self, input): - deezer_search = DeezerSearch() + deezer_search = DeezerSearch(self._config.default_limit_track) res = deezer_search.search_tracks(input, limit=10) if res is None: return