Showing 41 changed files with 3,692 additions and 4 deletions.
@@ -0,0 +1,15 @@
# python
__pycache__/
*.py[cod]
*$py.class
venv

# training and test data
data/*

# temporary outputs
src/vec2sent/evaluation/bleu/bleu-hypothesis
src/vec2sent/evaluation/bleu/bleu-reference

# IDE
.idea
@@ -0,0 +1,12 @@
[submodule "src/external/mos/mos"]
	path = src/external/mos/mos
	url = https://github.com/zihangdai/mos.git
[submodule "src/external/InferSent"]
	path = src/external/InferSent
	url = https://github.com/facebookresearch/InferSent.git
[submodule "src/external/geometric_embedding/geometric_embedding"]
	path = src/external/geometric_embedding/geometric_embedding
	url = https://github.com/fursovia/geometric_embedding.git
[submodule "src/external/quick_thought/S2V"]
	path = src/external/quick_thought/S2V
	url = https://github.com/lajanugen/S2V.git
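For context (a note, not part of the diff content itself): these submodules vendor the external baseline encoders used by vec2sent, namely the MoS language model, InferSent, GEM (geometric_embedding) and Quick-Thought (S2V), under src/external. A fresh clone is expected to need the standard git submodule update --init before building, since the build configuration below packages src/external into the wheel.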
@@ -0,0 +1 @@
3.7
@@ -0,0 +1,56 @@
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[tool.hatch.build.targets.wheel]
packages = ["src/vec2sent"]

# This moves the contents of src/external into src/vec2sent during build
[tool.hatch.metadata]
allow-direct-references = true

[tool.hatch.build.targets.wheel.force-include]
"src/external" = "src/vec2sent"

[project]
name = "vec2sent"
version = "0.1.0"
description = "Generate sentences from embeddings and evaluate the results."
authors = [
    { name="Martin Kerscher" },
]
requires-python = "<3.8"
dependencies = [
    "bpemb>=0.3.5",
    "fastBPE>=0.0.0",
    "gensim==3.4.0",
    "nltk==3.4.1",
    "numpy>=1.17.1",
    "pytorch-transformers==1.1.0",
    "laserembeddings",
    "scikit-learn==1.0.2",
    "sentence-transformers==0.2.2",
    "torch==1.1.0",
    "tensorflow==1.15.0",
    "protobuf==3.14.0",
    "pandas==1.2.0",
    "tqdm==4.42.1",
    "huggingface-hub>=0.16.4",
    "sent2vec @ git+https://github.com/epfml/sent2vec.git",
    "pyemd==0.5.1",
    "pytorch-pretrained-bert==0.6.2",
    "platformdirs",
    "gdown==4.7.3",
    "requests",
]

[project.optional-dependencies]
evaluate_linguistic_features = [
    "spacy"
]

[project.scripts]
vec2sent_cleanup = "vec2sent.sentence_embeddings.cache_utils:cleanup"
vec2sent_generate = "vec2sent.lstm.generate:main"
vec2sent_evaluate = "vec2sent.evaluation.__main__:main"
vec2sent_arithmetic = "vec2sent.scripts.vector_arithmetic:main"
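For orientation (an illustrative sketch, not part of the commit): each entry in [project.scripts] installs a console command that imports the referenced object and calls it. After installation, vec2sent_generate behaves roughly like the following Python stub, with the module path taken from the table above:

# Rough equivalent of the installed vec2sent_generate console script:
# resolve "vec2sent.lstm.generate:main" and call it.
from vec2sent.lstm.generate import main

if __name__ == "__main__":
    main()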
@@ -0,0 +1,5 @@
from vec2sent.sys_path_hack import add_to_path

add_to_path(["geometric_embedding", "geometric_embedding"])

from .geometric_embedding.gem import SentenceEmbedder
Submodule geometric_embedding added at 7a84ef
@@ -0,0 +1,7 @@
from vec2sent.sys_path_hack import add_to_path

add_to_path(["mos", "mos"])

from vec2sent.mos.mos.embed_regularize import embedded_dropout
from vec2sent.mos.mos.model import RNNModel
from vec2sent.mos.mos.weight_drop import WeightDrop
@@ -0,0 +1,5 @@
from vec2sent.sys_path_hack import add_to_path

add_to_path(["quick_thought", "S2V", "src"])

from .S2V.src import encoder_manager, configuration
@@ -0,0 +1,14 @@
import sys
from pathlib import Path
from typing import List


def add_to_path(path: List[str]) -> None:
    """
    This is unfortunately necessary because I am importing from a lot of old code that is not organized as a python module.
    A lot of old code uses absolute imports that assume the script is executed from the root folder of the repository.
    The only solution is to add the root folder to pythonpath.
    @param path: list of folders inside src/external containing the code (i.e. ["mos", "mos"] -> src/external/mos/mos)
    """
    current_dir = Path(__file__).resolve().parent
    sys.path.append(str(current_dir.joinpath(*path)))
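As an illustration (assumptions noted in the comments, not part of the commit): with the external code packaged next to this module, a call like the one in the MoS shim extends sys.path as follows:

# Hypothetical illustration of the shim call shown earlier in this diff:
# assuming this module sits at .../vec2sent/sys_path_hack.py, the call appends
# .../vec2sent/mos/mos to sys.path, so the vendored code's absolute imports
# (written to run from its own repository root) resolve.
add_to_path(["mos", "mos"])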
@@ -0,0 +1,65 @@
import logging
from tqdm import tqdm
from nltk.tokenize import word_tokenize

from vec2sent.sentence_embeddings.abstract_sentence_embedding import AbstractEmbedding
from vec2sent.util.embedding_wrapper import EmbeddingWrapper
from vec2sent.dataset.sentence_dataset import SortedSentenceDataset

import torch


def determine_batch_size(sentence_embeddings: AbstractEmbedding) -> int:
    if sentence_embeddings.get_name() in ['randomLSTM', 'borep', 'gem'] or sentence_embeddings.input_strings():
        return 200
    return 1


def load_dataset(
        path: str,
        word_embeddings: EmbeddingWrapper,
        sentence_embeddings: AbstractEmbedding,
        device: torch.device,
        leave_order: bool,
        batch_size: int,
        max_len: int,
        num_sentences: int = 0,
        start_token: str = None
) -> SortedSentenceDataset:
    """
    Loads a dataset for training or evaluation.
    @param path: path to the file containing the data
    @param word_embeddings: word embeddings
    @param sentence_embeddings: sentence embeddings
    @param device: where to initially load the dataset (might not fit in VRAM)
    @param leave_order: whether to leave the dataset in order, or sort it into batches of even length
    @param batch_size: batch size
    @param max_len: Maximum sentence length in dataset
    @param num_sentences: if set to a number > 0, the dataset will be cut off after said number of sentences
    @param start_token: String added to each line of the dataset (after tokenization)
    """

    dataset = SortedSentenceDataset(word_embeddings, sentence_embeddings, device, batch_size, max_len)
    oov = 0

    with open(path, "r", encoding="utf-8") as f:
        for i, line in enumerate(tqdm(f, total=num_sentences, desc="Loading dataset {}".format(path))):
            # Apply basic tokenization to punctuation first
            line = " ".join(word_tokenize(line))
            line = line.replace("''", '"').replace("``", '"')

            oov += dataset.add(line, max_len, start_token)

            if i == num_sentences - 1:
                break

    logger = logging.getLogger(__name__)
    logger.info('Loaded {} sentences'.format(len(dataset)))
    dataset.finish_init()

    logger.info("Out of vocabulary: {}".format(oov))

    pin_memory = device.type != 'cpu'
    dataset.create_data_loader(batch_size, leave_order, pin_memory)
    return dataset
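A hedged usage sketch (not part of the commit; the embedding objects, the file path and the numbers below are placeholders, since their construction and values are defined elsewhere in the package):

import torch

# Placeholders: concrete EmbeddingWrapper / AbstractEmbedding instances are
# created elsewhere; the path and numbers below are purely illustrative.
word_embeddings = ...       # an EmbeddingWrapper
sentence_embeddings = ...   # an AbstractEmbedding implementation
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

dataset = load_dataset(
    path="data/train.txt",                              # hypothetical file, one sentence per line
    word_embeddings=word_embeddings,
    sentence_embeddings=sentence_embeddings,
    device=device,
    leave_order=False,                                  # sort sentences into batches of similar length
    batch_size=determine_batch_size(sentence_embeddings),
    max_len=30,                                         # illustrative cap on sentence length
    num_sentences=10000,                                # stop after 10,000 lines
)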