adding readme
OliverGrainge committed Sep 4, 2024
1 parent c6919f3 commit 6e69264
Showing 3 changed files with 50 additions and 4 deletions.
45 changes: 45 additions & 0 deletions README.md
@@ -0,0 +1,45 @@

# GPT Model Training Repository

## Overview

This repository contains code for training a GPT-based model. The project is organized into modules for configuration, data preparation, model definition, and training.

## File Structure

- **`config.py`**: Contains configuration details for the model and training process, such as hyperparameters, file paths, and other settings.
- **`dataset.py`**: Manages dataset loading and preprocessing, preparing the data pipeline required for model training.
- **`distributed.py`**: Provides functionalities for distributed training, allowing the model to be trained across multiple GPUs or machines.
- **`download_data.py`**: A utility script for downloading the necessary datasets or external files for training.
- **`model.py`**: Defines the architecture of the GPT model, including layers, forward passes, and other components.
- **`train.py`**: The main script for training the model. It includes code to initialize the model, load data, and handle training loops.
- **`utils.py`**: Contains various utility functions that are used throughout the project, such as logging, checkpoint saving, or performance metrics.

## Usage

1. **Install dependencies**: Ensure you have the required libraries installed by running:
```bash
pip install -r requirements.txt
```

2. **Download data**: Use the `download_data.py` script to download the necessary datasets:
```bash
python download_data.py
```

3. **Configure settings**: Adjust the settings in `config.py` to match your training environment, such as hyperparameters, data paths, or training options; the sketch below illustrates the kind of fields involved.
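   A minimal, illustrative excerpt using the dataclass and field names from this repository's `config.py` (the values shown are examples to adapt, not recommendations):
   ```python
   # Illustrative excerpt of config.py; names match the dataclasses in this
   # repository, values are examples to be adapted to your environment.
   from dataclasses import dataclass

   @dataclass
   class DatasetConfig:
       root: str = "/mnt/datasets_drive/nlp"     # where to find the data
       train_small_name: str = "input.txt"       # small training dataset
       train_large_name: str = "edu_fineweb10B"  # large training dataset
       val_name: str = "hellaswag"               # evaluation dataset

   @dataclass
   class TraningConfig:  # class name as it appears in config.py
       lr: float = 6e-4            # maximum learning rate
       steps_per_pass: int = 1000  # training steps before validation
       max_steps: int = 1000       # max number of total training steps
       warmup_steps: int = 10      # number of lr warmup steps
       min_lr_mult: float = 0.1    # minimum lr value
   ```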

4. **Train the model**: Run the `train.py` script to begin training:
```bash
python train.py
```

5. **Distributed training**: To train across multiple GPUs or machines, make sure `distributed.py` is configured for your hardware; an example launch command is sketched below.
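   This commit does not show the contents of `distributed.py`, so the launch mechanism is an assumption; if it builds on PyTorch's `torch.distributed`, a typical single-node, multi-GPU launch would look like:
   ```bash
   # Assumes PyTorch DDP launched via torchrun; set --nproc_per_node
   # to the number of available GPUs.
   torchrun --standalone --nproc_per_node=4 train.py
   ```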

## Contributing

Feel free to contribute to this project by submitting a pull request or opening an issue to report bugs or suggest features.

## License

This project is licensed under the MIT License.
4 changes: 2 additions & 2 deletions config.py
```diff
@@ -3,7 +3,7 @@
 
 @dataclass
 class DatasetConfig:
-    root: str="/Users/olivergrainge/Documents/github/GPT/raw_data" # where to find the data
+    root: str="/mnt/datasets_drive/nlp" # where to find the data
     train_small_name: str="input.txt" # name of the small training dataset
     train_large_name: str="edu_fineweb10B" # name of the large training dataset
     val_name: str="hellaswag" # name of the evaluation dataset
@@ -25,7 +25,7 @@ class TraningConfig:
     samples: bool = True # whether to sample output text on each training pass
     compile: bool = False # whether to use torch.compile
     lr: float = 6e-4 # maximum learning rate
-    steps_per_pass: int = 5 # number of training steps before validation
+    steps_per_pass: int = 1000 # number of training steps before validation
     max_steps: int = 1000 # max number of total training steps required
     warmup_steps: int=10 # number of lr warmup steps
     min_lr_mult: float = 0.1 # minimum lr value
```
5 changes: 3 additions & 2 deletions train.py
```diff
@@ -101,10 +101,11 @@ def samples(training_state, model, text, n_samples=5, max_length=30, logger=None
 def validate(training_state, model, val_dl, logger=None):
     model.eval()
     correct, total = 0, 0
+    device = utils.detect_device()
     for x, y, l in tqdm(val_dl):
         with torch.no_grad():
-            x = x[:, :-1]
-            targets = x[:, 1:]
+            x = x[:, :-1].to(device)
+            targets = x[:, 1:].to(device)
             logits, _ = model(x)
             log_probs = logits.log_softmax(dim=2)
             selected_log_probs = torch.gather(
```
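For context on the change above: `validate` scores the `hellaswag` multiple-choice set (named in `config.py`) by shifting inputs against targets and summing the log-probabilities of the target tokens. A self-contained sketch of that scoring step, with assumed tensor shapes:

```python
import torch

# Sketch of the scoring pattern used in validate(); shapes are assumed.
B, T, V = 2, 8, 50257                  # batch, sequence length, vocab size
logits = torch.randn(B, T, V)          # model outputs for x[:, :-1]
targets = torch.randint(0, V, (B, T))  # next-token ids (the shifted input)

log_probs = logits.log_softmax(dim=2)  # per-token log-probabilities
selected = torch.gather(               # pick the log-prob of each target token
    log_probs, 2, targets.unsqueeze(-1)
).squeeze(-1)                          # -> [B, T]
scores = selected.sum(dim=1)           # total log-likelihood per candidate
```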
