Test basic AI model #17
base: main
@@ -0,0 +1,200 @@
import os
import io
import sys
import fasttext
import hunspell
import logging
import urllib.request
import pathlib
import timeit
import argparse
import traceback

from sklearn.feature_extraction import DictVectorizer
import pycountry
import xgboost
from sklearn.model_selection import GridSearchCV

try:
    from . import __version__
    from .util import logging_setup, remove_unwanted_words, get_hash, check_dir, load_config
    from .fastspell import FastSpell
except ImportError:
    from fastspell import __version__
    from util import logging_setup, remove_unwanted_words, get_hash, check_dir, load_config
    from fastspell import FastSpell

class FastSpellAI(FastSpell):
    def __init__(self, lang, *args, **kwargs):
        super().__init__(lang, *args, **kwargs)

# Download the pretrained fastText language identification model (176
# languages) on first run, then load it.
ft_download_url = "https://dl.fbaipublicfiles.com/fasttext/supervised-models/lid.176.bin"
ft_model_path = "lid.176.bin"
if os.path.exists(ft_model_path):
    ft_model = fasttext.load_model(ft_model_path)
else:
    urllib.request.urlretrieve(ft_download_url, ft_model_path)
    ft_model = fasttext.load_model(ft_model_path)

ft_prefix = "__label__"

fsobj = FastSpellAI("en")

# All languages the fastText model can predict, stripped of the label prefix.
languages = [label[len(ft_prefix):] for label in fsobj.model.get_labels()]

# Try to load a hunspell dictionary for every language, tracking the ones
# for which none can be found.
unsupported = []
hunspell_objs = {}
for language in languages:
    try:
        if language in fsobj.hunspell_codes:
            search_l = fsobj.hunspell_codes[language]
        elif f"{language}_lat" in fsobj.hunspell_codes:
            search_l = fsobj.hunspell_codes[f"{language}_lat"]
        elif f"{language}_cyr" in fsobj.hunspell_codes:
            search_l = fsobj.hunspell_codes[f"{language}_cyr"]
        else:
            search_l = language
        hunspell_objs[language] = fsobj.search_hunspell_dict(search_l)
    except Exception:
        unsupported.append(language)

print(len(languages))
print(len(unsupported))
print(unsupported)

# predict() returns a (labels, probabilities) pair; inspect its structure.
prediction = fsobj.model.predict("Ciao, mondo!".lower(), k=3)
print(prediction)
print(prediction[0])                      # tuple of "__label__xx" strings
print(prediction[0][0])                   # top guess, e.g. "__label__it"
print(prediction[0][0][len(ft_prefix):])  # language code without the prefix

sentences = []
labels = []
count = 0
with open("../sentences.csv", "r") as f:
    for l in f:
        number, language, text = l.split("\t")

        # Only keep Italian sentences for this experiment.
        if language != "ita":
            continue

        lang = pycountry.languages.get(alpha_3=language)

        text = text.replace("\n", " ").strip()
        prediction = fsobj.model.predict(text.lower(), k=3)

        # print(prediction)

        # Unpack the top-3 candidates and their probabilities.
        lang0 = prediction[0][0][len(ft_prefix):]
        lang0_prob = prediction[1][0]
        if len(prediction[0]) >= 2:
            lang1 = prediction[0][1][len(ft_prefix):]
            lang1_prob = prediction[1][1]
        else:
            # If there's only one option... Not much to do.
            continue
        if len(prediction[0]) >= 3:
            lang2 = prediction[0][2][len(ft_prefix):]
            lang2_prob = prediction[1][2]
        else:
            lang2 = None
            lang2_prob = 0.0

        # The label is the index of the true language among the candidates.
        label = None
        if lang0 == lang.alpha_2:
            label = 0
        elif lang1 == lang.alpha_2:
            label = 1
        elif lang2 == lang.alpha_2:
            label = 2

        if label is None:
            continue

        # print(lang0)

        raw_tokens = text.strip().split(" ")
        # For each candidate that has a dictionary, compute the fraction of
        # tokens that hunspell accepts.
        if lang0 in hunspell_objs:
            tokens = remove_unwanted_words(raw_tokens, lang0)
            correct = 0
            for token in tokens:
                try:
                    if hunspell_objs[lang0].spell(token):
                        correct += 1
                except UnicodeEncodeError:
                    pass
            lang0_dic_tokens = correct / len(tokens) if tokens else 0.0
        else:
            lang0_dic_tokens = None

        if lang1 in hunspell_objs:
            tokens = remove_unwanted_words(raw_tokens, lang1)
            correct = 0
            for token in tokens:
                try:
                    if hunspell_objs[lang1].spell(token):
                        correct += 1
                except UnicodeEncodeError:
                    pass
            lang1_dic_tokens = correct / len(tokens) if tokens else 0.0
        else:
            lang1_dic_tokens = None

        if lang2 in hunspell_objs:
            tokens = remove_unwanted_words(raw_tokens, lang2)
            correct = 0
            for token in tokens:
                try:
                    if hunspell_objs[lang2].spell(token):
                        correct += 1
                except UnicodeEncodeError:
                    pass
            lang2_dic_tokens = correct / len(tokens) if tokens else 0.0
        else:
            lang2_dic_tokens = None

        sentences.append({
            "fastText_lang0": lang0_prob,
            "fastText_lang1": lang1_prob,
            "fastText_lang2": lang2_prob,
            "lang0_dic_tokens": lang0_dic_tokens,
            "lang1_dic_tokens": lang1_dic_tokens,
            "lang2_dic_tokens": lang2_dic_tokens,
        })
        labels.append(label)
Review comment: The label is 0 if the language to choose is the first among the three candidates, 1 if it's the second, 2 if it's the third.

Reply: Oh sorry, I did not see this. Maybe this changes my proposal a little bit.
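As a minimal illustration of that encoding, reusing the variable names from the loop above, the label is simply the index of the correct language among fastText's top-3 candidates:

# Illustrative only: the label indexes the true language among the candidates.
candidates = [lang0, lang1, lang2]   # fastText's guesses, best first
assert candidates[label] == lang.alpha_2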
        # count += 1
        # if count == 7:
        #     break

print(len(sentences))

# Vectorize the feature dicts and grid-search an XGBoost classifier.
dict_vectorizer = DictVectorizer()
X = dict_vectorizer.fit_transform(sentences)

xgb_model = xgboost.XGBClassifier(n_jobs=10)

clf = GridSearchCV(
    xgb_model,
    {"max_depth": [1, 2, 4, 6], "n_estimators": [25, 50, 100, 200]},
    verbose=1,
    n_jobs=1,
)
clf.fit(X, labels)
print(clf.best_score_)
print(clf.best_params_)
print(clf.best_estimator_)

clf.best_estimator_.save_model("model.ubj")

Review comment: We could presumably introduce another hyperparameter, a "threshold" above which we don't need to go look at the dictionary (e.g. if fastText is 99% sure it's Italian, we don't need to check further). This could lower the cost of the approach.

Reply: That's right. I've always been interested in whether there is a correlation between confidence and precision, but I never had time to check whether the number of false positives or false negatives is very small when confidence is high. If it is, we could add this exception and speed things up.
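A minimal sketch of that threshold idea, reusing the fsobj and ft_prefix names from the script above; the 0.99 cutoff is an assumed, untuned value, not part of this PR:

# Illustrative only: bypass the hunspell features when fastText alone is
# confident enough. THRESHOLD is an assumption, not a tuned hyperparameter.
THRESHOLD = 0.99

def needs_refinement(prediction, threshold=THRESHOLD):
    """True when the top fastText probability is below the threshold."""
    return prediction[1][0] < threshold

prediction = fsobj.model.predict("Ciao, mondo!".lower(), k=3)
if not needs_refinement(prediction):
    # Trust fastText's top guess directly, skipping the dictionary lookups.
    lang = prediction[0][0][len(ft_prefix):]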
# Quick sanity check of the trained model on the first training sample.
X_try = dict_vectorizer.transform([sentences[0]])
classes = clf.best_estimator_.predict(X_try)
if classes[0] == 0:
    print("Lang0 chosen")
elif classes[0] == 1:
    print("Lang1 chosen")
elif classes[0] == 2:
    print("Lang2 chosen")
@@ -18,6 +18,7 @@ hunspell_codes:
  da: da_DK
  de: de_DE
  en: en_GB
  el: el_GR
Review comment: This will be cleaned up, right? We prefer to add new languages to the default config only when they come along with the corresponding fastspell-dictionaries update that adds them.

Reply: Yeah, this was just for testing.
  es: es_ES
  et: et_ET
  fa: fa_IR

@@ -75,3 +76,4 @@ hunspell_codes:
  ur: ur_PK
  uz: uz_UZ
  yi: yi
  vi: vi_VN
@@ -24,6 +24,9 @@
HBS_LANGS = ('hbs', 'sh', 'bs', 'sr', 'hr', 'me')

# logger = logging.getLogger()
# logger.setLevel(logging.DEBUG)

def initialization():
    parser = argparse.ArgumentParser(prog=os.path.basename(sys.argv[0]), formatter_class=argparse.ArgumentDefaultsHelpFormatter, description=__doc__)
    parser.add_argument('lang', type=str)

@@ -94,6 +97,8 @@ def download_fasttext(self):

    def search_hunspell_dict(self, lang_code):
        ''' Search in the paths for a hunspell dictionary and load it '''
        hunspell_obj = None

        for p in self.hunspell_paths:
            if os.path.exists(f"{p}/{lang_code}.dic") and os.path.exists(f"{p}/{lang_code}.aff"):
                try:

@@ -105,10 +110,36 @@ def search_hunspell_dict(self, lang_code):
                    logging.error("Failed building Hunspell object for " + lang_code)
                    logging.error("Aborting.")
                    exit(1)
        else:
            raise RuntimeError(f"It does not exist any valid dictionary directory "
                               f"for {lang_code} in the paths {self.hunspell_paths}.")

        if hunspell_obj is None:
Review comment: This change is unrelated; I needed it to load as many dictionaries as possible from my system (I was testing with Italian, which is not available in fastspell-dictionaries).

Review comment: I guess this change could be useful in its own right, so that users of fastspell can more easily load more dictionaries.

Review comment: Let me know if you want to land this change (after cleaning it up, of course) and I'll open a PR for it.

Reply: I'm not sure I'm understanding this change correctly, but there's already a whole path search for possible dictionary candidates other than fastspell-dictionaries, for users who want to use the system's. It's explained in the documentation, and here is the code.

Reply: This was to allow loading a dictionary more easily: e.g. if "it" is the language and there's "it-IT.dic" on the system, it will assume that's the one to load (even though there might be others, like "it_CH.dic").
            # Fall back to scanning each path for any dictionary file whose
            # name starts with the language code (e.g. it_IT for "it").
            for p in self.hunspell_paths:
                if not os.path.exists(p):
                    continue

                potential_files = [path for path in os.listdir(p) if os.path.basename(path).startswith(lang_code)]
                if f"{lang_code}.dic" in potential_files and f"{lang_code}.aff" in potential_files:
                    dic = lang_code
                elif f"{lang_code}_{lang_code.upper()}.dic" in potential_files and f"{lang_code}_{lang_code.upper()}.aff" in potential_files:
                    dic = f"{lang_code}_{lang_code.upper()}"
                elif len(potential_files) == 2:
                    # Exactly one .dic/.aff pair with some other suffix: take it.
                    dic = potential_files[0][:-4]
                else:
                    continue

                try:
                    hunspell_obj = hunspell.Hunspell(dic, hunspell_data_dir=p)
                    logging.debug(f"Loaded hunspell obj for '{lang_code}' in path: {p + '/' + dic}")
                    break
                except:
                    logging.error("Failed building Hunspell object for " + dic)
                    logging.error("Aborting.")
                    exit(1)

        if hunspell_obj is None:
            raise RuntimeError(f"No valid dictionary directory exists "
                               f"for {lang_code} in the paths {self.hunspell_paths}. "
                               f"Please, execute 'fastspell-download'.")

        return hunspell_obj

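If this fallback lands, usage might look like the following sketch. It assumes an Italian system dictionary installed as an it_IT.dic/it_IT.aff pair somewhere in self.hunspell_paths, and that FastSpell can be instantiated with just a language code, as the test script above does:

# Hypothetical example: there is no plain it.dic/it.aff pair, but a system
# package provides it_IT.dic and it_IT.aff in one of the hunspell paths.
fs = FastSpell("it")
hobj = fs.search_hunspell_dict("it")   # falls back to the it_IT pair
print(hobj.spell("ciao"))              # True once the dictionary is loaded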
@@ -127,7 +158,7 @@ def load_hunspell_dicts(self):
        self.similar = []
        for sim_entry in self.similar_langs:
            if sim_entry.split('_')[0] == self.lang:
                self.similar.append(self.similar_langs[sim_entry])
                self.similar.append(self.similar_langs[sim_entry] + [self.lang])

        logging.debug(f"Similar lists for '{self.lang}': {self.similar}")
        self.hunspell_objs = {}

@@ -208,6 +239,9 @@ def getlang(self, sent):

        #TODO: Confidence score?

        logging.debug(prediction)
        logging.debug(self.similar)

        if self.similar == [] or prediction not in self.hunspell_objs:
            #Non mistakeable language: just return FastText prediction
            refined_prediction = prediction

@@ -218,6 +252,7 @@ def getlang(self, sent):
        for sim_list in self.similar:
            if prediction in sim_list or f'{prediction}_{script}' in sim_list:
                current_similar = sim_list
                logging.debug(current_similar)

        spellchecked = {}
        for l in current_similar:

Review comment: These are the input features for the model.
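For context, a hedged sketch of how those features could be fed back to the trained classifier at inference time, assuming the six-feature layout and the model.ubj file produced by the training script above (the feature values here are made up):

import xgboost
from sklearn.feature_extraction import DictVectorizer

# Made-up values following the feature layout built in the training script.
features = {
    "fastText_lang0": 0.93, "fastText_lang1": 0.05, "fastText_lang2": 0.01,
    "lang0_dic_tokens": 0.88, "lang1_dic_tokens": 0.12, "lang2_dic_tokens": 0.0,
}

model = xgboost.XGBClassifier()
model.load_model("model.ubj")  # restore the classifier saved above

# Ideally the DictVectorizer fitted at training time would be reused;
# refitting works here only because the key set is fixed and sorts the same.
X = DictVectorizer().fit_transform([features])
print(model.predict(X)[0])     # 0, 1 or 2: which fastText candidate to pick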