Optimize byte pair merge for really big tokens (40x faster for a 2500 token word) #239

Status: Open. Wants to merge 2 commits into base: main.
85 changes: 84 additions & 1 deletion src/lib.rs
@@ -1,7 +1,8 @@
// This check is new and seems buggy (possibly with PyO3 interaction)
#![allow(clippy::borrow_deref_ref)]

use std::collections::HashSet;
use std::collections::{BTreeMap, BTreeSet, HashSet};
use std::iter::successors;
use std::num::NonZeroU64;
use std::thread;

@@ -15,7 +16,17 @@ use rustc_hash::FxHashMap as HashMap;

type Rank = u32;

const LARGE_ENCODER_CHARACTER_LIMIT: usize = 500;
Comment (l0rinc, Contributor Author): In the Java version this was controlled by an environment variable, which enabled us to run all tests against both implementations. Should I do it here as well?
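To make that suggestion concrete: a minimal sketch of an environment-variable-controlled cutoff. Both the variable name `TIKTOKEN_LARGE_ENCODER_LIMIT` and the mechanism are hypothetical, not part of this PR.

```rust
use std::env;

// Sketch only: the PR hard-codes LARGE_ENCODER_CHARACTER_LIMIT = 500.
// TIKTOKEN_LARGE_ENCODER_LIMIT below is a hypothetical variable name.
fn limit_from(raw: Option<String>) -> usize {
    // Fall back to the compiled-in default on absence or a parse failure.
    raw.and_then(|v| v.parse::<usize>().ok()).unwrap_or(500)
}

fn large_encoder_character_limit() -> usize {
    limit_from(env::var("TIKTOKEN_LARGE_ENCODER_LIMIT").ok())
}

fn main() {
    println!("{}", large_encoder_character_limit());
}
```

This would let a test suite force the cutoff to 0 (always large) or usize::MAX (always small) and run every test against both code paths.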


fn _byte_pair_merge(ranks: &HashMap<Vec<u8>, Rank>, piece: &[u8]) -> Vec<(usize, Rank)> {
if piece.len() < LARGE_ENCODER_CHARACTER_LIMIT {
_byte_pair_merge_small(ranks, piece) // Quadratic, but lightweight
Comment (l0rinc, Jan 26, 2024): The quadratic algo is usually faster for very small words, which is always the case for natural language; but for inputs like DNA sequences, or to defend against a DoS attack, we switch to the linearithmic algo.

} else {
_byte_pair_merge_large(ranks, piece) // Linearithmic, but heavy
}
}

fn _byte_pair_merge_small(ranks: &HashMap<Vec<u8>, Rank>, piece: &[u8]) -> Vec<(usize, Rank)> {
// This is a vector of (start, rank).
// The rank is of the pair starting at position start.
let mut parts = Vec::with_capacity(piece.len() + 1);
@@ -73,6 +84,78 @@ fn _byte_pair_merge(ranks: &HashMap<Vec<u8>, Rank>, piece: &[u8]) -> Vec<(usize,
parts
}

fn _byte_pair_merge_large(ranks: &HashMap<Vec<u8>, Rank>, piece: &[u8]) -> Vec<(usize, Rank)> {
let mut rank_indexes = BTreeMap::<Rank, BTreeSet<usize>>::new();
Comment (l0rinc, Contributor Author): Grouped by rank, with the values ordered by index (basically a LinkedHashSet inside).
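That grouping can be sketched in isolation. The `pop_lowest_rank` helper and the toy ranks below are illustrative, not code from this PR:

```rust
use std::collections::{BTreeMap, BTreeSet};

type Rank = u32;

// The BTreeMap keeps ranks sorted, so the minimum rank is always at the
// front; each BTreeSet keeps the positions holding that rank sorted by index.
fn pop_lowest_rank(
    groups: &mut BTreeMap<Rank, BTreeSet<usize>>,
) -> Option<(Rank, BTreeSet<usize>)> {
    // O(log n): removes the smallest rank together with all its positions.
    groups.pop_first()
}

fn main() {
    let mut groups: BTreeMap<Rank, BTreeSet<usize>> = BTreeMap::new();
    for (pos, rank) in [(0usize, 7u32), (1, 3), (2, 7), (3, 3)] {
        groups.entry(rank).or_default().insert(pos);
    }
    // All positions sharing the minimum rank come out together, in index order.
    let (rank, positions) = pop_lowest_rank(&mut groups).unwrap();
    println!("{rank} {positions:?}"); // 3 {1, 3}
}
```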

let mut index_rank = vec![Rank::MAX; piece.len() + 1];
Comment (l0rinc, Jan 15, 2024): Mutations seemed easier this way, compared to creating a struct with index/rank/prev/next, especially in Rust.

let mut index_prev = vec![usize::MAX; piece.len() + 1];
let mut index_next = vec![usize::MAX; piece.len() + 1];

let get_rank = |start_idx: usize, end_idx: usize| -> Rank {
*piece.get(start_idx..end_idx)
Comment (l0rinc, Feb 11, 2024): When end_idx is out of bounds we default to Rank::MAX.
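That defaulting can be shown standalone: `slice::get` with a range returns `None` when the range runs past the end, so an out-of-bounds pair lookup falls through to the sentinel instead of panicking. The `rank_or_max` helper and toy `lookup` closure are illustrative, not from this PR:

```rust
type Rank = u32;

// The closure stands in for the real ranks HashMap lookup.
fn rank_or_max(
    piece: &[u8],
    start: usize,
    end: usize,
    lookup: impl Fn(&[u8]) -> Option<Rank>,
) -> Rank {
    piece
        .get(start..end) // None if end > piece.len(), no panic
        .and_then(|p| lookup(p))
        .unwrap_or(Rank::MAX)
}

fn main() {
    let lookup = |p: &[u8]| if p == b"bc".as_slice() { Some(42) } else { None };
    println!("{}", rank_or_max(b"abc", 1, 3, &lookup)); // 42
    println!("{}", rank_or_max(b"abc", 2, 4, &lookup)); // out of bounds: 4294967295
}
```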

.and_then(|p| ranks.get(p))
.unwrap_or(&Rank::MAX)
};

let mut prev_node = None;
for i in 0..=piece.len() {
let rank = get_rank(i, i + 2);
index_rank[i] = rank;
if let Some(prev) = prev_node {
index_prev[i] = prev;
index_next[prev] = i;
}
prev_node = Some(i);

rank_indexes.entry(rank).or_default().insert(i);
}

while rank_indexes.len() > 1 {
let mut skip_next = false;
if let Some((_, nodes)) = rank_indexes.pop_first() {
Comment (l0rinc, Jan 15, 2024): The next min is popped off in logarithmic time instead of linearly.

for &min_node in &nodes {
Comment (l0rinc, Jan 15, 2024): Duplicates are processed in bulk (since the next min is strictly greater than or equal); no need to remove them one by one.

if skip_next {
skip_next = false;
Comment (l0rinc, Contributor Author): This happens when merging neighboring elements with the same rank.

continue;
}

let min_rank = index_rank[min_node];

let prev_node = index_prev[min_node];
let next_node = index_next[min_node];
Comment (l0rinc, Contributor Author): Getting next and previous requires lookups now.

let next_next_node = index_next[next_node];
Comment (aldanor, Aug 28, 2024): @l0rinc I think these lines would panic (out of range) if your min_node is close to an end; how do you know you have 3 nodes to the right of it? Your last node's next will be usize::MAX.

Comment (aldanor, Aug 28, 2024): Never mind, I was misreading this. The only one that can be "none" is next_next_next_node; but then it's usize::MAX and the .get() in get_rank() covers this and returns Rank::MAX.

let next_next_next_node = index_next[next_next_node];
Comment (l0rinc, Contributor Author, on lines +123 to +126): We keep track of the order of characters inside the rank-balanced tree, providing logarithmic access to the minimum rank and constant access to the previous/next nodes.
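The constant-time prev/next access comes from storing the linked list in flat vectors. A minimal standalone sketch of that bookkeeping (the `unlink` helper is illustrative, not from this PR):

```rust
// A doubly linked list stored in two Vecs, with usize::MAX as the
// "no neighbour" sentinel, mirroring index_prev/index_next above.
// Unlinking a node when its pair is merged away is O(1).
fn unlink(prev: &mut [usize], next: &mut [usize], node: usize) {
    let (p, n) = (prev[node], next[node]);
    if p != usize::MAX {
        next[p] = n; // predecessor now skips over `node`
    }
    if n != usize::MAX {
        prev[n] = p; // successor points back past `node`
    }
}

fn main() {
    // Nodes 0 <-> 1 <-> 2
    let mut prev = vec![usize::MAX, 0, 1];
    let mut next = vec![1, 2, usize::MAX];
    unlink(&mut prev, &mut next, 1); // remove node 1: 0 <-> 2
    println!("{} {}", next[0], prev[2]); // 2 0
}
```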


if prev_node != usize::MAX {
let new_rank = get_rank(prev_node, next_next_node);
if index_rank[prev_node] != new_rank {
rank_indexes.get_mut(&index_rank[prev_node]).unwrap().remove(&prev_node);
index_rank[prev_node] = new_rank;
rank_indexes.entry(new_rank).or_default().insert(prev_node);
}
}

let new_rank = get_rank(min_node, next_next_next_node);
index_rank[min_node] = new_rank;
rank_indexes.entry(new_rank).or_default().insert(min_node);

index_next[min_node] = next_next_node;
index_prev[next_next_node] = min_node;

let next_node_rank = index_rank[next_node];
if next_node_rank == min_rank {
skip_next = true;
} else if next_node_rank != Rank::MAX {
rank_indexes.get_mut(&next_node_rank).unwrap().remove(&next_node);
}
}
}
}

successors(Some(0), |&n| index_next.get(n).filter(|&&x| x != usize::MAX).copied())
Comment (l0rinc, Contributor Author): Iterate until there's a valid rank.
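The final walk can be demonstrated standalone; `collect_chain` is an illustrative wrapper around the same `successors` call, not code from this PR:

```rust
use std::iter::successors;

// Follow index_next from node 0 until the usize::MAX sentinel,
// yielding the surviving indices in order.
fn collect_chain(index_next: &[usize]) -> Vec<usize> {
    successors(Some(0), |&n| {
        index_next.get(n).filter(|&&x| x != usize::MAX).copied()
    })
    .collect()
}

fn main() {
    // 0 -> 2 -> 3 -> end (node 1 was merged away)
    let index_next = vec![2, usize::MAX, 3, usize::MAX];
    println!("{:?}", collect_chain(&index_next)); // [0, 2, 3]
}
```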

.map(|n| (n, Rank::MAX))
.collect()
}

pub fn byte_pair_encode(piece: &[u8], ranks: &HashMap<Vec<u8>, Rank>) -> Vec<Rank> {
assert!(piece.len() > 1);
_byte_pair_merge(&ranks, &piece)
3 changes: 3 additions & 0 deletions tests/test_encoding.py
@@ -61,13 +61,16 @@ def test_simple_regex():
def test_basic_encode():
enc = tiktoken.get_encoding("r50k_base")
assert enc.encode("hello world") == [31373, 995]
assert enc.encode("a" * 1000) == [24794] * 250
Comment (l0rinc, Jan 15, 2024): To cover the big encoder as well.


enc = tiktoken.get_encoding("p50k_base")
assert enc.encode("hello world") == [31373, 995]
assert enc.encode("a" * 1000) == [24794] * 250

enc = tiktoken.get_encoding("cl100k_base")
assert enc.encode("hello world") == [15339, 1917]
assert enc.encode(" \x850") == [220, 126, 227, 15]
assert enc.encode("a" * 1000) == [70540] * 125


def test_encode_empty():