Optimize byte pair merge for really big tokens (40x faster for a 2500 token word) #239
base: main
Conversation
@@ -61,13 +60,16 @@ def test_simple_regex():
def test_basic_encode():
    enc = tiktoken.get_encoding("r50k_base")
    assert enc.encode("hello world") == [31373, 995]
    assert enc.encode("a" * 1000) == [24794] * 250
to cover the big encoder as well
@@ -15,9 +16,22 @@ use rustc_hash::FxHashMap as HashMap;

type Rank = u32;

const LARGE_ENCODER_CHARACTER_LIMIT: usize = 500;
in the Java version this was controlled by an environment variable, which enabled us to run all tests against both implementations - should I do it here as well?
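For illustration, a minimal sketch of what such an environment-variable override could look like on the Rust side (the variable name and the caching approach are assumptions, not part of this PR):

```rust
// Hypothetical env-var override, mirroring the Java implementation's
// ability to force either code path in tests. Name and default assumed.
use std::sync::OnceLock;

fn large_encoder_character_limit() -> usize {
    static LIMIT: OnceLock<usize> = OnceLock::new();
    *LIMIT.get_or_init(|| {
        std::env::var("TIKTOKEN_LARGE_ENCODER_CHARACTER_LIMIT")
            .ok()
            .and_then(|v| v.parse().ok())
            .unwrap_or(500)
    })
}
```

Tests could then set the variable to 0 to route every piece through the large encoder.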
            .unwrap_or(&Rank::MAX)
    };

    let mut rank_indexes = BTreeMap::<Rank, BTreeSet<usize>>::new();
grouped by rank, the values ordered by index (basically a LinkedHashSet inside)
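A toy illustration of that shape (not from the PR, just demonstrating the standard-library behavior the comment relies on):

```rust
use std::collections::{BTreeMap, BTreeSet};

fn main() {
    let mut rank_indexes = BTreeMap::<u32, BTreeSet<usize>>::new();
    // Insertion order doesn't matter: keys sort by rank,
    // and each inner set keeps its indexes sorted.
    for (index, rank) in [(3, 10), (1, 10), (2, 5)] {
        rank_indexes.entry(rank).or_default().insert(index);
    }
    let flattened: Vec<(u32, Vec<usize>)> = rank_indexes
        .iter()
        .map(|(&r, s)| (r, s.iter().copied().collect()))
        .collect();
    assert_eq!(flattened, vec![(5, vec![2]), (10, vec![1, 3])]);
}
```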
    };

    let mut rank_indexes = BTreeMap::<Rank, BTreeSet<usize>>::new();
    let mut index_rank = vec![Rank::MAX; piece.len() + 1];
mutations seemed easier this way, compared to creating a struct with index/rank/prev/next - especially in Rust
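A minimal sketch of that parallel-array representation (illustrative only; the array names follow the PR, everything else is assumed):

```rust
fn main() {
    // A doubly linked list over byte positions, stored as two Vecs.
    // usize::MAX is the "no neighbor" sentinel, as in the PR.
    let n = 5;
    let mut index_prev = vec![usize::MAX; n + 1];
    let mut index_next = vec![usize::MAX; n + 1];
    for i in 0..n {
        index_prev[i + 1] = i;
        index_next[i] = i + 1;
    }
    // Unlinking an interior node after a merge is O(1) index arithmetic,
    // with no ownership or borrow-checker gymnastics over node structs:
    let removed = 2;
    let (prev, next) = (index_prev[removed], index_next[removed]);
    index_next[prev] = next;
    index_prev[next] = prev;
    assert_eq!(index_next[1], 3);
    assert_eq!(index_prev[3], 1);
}
```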
    let mut token_count = piece.len();
    while token_count > 2 && rank_indexes.len() > 1 {
        let mut skip_next = false;
        if let Some((_, nodes)) = rank_indexes.pop_first() {
the next min is popped off in logarithmic time instead of linear time
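The relevant standard-library call, shown in isolation (illustrative; `BTreeMap::pop_first` requires Rust 1.66+):

```rust
use std::collections::{BTreeMap, BTreeSet};

fn main() {
    let mut rank_indexes = BTreeMap::<u32, BTreeSet<usize>>::new();
    rank_indexes.entry(7).or_default().insert(0);
    rank_indexes.entry(3).or_default().insert(4);

    // pop_first removes and returns the entry with the smallest key
    // in O(log n) - no linear scan over all candidate pairs.
    let (min_rank, nodes) = rank_indexes.pop_first().unwrap();
    assert_eq!(min_rank, 3);
    assert!(nodes.contains(&4));
}
```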
    while token_count > 2 && rank_indexes.len() > 1 {
        let mut skip_next = false;
        if let Some((_, nodes)) = rank_indexes.pop_first() {
            for &min_node in &nodes {
duplicates are processed in bulk (since the next min is strictly greater), no need to remove them one by one
        if let Some((_, nodes)) = rank_indexes.pop_first() {
            for &min_node in &nodes {
                if skip_next {
                    skip_next = false;
when merging neighboring elements with the same rank
}
) -> Vec<(usize, Rank)> {
    if piece.len() < LARGE_ENCODER_CHARACTER_LIMIT {
        _byte_pair_merge_small(ranks, piece) // Quadratic, but lightweight
the quadratic algo is usually faster for very small words - which is always the case for natural language - but pathological inputs, e.g. DNA sequences or a DoS attack, are handled by switching to the linearithmic algo
                let prev_node = index_prev[min_node];
                let next_node = index_next[min_node];
                let next_next_node = index_next[next_node];
                let next_next_next_node = index_next[next_next_node];
This is great! Thanks for keeping a simple-to-follow history. Most of the commits here are straightforward, so I've separated them into different PRs (I've preserved authorship information, but let me know if you'd prefer to re-open them yourself).
Thanks a lot for the thorough review, Shantanu. After merging, you may want to update the benchmark results in the readme.
Based on suggestion in #239 (specifically 8f5dd7d)

Like that commit, this:
- Does the init in a single loop and saves a loop if there are no merges
- Simplifies get_rank and no longer uses it in init (so you don't need multiple skip values)

Unlike that commit:
- We drop optimisations enabled by ignoring single tokens. These didn't show any benefit on benchmarks for me (this makes sense given typical piece sizes, but let me know if that's unexpected!). Given this, I opted for the simpler version.
- I preserve some of the comments from the original that I think are still useful

Co-authored-by: @paplorinc
Co-authored-by: Lőrinc Pap <[email protected]>
Force-pushed from 24d68bd to b7c6ac8
        }
    }

    successors(Some(0), |&n| index_next.get(n).filter(|&&x| x != usize::MAX).copied())
iterate while there's a valid next index
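What that successors call does, as a standalone example (toy data, not from the PR):

```rust
use std::iter::successors;

fn main() {
    // index_next encodes a linked list over positions 0 -> 2 -> 4,
    // with usize::MAX as the end sentinel.
    let index_next = vec![2, usize::MAX, 4, usize::MAX, usize::MAX];
    let boundaries: Vec<usize> = successors(Some(0), |&n| {
        index_next.get(n).filter(|&&x| x != usize::MAX).copied()
    })
    .collect();
    // The surviving boundaries delimit the final tokens.
    assert_eq!(boundaries, vec![0, 2, 4]);
}
```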
@hauntsaninja, I've rebased this PR, removing the merged commits and adjusting the result a bit based on your previous preferences.
    let mut index_next = vec![usize::MAX; piece.len() + 1];

    let get_rank = |start_idx: usize, end_idx: usize| -> Rank {
        *piece.get(start_idx..end_idx)
when end_idx is out of bounds we're defaulting to Rank::MAX
We're storing the ranks in a sorted tree of sorted (effectively linked) trees. Getting the minimum rank is logarithmic, and each subsequent occurrence of that rank is retrieved in constant time. To know the previous and next indexes (and the corresponding ranks), we store them in arrays keyed by index, updating them after each minimum found via the tree. We iterate duplicates without removing them one by one, but if they are neighbors, we skip them manually.
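A condensed sketch of that setup, following the names used in the PR (simplified and illustrative only; the real init also handles more bookkeeping around merges):

```rust
use std::collections::{BTreeMap, BTreeSet};

type Rank = u32;

// Sketch: builds the index arrays plus the rank -> indexes tree for a
// piece, given some pair-ranking function. Details assumed, not from the PR.
fn init_structures(
    piece: &[u8],
    get_rank: impl Fn(usize, usize) -> Rank,
) -> (BTreeMap<Rank, BTreeSet<usize>>, Vec<Rank>, Vec<usize>, Vec<usize>) {
    let mut rank_indexes = BTreeMap::<Rank, BTreeSet<usize>>::new();
    let mut index_rank = vec![Rank::MAX; piece.len() + 1];
    let mut index_prev = vec![usize::MAX; piece.len() + 1];
    let mut index_next = vec![usize::MAX; piece.len() + 1];

    for i in 0..piece.len() {
        let rank = get_rank(i, i + 2); // rank of the byte pair starting at i
        if rank != Rank::MAX {
            rank_indexes.entry(rank).or_default().insert(i);
        }
        index_rank[i] = rank;
        index_prev[i] = i.wrapping_sub(1); // usize::MAX sentinel at i == 0
        index_next[i] = i + 1;
    }
    (rank_indexes, index_rank, index_prev, index_next)
}

fn main() {
    // Toy ranking: only the pair at position 0 is mergeable.
    let (tree, ranks, _prev, next) =
        init_structures(b"abc", |s, _e| if s == 0 { 1 } else { Rank::MAX });
    assert_eq!(tree.len(), 1);
    assert_eq!(ranks[0], 1);
    assert_eq!(next[2], 3);
}
```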
Force-pushed from b7c6ac8 to 5af8058
                let min_rank = index_rank[min_node];

                let prev_node = index_prev[min_node];
                let next_node = index_next[min_node];
getting next and previous requires lookups now
Hi! Any updates on this PR? It'd be great to have this 🙏 @hauntsaninja
                let prev_node = index_prev[min_node];
                let next_node = index_next[min_node];
                let next_next_node = index_next[next_node];
@l0rinc I think these lines would panic (out-of-range) if your min_node is close to an end - how do you know you have 3 nodes to the right of it? Your last node's next will be usize::MAX.
Nevermind, I'm misreading this. The only one that can be "none" is next_next_next_node; but then it's usize::MAX and the .get() in get_rank() will cover this and return Rank::MAX.
Continuing the optimizations started in #237 and #234, migrated from knuddelsgmbh/jtokkit#75, knuddelsgmbh/jtokkit#76, knuddelsgmbh/jtokkit#77.
This commit is mainly meant to address the issue of really big tokens spiraling out of control, see: #195
The original byte pair merge algorithm scales superlinearly with sequence length, since every merge round rescans the remaining pairs - e.g. a 20,000-character word (comprising ~2,500 tokens) can take several seconds to tokenize.
The same slowdown can be reproduced on https://platform.openai.com/tokenizer.
The new algorithm scales so well that it could theoretically process the whole text in a single byte-pair-merge loop without any regex splitting (though it would need a different token set to be optimal, since it produces slightly different results - mostly around whitespace - and it also consumes a lot more memory and is slower than the current one).
The new algorithm does the minimum search logarithmically and handles duplicates in constant time, but it has a higher setup cost, so we're only using it for extreme cases (when the piece produced by the regex is longer than 500 characters).
The benchmarking was done step by step in the Java clone and retested here in the way described in #234.
110 multilingual books + some source code + some big-token files:

Before:
After:

Of these, only the big-token files (the purpose of this PR):

Before:
After:
i.e. 40x faster for 20k-character words. And if we combine this with the previous regex optimizations, we get the following for the 110 books + sources + big tokens case:
i.e. 50% faster on average after all optimizations.
I recommend reviewing commit by commit.