# model_config.py
from gpt_download import download_and_load_gpt2
from previous_chapters import GPTModel, load_weights_into_gpt


class ModelConfig:
    def __init__(self, model_name, train_dataset_max_length):
        self.model_name = model_name
        self.base_config = {
            "vocab_size": 50257,     # Vocabulary size
            "context_length": 1024,  # Context length
            "drop_rate": 0.0,        # Dropout rate
            "qkv_bias": True,        # Query-key-value bias
        }
        # Model-specific configurations
        self.model_configs = {
            "gpt2-small (124M)": {"emb_dim": 768, "n_layers": 12, "n_heads": 12},
            "gpt2-medium (355M)": {"emb_dim": 1024, "n_layers": 24, "n_heads": 16},
            "gpt2-large (774M)": {"emb_dim": 1280, "n_layers": 36, "n_heads": 20},
            "gpt2-xl (1558M)": {"emb_dim": 1600, "n_layers": 48, "n_heads": 25},
        }
        self._validate_model_name()
        self.base_config.update(self.model_configs[self.model_name])
        self._validate_context_length(train_dataset_max_length)

    def _validate_model_name(self):
        if self.model_name not in self.model_configs:
            raise ValueError(
                f"Invalid model name '{self.model_name}'. Choose from: "
                f"{list(self.model_configs.keys())}"
            )

    def _validate_context_length(self, train_dataset_max_length):
        if train_dataset_max_length > self.base_config["context_length"]:
            raise ValueError(
                f"Dataset length {train_dataset_max_length} exceeds model's "
                f"context length {self.base_config['context_length']}. "
                f"Reinitialize data sets with max_length="
                f"{self.base_config['context_length']}"
            )

    def get_config(self):
        return self.base_config
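

# Example usage of ModelConfig (illustrative only; the max_length value below is a
# hypothetical placeholder, not taken from the original script):
#
#   config = ModelConfig("gpt2-small (124M)", train_dataset_max_length=120).get_config()
#   config["emb_dim"]         # -> 768
#   config["context_length"]  # -> 1024 (inherited from the shared base config)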


def initialize_gpt_model(model_name, train_dataset_max_length, models_dir="gpt2"):
    """
    Initializes the GPT model with pre-trained weights.

    Args:
        model_name (str): Name of the GPT model.
        train_dataset_max_length (int): Max sequence length in the training data.
        models_dir (str): Directory in which to store the downloaded weights.

    Returns:
        torch.nn.Module: The GPT model in evaluation mode.
    """
    # Step 1: Load the configuration for the chosen model
    config = ModelConfig(model_name, train_dataset_max_length).get_config()

    # Step 2: Download weights and settings; the size identifier is parsed from
    # the display name, e.g. "gpt2-small (124M)" -> "124M"
    model_size = model_name.split(" ")[-1].lstrip("(").rstrip(")")
    settings, params = download_and_load_gpt2(model_size=model_size, models_dir=models_dir)

    # Step 3: Initialize the GPT model and load the pretrained weights into it
    model = GPTModel(config)
    load_weights_into_gpt(model, params)

    # Set to evaluation mode (disables dropout and other training-only behavior)
    model.eval()

    return model


# Main execution
if __name__ == "__main__":
    CHOOSE_MODEL = "gpt2-small (124M)"
    INPUT_PROMPT = "Every effort moves"

    # Initialize the GPT model. `train_dataset` is assumed to have been created
    # by the dataset-preparation code earlier in the chapter; its `max_length`
    # attribute holds the longest tokenized sequence in the training data.
    model = initialize_gpt_model(
        model_name=CHOOSE_MODEL,
        train_dataset_max_length=train_dataset.max_length,
        models_dir="gpt2"
    )
    print(f"{CHOOSE_MODEL} initialized and ready for evaluation.")
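
    # A minimal sketch of greedy text generation with the initialized model, using
    # INPUT_PROMPT. This is not part of the original script: it assumes GPTModel's
    # forward pass returns logits of shape (batch, num_tokens, vocab_size), as in
    # the earlier chapters, and uses the tiktoken GPT-2 tokenizer.
    import tiktoken
    import torch

    tokenizer = tiktoken.get_encoding("gpt2")
    token_ids = torch.tensor([tokenizer.encode(INPUT_PROMPT)])

    with torch.no_grad():
        for _ in range(10):  # generate 10 tokens greedily (well under the context length)
            logits = model(token_ids)
            next_id = logits[:, -1, :].argmax(dim=-1, keepdim=True)
            token_ids = torch.cat([token_ids, next_id], dim=1)

    print(tokenizer.decode(token_ids.squeeze(0).tolist()))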