Commit

Update support for encodings hashes when hosting cache files.
blaney83 committed Mar 8, 2024
1 parent d8ce942 commit c9edc30
Showing 2 changed files with 61 additions and 32 deletions.
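
In short: when the `ENCODINGS_HOST` environment variable is set, `tiktoken` treats that host as a self-hosted mirror of the encoding files (`IS_HOSTING_ENCODINGS = True`) and fails loudly on a hash mismatch instead of silently re-fetching. A minimal sketch of the intended configuration, assuming a hypothetical mirror URL that replicates the upstream paths:

import os

# Hypothetical internal mirror; it must replicate the upstream layout,
# e.g. /encodings/cl100k_base.tiktoken and /gpt-2/encodings/main/vocab.bpe.
os.environ["ENCODINGS_HOST"] = "https://artifacts.example.com"

import tiktoken  # import after setting the variable so the plugin picks it up

enc = tiktoken.get_encoding("cl100k_base")
print(enc.encode("hello world"))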
31 changes: 21 additions & 10 deletions tiktoken/load.py
@@ -32,7 +32,11 @@ def check_hash(data: bytes, expected_hash: str) -> bool:
return actual_hash == expected_hash


def read_file_cached(blobpath: str, expected_hash: Optional[str] = None) -> bytes:
def read_file_cached(
blobpath: str,
expected_hash: Optional[str] = None,
    is_self_hosting: bool = False
) -> bytes:
user_specified_cache = True
if "TIKTOKEN_CACHE_DIR" in os.environ:
cache_dir = os.environ["TIKTOKEN_CACHE_DIR"]
@@ -52,9 +56,20 @@ def read_file_cached(blobpath: str, expected_hash: Optional[str] = None) -> bytes:
if os.path.exists(cache_path):
with open(cache_path, "rb") as f:
data = f.read()
if expected_hash is None or check_hash(data, expected_hash):
if expected_hash is None:
return data

if check_hash(data, expected_hash):
return data

if is_self_hosting:
raise ValueError(
f"Hash mismatch for data from {blobpath} (expected {expected_hash}). "
f"This may indicate change in the `tiktoken` encodings for this version. "
f"Please update the hosted encodings or remove/unset the `ENCODINGS_HOST` "
"to attempt to refresh the cache from the central host (`openaipublic`)."
)

# the cached file does not match the hash, remove it and re-fetch
try:
os.remove(cache_path)
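
A hedged sketch of the new failure mode: with `is_self_hosting=True`, a hash mismatch on the cached copy raises a `ValueError` rather than deleting the cache and re-fetching (signature as added in this commit; the mirror URL is a placeholder):

from tiktoken.load import read_file_cached

try:
    data = read_file_cached(
        "https://artifacts.example.com/encodings/cl100k_base.tiktoken",  # placeholder mirror
        expected_hash="223921b76ee99bde995b7ff738513eef100fb51d18c93597a113bcffe865b2a7",
        is_self_hosting=True,
    )
except ValueError as e:
    # Raised when the cached file does not hash to expected_hash while
    # self-hosting; update the hosted files or unset ENCODINGS_HOST.
    print(e)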
@@ -83,10 +98,8 @@ def read_file_cached(blobpath: str, expected_hash: Optional[str] = None) -> bytes:


def data_gym_to_mergeable_bpe_ranks(
vocab_bpe_file: str,
encoder_json_file: str,
vocab_bpe_hash: Optional[str] = None,
encoder_json_hash: Optional[str] = None,
vocab_bpe_contents: str,
encoder_json_contents: str,
) -> dict[bytes, int]:
# NB: do not add caching to this function
rank_to_intbyte = [b for b in range(2**8) if chr(b).isprintable() and chr(b) != " "]
@@ -101,7 +114,6 @@ def data_gym_to_mergeable_bpe_ranks(
assert len(rank_to_intbyte) == 2**8

# vocab_bpe contains the merges along with associated ranks
vocab_bpe_contents = read_file_cached(vocab_bpe_file, vocab_bpe_hash).decode()
bpe_merges = [tuple(merge_str.split()) for merge_str in vocab_bpe_contents.split("\n")[1:-1]]

def decode_data_gym(value: str) -> bytes:
@@ -118,7 +130,7 @@ def decode_data_gym(value: str) -> bytes:
# check that the encoder file matches the merges file
# this sanity check is important since tiktoken assumes that ranks are ordered the same
# as merge priority
encoder_json = json.loads(read_file_cached(encoder_json_file, encoder_json_hash))
encoder_json = json.loads(encoder_json_contents)
encoder_json_loaded = {decode_data_gym(k): v for k, v in encoder_json.items()}
# drop these two special tokens if present, since they're not mergeable bpe tokens
encoder_json_loaded.pop(b"<|endoftext|>", None)
@@ -141,10 +153,9 @@ def dump_tiktoken_bpe(bpe_ranks: dict[bytes, int], tiktoken_bpe_file: str) -> None:


def load_tiktoken_bpe(
tiktoken_bpe_file: str, expected_hash: Optional[str] = None
    contents: bytes
) -> dict[bytes, int]:
# NB: do not add caching to this function
contents = read_file_cached(tiktoken_bpe_file, expected_hash)
return {
base64.b64decode(token): int(rank)
for token, rank in (line.split() for line in contents.splitlines() if line)
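Net effect of the load.py changes: fetching and parsing are decoupled. `data_gym_to_mergeable_bpe_ranks` and `load_tiktoken_bpe` now take file contents rather than URLs and hashes, and callers fetch via `read_file_cached` themselves. A sketch of the new caller pattern, mirroring `r50k_base()` below:

from tiktoken.load import load_tiktoken_bpe, read_file_cached

contents = read_file_cached(
    "https://openaipublic.blob.core.windows.net/encodings/r50k_base.tiktoken",
    expected_hash="306cd27f03c1a714eca7108e03d66b7dc042abe8c258b44c199a7ed9838dd930",
)
mergeable_ranks = load_tiktoken_bpe(contents)  # parses "base64-token rank" lines
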
62 changes: 40 additions & 22 deletions tiktoken_ext/openai_public.py
@@ -1,20 +1,46 @@
import os
from tiktoken.load import data_gym_to_mergeable_bpe_ranks, load_tiktoken_bpe
from tiktoken.load import data_gym_to_mergeable_bpe_ranks, load_tiktoken_bpe, read_file_cached

ENDOFTEXT = "<|endoftext|>"
FIM_PREFIX = "<|fim_prefix|>"
FIM_MIDDLE = "<|fim_middle|>"
FIM_SUFFIX = "<|fim_suffix|>"
ENDOFPROMPT = "<|endofprompt|>"

ENCODINGS_HOST = os.getenv("ENCODINGS_HOST", "https://openaipublic.blob.core.windows.net")
ENCODINGS_HOST = os.getenv("ENCODINGS_HOST", None)

if "ENCODINGS_HOST" in os.environ:
ENCODINGS_HOST = os.environ["ENCODINGS_HOST"]
IS_HOSTING_ENCODINGS = True
else:
ENCODINGS_HOST = "https://openaipublic.blob.core.windows.net"
IS_HOSTING_ENCODINGS = False

VOCAB_BPE_FILE = f"{ENCODINGS_HOST}/gpt-2/encodings/main/vocab.bpe"
VOCAB_BPE_HASH = "1ce1664773c50f3e0cc8842619a93edc4624525b728b188a9e0be33b7726adc5"
ENCODER_JSON_FILE = f"{ENCODINGS_HOST}/gpt-2/encodings/main/encoder.json"
ENCODER_JSON_HASH = "196139668be63f3b5d6574427317ae82f612a97c5d1cdaf36ed2256dbf636783"
R50K_BASE_FILE = f"{ENCODINGS_HOST}/encodings/r50k_base.tiktoken"
R50K_BASE_HASH = "306cd27f03c1a714eca7108e03d66b7dc042abe8c258b44c199a7ed9838dd930"
P50K_BASE_FILE = f"{ENCODINGS_HOST}/encodings/p50k_base.tiktoken"
P50K_BASE_HASH = "94b5ca7dff4d00767bc256fdd1b27e5b17361d7b8a5f968547f9f23eb70d2069"
CL100K_BASE_FILE = f"{ENCODINGS_HOST}/encodings/cl100k_base.tiktoken"
CL100K_BASE_HASH = "223921b76ee99bde995b7ff738513eef100fb51d18c93597a113bcffe865b2a7"
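
A mirror must serve files that hash to the pinned constants above; `check_hash` in load.py compares the SHA-256 hex digest of the fetched bytes against the expected value. A quick pre-deployment check along the same lines (the local path is an example):

import hashlib

def verify_encoding_file(path: str, expected_hash: str) -> bool:
    # Same comparison check_hash performs on the downloaded bytes.
    with open(path, "rb") as f:
        return hashlib.sha256(f.read()).hexdigest() == expected_hash

assert verify_encoding_file(
    "cl100k_base.tiktoken",  # example local path
    "223921b76ee99bde995b7ff738513eef100fb51d18c93597a113bcffe865b2a7",  # CL100K_BASE_HASH
)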

def gpt2():
vocab_bpe_contents = read_file_cached(
VOCAB_BPE_FILE,
VOCAB_BPE_HASH,
IS_HOSTING_ENCODINGS
).decode()
encoder_json_contents = read_file_cached(
ENCODER_JSON_FILE,
ENCODER_JSON_HASH,
IS_HOSTING_ENCODINGS
)
mergeable_ranks = data_gym_to_mergeable_bpe_ranks(
vocab_bpe_file=f"{ENCODINGS_HOST}/gpt-2/encodings/main/vocab.bpe",
encoder_json_file=f"{ENCODINGS_HOST}/gpt-2/encodings/main/encoder.json",
vocab_bpe_hash="1ce1664773c50f3e0cc8842619a93edc4624525b728b188a9e0be33b7726adc5",
encoder_json_hash="196139668be63f3b5d6574427317ae82f612a97c5d1cdaf36ed2256dbf636783",
        vocab_bpe_contents=vocab_bpe_contents,
encoder_json_contents=encoder_json_contents
)
return {
"name": "gpt2",
@@ -29,10 +55,8 @@ def gpt2():


def r50k_base():
mergeable_ranks = load_tiktoken_bpe(
f"{ENCODINGS_HOST}/encodings/r50k_base.tiktoken",
expected_hash="306cd27f03c1a714eca7108e03d66b7dc042abe8c258b44c199a7ed9838dd930",
)
contents = read_file_cached(R50K_BASE_FILE, R50K_BASE_HASH, IS_HOSTING_ENCODINGS)
mergeable_ranks = load_tiktoken_bpe(contents)
return {
"name": "r50k_base",
"explicit_n_vocab": 50257,
@@ -43,10 +67,8 @@ def r50k_base():


def p50k_base():
mergeable_ranks = load_tiktoken_bpe(
f"{ENCODINGS_HOST}/encodings/p50k_base.tiktoken",
expected_hash="94b5ca7dff4d00767bc256fdd1b27e5b17361d7b8a5f968547f9f23eb70d2069",
)
contents = read_file_cached(P50K_BASE_FILE, P50K_BASE_HASH, IS_HOSTING_ENCODINGS)
mergeable_ranks = load_tiktoken_bpe(contents)
return {
"name": "p50k_base",
"explicit_n_vocab": 50281,
@@ -57,10 +79,8 @@ def p50k_base():


def p50k_edit():
mergeable_ranks = load_tiktoken_bpe(
f"{ENCODINGS_HOST}/encodings/p50k_base.tiktoken",
expected_hash="94b5ca7dff4d00767bc256fdd1b27e5b17361d7b8a5f968547f9f23eb70d2069",
)
contents = read_file_cached(P50K_BASE_FILE, P50K_BASE_HASH, IS_HOSTING_ENCODINGS)
mergeable_ranks = load_tiktoken_bpe(contents)
special_tokens = {ENDOFTEXT: 50256, FIM_PREFIX: 50281, FIM_MIDDLE: 50282, FIM_SUFFIX: 50283}
return {
"name": "p50k_edit",
@@ -71,10 +91,8 @@ def p50k_edit():


def cl100k_base():
mergeable_ranks = load_tiktoken_bpe(
f"{ENCODINGS_HOST}/encodings/cl100k_base.tiktoken",
expected_hash="223921b76ee99bde995b7ff738513eef100fb51d18c93597a113bcffe865b2a7",
)
contents = read_file_cached(CL100K_BASE_FILE, CL100K_BASE_HASH, IS_HOSTING_ENCODINGS)
mergeable_ranks = load_tiktoken_bpe(contents)
special_tokens = {
ENDOFTEXT: 100257,
FIM_PREFIX: 100258,
