diff --git a/jarvis/config/spacy.json b/jarvis/config/spacy.json
new file mode 100644
index 0000000..44fb619
--- /dev/null
+++ b/jarvis/config/spacy.json
@@ -0,0 +1,19 @@
+{
+  "fr-fr": "fr_core_news_sm",
+  "en-en": "en_core_web_sm",
+  "en-us": "en_core_web_sm",
+  "de-de": "de_core_news_sm",
+  "tr-tr": "xx_ent_wiki_sm",
+  "cs-cz": "xx_ent_wiki_sm",
+  "el-gr": "xx_ent_wiki_sm",
+  "da-dk": "da_core_news_sm",
+  "et-ee": "xx_ent_wiki_sm",
+  "pt-pt": "pt_core_news_sm",
+  "es-es": "es_core_news_sm",
+  "nl-nl": "nl_core_news_sm",
+  "fi-fi": "xx_ent_wiki_sm",
+  "it-it": "it_core_news_sm",
+  "pl-pl": "pl_core_news_sm",
+  "sl-si": "xx_ent_wiki_sm",
+  "sv-se": "xx_ent_wiki_sm"
+}
\ No newline at end of file
diff --git a/jarvis/utils/languages_utils.py b/jarvis/utils/languages_utils.py
index 01f1694..47fd563 100644
--- a/jarvis/utils/languages_utils.py
+++ b/jarvis/utils/languages_utils.py
@@ -26,3 +26,23 @@ def get_language_full_name(name=None):
         return config_json.get(name)
 
     return 'english'
+
+
+def get_spacy_model(language=None):
+    """
+    Return the name of the spaCy model configured for a language.
+
+    :param language: key looked up in config/spacy.json (e.g. 'fr-fr');
+                     when None, the value returned by get_language() is used
+    :return: the configured model name, or the multi-language model
+             'xx_ent_wiki_sm' when the language has no entry
+    """
+    # Context manager so the config file handle is closed right after parsing
+    with open(path + "/config/spacy.json", encoding='utf-8', mode='r') as spacy_config:
+        spacy_model = json.load(spacy_config)
+
+    if language is None:
+        language = get_language()
+
+    # Fall back to the multi-language model for unsupported languages
+    return spacy_model.get(language, 'xx_ent_wiki_sm')
diff --git a/jarvis/utils/nlp_utils.py b/jarvis/utils/nlp_utils.py
index ec79e77..b4994be 100644
--- a/jarvis/utils/nlp_utils.py
+++ b/jarvis/utils/nlp_utils.py
@@ -1,13 +1,15 @@
 import spacy
-from nltk.corpus import stopwords
+
+from jarvis.utils import languages_utils
 
 
 def get_spacy_nlp():
-    """
-
-    :return: spacy
-    """
-    nlp = spacy.load("en_core_web_sm")
+    """
+    Load the spaCy model selected by languages_utils.get_spacy_model().
+
+    :return: the object returned by spacy.load for that model
+    """
+    nlp = spacy.load(languages_utils.get_spacy_model())
     return nlp
 
 
@@ -18,10 +20,3 @@ def get_text_without_stopwords(sentence):
     filtered_sentence = [w for w in sentence.lower().split() if w not in stop_words]
     filtered_sentence = " ".join(filtered_sentence)
     return filtered_sentence
-
-
-def get_text_without_stopwords_nltk(sentence, language='english'):
-    stop_words = set(stopwords.words(language))
-    filtered_sentence = [w for w in sentence.lower().split() if w not in stop_words]
-    filtered_sentence = " ".join(filtered_sentence)
-    return filtered_sentence
diff --git a/requirements.txt b/requirements.txt
index ea8d29a..c61ee7f 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,4 +4,5 @@ nltk~=3.6.2
 torch~=1.9.0
 numpy~=1.21.1
 requests~=2.26.0
-adapt-parser
\ No newline at end of file
+adapt-parser
+spacy~=3.1.1
\ No newline at end of file