Switch Multi-Query Attention Transformer

This repository contains an implementation of a Transformer model with a novel Switch Multi-Query Attention mechanism. This architecture is designed to efficiently handle large-scale attention tasks by routing tokens to a subset of attention experts, optimizing both computational resources and model performance.

Table of Contents

  • Overview
  • Features
  • Installation
  • Usage
  • API Reference
  • Examples
  • Contributing
  • License

Overview

The Switch Multi-Query Attention Transformer is an advanced neural network model that leverages a mixture-of-experts approach to improve the efficiency and scalability of attention mechanisms. It is particularly suited for tasks involving large sequences and complex attention patterns.

Features

  • Multi-Query Attention: Projects queries into multiple heads while sharing a single set of keys and values across all heads.
  • Switch Mechanism: Routes each token to a subset of attention experts via top-k routing, reducing the compute spent per token (see the sketch after this list).
  • Load Balancing: Adds a load-balancing loss that encourages tokens to be distributed evenly across experts.
  • Residual Connections and Layer Normalization: Improve training stability and overall performance.
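
To make the routing idea above concrete, here is a minimal, self-contained sketch of top-k routing with a Switch-Transformer-style load-balancing loss. It does not use this repository's classes; ToyRouter and all shapes are illustrative assumptions only.

import torch
import torch.nn as nn
import torch.nn.functional as F

class ToyRouter(nn.Module):
    """Illustrative top-k router with a Switch-style load-balancing loss (not the repo's implementation)."""

    def __init__(self, d_model: int, num_experts: int, k: int = 1):
        super().__init__()
        self.gate = nn.Linear(d_model, num_experts)
        self.num_experts = num_experts
        self.k = k

    def forward(self, x):
        # x: (batch, seq_len, d_model)
        logits = self.gate(x)                            # (batch, seq_len, num_experts)
        probs = logits.softmax(dim=-1)
        topk_probs, topk_idx = probs.topk(self.k, dim=-1)

        # Load-balancing loss: push the fraction of tokens routed to each expert
        # toward that expert's mean gate probability; uniform usage minimizes it.
        dispatch = F.one_hot(topk_idx[..., 0], self.num_experts).float()
        tokens_per_expert = dispatch.mean(dim=(0, 1))    # fraction of tokens sent to each expert
        prob_per_expert = probs.mean(dim=(0, 1))         # mean router probability per expert
        aux_loss = self.num_experts * (tokens_per_expert * prob_per_expert).sum()

        return topk_idx, topk_probs, aux_loss

router = ToyRouter(d_model=64, num_experts=4, k=1)
expert_idx, expert_weight, aux_loss = router(torch.randn(2, 16, 64))
print(expert_idx.shape, aux_loss.item())  # torch.Size([2, 16, 1]) and a scalar near 1.0

In a full model, each token's hidden state would then be processed only by its selected expert(s), and the auxiliary loss would be added to the training objective with a small weight.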

Installation

To use this model, ensure you have Python 3.7+ and PyTorch installed. You can install the necessary dependencies using:

pip install torch

Usage

To integrate the Switch Multi-Query Attention Transformer into your project, you can instantiate and use the model as follows:

import torch
from experimental.new_models.switch_attn import Transformer

# Define model parameters
vocab_size = 1000
d_model = 64
num_heads = 8
num_experts = 4
num_layers = 2

# Create a Transformer instance
transformer = Transformer(d_model=d_model, num_heads=num_heads, num_experts=num_experts, num_layers=num_layers, vocab_size=vocab_size)

# Prepare input data: a batch of 2 sequences, each 16 tokens long
x = torch.randint(0, vocab_size, (2, 16))  # random token IDs in [0, vocab_size)

# Forward pass
output = transformer(x)
print(output)  # Output logits
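
The output is expected to contain per-token logits over the vocabulary, i.e. roughly of shape (batch, seq_len, vocab_size), or (2, 16, 1000) with the settings above; the exact return format (and whether the auxiliary load-balancing loss is also returned) depends on the implementation in switch_attn.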

API Reference

Classes

  • MultiQueryAttention: Implements multi-query attention with rotary positional embeddings (the general idea is sketched after this list).
  • SwitchMultiQueryAttention: Routes tokens to multiple experts and computes attention with load balancing.
  • Block: Combines attention and feed-forward networks with residual connections.
  • Transformer: Full transformer model with multiple layers of attention and feed-forward blocks.
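
As a point of reference for the MultiQueryAttention class above, the sketch below shows the core multi-query idea in plain PyTorch: each head gets its own query projection, while a single key/value projection is shared by all heads. It is illustrative only (no rotary embeddings, no masking) and does not reuse this repository's code.

import torch
import torch.nn as nn

class ToyMultiQueryAttention(nn.Module):
    """Illustrative multi-query attention: per-head queries, one shared key/value (not the repo's class)."""

    def __init__(self, d_model: int, num_heads: int):
        super().__init__()
        assert d_model % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.q_proj = nn.Linear(d_model, d_model)             # separate queries per head
        self.kv_proj = nn.Linear(d_model, 2 * self.head_dim)  # single shared key/value
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, x):
        b, t, _ = x.shape
        q = self.q_proj(x).view(b, t, self.num_heads, self.head_dim).transpose(1, 2)
        k, v = self.kv_proj(x).chunk(2, dim=-1)   # each (b, t, head_dim)
        k, v = k.unsqueeze(1), v.unsqueeze(1)     # broadcast the shared K/V across heads
        scores = q @ k.transpose(-2, -1) / self.head_dim ** 0.5
        out = (scores.softmax(dim=-1) @ v).transpose(1, 2).reshape(b, t, -1)
        return self.out_proj(out)

mqa = ToyMultiQueryAttention(d_model=64, num_heads=8)
print(mqa(torch.randn(2, 16, 64)).shape)  # torch.Size([2, 16, 64])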

Methods

  • forward(x, attn_mask=None): Computes the forward pass of the model; attn_mask is an optional attention mask (see the example below).
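
For instance, continuing from the Usage snippet, a mask can be passed alongside the input tokens. The exact mask shape, dtype, and convention depend on the implementation, so the boolean padding mask below is only an assumed example.

# Continuing from the Usage example; the mask shape/convention here is an assumption.
attn_mask = torch.ones(2, 16, dtype=torch.bool)   # hypothetical padding mask (True = attend)
output = transformer(x, attn_mask=attn_mask)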

Examples

To run an example, you can use the provided script in the repository:

python examples/main.py

This script demonstrates how to initialize the model, process input data, and obtain predictions.

Contributing

Contributions are welcome! Please submit a pull request or open an issue for any bugs or feature requests. Ensure that your code adheres to the project's coding standards and includes appropriate tests.

License

This project is licensed under the MIT License. See the LICENSE file for more details.

