This repository contains an implementation of a Transformer model with a novel Switch Multi-Query Attention mechanism. This architecture is designed to efficiently handle large-scale attention tasks by routing tokens to a subset of attention experts, optimizing both computational resources and model performance.
The Switch Multi-Query Attention Transformer combines multi-query attention with a mixture-of-experts routing scheme to improve the efficiency and scalability of the attention mechanism. It is particularly suited to tasks involving long sequences and complex attention patterns.
- Multi-Query Attention: Efficiently projects queries into multiple heads while sharing keys and values across heads (see the first sketch after this list).
- Switch Mechanism: Routes tokens to a subset of experts using top-k routing, optimizing resource usage (see the second sketch after this list).
- Load Balancing: Includes a load-balancing loss that encourages a uniform distribution of tokens across experts.
- Residual Connections and Layer Normalization: Enhance model stability and performance.
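To make the first feature concrete, here is a minimal sketch of multi-query attention: each head gets its own query projection, while a single key/value projection is shared across all heads. This is a generic illustration rather than the repository's `MultiQueryAttention` class (which, per the list above, also applies rotary positional embeddings); the function name, weight shapes, and tensor layout below are assumptions.

```python
import torch

def multi_query_attention(x, w_q, w_k, w_v, num_heads):
    """Illustrative multi-query attention: per-head queries, one shared K/V head."""
    b, s, d = x.shape
    head_dim = d // num_heads
    q = (x @ w_q).view(b, s, num_heads, head_dim).transpose(1, 2)  # (b, heads, s, head_dim)
    k = (x @ w_k).view(b, s, 1, head_dim).transpose(1, 2)          # (b, 1, s, head_dim), shared
    v = (x @ w_v).view(b, s, 1, head_dim).transpose(1, 2)          # (b, 1, s, head_dim), shared
    scores = q @ k.transpose(-2, -1) / head_dim ** 0.5             # shared K broadcasts over heads
    out = scores.softmax(dim=-1) @ v                               # (b, heads, s, head_dim)
    return out.transpose(1, 2).reshape(b, s, d)

# Toy usage with random weights
x = torch.randn(2, 16, 64)
w_q = torch.randn(64, 64)   # d_model -> num_heads * head_dim
w_k = torch.randn(64, 8)    # d_model -> head_dim (single shared head)
w_v = torch.randn(64, 8)
print(multi_query_attention(x, w_q, w_k, w_v, num_heads=8).shape)  # (2, 16, 64)
```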
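The switch mechanism and load-balancing loss can be sketched similarly. The snippet below illustrates generic top-k routing with a Switch-Transformer-style auxiliary loss rather than the repository's `SwitchMultiQueryAttention`; the router projection, the top-1 dispatch, and the exact loss formulation are assumptions about how such layers are commonly built.

```python
import torch
import torch.nn.functional as F

def route_tokens(x, router_weight, k=1):
    """Illustrative top-k routing with a load-balancing auxiliary loss.

    x:             (batch, seq_len, d_model) token representations
    router_weight: (num_experts, d_model) learned router projection
    """
    logits = x @ router_weight.t()                   # (batch, seq, num_experts)
    probs = logits.softmax(dim=-1)
    topk_probs, topk_idx = probs.topk(k, dim=-1)     # chosen experts per token

    # Load-balancing loss: encourage the fraction of tokens dispatched to each
    # expert to match the mean routing probability for that expert.
    num_experts = router_weight.shape[0]
    dispatch = F.one_hot(topk_idx[..., 0], num_experts).float()  # top-1 assignment
    tokens_per_expert = dispatch.mean(dim=(0, 1))    # fraction of tokens per expert
    prob_per_expert = probs.mean(dim=(0, 1))         # mean router probability
    lb_loss = num_experts * (tokens_per_expert * prob_per_expert).sum()

    return topk_idx, topk_probs, lb_loss

# Toy usage with random tensors
x = torch.randn(2, 16, 64)
router_weight = torch.randn(4, 64)
idx, p, lb = route_tokens(x, router_weight, k=1)
print(idx.shape, p.shape, lb.item())
```

Intuitively, the auxiliary loss is smallest when both the dispatched token fractions and the mean router probabilities are uniform across experts, which discourages the router from collapsing onto a few experts.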
To use this model, ensure you have Python 3.7+ and PyTorch installed. You can install the necessary dependencies using:
```bash
pip install torch
```
To integrate the Switch Multi-Query Attention Transformer into your project, you can instantiate and use the model as follows:
```python
import torch
from experimental.new_models.switch_attn import Transformer

# Define model parameters
vocab_size = 1000
d_model = 64
num_heads = 8
num_experts = 4
num_layers = 2

# Create a Transformer instance
transformer = Transformer(
    d_model=d_model,
    num_heads=num_heads,
    num_experts=num_experts,
    num_layers=num_layers,
    vocab_size=vocab_size,
)

# Prepare input data
x = torch.randint(0, vocab_size, (2, 16))  # Example input tokens

# Forward pass
output = transformer(x)
print(output)  # Output logits
```
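For training, the logits can be combined with a standard cross-entropy loss. The sketch below continues the usage snippet above (reusing `transformer`, `x`, and `vocab_size`) and makes two assumptions not confirmed by this README: that the logits have shape `(batch, seq_len, vocab_size)`, and that the model exposes its auxiliary load-balancing loss; the attribute name `aux_loss` mentioned in the comment is hypothetical.

```python
import torch
import torch.nn.functional as F

optimizer = torch.optim.AdamW(transformer.parameters(), lr=3e-4)

targets = torch.randint(0, vocab_size, (2, 16))   # dummy next-token targets
logits = transformer(x)                           # assumed shape: (batch, seq_len, vocab_size)
loss = F.cross_entropy(logits.reshape(-1, vocab_size), targets.reshape(-1))
# Hypothetical: if the model exposes its auxiliary load-balancing loss,
# add it here, e.g. loss = loss + transformer.aux_loss (attribute name assumed).

optimizer.zero_grad()
loss.backward()
optimizer.step()
```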
The main components are:

- `MultiQueryAttention`: Implements multi-query attention with rotary positional embeddings.
- `SwitchMultiQueryAttention`: Routes tokens to multiple experts and computes attention with load balancing.
- `Block`: Combines attention and feed-forward networks with residual connections.
- `Transformer`: Full transformer model with multiple layers of attention and feed-forward blocks.

`forward(x, attn_mask=None)`: Computes the forward pass of the model.
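The `attn_mask` argument restricts which positions may attend to one another. Continuing the usage example above, a causal (autoregressive) mask could be built as follows; the exact shape and dtype the model expects for `attn_mask` are assumptions, so check the implementation before relying on this sketch.

```python
import torch

seq_len = 16
# Lower-triangular mask: position i may only attend to positions <= i.
causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))

output = transformer(x, attn_mask=causal_mask)  # mask shape/dtype assumed
```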
To run an example, you can use the provided script in the repository:
```bash
python examples/main.py
```
This script demonstrates how to initialize the model, process input data, and obtain predictions.
Contributions are welcome! Please submit a pull request or open an issue for any bugs or feature requests. Ensure that your code adheres to the project's coding standards and includes appropriate tests.
This project is licensed under the MIT License. See the LICENSE file for more details.