Skip to content

Commit

Permalink
Test basic AI model
Browse files Browse the repository at this point in the history
  • Loading branch information
marco-c committed Nov 23, 2023
1 parent 5bfc20e commit 440da84
Show file tree
Hide file tree
Showing 3 changed files with 241 additions and 4 deletions.
200 changes: 200 additions & 0 deletions src/fastspell/ai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
import os
import io
import sys
import fasttext
import hunspell
import logging
import urllib.request
import pathlib
import timeit
import argparse
import traceback
import logging
from sklearn.feature_extraction import DictVectorizer
import pycountry
import xgboost
from sklearn.model_selection import GridSearchCV

try:
from . import __version__
from .util import logging_setup, remove_unwanted_words, get_hash, check_dir, load_config
from .fastspell import FastSpell
except ImportError:
from fastspell import __version__
from util import logging_setup, remove_unwanted_words, get_hash, check_dir, load_config
from fastspell import FastSpell

class FastSpellAI(FastSpell):
    """FastSpell subclass used by this training script.

    NOTE(review): this __init__ is currently a pass-through no-op override —
    it adds no behavior beyond FastSpell.__init__; presumably kept as a hook
    for later customization.
    """
    def __init__(self, lang, *args, **kwargs):
        # Delegate everything to FastSpell unchanged.
        super().__init__(lang, *args, **kwargs)

# Fetch the pretrained fastText language-identification model, caching it
# locally so repeated runs skip the download.
ft_download_url = "https://dl.fbaipublicfiles.com/fasttext/supervised-models/lid.176.bin"
ft_model_path = "lid.176.bin"
if not os.path.exists(ft_model_path):
    urllib.request.urlretrieve(ft_download_url, ft_model_path)
ft_model = fasttext.load_model(ft_model_path)

# fastText prefixes every predicted label with this marker (e.g. "__label__en").
ft_prefix = "__label__"

fsobj = FastSpellAI("en")

# Map every fastText language label to a loaded hunspell dictionary; languages
# with no resolvable dictionary are collected in `unsupported`.
languages = [label[len(ft_prefix):] for label in fsobj.model.get_labels()]

unsupported = []
hunspell_objs = {}
for language in languages:
    # Resolve the hunspell code: exact match first, then the latin/cyrillic
    # script-specific variants, finally fall back to the bare language code.
    if language in fsobj.hunspell_codes:
        search_l = fsobj.hunspell_codes[language]
    elif f"{language}_lat" in fsobj.hunspell_codes:
        search_l = fsobj.hunspell_codes[f"{language}_lat"]
    elif f"{language}_cyr" in fsobj.hunspell_codes:
        search_l = fsobj.hunspell_codes[f"{language}_cyr"]
    else:
        search_l = language
    try:
        hunspell_objs[language] = fsobj.search_hunspell_dict(search_l)
    except Exception:
        # Was a bare `except:`, which also swallows KeyboardInterrupt and
        # SystemExit; only dictionary-lookup failures should mark a language
        # as unsupported. The try body is also narrowed to the one call that
        # can actually raise.
        unsupported.append(language)

print(len(languages))
print(len(unsupported))
print(unsupported)

# Sanity check: show the top-3 language predictions for a sample Italian
# sentence, then unwrap the label tuple step by step.
sample_prediction = fsobj.model.predict("Ciao, mondo!".lower(), k=3)
print(sample_prediction)
label_tuple = sample_prediction[0]
print(label_tuple)
top_label = label_tuple[0]
print(top_label)
print(top_label[len(ft_prefix):])

def _dic_token_ratio(lang_code, raw_tokens):
    """Fraction of tokens accepted by the hunspell dictionary for lang_code.

    Returns None when no dictionary is loaded for that language (including
    lang_code=None) or when no usable tokens remain after filtering. Tokens
    that hunspell cannot encode count as misspelled.
    """
    if lang_code not in hunspell_objs:
        return None
    tokens = remove_unwanted_words(raw_tokens, lang_code)
    if not tokens:
        # Guard: the original divided by len(tokens) unconditionally, which
        # raises ZeroDivisionError on sentences with no usable tokens.
        return None
    correct = 0
    for token in tokens:
        try:
            if hunspell_objs[lang_code].spell(token):
                correct += 1
        except UnicodeEncodeError:
            pass
    return correct / len(tokens)


# Build the training set: for each Italian sentence, record fastText's top-3
# probabilities plus the per-candidate hunspell dictionary hit-rate; the label
# is the rank (0/1/2) of the true language among the predictions.
sentences = []
labels = []
with open("../sentences.csv", "r") as f:
    for l in f:
        # BUG FIX: the original did `next(f).split("\t")` here, which ignored
        # the line already bound to `l`, processed only every other line, and
        # could raise StopIteration on a file with an odd number of lines.
        number, language, text = l.split("\t")

        if language != "ita":
            continue

        # Resolve ISO 639-3 ("ita") to the pycountry record so we can compare
        # against fastText's two-letter codes via .alpha_2 below.
        lang = pycountry.languages.get(alpha_3=language)

        text = text.replace("\n", " ").strip()
        prediction = fsobj.model.predict(text.lower(), k=3)

        lang0 = prediction[0][0][len(ft_prefix):]
        lang0_prob = prediction[1][0]
        if len(prediction[0]) >= 2:
            lang1 = prediction[0][1][len(ft_prefix):]
            lang1_prob = prediction[1][1]
        else:
            # If there's only one option... Not much to do.
            continue
        if len(prediction[0]) >= 3:
            lang2 = prediction[0][2][len(ft_prefix):]
            lang2_prob = prediction[1][2]
        else:
            lang2 = None
            lang2_prob = 0.0

        # Label = position of the true language among the top-3 predictions;
        # sentences whose true language is absent from the top-3 are skipped.
        label = None
        if lang0 == lang.alpha_2:
            label = 0
        elif lang1 == lang.alpha_2:
            label = 1
        elif lang2 == lang.alpha_2:
            label = 2

        if label is None:
            continue

        raw_tokens = text.strip().split(" ")

        sentences.append({
            "fastText_lang0": lang0_prob,
            "fastText_lang1": lang1_prob,
            "fastText_lang2": lang2_prob,
            "lang0_dic_tokens": _dic_token_ratio(lang0, raw_tokens),
            "lang1_dic_tokens": _dic_token_ratio(lang1, raw_tokens),
            "lang2_dic_tokens": _dic_token_ratio(lang2, raw_tokens),
        })
        labels.append(label)

print(len(sentences))

# Vectorize the per-sentence feature dicts into a sparse matrix, then
# grid-search an XGBoost classifier over tree depth and ensemble size.
dict_vectorizer = DictVectorizer()
X = dict_vectorizer.fit_transform(sentences)

xgb_model = xgboost.XGBClassifier(n_jobs=10)

param_grid = {
    "max_depth": [1, 2, 4, 6],
    "n_estimators": [25, 50, 100, 200],
}
clf = GridSearchCV(xgb_model, param_grid, verbose=1, n_jobs=1)
clf.fit(X, labels)

# Report the winning configuration.
for attr in ("best_score_", "best_params_", "best_estimator_"):
    print(getattr(clf, attr))

clf.best_estimator_.save_model("model.ubj")

# Smoke-test the trained model on the first training sample.
# BUG FIX: use transform, not fit_transform — refitting the vectorizer on a
# single sample rebuilds the feature-name mapping and invalidates the column
# layout the trained model expects.
X_try = dict_vectorizer.transform([sentences[0]])
# BUG FIX: predict with the tuned estimator on the single-sample matrix.
# GridSearchCV fits clones, so the bare `xgb_model` instance was never
# trained, and the original also passed the full X instead of X_try.
classes = clf.best_estimator_.predict(X_try)
# BUG FIX: messages now match the class indices (the original printed
# "Lang2"/"Lang3" for classes 1 and 2).
if classes[0] == 0:
    print("Lang0 chosen")
elif classes[0] == 1:
    print("Lang1 chosen")
elif classes[0] == 2:
    print("Lang2 chosen")
2 changes: 2 additions & 0 deletions src/fastspell/config/hunspell.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ hunspell_codes:
da: da_DK
de: de_DE
en: en_GB
el: el_GR
es: es_ES
et: et_ET
fa: fa_IR
Expand Down Expand Up @@ -75,3 +76,4 @@ hunspell_codes:
ur: ur_PK
uz: uz_UZ
yi: yi
vi: vi_VN
43 changes: 39 additions & 4 deletions src/fastspell/fastspell.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@
HBS_LANGS = ('hbs', 'sh', 'bs', 'sr', 'hr', 'me')


# logger = logging.getLogger()
# logger.setLevel(logging.DEBUG)

def initialization():
parser = argparse.ArgumentParser(prog=os.path.basename(sys.argv[0]), formatter_class=argparse.ArgumentDefaultsHelpFormatter, description=__doc__)
parser.add_argument('lang', type=str)
Expand Down Expand Up @@ -94,6 +97,8 @@ def download_fasttext(self):

def search_hunspell_dict(self, lang_code):
''' Search in the paths for a hunspell dictionary and load it '''
hunspell_obj = None

for p in self.hunspell_paths:
if os.path.exists(f"{p}/{lang_code}.dic") and os.path.exists(f"{p}/{lang_code}.aff"):
try:
Expand All @@ -105,10 +110,36 @@ def search_hunspell_dict(self, lang_code):
logging.error("Failed building Hunspell object for " + lang_code)
logging.error("Aborting.")
exit(1)
else:
raise RuntimeError(f"It does not exist any valid dictionary directory"
f"for {lang_code} in the paths {self.hunspell_paths}."

if hunspell_obj is None:
for p in self.hunspell_paths:
if not os.path.exists(p):
continue

potential_files = [path for path in os.listdir(p) if os.path.basename(path).startswith(lang_code)]
if f"{lang_code}.dic" in potential_files and f"{lang_code}.aff" in potential_files:
dic = lang_code
elif f"{lang_code}_{lang_code.upper()}.dic" in potential_files and f"{lang_code}_{lang_code.upper()}.aff" in potential_files:
dic = f"{lang_code}_{lang_code.upper()}"
elif len(potential_files) == 2:
dic = potential_files[0][:-4]
else:
continue

try:
hunspell_obj = hunspell.Hunspell(dic, hunspell_data_dir=p)
logging.debug(f"Loaded hunspell obj for '{lang_code}' in path: {p + '/' + dic}")
break
except:
logging.error("Failed building Hunspell object for " + dic)
logging.error("Aborting.")
exit(1)

if hunspell_obj is None:
raise RuntimeError(f"It does not exist any valid dictionary directory "
f"for {lang_code} in the paths {self.hunspell_paths}. "
f"Please, execute 'fastspell-download'.")

return hunspell_obj


Expand All @@ -127,7 +158,7 @@ def load_hunspell_dicts(self):
self.similar = []
for sim_entry in self.similar_langs:
if sim_entry.split('_')[0] == self.lang:
self.similar.append(self.similar_langs[sim_entry])
self.similar.append(self.similar_langs[sim_entry] + [self.lang])

logging.debug(f"Similar lists for '{self.lang}': {self.similar}")
self.hunspell_objs = {}
Expand Down Expand Up @@ -208,6 +239,9 @@ def getlang(self, sent):

#TODO: Confidence score?

logging.debug(prediction)
logging.debug(self.similar)

if self.similar == [] or prediction not in self.hunspell_objs:
#Non mistakeable language: just return FastText prediction
refined_prediction = prediction
Expand All @@ -218,6 +252,7 @@ def getlang(self, sent):
for sim_list in self.similar:
if prediction in sim_list or f'{prediction}_{script}' in sim_list:
current_similar = sim_list
logging.debug(current_similar)

spellchecked = {}
for l in current_similar:
Expand Down

0 comments on commit 440da84

Please sign in to comment.