Add mean pooling experiment to classifier bonus experiments #406

Merged (4 commits) on Oct 20, 2024
3 changes: 3 additions & 0 deletions ch06/02_bonus_additional-experiments/README.md
@@ -28,6 +28,7 @@ For example,
| 15 | gpt2-small (124M) | pretrained | last | last_block | variable: no padding (batch size 8) | 99.33% | 98.66% | 98.33% | 1.70 min | A100 |
| 16 | gpt2-small (124M) | pretrained | last | last_block | longest train ex. (120); but no causal mask | 99.23% | 98.66% | 95.33% | 0.29 min | A100 |
| 17 | gpt2-small (124M) | pretrained | last | last_block | longest train ex. (120) and `ignore_index` for padding | 96.63% | 99.33% | 95.00% | 0.28 min | A100 |
| 18 | gpt2-small (124M) | pretrained | last + pooled embeddings | last_block | longest train ex. (120) | 97.79% | 99.33% | 96.33% | 0.32 min | A100 |

 

@@ -52,6 +53,7 @@ You can use the following code to reproduce the experiments:
- Row 15: `python additional_experiments.py --no_padding --batch_size 1 --accumulation_steps 8`
- Row 16: `python additional_experiments.py --disable_causal_mask`
- Row 17: `python additional_experiments.py --ignore_index 50256`
- Row 18: `python additional_experiments.py --average_embeddings`

I've kept the LLM and dataset small on purpose, so you can run the training on a regular laptop like a MacBook Air M3 in about 15 minutes (for the default setting) in case you don't have access to a GPU.

@@ -70,3 +72,4 @@ I've kept the LLM and dataset small on purpose, so you can run the training on a
9. **Padding vs no padding (Row 1 vs. 14 and 15)**: The `--no_padding` option disables the padding in the dataset, which requires training the model with a batch size of 1 since the inputs have variable lengths. This results in a better test accuracy but takes longer to train. In row 15, we additionally enable gradient accumulation with 8 steps to achieve the same effective batch size as in the other experiments, which helps reduce overfitting and slightly boost the test set accuracy (a minimal gradient-accumulation sketch follows after this list).
10. **Disabling the causal attention mask (Row 1 vs. 16)**: Disables the causal attention mask used in the multi-head attention module. This means all tokens can attend to all other tokens. The model accuracy is slightly improved compared to the GPT model with the causal mask (see the attention sketch after this list).
11. **Ignoring the padding indices in the loss and backpropagation (Row 1 vs. 17)**: Setting `--ignore_index 50256` excludes the `<|endoftext|>` padding tokens from the `cross_entropy` loss function in PyTorch (see the `ignore_index` example after this list). In this case, it does not have any effect because we replaced the output layers so that the target labels are either 0 or 1 for the binary classification example. However, this setting is useful when instruction finetuning models in chapter 7.
12. **Averaging the embeddings over all tokens (Row 1 vs. 18)**: By default, only the output embedding at the chosen token position (specified by `--trainable_token_pos`) is used for classification, for example, the embedding of the last token. Setting `--average_embeddings` instead mean-pools the output embeddings over all tokens and uses the pooled vector in its place (see the mean-pooling sketch after this list). As we can see, this improves the test accuracy from 95.00% to 96.33% with only a minimal increase in run time (0.28 min to 0.32 min) and might be worth considering in practice.
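
The sketches below illustrate the techniques referenced in the list above. They are standalone toys, not the repository's training code; the tiny linear model, random data, and hyperparameters are hypothetical stand-ins.

**Gradient accumulation (row 15).** Scaling each loss by `1 / accumulation_steps` and stepping the optimizer only every `accumulation_steps` batches emulates a larger effective batch size while training with a batch size of 1:

```python
import torch

# Hypothetical stand-ins for the real model and data loader
model = torch.nn.Linear(10, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
loader = [(torch.randn(1, 10), torch.randint(0, 2, (1,))) for _ in range(8)]

accumulation_steps = 8  # emulate an effective batch size of 8 with batch size 1
optimizer.zero_grad()

for batch_idx, (inputs, targets) in enumerate(loader):
    logits = model(inputs)
    loss = torch.nn.functional.cross_entropy(logits, targets)
    (loss / accumulation_steps).backward()  # gradients accumulate across batches
    if (batch_idx + 1) % accumulation_steps == 0:
        optimizer.step()       # update once per 8 accumulated batches
        optimizer.zero_grad()
```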
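
**Disabling the causal mask (row 16).** A toy single-head attention computation; the causal variant zeroes out attention to future tokens, while the unmasked variant lets every token attend to every other token:

```python
import torch

torch.manual_seed(123)
num_tokens, head_dim = 4, 8
queries = torch.randn(num_tokens, head_dim)
keys = torch.randn(num_tokens, head_dim)

attn_scores = queries @ keys.T / head_dim**0.5

# Causal (default): mask out positions above the diagonal (future tokens)
causal_mask = torch.triu(torch.ones(num_tokens, num_tokens, dtype=torch.bool), diagonal=1)
causal_weights = torch.softmax(attn_scores.masked_fill(causal_mask, float("-inf")), dim=-1)

# With the mask disabled: plain softmax over all positions
bidirectional_weights = torch.softmax(attn_scores, dim=-1)

print(causal_weights)         # upper triangle is zero
print(bidirectional_weights)  # all positions receive nonzero weight
```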
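
**`ignore_index` in the loss (row 17).** PyTorch's `cross_entropy` skips any target position whose label equals `ignore_index`; the logits and targets below are toy values chosen only to show the mechanism:

```python
import torch

logits = torch.randn(4, 3)                    # 4 positions, 3 classes
targets = torch.tensor([0, 2, 50256, 50256])  # last two positions are padding

# Positions whose target equals ignore_index contribute nothing to the loss
loss = torch.nn.functional.cross_entropy(logits, targets, ignore_index=50256)

# Equivalent to averaging the loss over the non-padding positions only
mask = targets != 50256
loss_manual = torch.nn.functional.cross_entropy(logits[mask], targets[mask])

print(torch.allclose(loss, loss_manual))  # True
```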
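
**Mean pooling vs. last-token selection (row 18).** A random tensor stands in for the model output here; the only difference is whether we index a single token position or average over the sequence dimension, which mirrors the `model_output.mean(dim=1)` branch added to `calc_loss_batch` and `calc_accuracy_loader` in the diff below:

```python
import torch

torch.manual_seed(123)
batch_size, num_tokens, num_classes = 2, 5, 2
model_output = torch.randn(batch_size, num_tokens, num_classes)  # stand-in for model(input_batch)

trainable_token_pos = -1

# Default: use only the output at the chosen token position (the last token here)
logits_last = model_output[:, trainable_token_pos, :]

# --average_embeddings: mean-pool the outputs of all tokens
logits_pooled = model_output.mean(dim=1)

print(logits_last.shape, logits_pooled.shape)  # both are (2, 2)
```
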
89 changes: 67 additions & 22 deletions ch06/02_bonus_additional-experiments/additional_experiments.py
@@ -181,15 +181,24 @@ def instantiate_model(choose_model, load_weights):


def calc_loss_batch(input_batch, target_batch, model, device,
trainable_token_pos=-1, ignore_index=-100):
trainable_token_pos=-1, ignore_index=-100, average_embeddings=False):
input_batch, target_batch = input_batch.to(device), target_batch.to(device)
logits = model(input_batch)[:, trainable_token_pos, :] # Logits of last output token

model_output = model(input_batch)
if average_embeddings:
# Average over the sequence dimension (dim=1)
logits = model_output.mean(dim=1)
else:
# Select embeddings at the specified token position
logits = model_output[:, trainable_token_pos, :]

loss = torch.nn.functional.cross_entropy(logits, target_batch, ignore_index=ignore_index)
return loss


def calc_loss_loader(data_loader, model, device,
num_batches=None, trainable_token_pos=-1, ignore_index=-100):
num_batches=None, trainable_token_pos=-1,
ignore_index=-100, average_embeddings=False):
total_loss = 0.
if len(data_loader) == 0:
return float("nan")
@@ -203,7 +212,8 @@ def calc_loss_loader(data_loader, model, device,
if i < num_batches:
loss = calc_loss_batch(
input_batch, target_batch, model, device,
trainable_token_pos=trainable_token_pos, ignore_index=ignore_index
trainable_token_pos=trainable_token_pos, ignore_index=ignore_index,
average_embeddings=average_embeddings
)
total_loss += loss.item()
else:
@@ -212,7 +222,8 @@ def calc_loss_loader(data_loader, model, device,


@torch.no_grad() # Disable gradient tracking for efficiency
def calc_accuracy_loader(data_loader, model, device, num_batches=None, trainable_token_pos=-1):
def calc_accuracy_loader(data_loader, model, device, num_batches=None,
trainable_token_pos=-1, average_embeddings=False):
model.eval()
correct_predictions, num_examples = 0, 0

@@ -223,7 +234,15 @@ def calc_accuracy_loader(data_loader, model, device, num_batches=None, trainable
for i, (input_batch, target_batch) in enumerate(data_loader):
if i < num_batches:
input_batch, target_batch = input_batch.to(device), target_batch.to(device)
logits = model(input_batch)[:, trainable_token_pos, :] # Logits of last output token

model_output = model(input_batch)
if average_embeddings:
# Average over the sequence dimension (dim=1)
logits = model_output.mean(dim=1)
else:
# Select embeddings at the specified token position
logits = model_output[:, trainable_token_pos, :]

predicted_labels = torch.argmax(logits, dim=-1)

num_examples += predicted_labels.shape[0]
@@ -234,24 +253,27 @@ def calc_accuracy_loader(data_loader, model, device, num_batches=None, trainable


def evaluate_model(model, train_loader, val_loader, device,
eval_iter, trainable_token_pos=-1, ignore_index=-100):
eval_iter, trainable_token_pos=-1,
ignore_index=-100, average_embeddings=False):
model.eval()
with torch.no_grad():
train_loss = calc_loss_loader(
train_loader, model, device, num_batches=eval_iter,
trainable_token_pos=trainable_token_pos, ignore_index=ignore_index
trainable_token_pos=trainable_token_pos, ignore_index=ignore_index,
average_embeddings=average_embeddings
)
val_loss = calc_loss_loader(
val_loader, model, device, num_batches=eval_iter,
trainable_token_pos=trainable_token_pos, ignore_index=ignore_index
trainable_token_pos=trainable_token_pos, ignore_index=ignore_index,
average_embeddings=average_embeddings
)
model.train()
return train_loss, val_loss


def train_classifier_simple(model, train_loader, val_loader, optimizer, device, num_epochs,
eval_freq, eval_iter, max_steps=None, trainable_token_pos=-1,
accumulation_steps=1, ignore_index=-100):
accumulation_steps=1, ignore_index=-100, average_embeddings=False):
# Initialize lists to track losses and tokens seen
train_losses, val_losses, train_accs, val_accs = [], [], [], []
examples_seen, global_step = 0, -1
@@ -263,7 +285,8 @@ def train_classifier_simple(model, train_loader, val_loader, optimizer, device,
for batch_idx, (input_batch, target_batch) in enumerate(train_loader):
loss = calc_loss_batch(
input_batch, target_batch, model, device,
trainable_token_pos=trainable_token_pos, ignore_index=ignore_index
trainable_token_pos=trainable_token_pos, ignore_index=ignore_index,
average_embeddings=average_embeddings
)

# Use gradient accumulation if accumulation_steps > 1
@@ -286,7 +309,8 @@ def train_classifier_simple(model, train_loader, val_loader, optimizer, device,
if global_step % eval_freq == 0:
train_loss, val_loss = evaluate_model(
model, train_loader, val_loader, device, eval_iter,
trainable_token_pos=trainable_token_pos, ignore_index=ignore_index
trainable_token_pos=trainable_token_pos, ignore_index=ignore_index,
average_embeddings=average_embeddings
)
train_losses.append(train_loss)
val_losses.append(val_loss)
@@ -297,8 +321,14 @@ def train_classifier_simple(model, train_loader, val_loader, optimizer, device,
break

# New: Calculate accuracy after each epoch
train_accuracy = calc_accuracy_loader(train_loader, model, device, num_batches=eval_iter, trainable_token_pos=trainable_token_pos)
val_accuracy = calc_accuracy_loader(val_loader, model, device, num_batches=eval_iter, trainable_token_pos=trainable_token_pos)
train_accuracy = calc_accuracy_loader(
train_loader, model, device, num_batches=eval_iter,
trainable_token_pos=trainable_token_pos, average_embeddings=average_embeddings
)
val_accuracy = calc_accuracy_loader(
val_loader, model, device, num_batches=eval_iter,
trainable_token_pos=trainable_token_pos, average_embeddings=average_embeddings
)
print(f"Training accuracy: {train_accuracy*100:.2f}% | ", end="")
print(f"Validation accuracy: {val_accuracy*100:.2f}%")
train_accs.append(train_accuracy)
@@ -359,13 +389,22 @@ def replace_linear_with_lora(model, rank, alpha, alternative=False):
"Which token position to train. Options: 'first', 'last'."
)
)
parser.add_argument(
"--average_embeddings",
action='store_true',
default=False,
help=(
"Average the output embeddings from all tokens instead of using"
" only the embedding at the token position specified by `--trainable_token_pos`."
)
)
parser.add_argument(
"--context_length",
type=str,
default="longest_training_example",
help=(
"The context length of the data inputs."
"Options: 'longest_training_example', 'model_context_length' or integer value."
" Options: 'longest_training_example', 'model_context_length' or integer value."
)
)
parser.add_argument(
@@ -409,7 +448,6 @@ def replace_linear_with_lora(model, rank, alpha, alternative=False):
"The batch size used for training."
)
)

parser.add_argument(
"--accumulation_steps",
type=int,
@@ -422,7 +460,6 @@ def replace_linear_with_lora(model, rank, alpha, alternative=False):
" the latter setting uses more iterations."
)
)

parser.add_argument(
"--disable_causal_mask",
action='store_true',
@@ -431,7 +468,6 @@ def replace_linear_with_lora(model, rank, alpha, alternative=False):
"Disables the causal attention mask."
)
)

parser.add_argument(
"--ignore_index",
type=int,
@@ -589,7 +625,7 @@ def replace_linear_with_lora(model, rank, alpha, alternative=False):
model, train_loader, val_loader, optimizer, device,
num_epochs=args.num_epochs, eval_freq=50, eval_iter=5,
max_steps=None, trainable_token_pos=args.trainable_token_pos,
accumulation_steps=args.accumulation_steps
accumulation_steps=args.accumulation_steps, average_embeddings=args.average_embeddings
)

end_time = time.time()
@@ -600,9 +636,18 @@ def replace_linear_with_lora(model, rank, alpha, alternative=False):
# Evaluate model
###############################

train_accuracy = calc_accuracy_loader(train_loader, model, device, trainable_token_pos=args.trainable_token_pos)
val_accuracy = calc_accuracy_loader(val_loader, model, device, trainable_token_pos=args.trainable_token_pos)
test_accuracy = calc_accuracy_loader(test_loader, model, device, trainable_token_pos=args.trainable_token_pos)
train_accuracy = calc_accuracy_loader(
train_loader, model, device,
trainable_token_pos=args.trainable_token_pos, average_embeddings=args.average_embeddings
)
val_accuracy = calc_accuracy_loader(
val_loader, model, device,
trainable_token_pos=args.trainable_token_pos, average_embeddings=args.average_embeddings
)
test_accuracy = calc_accuracy_loader(
test_loader, model, device,
trainable_token_pos=args.trainable_token_pos, average_embeddings=args.average_embeddings
)

print(f"Training accuracy: {train_accuracy*100:.2f}%")
print(f"Validation accuracy: {val_accuracy*100:.2f}%")