diff --git a/download.py b/download.py new file mode 100644 index 0000000..26b4df1 --- /dev/null +++ b/download.py @@ -0,0 +1,224 @@ +#!/usr/bin/env python3 +"""Various mtdata dataset downloading utilities""" +import os +from glob import iglob +from itertools import chain +from typing import Iterable, Dict, List, Optional, Set +from enum import Enum +from queue import SimpleQueue +from subprocess import Popen +from threading import Thread +from collections import defaultdict + +import mtdata.entry +from mtdata.entry import lang_pair +from mtdata.index import Index, get_entries +from mtdata.iso.bcp47 import bcp47, BCP47Tag +from pydantic import BaseModel +from fastapi import FastAPI, HTTPException + + +DATA_PATH = os.getenv('DATA_PATH', 'data/train-parts/*.*.gz') + +DOWNLOAD_PATH = 'data' + + +class EntryRef(BaseModel): + id: str + + +class Entry(EntryRef): + group: str + name: str + version: str + langs: List[str] + paths: Set[str] + + +class DownloadState(Enum): + PENDING = 'pending' + CANCELLED = 'cancelled' + DOWNLOADING = 'downloading' + DOWNLOADED = 'downloaded' + FAILED = 'failed' + + +def get_dataset(entry: Entry, path: str) -> Popen: + """Gets datasets, using a subprocess call to mtdata. Might be less brittle to internal + mtdata interface changes""" + call_type: str # We have train/test/dev + if "dev" in entry.name: + call_type = "-dv" + elif "test" in entry.name: + call_type = "-ts" + else: + call_type = "-tr" + + # Download the dataset + return Popen(["mtdata", "get", "-l", "-".join(entry.langs), call_type, entry.id, "--compress", "-o", path]) + + +class EntryDownload: + entry: Entry + + def __init__(self, entry:Entry): + self.entry = entry + self._child = None + + def start(self): + self._child = get_dataset(self.entry, DOWNLOAD_PATH) + + def cancel(self): + if self._child and self._child.returncode is None: + self._child.kill() + + @property + def state(self): + if not self._child: + return DownloadState.PENDING + elif self._child.returncode is None: + return DownloadState.DOWNLOADING + elif self._child.returncode == 0: + return DownloadState.DOWNLOADED + elif self._child.returncode > 0: + return DownloadState.FAILED + else: + return DownloadState.CANCELLED + + +class Downloader: + def __init__(self, workers:int): + self.queue = SimpleQueue() + self.threads = [] + + for _ in range(workers): + thread = Thread(target=self.__class__.worker_thread, args=[self.queue], daemon=True) + thread.start() + self.threads.append(thread) + + def download(self, entry:Entry) -> EntryDownload: + download = EntryDownload(entry=entry) + self.queue.put(download) + return download + + @staticmethod + def worker_thread(queue): + while True: + entry = queue.get() + if not entry: + break + entry.start() + entry._child.wait() + + +class EntryDownloadView(BaseModel): + entry: Entry + state: DownloadState + + +app = FastAPI() + +downloads: Dict[str, EntryDownload] = {} + +downloader = Downloader(2) + + +def find_local_paths(entry: mtdata.entry.Entry) -> Set[str]: + return set( + filename + for data_root in [os.path.dirname(DATA_PATH), DOWNLOAD_PATH] + for lang in entry.did.langs + for filename in iglob(os.path.join(data_root, f'{entry.did!s}.{lang.lang}.gz'), recursive=True) + ) + + +def cast_entry(entry) -> Entry: + return Entry( + id = str(entry.did), + group = entry.did.group, + name = entry.did.name, + version = entry.did.version, + langs = [lang.lang for lang in entry.did.langs], + paths = find_local_paths(entry) + ) + + +@app.get("/languages/") +@app.get("/languages/{lang1}") +def list_languages(lang1:Optional[str] = None) -> Iterable[str]: + langs: set[str] = set() + filter_lang = bcp47(lang1) if lang1 is not None else None + for entry in Index.get_instance().get_entries(): + if filter_lang is not None and filter_lang not in entry.did.langs: + continue + langs.update(*entry.did.langs) + return sorted(lang for lang in langs if lang is not None) + + +@app.get("/by-language/{langs}") +def list_datasets(langs:str) -> Iterable[Entry]: + return dedupe_datasests( + cast_entry(entry) for entry in get_entries(lang_pair(langs)) + ) + + +@app.get('/downloads/') +def list_downloads() -> Iterable[EntryDownloadView]: + return ( + EntryDownloadView( + entry = download.entry, + state = download.state + ) + for download in downloads.values() + ) + + +@app.post('/downloads/') +def batch_add_downloads(datasets: List[EntryRef]) -> Iterable[EntryDownloadView]: + """Batch download requests!""" + needles = set(dataset.id + for dataset in datasets + if dataset.id not in downloads) + + entries = [ + cast_entry(entry) + for entry in Index.get_instance().get_entries() + if str(entry.did) in needles + ] + + for entry in entries: + downloads[entry.id] = downloader.download(entry) + + return list_downloads() + + +@app.delete('/downloads/{dataset_id}') +def cancel_download(dataset_id:str) -> EntryDownloadView: + """Cancel a download. Removes it from the queue, does not kill the process + if download is already happening. + """ + if dataset_id not in downloads: + raise HTTPException(status_code=404, detail='Download not found') + + download = downloads[dataset_id] + download.cancel() + + return EntryDownloadView( + entry = download.entry, + state = download.state + ) + + +def dedupe_datasests(datasets: Iterable[Entry]) -> Iterable[Entry]: + """Mtdata contains a multitude a datasets that have many different versions + (eg europarl). In the vast majority of the cases we ONLY EVER want the + latest version + """ + datadict: Dict[str, List[Entry]] = defaultdict(list) + for entry in datasets: + datadict[entry.name].append(entry) + # Sort by version and return one per name + return [ + sorted(entrylist, key=lambda t: t.version, reverse=True)[0] + for entrylist in datadict.values() + ] diff --git a/frontend/src/interval.js b/frontend/src/interval.js new file mode 100644 index 0000000..685a8c7 --- /dev/null +++ b/frontend/src/interval.js @@ -0,0 +1,37 @@ +/** + * Little wrapper that emulates setInterval() but with methods, and + * takes the time callback() takes to resolve into account. + */ +export class Interval { + #timeout; + + constructor(interval, callback) { + this.interval = interval; + this.callback = callback; + } + + start() { + if (this.#timeout) + clearTimeout(this.#timeout); + this.#timeout = setTimeout(this.#callback.bind(this), this.interval); + } + + stop() { + clearTimeout(this.#timeout); + this.#timeout = null; + } + + restart() { + this.stop(); + this.start(); + } + + #callback() { + // Wait for the callback() to resolve in case it is + // async, and then schedule a new call. + Promise.resolve(this.callback()).then(() => { + if (this.#timeout) + this.start() + }) + } +} \ No newline at end of file diff --git a/frontend/src/router/index.js b/frontend/src/router/index.js index 1e54e96..05148b0 100644 --- a/frontend/src/router/index.js +++ b/frontend/src/router/index.js @@ -12,7 +12,12 @@ const router = createRouter({ path: '/datasets/:datasetName/configuration', name: 'edit-filters', component: () => import('../views/EditFiltersView.vue') - } + }, + { + path: '/download/', + name: 'add-dataset', + component: () => import('../views/AddDatasetView.vue') + }, ] }) diff --git a/frontend/src/views/AddDatasetView.vue b/frontend/src/views/AddDatasetView.vue new file mode 100644 index 0000000..3a3e6ac --- /dev/null +++ b/frontend/src/views/AddDatasetView.vue @@ -0,0 +1,214 @@ + + + + + \ No newline at end of file diff --git a/frontend/src/views/ListDatasetsView.vue b/frontend/src/views/ListDatasetsView.vue index dded256..3317efd 100644 --- a/frontend/src/views/ListDatasetsView.vue +++ b/frontend/src/views/ListDatasetsView.vue @@ -1,5 +1,6 @@