diff --git a/src/fastspell/ai.py b/src/fastspell/ai.py
new file mode 100644
index 0000000..9136375
--- /dev/null
+++ b/src/fastspell/ai.py
@@ -0,0 +1,200 @@
+import os
+import urllib.request
+
+import fasttext
+import pycountry
+import xgboost
+from sklearn.feature_extraction import DictVectorizer
+from sklearn.model_selection import GridSearchCV
+
+try:
+    from .util import remove_unwanted_words
+    from .fastspell import FastSpell
+except ImportError:
+    from util import remove_unwanted_words
+    from fastspell import FastSpell
+
+
+class FastSpellAI(FastSpell):
+    def __init__(self, lang, *args, **kwargs):
+        super().__init__(lang, *args, **kwargs)
+
+
+# Download the fastText LID model if it is not already present.
+ft_download_url = "https://dl.fbaipublicfiles.com/fasttext/supervised-models/lid.176.bin"
+ft_model_path = "lid.176.bin"
+if not os.path.exists(ft_model_path):
+    urllib.request.urlretrieve(ft_download_url, ft_model_path)
+ft_model = fasttext.load_model(ft_model_path)
+
+ft_prefix = "__label__"
+
+fsobj = FastSpellAI("en")
+
+languages = [label[len(ft_prefix):] for label in fsobj.model.get_labels()]
+
+# Load a hunspell dictionary for every language the fastText model supports,
+# trying the script-specific codes (e.g. "sr_lat", "sr_cyr") as fallbacks.
+unsupported = []
+hunspell_objs = {}
+for language in languages:
+    try:
+        if language in fsobj.hunspell_codes:
+            search_l = fsobj.hunspell_codes[language]
+        elif f"{language}_lat" in fsobj.hunspell_codes:
+            search_l = fsobj.hunspell_codes[f"{language}_lat"]
+        elif f"{language}_cyr" in fsobj.hunspell_codes:
+            search_l = fsobj.hunspell_codes[f"{language}_cyr"]
+        else:
+            search_l = language
+        hunspell_objs[language] = fsobj.search_hunspell_dict(search_l)
+    except Exception:
+        unsupported.append(language)
+
+print(f"{len(languages)} fastText languages, {len(unsupported)} without a hunspell dictionary:")
+print(unsupported)
+
+prediction = fsobj.model.predict("Ciao, mondo!".lower(), k=3)
+print(prediction)
+print(prediction[0])
+print(prediction[0][0])
+print(prediction[0][0][len(ft_prefix):])
+
+
+def dic_tokens_ratio(raw_tokens, lang_code):
+    """Fraction of tokens accepted by the hunspell dictionary for lang_code,
+    or None when no dictionary is available."""
+    if lang_code not in hunspell_objs:
+        return None
+    tokens = remove_unwanted_words(raw_tokens, lang_code)
+    if not tokens:
+        return 0.0
+    correct = 0
+    for token in tokens:
+        try:
+            if hunspell_objs[lang_code].spell(token):
+                correct += 1
+        except UnicodeEncodeError:
+            pass
+    return correct / len(tokens)
+
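+# Feature layout for the ranker: for each sentence, the fastText probabilities
+# of the top-3 candidate languages plus, per candidate, the fraction of its
+# tokens accepted by that candidate's hunspell dictionary (None when no
+# dictionary is available). The label is the rank (0, 1 or 2) of the true
+# language among the candidates.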
+sentences = []
+labels = []
+with open("../sentences.csv", "r") as f:
+    for line in f:
+        # Tatoeba-style export: id, ISO 639-3 code and text, tab-separated.
+        number, language, text = line.split("\t")
+
+        if language != "ita":
+            continue
+
+        lang = pycountry.languages.get(alpha_3=language)
+
+        text = text.replace("\n", " ").strip()
+        prediction = fsobj.model.predict(text.lower(), k=3)
+
+        lang0 = prediction[0][0][len(ft_prefix):]
+        lang0_prob = prediction[1][0]
+        if len(prediction[0]) >= 2:
+            lang1 = prediction[0][1][len(ft_prefix):]
+            lang1_prob = prediction[1][1]
+        else:
+            # If there is only one candidate, there is nothing to rank.
+            continue
+        if len(prediction[0]) >= 3:
+            lang2 = prediction[0][2][len(ft_prefix):]
+            lang2_prob = prediction[1][2]
+        else:
+            lang2 = None
+            lang2_prob = 0.0
+
+        label = None
+        if lang0 == lang.alpha_2:
+            label = 0
+        elif lang1 == lang.alpha_2:
+            label = 1
+        elif lang2 == lang.alpha_2:
+            label = 2
+
+        # Skip sentences whose true language is not among the candidates.
+        if label is None:
+            continue
+
+        raw_tokens = text.strip().split(" ")
+
+        sentences.append({
+            "fastText_lang0": lang0_prob,
+            "fastText_lang1": lang1_prob,
+            "fastText_lang2": lang2_prob,
+            "lang0_dic_tokens": dic_tokens_ratio(raw_tokens, lang0),
+            "lang1_dic_tokens": dic_tokens_ratio(raw_tokens, lang1),
+            "lang2_dic_tokens": dic_tokens_ratio(raw_tokens, lang2),
+        })
+        labels.append(label)
+
+print(f"Collected {len(sentences)} training examples")
+
+# None feature values are expected to surface as NaN after vectorization,
+# which XGBoost treats as missing.
+dict_vectorizer = DictVectorizer()
+X = dict_vectorizer.fit_transform(sentences)
+
+xgb_model = xgboost.XGBClassifier(n_jobs=10)
+
+clf = GridSearchCV(
+    xgb_model,
+    {"max_depth": [1, 2, 4, 6], "n_estimators": [25, 50, 100, 200]},
+    verbose=1,
+    n_jobs=1,
+)
+clf.fit(X, labels)
+print(clf.best_score_)
+print(clf.best_params_)
+print(clf.best_estimator_)
+
+clf.best_estimator_.save_model("model.ubj")
+
+# Sanity check: predict the first training sentence again with the tuned model.
+X_try = dict_vectorizer.transform([sentences[0]])
+classes = clf.best_estimator_.predict(X_try)
+if classes[0] == 0:
+    print("Lang0 chosen")
+elif classes[0] == 1:
+    print("Lang1 chosen")
+elif classes[0] == 2:
+    print("Lang2 chosen")
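Note for later consumers of model.ubj: the script saves the tuned XGBoost model but not the fitted DictVectorizer, whose feature ordering is needed to build inputs at prediction time. A minimal sketch of what inference could look like, assuming the vectorizer were also persisted with joblib; the file name vectorizer.joblib and the example feature values are illustrative, not part of this patch:

# Illustrative inference sketch, not part of this patch.
# Assumes ai.py additionally ran: joblib.dump(dict_vectorizer, "vectorizer.joblib")
import joblib
import xgboost

vectorizer = joblib.load("vectorizer.joblib")
model = xgboost.XGBClassifier()
model.load_model("model.ubj")

def rank_candidates(features):
    # features is one dict shaped like the training rows built in ai.py.
    X = vectorizer.transform([features])
    # The predicted class is the rank (0, 1 or 2) of the candidate to pick.
    return int(model.predict(X)[0])

example = {
    "fastText_lang0": 0.52, "fastText_lang1": 0.31, "fastText_lang2": 0.09,
    "lang0_dic_tokens": 0.95, "lang1_dic_tokens": 0.40, "lang2_dic_tokens": 0.35,
}
print(rank_candidates(example))  # e.g. 0 -> keep the top fastText candidate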
diff --git a/src/fastspell/config/hunspell.yaml b/src/fastspell/config/hunspell.yaml
index 4eb1e0f..da53768 100644
--- a/src/fastspell/config/hunspell.yaml
+++ b/src/fastspell/config/hunspell.yaml
@@ -18,6 +18,7 @@ hunspell_codes:
   da: da_DK
   de: de_DE
+  el: el_GR
   en: en_GB
   es: es_ES
   et: et_ET
   fa: fa_IR
@@ -75,3 +76,4 @@ hunspell_codes:
   ur: ur_PK
   uz: uz_UZ
+  vi: vi_VN
   yi: yi
diff --git a/src/fastspell/fastspell.py b/src/fastspell/fastspell.py
index 0a39921..97cd5e6 100644
--- a/src/fastspell/fastspell.py
+++ b/src/fastspell/fastspell.py
@@ -94,6 +97,8 @@ def download_fasttext(self):
 
     def search_hunspell_dict(self, lang_code):
         ''' Search in the paths for a hunspell dictionary and load it '''
+        hunspell_obj = None
+
         for p in self.hunspell_paths:
             if os.path.exists(f"{p}/{lang_code}.dic") and os.path.exists(f"{p}/{lang_code}.aff"):
                 try:
@@ -105,10 +110,36 @@ def search_hunspell_dict(self, lang_code):
                     logging.error("Failed building Hunspell object for " + lang_code)
                     logging.error("Aborting.")
                     exit(1)
-        else:
-            raise RuntimeError(f"It does not exist any valid dictionary directory"
-                        f"for {lang_code} in the paths {self.hunspell_paths}."
+
+        # Fall back to dictionary files whose names only start with lang_code,
+        # e.g. "vi_VN.dic" when asked for "vi".
+        if hunspell_obj is None:
+            for p in self.hunspell_paths:
+                if not os.path.exists(p):
+                    continue
+
+                potential_files = [f for f in os.listdir(p) if f.startswith(lang_code)]
+                if f"{lang_code}.dic" in potential_files and f"{lang_code}.aff" in potential_files:
+                    dic = lang_code
+                elif f"{lang_code}_{lang_code.upper()}.dic" in potential_files and f"{lang_code}_{lang_code.upper()}.aff" in potential_files:
+                    dic = f"{lang_code}_{lang_code.upper()}"
+                elif len(potential_files) == 2:
+                    # Exactly one candidate .dic/.aff pair: strip the extension.
+                    dic = potential_files[0][:-4]
+                else:
+                    continue
+
+                try:
+                    hunspell_obj = hunspell.Hunspell(dic, hunspell_data_dir=p)
+                    logging.debug(f"Loaded hunspell obj for '{lang_code}' in path: {p + '/' + dic}")
+                    break
+                except Exception:
+                    logging.error("Failed building Hunspell object for " + dic)
+                    logging.error("Aborting.")
+                    exit(1)
+
+        if hunspell_obj is None:
+            raise RuntimeError(f"No valid dictionary was found "
+                        f"for {lang_code} in the paths {self.hunspell_paths}. "
                         f"Please, execute 'fastspell-download'.")
+
         return hunspell_obj
@@ -127,7 +158,7 @@ def load_hunspell_dicts(self):
         self.similar = []
         for sim_entry in self.similar_langs:
             if sim_entry.split('_')[0] == self.lang:
-                self.similar.append(self.similar_langs[sim_entry])
+                self.similar.append(self.similar_langs[sim_entry] + [self.lang])
         logging.debug(f"Similar lists for '{self.lang}': {self.similar}")
 
         self.hunspell_objs = {}
@@ -208,6 +239,9 @@ def getlang(self, sent):
 
         #TODO: Confidence score?
 
+        logging.debug(f"fastText prediction: {prediction}")
+        logging.debug(f"Similar lists: {self.similar}")
+
         if self.similar == [] or prediction not in self.hunspell_objs:
             #Non mistakeable language: just return FastText prediction
             refined_prediction = prediction
@@ -218,6 +252,7 @@ def getlang(self, sent):
         for sim_list in self.similar:
             if prediction in sim_list or f'{prediction}_{script}' in sim_list:
                 current_similar = sim_list
+                logging.debug(f"Current similar list: {current_similar}")
 
         spellchecked = {}
         for l in current_similar:
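The fallback added to search_hunspell_dict accepts three filename shapes: an exact <lang>.dic/<lang>.aff pair, a <lang>_<LANG> pair such as vi_VN, or any single .dic/.aff pair sharing the language prefix (the len == 2 branch assumes those two files are a matching pair). A standalone sketch of the same matching rule, handy for exercising it in isolation; the function name resolve_dic_name and the temporary directory are illustrative, not part of this patch:

import os
import tempfile

def resolve_dic_name(path, lang_code):
    # Mirrors the fallback matching in search_hunspell_dict: returns the
    # dictionary base name to load from path, or None if nothing matches.
    if not os.path.exists(path):
        return None
    potential_files = [f for f in os.listdir(path) if f.startswith(lang_code)]
    if f"{lang_code}.dic" in potential_files and f"{lang_code}.aff" in potential_files:
        return lang_code
    upper = f"{lang_code}_{lang_code.upper()}"
    if f"{upper}.dic" in potential_files and f"{upper}.aff" in potential_files:
        return upper
    if len(potential_files) == 2:
        # Exactly one candidate pair with this prefix: strip the extension.
        return potential_files[0][:-4]
    return None

with tempfile.TemporaryDirectory() as d:
    for name in ("vi_VN.dic", "vi_VN.aff"):
        open(os.path.join(d, name), "w").close()
    print(resolve_dic_name(d, "vi"))  # -> "vi_VN"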