diff --git a/download.py b/download.py
new file mode 100644
index 0000000..26b4df1
--- /dev/null
+++ b/download.py
@@ -0,0 +1,224 @@
+#!/usr/bin/env python3
+"""Various mtdata dataset downloading utilities"""
+import os
+from glob import iglob
+from itertools import chain
+from typing import Iterable, Dict, List, Optional, Set
+from enum import Enum
+from queue import SimpleQueue
+from subprocess import Popen
+from threading import Thread
+from collections import defaultdict
+
+import mtdata.entry
+from mtdata.entry import lang_pair
+from mtdata.index import Index, get_entries
+from mtdata.iso.bcp47 import bcp47, BCP47Tag
+from pydantic import BaseModel
+from fastapi import FastAPI, HTTPException
+
+
+DATA_PATH = os.getenv('DATA_PATH', 'data/train-parts/*.*.gz')
+
+DOWNLOAD_PATH = 'data'
+
+
+class EntryRef(BaseModel):
+ id: str
+
+
+class Entry(EntryRef):
+ group: str
+ name: str
+ version: str
+ langs: List[str]
+ paths: Set[str]
+
+
+class DownloadState(Enum):
+ PENDING = 'pending'
+ CANCELLED = 'cancelled'
+ DOWNLOADING = 'downloading'
+ DOWNLOADED = 'downloaded'
+ FAILED = 'failed'
+
+
+def get_dataset(entry: Entry, path: str) -> Popen:
+ """Gets datasets, using a subprocess call to mtdata. Might be less brittle to internal
+ mtdata interface changes"""
+ call_type: str # We have train/test/dev
+ if "dev" in entry.name:
+ call_type = "-dv"
+ elif "test" in entry.name:
+ call_type = "-ts"
+ else:
+ call_type = "-tr"
+
+ # Download the dataset
+ return Popen(["mtdata", "get", "-l", "-".join(entry.langs), call_type, entry.id, "--compress", "-o", path])
+
+
+class EntryDownload:
+ entry: Entry
+
+ def __init__(self, entry:Entry):
+ self.entry = entry
+ self._child = None
+
+ def start(self):
+ self._child = get_dataset(self.entry, DOWNLOAD_PATH)
+
+ def cancel(self):
+ if self._child and self._child.returncode is None:
+ self._child.kill()
+
+ @property
+ def state(self):
+ if not self._child:
+ return DownloadState.PENDING
+ elif self._child.returncode is None:
+ return DownloadState.DOWNLOADING
+ elif self._child.returncode == 0:
+ return DownloadState.DOWNLOADED
+ elif self._child.returncode > 0:
+ return DownloadState.FAILED
+ else:
+ return DownloadState.CANCELLED
+
+
+class Downloader:
+ def __init__(self, workers:int):
+ self.queue = SimpleQueue()
+ self.threads = []
+
+ for _ in range(workers):
+ thread = Thread(target=self.__class__.worker_thread, args=[self.queue], daemon=True)
+ thread.start()
+ self.threads.append(thread)
+
+ def download(self, entry:Entry) -> EntryDownload:
+ download = EntryDownload(entry=entry)
+ self.queue.put(download)
+ return download
+
+ @staticmethod
+ def worker_thread(queue):
+ while True:
+ entry = queue.get()
+ if not entry:
+ break
+ entry.start()
+ entry._child.wait()
+
+
+class EntryDownloadView(BaseModel):
+ entry: Entry
+ state: DownloadState
+
+
+app = FastAPI()
+
+downloads: Dict[str, EntryDownload] = {}
+
+downloader = Downloader(2)
+
+
+def find_local_paths(entry: mtdata.entry.Entry) -> Set[str]:
+ return set(
+ filename
+ for data_root in [os.path.dirname(DATA_PATH), DOWNLOAD_PATH]
+ for lang in entry.did.langs
+ for filename in iglob(os.path.join(data_root, f'{entry.did!s}.{lang.lang}.gz'), recursive=True)
+ )
+
+
+def cast_entry(entry) -> Entry:
+ return Entry(
+ id = str(entry.did),
+ group = entry.did.group,
+ name = entry.did.name,
+ version = entry.did.version,
+ langs = [lang.lang for lang in entry.did.langs],
+ paths = find_local_paths(entry)
+ )
+
+
+@app.get("/languages/")
+@app.get("/languages/{lang1}")
+def list_languages(lang1:Optional[str] = None) -> Iterable[str]:
+ langs: set[str] = set()
+ filter_lang = bcp47(lang1) if lang1 is not None else None
+ for entry in Index.get_instance().get_entries():
+ if filter_lang is not None and filter_lang not in entry.did.langs:
+ continue
+ langs.update(*entry.did.langs)
+ return sorted(lang for lang in langs if lang is not None)
+
+
+@app.get("/by-language/{langs}")
+def list_datasets(langs:str) -> Iterable[Entry]:
+ return dedupe_datasests(
+ cast_entry(entry) for entry in get_entries(lang_pair(langs))
+ )
+
+
+@app.get('/downloads/')
+def list_downloads() -> Iterable[EntryDownloadView]:
+ return (
+ EntryDownloadView(
+ entry = download.entry,
+ state = download.state
+ )
+ for download in downloads.values()
+ )
+
+
+@app.post('/downloads/')
+def batch_add_downloads(datasets: List[EntryRef]) -> Iterable[EntryDownloadView]:
+ """Batch download requests!"""
+ needles = set(dataset.id
+ for dataset in datasets
+ if dataset.id not in downloads)
+
+ entries = [
+ cast_entry(entry)
+ for entry in Index.get_instance().get_entries()
+ if str(entry.did) in needles
+ ]
+
+ for entry in entries:
+ downloads[entry.id] = downloader.download(entry)
+
+ return list_downloads()
+
+
+@app.delete('/downloads/{dataset_id}')
+def cancel_download(dataset_id:str) -> EntryDownloadView:
+ """Cancel a download. Removes it from the queue, does not kill the process
+ if download is already happening.
+ """
+ if dataset_id not in downloads:
+ raise HTTPException(status_code=404, detail='Download not found')
+
+ download = downloads[dataset_id]
+ download.cancel()
+
+ return EntryDownloadView(
+ entry = download.entry,
+ state = download.state
+ )
+
+
+def dedupe_datasests(datasets: Iterable[Entry]) -> Iterable[Entry]:
+ """Mtdata contains a multitude a datasets that have many different versions
+ (eg europarl). In the vast majority of the cases we ONLY EVER want the
+ latest version
+ """
+ datadict: Dict[str, List[Entry]] = defaultdict(list)
+ for entry in datasets:
+ datadict[entry.name].append(entry)
+ # Sort by version and return one per name
+ return [
+ sorted(entrylist, key=lambda t: t.version, reverse=True)[0]
+ for entrylist in datadict.values()
+ ]
diff --git a/frontend/src/interval.js b/frontend/src/interval.js
new file mode 100644
index 0000000..685a8c7
--- /dev/null
+++ b/frontend/src/interval.js
@@ -0,0 +1,37 @@
+/**
+ * Little wrapper that emulates setInterval(), but with start/stop methods,
+ * and which waits for callback() to resolve before scheduling the next call.
+ */
+export class Interval {
+ #timeout;
+
+ constructor(interval, callback) {
+ this.interval = interval;
+ this.callback = callback;
+ }
+
+ start() {
+ if (this.#timeout)
+ clearTimeout(this.#timeout);
+ this.#timeout = setTimeout(this.#callback.bind(this), this.interval);
+ }
+
+ stop() {
+ clearTimeout(this.#timeout);
+ this.#timeout = null;
+ }
+
+ restart() {
+ this.stop();
+ this.start();
+ }
+
+ #callback() {
+ // Wait for the callback() to resolve in case it is
+ // async, and then schedule a new call.
+ Promise.resolve(this.callback()).then(() => {
+ if (this.#timeout)
+ this.start()
+ })
+ }
+}
\ No newline at end of file
diff --git a/frontend/src/router/index.js b/frontend/src/router/index.js
index 1e54e96..05148b0 100644
--- a/frontend/src/router/index.js
+++ b/frontend/src/router/index.js
@@ -12,7 +12,12 @@ const router = createRouter({
path: '/datasets/:datasetName/configuration',
name: 'edit-filters',
component: () => import('../views/EditFiltersView.vue')
- }
+ },
+ {
+ path: '/download/',
+ name: 'add-dataset',
+ component: () => import('../views/AddDatasetView.vue')
+ },
]
})
diff --git a/frontend/src/views/AddDatasetView.vue b/frontend/src/views/AddDatasetView.vue
new file mode 100644
index 0000000..3a3e6ac
--- /dev/null
+++ b/frontend/src/views/AddDatasetView.vue
@@ -0,0 +1,214 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
{{ dataset.name }}
+
{{ dataset.group }}
+
{{ dataset.version }}
+
{{ dataset.langs.join(', ') }}
+
+
+
+
+
Downloads
+
+
{{ download.entry.name }} {{ download.state }}
+
+
Shopping cart
+
+
{{ dataset.name }}
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/frontend/src/views/ListDatasetsView.vue b/frontend/src/views/ListDatasetsView.vue
index dded256..3317efd 100644
--- a/frontend/src/views/ListDatasetsView.vue
+++ b/frontend/src/views/ListDatasetsView.vue
@@ -1,5 +1,6 @@