From aeca53251c0410ed7869a38a9fd44ee7cee4d162 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?L=C5=91rinc?=
Date: Mon, 15 Jan 2024 21:06:29 +0100
Subject: [PATCH 1/2] Add test for encoding huge byte sequences

---
 tests/test_encoding.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/tests/test_encoding.py b/tests/test_encoding.py
index 27b21925..9f313197 100644
--- a/tests/test_encoding.py
+++ b/tests/test_encoding.py
@@ -61,13 +61,16 @@ def test_simple_regex():
 def test_basic_encode():
     enc = tiktoken.get_encoding("r50k_base")
     assert enc.encode("hello world") == [31373, 995]
+    assert enc.encode("a" * 1000) == [24794] * 250
 
     enc = tiktoken.get_encoding("p50k_base")
     assert enc.encode("hello world") == [31373, 995]
+    assert enc.encode("a" * 1000) == [24794] * 250
 
     enc = tiktoken.get_encoding("cl100k_base")
     assert enc.encode("hello world") == [15339, 1917]
     assert enc.encode(" \x850") == [220, 126, 227, 15]
+    assert enc.encode("a" * 1000) == [70540] * 125
 
 
 def test_encode_empty():

From 5af8058ea743b2cb78793755a344a8be12773cc5 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?L=C5=91rinc?=
Date: Mon, 15 Jan 2024 22:03:26 +0100
Subject: [PATCH 2/2] Add a _byte_pair_merge_large for worst-case scenarios

We store the ranks in a sorted map from rank to the sorted set of
indexes where that rank occurs (a BTreeMap<Rank, BTreeSet<usize>>).
Popping the minimum rank is logarithmic, and each further occurrence
of that rank is visited in constant time. The previous and next
indexes (and their corresponding ranks) live in plain arrays keyed by
index, forming a doubly linked list that is updated after each
minimum pulled from the tree. Duplicates are iterated without
removing them from the tree one by one, but when two occurrences are
neighbors, the second is skipped manually, since merging the first
pair has already consumed it.
---
 src/lib.rs | 85 +++++++++++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 84 insertions(+), 1 deletion(-)
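For illustration, a minimal self-contained sketch of that queue
discipline (the ranks below are made up, not taken from a real
encoder):

    use std::collections::{BTreeMap, BTreeSet};

    type Rank = u32;

    fn main() {
        // Hypothetical ranks of the byte pairs starting at indexes 0..4.
        let pair_ranks: [Rank; 4] = [7, 3, 9, 3];

        // Sorted map from rank to the sorted set of indexes holding it.
        let mut rank_indexes = BTreeMap::<Rank, BTreeSet<usize>>::new();
        for (i, &r) in pair_ranks.iter().enumerate() {
            rank_indexes.entry(r).or_default().insert(i);
        }

        // pop_first removes and returns the minimum rank in logarithmic
        // time; iterating its inner set then yields every occurrence of
        // that rank, in index order, in constant time each.
        let (min_rank, nodes) = rank_indexes.pop_first().unwrap();
        assert_eq!(min_rank, 3);
        assert_eq!(nodes.into_iter().collect::<Vec<_>>(), vec![1, 3]);
    }

Both trees keep their contents sorted, which is what lets the merge
loop visit equal-rank pairs left to right without a separate sort.
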
diff --git a/src/lib.rs b/src/lib.rs
index b466edd1..66b068f3 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,7 +1,8 @@
 // This check is new and seems buggy (possibly with PyO3 interaction)
 #![allow(clippy::borrow_deref_ref)]
 
-use std::collections::HashSet;
+use std::collections::{BTreeMap, BTreeSet, HashSet};
+use std::iter::successors;
 use std::num::NonZeroU64;
 use std::thread;
 
@@ -15,7 +16,17 @@ use rustc_hash::FxHashMap as HashMap;
 
 type Rank = u32;
 
+const LARGE_ENCODER_CHARACTER_LIMIT: usize = 500;
+
 fn _byte_pair_merge(ranks: &HashMap<Vec<u8>, Rank>, piece: &[u8]) -> Vec<(usize, Rank)> {
+    if piece.len() < LARGE_ENCODER_CHARACTER_LIMIT {
+        _byte_pair_merge_small(ranks, piece) // Quadratic, but lightweight
+    } else {
+        _byte_pair_merge_large(ranks, piece) // Linearithmic, but heavy
+    }
+}
+
+fn _byte_pair_merge_small(ranks: &HashMap<Vec<u8>, Rank>, piece: &[u8]) -> Vec<(usize, Rank)> {
     // This is a vector of (start, rank).
     // The rank is of the pair starting at position start.
     let mut parts = Vec::with_capacity(piece.len() + 1);
@@ -73,6 +84,78 @@ fn _byte_pair_merge(ranks: &HashMap<Vec<u8>, Rank>, piece: &[u8]) -> Vec<(usize,
     parts
 }
 
+fn _byte_pair_merge_large(ranks: &HashMap<Vec<u8>, Rank>, piece: &[u8]) -> Vec<(usize, Rank)> {
+    let mut rank_indexes = BTreeMap::<Rank, BTreeSet<usize>>::new();
+    let mut index_rank = vec![Rank::MAX; piece.len() + 1];
+    let mut index_prev = vec![usize::MAX; piece.len() + 1];
+    let mut index_next = vec![usize::MAX; piece.len() + 1];
+
+    let get_rank = |start_idx: usize, end_idx: usize| -> Rank {
+        *piece.get(start_idx..end_idx)
+            .and_then(|p| ranks.get(p))
+            .unwrap_or(&Rank::MAX)
+    };
+
+    let mut prev_node = None;
+    for i in 0..=piece.len() {
+        let rank = get_rank(i, i + 2);
+        index_rank[i] = rank;
+        if let Some(prev) = prev_node {
+            index_prev[i] = prev;
+            index_next[prev] = i;
+        }
+        prev_node = Some(i);
+
+        rank_indexes.entry(rank).or_default().insert(i);
+    }
+
+    while rank_indexes.len() > 1 {
+        let mut skip_next = false;
+        if let Some((_, nodes)) = rank_indexes.pop_first() {
+            for &min_node in &nodes {
+                if skip_next {
+                    skip_next = false;
+                    continue;
+                }
+
+                let min_rank = index_rank[min_node];
+
+                let prev_node = index_prev[min_node];
+                let next_node = index_next[min_node];
+                let next_next_node = index_next[next_node];
+                let next_next_next_node = index_next[next_next_node];
+
+                if prev_node != usize::MAX {
+                    let new_rank = get_rank(prev_node, next_next_node);
+                    if index_rank[prev_node] != new_rank {
+                        rank_indexes.get_mut(&index_rank[prev_node]).unwrap().remove(&prev_node);
+                        index_rank[prev_node] = new_rank;
+                        rank_indexes.entry(new_rank).or_default().insert(prev_node);
+                    }
+                }
+
+                let new_rank = get_rank(min_node, next_next_next_node);
+                index_rank[min_node] = new_rank;
+                rank_indexes.entry(new_rank).or_default().insert(min_node);
+
+                index_next[min_node] = next_next_node;
+                index_prev[next_next_node] = min_node;
+
+                let next_node_rank = index_rank[next_node];
+                if next_node_rank == min_rank {
+                    skip_next = true;
+                } else if next_node_rank != Rank::MAX {
+                    rank_indexes.get_mut(&next_node_rank).unwrap().remove(&next_node);
+                }
+            }
+        }
+    }
+
+    successors(Some(0), |&n| index_next.get(n).filter(|&&x| x != usize::MAX).copied())
+        .map(|n| (n, Rank::MAX))
+        .collect()
+}
+
 pub fn byte_pair_encode(piece: &[u8], ranks: &HashMap<Vec<u8>, Rank>) -> Vec<Rank> {
     assert!(piece.len() > 1);
     _byte_pair_merge(&ranks, &piece)
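
A note on the final collection step: once all merges are done, the
surviving parts are read back by walking the index_next links from
node 0 with successors. A minimal sketch with hypothetical pointers
(merged-away nodes keep stale links, exactly as the algorithm leaves
them):

    use std::iter::successors;

    fn main() {
        // Hypothetical next-links for a 5-byte piece after merging the
        // pairs starting at 0 and at 3 (removing nodes 1 and 4). The
        // sentinel node 5 (== piece.len()) never gets a successor.
        let index_next = vec![2usize, 2, 3, 5, 5, usize::MAX];

        // Follow the links from node 0 until the usize::MAX sentinel;
        // stale entries for removed nodes are simply never visited.
        let parts: Vec<usize> =
            successors(Some(0), |&n| {
                index_next.get(n).filter(|&&x| x != usize::MAX).copied()
            })
            .collect();

        assert_eq!(parts, vec![0, 2, 3, 5]); // surviving part boundaries
    }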