Skip to content

Commit

Permalink
Simplify byte_pair_merge (#255)
Browse files Browse the repository at this point in the history
Based on suggestion in #239
(specifically 8f5dd7d)

Like that commit, this:
- Does the init in a single loop and saves a loop if there are no merges
- Simplifies get_rank and no longer uses it in init (so you don't need
multiple skip values)

Unlike that commit:
- We drop optimisations enabled by ignoring single tokens. These didn't
show any benefit on benchmarks for me (this makes sense given typical
piece sizes, but let me know if that's unexpected!). Given this, I opted
for the simpler version.
- I preserve some of the comments from the original that I think are
still useful

Co-authored-by: @paplorinc

---------

Co-authored-by: Lőrinc Pap <[email protected]>
  • Loading branch information
hauntsaninja and l0rinc authored Feb 11, 2024
1 parent 6defed5 commit 1b9faf2
Showing 1 changed file with 36 additions and 60 deletions.
96 changes: 36 additions & 60 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,85 +15,61 @@ use rustc_hash::FxHashMap as HashMap;

type Rank = u32;

fn _byte_pair_merge(
ranks: &HashMap<Vec<u8>, Rank>,
piece: &[u8],
) -> Vec<(usize, Rank)> {
fn _byte_pair_merge(ranks: &HashMap<Vec<u8>, Rank>, piece: &[u8]) -> Vec<(usize, Rank)> {
// This is a vector of (start, rank).
// The rank is of the byte pair starting at position start.
// The rank of the last item in the vector is not a valid value.
let mut parts: Vec<(usize, Rank)> = (0..piece.len() + 1).map(|i| (i, Rank::MAX)).collect();
// The rank is of the pair starting at position start.
let mut parts = Vec::with_capacity(piece.len() + 1);

// Note that we hash bytes when indexing into `ranks`, not token pairs. As long as we train BPE
// the way we currently do, this is equivalent. An easy way to break this would be to decouple
// merge priority from token index or to prevent specific token merges.
let mut min_rank: (Rank, usize) = (Rank::MAX, usize::MAX);
for i in 0..piece.len() - 1 {
let rank = *ranks.get(&piece[i..i + 2]).unwrap_or(&Rank::MAX);
if rank < min_rank.0 {
min_rank = (rank, i);
}
parts.push((i, rank));
}
parts.push((piece.len() - 1, Rank::MAX));
parts.push((piece.len(), Rank::MAX));

let get_rank = {
#[inline(always)]
|parts: &Vec<(usize, Rank)>, start_idx: usize, skip: usize| {
if (start_idx + skip + 2) < parts.len() {
ranks
.get(&piece[parts[start_idx].0..parts[start_idx + skip + 2].0])
.copied()
|parts: &Vec<(usize, Rank)>, i: usize| {
if (i + 3) < parts.len() {
// Similar to `piece[i..i + 2]` above. The +3 is because we haven't yet deleted
// parts[i + 1], see comment in the main loop.
*ranks
.get(&piece[parts[i].0..parts[i + 3].0])
.unwrap_or(&Rank::MAX)
} else {
None
Rank::MAX
}
}
};

// We look up the ranks once in the beginning and iteratively update
// them during each merge, which reduces the number of rank lookups.
for i in 0..parts.len() - 2 {
match get_rank(&parts, i, 0) {
Some(rank) => {
// Rank::MAX is a sentinel value and cannot be a valid rank
debug_assert!(rank != Rank::MAX);
parts[i].1 = rank;
}
None => {
continue;
}
};
}

// If you have n parts and m merges, this does O(mn) work.
// We could do something with a heap and do O(m log n) work.
// It is important to consider that n is often small (<100), and as such
// the cache-locality benefits outweigh the algorithmic complexity downsides
// of the `parts` vector data structure above.

// Note that we hash bytes, not token pairs. As long as we train BPE the way we
// currently do, this is equivalent. An easy way to break this would be to decouple
// merge priority from token index or to prevent specific token merges.
loop {
if parts.len() == 1 {
break;
// n is often very small so considerations like cache-locality outweigh the algorithmic
// complexity downsides of the `parts` vector.
while min_rank.0 != Rank::MAX {
let i = min_rank.1;
// Update parts[i] and parts[i - 1] before removing parts[i + 1], since
// `parts.remove(i + 1)` will thrash the cache.
if i > 0 {
parts[i - 1].1 = get_rank(&parts, i - 1);
}
parts[i].1 = get_rank(&parts, i);
parts.remove(i + 1);

// Rank::MAX is a sentinel rank value allowing us to
// take the min more quickly
let mut min_rank: (Rank, usize) = (Rank::MAX, 0);
min_rank = (Rank::MAX, usize::MAX);
for (i, &(_, rank)) in parts[..parts.len() - 1].iter().enumerate() {
if rank < min_rank.0 {
min_rank = (rank, i);
}
}

if min_rank.0 != Rank::MAX {
let i = min_rank.1;

// NOTE: We are about to remove parts[i + 1]. We do not do it
// yet because there are cache-locality benefits to updating
// parts[i] and parts[i-1] before removing, which could thrash
// the cache. Thus, we update the rank calculation by skipping over
// parts[i + 1], by invoking `get_rank!` with `skip = 1`.
parts[i].1 = get_rank(&parts, i, 1).unwrap_or(Rank::MAX);
if i > 0 {
parts[i - 1].1 = get_rank(&parts, i - 1, 1).unwrap_or(Rank::MAX);
}

parts.remove(i + 1);
} else {
break;
}
}

parts
}

Expand Down

0 comments on commit 1b9faf2

Please sign in to comment.