-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
104 lines (83 loc) · 2.69 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import os
import torch
import torch.nn as nn
import json
from src.dataloader import get_dataloader, tokenize
from src.custom_word2vec import CBOWModel, SkipGramModel
from src.trainer import Trainer
from src.metric_monitor import MetricMonitor
# Path to the raw training corpus file.
TEXT_PATH = os.path.join("dataset", "text8.txt")
# Compute device: prefer CUDA when available, otherwise fall back to CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Upper bound on vocabulary size passed to tokenize();
# out-of-vocabulary words presumably map to the "<unk>" default token — confirm in tokenize().
MAX_VOCAB_SIZE = 5000
# Number of training epochs.
EPOCHS = 5
# Which word2vec variant to train; also used as the output subdirectory name.
MODEL_TYPE = "cbow" # the other supported value is "skipgram"
# Dimensionality of the learned word-embedding vectors.
EMBEDDING_SIZE = 100
# Root directory under which all artifacts (vocab, tokens, metrics, model) are saved.
SAVE_PATH = "results"
if __name__ == "__main__":
    # --- Data preparation -------------------------------------------------
    # Read the raw corpus, tokenize it, and build the (capped) vocabulary.
    with open(TEXT_PATH, "r", encoding="utf-8") as f:
        raw_txt = f.read()
    vocab, tokens = tokenize(inp=raw_txt, vocab_size=MAX_VOCAB_SIZE, default_token="<unk>")
    # The actual vocabulary may be smaller than the requested cap.
    VOCAB_SIZE = min(MAX_VOCAB_SIZE, len(vocab))

    # Train / validation dataloaders over the same token stream.
    train_dataloader = get_dataloader(
        tokens=tokens,
        model_type=MODEL_TYPE,
        loader_type="train",
        vocab=vocab,
    )
    val_dataloader = get_dataloader(
        tokens=tokens,
        model_type=MODEL_TYPE,
        loader_type="val",
        vocab=vocab,
    )

    # --- Model / optimization setup ---------------------------------------
    if MODEL_TYPE == "cbow":
        model = CBOWModel(vocab_size=VOCAB_SIZE, embedding_size=EMBEDDING_SIZE)
    elif MODEL_TYPE == "skipgram":
        model = SkipGramModel(vocab_size=VOCAB_SIZE, embedding_size=EMBEDDING_SIZE)
    else:
        raise NotImplementedError(f"Unsupported MODEL_TYPE: {MODEL_TYPE!r}")

    # Loss function (both model variants emit vocabulary-sized logits).
    criterion = nn.CrossEntropyLoss()
    # Optimizer. NOTE(review): the optimizer is built before the model is moved
    # to DEVICE — presumably Trainer calls model.to(device); confirm that the
    # move happens in-place so these parameter references stay valid.
    optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
    metric_monitor = MetricMonitor(
        epochs=EPOCHS
    )
    trainer = Trainer(
        model=model,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        criterion=criterion,
        optimizer=optimizer,
        device=DEVICE,
        metric_monitor=metric_monitor,
        epochs=EPOCHS,
    )
    trainer.train()

    # --- Persist artifacts -------------------------------------------------
    # Fix: create the output directory first. Previously every open(..., "w")
    # and torch.save below raised FileNotFoundError whenever
    # "results/<MODEL_TYPE>/" did not already exist.
    save_dir = os.path.join(SAVE_PATH, MODEL_TYPE)
    os.makedirs(save_dir, exist_ok=True)

    # Vocabulary: string-to-index mapping as JSON.
    with open(os.path.join(save_dir, "vocab.json"), "w", encoding="utf-8") as f:
        json.dump(vocab.get_stoi(), f)
    # Token stream used for training, space-separated.
    with open(os.path.join(save_dir, "tokens.txt"), "w", encoding="utf-8") as f:
        f.write(" ".join(tokens))
    # Metrics accumulated by the monitor during training.
    with open(os.path.join(save_dir, "metrics.json"), "w", encoding="utf-8") as f:
        json.dump(metric_monitor.metrics, f)
    # Full pickled model (convenient, but tied to the class definitions) ...
    torch.save(model, os.path.join(save_dir, "model.pth"))
    # ... plus the portable state_dict, preferred for reloading weights.
    torch.save(model.state_dict(), os.path.join(save_dir, "model_state.pth"))