From 1b9faf2779855124f05174adf1383e53689ed94b Mon Sep 17 00:00:00 2001 From: Shantanu <12621235+hauntsaninja@users.noreply.github.com> Date: Sun, 11 Feb 2024 00:20:22 -0800 Subject: [PATCH] Simplify byte_pair_merge (#255) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Based on suggestion in https://github.com/openai/tiktoken/pull/239 (specifically 8f5dd7d) Like that commit, this: - Does the init in a single loop and saves a loop if there are no merges - Simplifies get_rank and no longer uses it in init (so you don't need multiple skip values) Unlike that commit: - We drop optimisations enabled by ignoring single tokens. These didn't show any benefit on benchmarks for me (this makes sense given typical piece sizes, but let me know if that's unexpected!). Given this, I opted for the simpler version. - I preserve some of the comments from the original that I think are still useful Co-authored-by: @paplorinc --------- Co-authored-by: LÅ‘rinc Pap <1841944+paplorinc@users.noreply.github.com> --- src/lib.rs | 96 ++++++++++++++++++++---------------------------------- 1 file changed, 36 insertions(+), 60 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 2b9e15ff..b466edd1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -15,85 +15,61 @@ use rustc_hash::FxHashMap as HashMap; type Rank = u32; -fn _byte_pair_merge( - ranks: &HashMap, Rank>, - piece: &[u8], -) -> Vec<(usize, Rank)> { +fn _byte_pair_merge(ranks: &HashMap, Rank>, piece: &[u8]) -> Vec<(usize, Rank)> { // This is a vector of (start, rank). - // The rank is of the byte pair starting at position start. - // The rank of the last item in the vector is not a valid value. - let mut parts: Vec<(usize, Rank)> = (0..piece.len() + 1).map(|i| (i, Rank::MAX)).collect(); + // The rank is of the pair starting at position start. + let mut parts = Vec::with_capacity(piece.len() + 1); + + // Note that we hash bytes when indexing into `ranks`, not token pairs. As long as we train BPE + // the way we currently do, this is equivalent. An easy way to break this would be to decouple + // merge priority from token index or to prevent specific token merges. + let mut min_rank: (Rank, usize) = (Rank::MAX, usize::MAX); + for i in 0..piece.len() - 1 { + let rank = *ranks.get(&piece[i..i + 2]).unwrap_or(&Rank::MAX); + if rank < min_rank.0 { + min_rank = (rank, i); + } + parts.push((i, rank)); + } + parts.push((piece.len() - 1, Rank::MAX)); + parts.push((piece.len(), Rank::MAX)); let get_rank = { #[inline(always)] - |parts: &Vec<(usize, Rank)>, start_idx: usize, skip: usize| { - if (start_idx + skip + 2) < parts.len() { - ranks - .get(&piece[parts[start_idx].0..parts[start_idx + skip + 2].0]) - .copied() + |parts: &Vec<(usize, Rank)>, i: usize| { + if (i + 3) < parts.len() { + // Similar to `piece[i..i + 2]` above. The +3 is because we haven't yet deleted + // parts[i + 1], see comment in the main loop. + *ranks + .get(&piece[parts[i].0..parts[i + 3].0]) + .unwrap_or(&Rank::MAX) } else { - None + Rank::MAX } } }; - // We look up the ranks once in the beginning and iteratively update - // them during each merge, which reduces the number of rank lookups. - for i in 0..parts.len() - 2 { - match get_rank(&parts, i, 0) { - Some(rank) => { - // Rank::MAX is a sentinel value and cannot be a valid rank - debug_assert!(rank != Rank::MAX); - parts[i].1 = rank; - } - None => { - continue; - } - }; - } - // If you have n parts and m merges, this does O(mn) work. // We could do something with a heap and do O(m log n) work. - // It is important to consider that n is often small (<100), and as such - // the cache-locality benefits outweigh the algorithmic complexity downsides - // of the `parts` vector data structure above. - - // Note that we hash bytes, not token pairs. As long as we train BPE the way we - // currently do, this is equivalent. An easy way to break this would be to decouple - // merge priority from token index or to prevent specific token merges. - loop { - if parts.len() == 1 { - break; + // n is often very small so considerations like cache-locality outweigh the algorithmic + // complexity downsides of the `parts` vector. + while min_rank.0 != Rank::MAX { + let i = min_rank.1; + // Update parts[i] and parts[i - 1] before removing parts[i + 1], since + // `parts.remove(i + 1)` will thrash the cache. + if i > 0 { + parts[i - 1].1 = get_rank(&parts, i - 1); } + parts[i].1 = get_rank(&parts, i); + parts.remove(i + 1); - // Rank::MAX is a sentinel rank value allowing us to - // take the min more quickly - let mut min_rank: (Rank, usize) = (Rank::MAX, 0); + min_rank = (Rank::MAX, usize::MAX); for (i, &(_, rank)) in parts[..parts.len() - 1].iter().enumerate() { if rank < min_rank.0 { min_rank = (rank, i); } } - - if min_rank.0 != Rank::MAX { - let i = min_rank.1; - - // NOTE: We are about to remove parts[i + 1]. We do not do it - // yet because there are cache-locality benefits to updating - // parts[i] and parts[i-1] before removing, which could thrash - // the cache. Thus, we update the rank calculation by skipping over - // parts[i + 1], by invoking `get_rank!` with `skip = 1`. - parts[i].1 = get_rank(&parts, i, 1).unwrap_or(Rank::MAX); - if i > 0 { - parts[i - 1].1 = get_rank(&parts, i - 1, 1).unwrap_or(Rank::MAX); - } - - parts.remove(i + 1); - } else { - break; - } } - parts }