Skip to content

Commit

Permalink
feat:add nltk model installation logic
Browse files Browse the repository at this point in the history
  • Loading branch information
christinestraub committed Jan 6, 2025
1 parent 30198a7 commit 39187d0
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 24 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ jobs:
fi
# FIXME (yao): sometimes there is cache but we still miss argilla in the env; so we add make install-ci again
make install-ci
make install-nltk-models
make test CI=true UNSTRUCTURED_INCLUDE_DEBUG_METADATA=true
make check-coverage
Expand Down Expand Up @@ -317,6 +318,7 @@ jobs:
tesseract --version
make install-all-docs
make install-ingest
make install-nltk-models
./test_unstructured_ingest/test-ingest-src.sh
Expand Down
3 changes: 2 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ install-huggingface:

.PHONY: install-nltk-models
install-nltk-models:
${PYTHON} -c "from unstructured.nlp.tokenize import copy_nltk_packages; copy_nltk_packages()"
export NLTK_DATA=/home/notebook-user/nltk_data && \
${python} -m nltk.downloader punkt_tab averaged_perceptron_tagger_eng

.PHONY: install-test
install-test:
Expand Down
23 changes: 0 additions & 23 deletions unstructured/nlp/tokenize.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

import os
import shutil
from functools import lru_cache
from typing import Final, List, Tuple

Expand All @@ -16,26 +15,6 @@
NLTK_DATA_PATH = os.getenv("NLTK_DATA", "/home/notebook-user/nltk_data")
nltk.data.path.append(NLTK_DATA_PATH)

PROJECT_NLTK_ASSETS_PATH = os.path.abspath("../../nltk_data")


def copy_nltk_packages():
if os.path.exists(PROJECT_NLTK_ASSETS_PATH):
if not os.path.exists(NLTK_DATA_PATH):
os.makedirs(NLTK_DATA_PATH)
for item in os.listdir(PROJECT_NLTK_ASSETS_PATH):
s = os.path.join(PROJECT_NLTK_ASSETS_PATH, item)
d = os.path.join(NLTK_DATA_PATH, item)
if os.path.isdir(s):
shutil.copytree(s, d, dirs_exist_ok=True)
else:
shutil.copy2(s, d)
print(f"NLTK data copied to {NLTK_DATA_PATH}")
else:
raise RuntimeError(
f"Local NLTK data path does not exist: {PROJECT_NLTK_ASSETS_PATH}"
)


def check_for_nltk_package(package_name: str, package_category: str) -> bool:
"""Checks to see if the specified NLTK package exists on the file system."""
Expand All @@ -48,8 +27,6 @@ def check_for_nltk_package(package_name: str, package_category: str) -> bool:

# Ensure NLTK data exists in the specified path (pre-baked in Docker)
def validate_nltk_assets():
if not os.path.exists(NLTK_DATA_PATH):
copy_nltk_packages()
"""Validate that required NLTK packages are preloaded in the image."""
required_assets = [
("punkt_tab", "tokenizers"),
Expand Down

0 comments on commit 39187d0

Please sign in to comment.