Add a _byte_pair_merge_large for worst-case scenarios
We store the ranks in a sorted tree of sorted trees.
Getting the minimum rank is logarithmic, and each subsequent occurrence of that rank is retrieved in constant time.
To know the previous and next indexes (and their corresponding ranks), we store them in arrays whose keys are the indexes, updating them after each minimum found via the tree.
We iterate over duplicates of the minimum rank without removing them one by one; when two occurrences are adjacent, we skip the second manually, since merging the first invalidates it.
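
A minimal standalone sketch of that structure (illustration only, not code from this commit): BTreeMap::pop_first removes the whole minimum-rank bucket in logarithmic time, and iterating the bucket then visits each duplicate occurrence in constant time, in index order.

use std::collections::{BTreeMap, BTreeSet};

type Rank = u32;

fn main() {
    // rank -> sorted set of start indexes whose current pair has that rank
    let mut rank_indexes = BTreeMap::<Rank, BTreeSet<usize>>::new();
    for (index, rank) in [(0, 7), (1, 3), (2, 7), (3, 3)] {
        rank_indexes.entry(rank).or_default().insert(index);
    }

    // Logarithmic: remove the entire bucket holding the minimum rank (3)...
    if let Some((min_rank, nodes)) = rank_indexes.pop_first() {
        // ...then constant time per duplicate: visits index 1, then index 3.
        for index in &nodes {
            println!("pair at index {index} has minimum rank {min_rank}");
        }
    }
}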
Lőrinc committed Feb 11, 2024
1 parent 25e9a4f commit b7c6ac8
Showing 1 changed file with 84 additions and 1 deletion.
src/lib.rs
@@ -1,7 +1,8 @@
// This check is new and seems buggy (possibly with PyO3 interaction)
#![allow(clippy::borrow_deref_ref)]

use std::collections::HashSet;
use std::collections::{BTreeMap, BTreeSet, HashSet};
use std::iter::successors;
use std::num::NonZeroU64;
use std::thread;

@@ -15,7 +16,17 @@ use rustc_hash::FxHashMap as HashMap;

type Rank = u32;

const LARGE_ENCODER_CHARACTER_LIMIT: usize = 500;

fn _byte_pair_merge(ranks: &HashMap<Vec<u8>, Rank>, piece: &[u8]) -> Vec<(usize, Rank)> {
    if piece.len() < LARGE_ENCODER_CHARACTER_LIMIT {
        _byte_pair_merge_small(ranks, piece) // Quadratic, but lightweight
    } else {
        _byte_pair_merge_large(ranks, piece) // Linearithmic, but heavy
    }
}

fn _byte_pair_merge_small(ranks: &HashMap<Vec<u8>, Rank>, piece: &[u8]) -> Vec<(usize, Rank)> {
    // This is a vector of (start, rank).
    // The rank is of the pair starting at position start.
    let mut parts = Vec::with_capacity(piece.len() + 1);
@@ -73,6 +84,78 @@ fn _byte_pair_merge(ranks: &HashMap<Vec<u8>, Rank>, piece: &[u8]) -> Vec<(usize,
    parts
}

fn _byte_pair_merge_large(ranks: &HashMap<Vec<u8>, Rank>, piece: &[u8]) -> Vec<(usize, Rank)> {
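    // The "sorted tree of sorted trees": for each rank, the sorted set of
    // start indexes whose current pair has that rank.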
    let mut rank_indexes = BTreeMap::<Rank, BTreeSet<usize>>::new();
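    // Flat arrays keyed by start index: the current rank of the pair starting
    // there, plus prev/next pointers forming a doubly-linked list of live parts.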
    let mut index_rank = vec![Rank::MAX; piece.len() + 1];
    let mut index_prev = vec![usize::MAX; piece.len() + 1];
    let mut index_next = vec![usize::MAX; piece.len() + 1];

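    // Rank of the byte range [start_idx, end_idx), or Rank::MAX when the range
    // is out of bounds or is not a token in the vocabulary.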
    let get_rank = |start_idx: usize, end_idx: usize| -> Rank {
        *piece.get(start_idx..end_idx)
            .and_then(|p| ranks.get(p))
            .unwrap_or(&Rank::MAX)
    };

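    // Build the doubly-linked list of parts and record the rank of the
    // two-byte pair starting at each index, in both the flat array and the tree.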
    let mut prev_node = None;
    for i in 0..=piece.len() {
        let rank = get_rank(i, i + 2);
        index_rank[i] = rank;
        if let Some(prev) = prev_node {
            index_prev[i] = prev;
            index_next[prev] = i;
        }
        prev_node = Some(i);

        rank_indexes.entry(rank).or_default().insert(i);
    }

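    // Keep merging the lowest-ranked pairs until only the Rank::MAX bucket
    // (the sentinel for "nothing left to merge") remains.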
    while rank_indexes.len() > 1 {
        let mut skip_next = false;
        if let Some((_, nodes)) = rank_indexes.pop_first() {
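            // Every node in this bucket shares the current minimum rank, so one
            // pass over the popped, sorted set merges them all.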
            for &min_node in &nodes {
                if skip_next {
                    skip_next = false;
                    continue;
                }

                let min_rank = index_rank[min_node];

                let prev_node = index_prev[min_node];
                let next_node = index_next[min_node];
                let next_next_node = index_next[next_node];
                let next_next_next_node = index_next[next_next_node];

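                // The pair starting at `prev_node` now extends to `next_next_node`,
                // so recompute its rank and reposition it in the tree.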
                if prev_node != usize::MAX {
                    let new_rank = get_rank(prev_node, next_next_node);
                    if index_rank[prev_node] != new_rank {
                        rank_indexes.get_mut(&index_rank[prev_node]).unwrap().remove(&prev_node);
                        index_rank[prev_node] = new_rank;
                        rank_indexes.entry(new_rank).or_default().insert(prev_node);
                    }
                }

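                // The part at `min_node` now spans up to `next_next_node`; its new
                // pair with the following part covers piece[min_node..next_next_next_node].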
                let new_rank = get_rank(min_node, next_next_next_node);
                index_rank[min_node] = new_rank;
                rank_indexes.entry(new_rank).or_default().insert(min_node);

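                // Unlink the absorbed `next_node` from the doubly-linked list.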
                index_next[min_node] = next_next_node;
                index_prev[next_next_node] = min_node;

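                // If the absorbed node also holds the minimum rank, it is the next
                // element of this very iteration: skip it manually, since the popped
                // set cannot be modified while it is being iterated.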
                let next_node_rank = index_rank[next_node];
                if next_node_rank == min_rank {
                    skip_next = true;
                } else if next_node_rank != Rank::MAX {
                    rank_indexes.get_mut(&next_node_rank).unwrap().remove(&next_node);
                }
            }
        }
    }

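    // Walk the surviving linked list from index 0 to emit the parts in order;
    // the ranks are no longer needed downstream, so Rank::MAX stands in.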
    successors(Some(0), |&n| index_next.get(n).filter(|&&x| x != usize::MAX).copied())
        .map(|n| (n, Rank::MAX))
        .collect()
}

pub fn byte_pair_encode(piece: &[u8], ranks: &HashMap<Vec<u8>, Rank>) -> Vec<Rank> {
    assert!(piece.len() > 1);
    _byte_pair_merge(&ranks, &piece)
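
A hypothetical in-crate test (illustration only; the test name, toy vocabulary, and expected tokens are assumptions, not part of the diff) sketching how a 500-byte piece reaches the new large path:

#[test]
fn hypothetical_large_path_test() {
    // Invented toy vocabulary; real ranks come from a tokenizer definition.
    let mut ranks: HashMap<Vec<u8>, Rank> = HashMap::default();
    ranks.insert(b"ab".to_vec(), 0);
    ranks.insert(b"cd".to_vec(), 1);

    // 500 bytes meets LARGE_ENCODER_CHARACTER_LIMIT, so this exercises
    // _byte_pair_merge_large rather than _byte_pair_merge_small.
    let piece = b"abcd".repeat(125);
    let tokens = byte_pair_encode(&piece, &ranks);
    assert_eq!(tokens[..2], [0, 1]); // "ab", "cd", repeated throughout
}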
