Add batch support to generate_simple(), also example
turboderp committed Jun 24, 2023
1 parent cdb6f54 commit a01b25c
Showing 5 changed files with 106 additions and 21 deletions.
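
In short, generate_simple() now accepts either a single prompt string or a list of prompts, as long as the cache was created with a matching batch size. A minimal sketch of the new call pattern (a sketch only, assuming the setup from example_batch.py below):

    # Single prompt: behaves as before and returns one string.
    text = generator.generate_simple("Once upon a time,", max_new_tokens = 200)

    # Batched prompts: requires cache batch_size == len(prompts);
    # returns a list with one completion per prompt.
    texts = generator.generate_simple(prompts, max_new_tokens = 200)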
2 changes: 1 addition & 1 deletion example_basic.py
@@ -3,7 +3,7 @@
from generator import ExLlamaGenerator
import os, glob

-# Directory containt model, tokenizer, generator
+# Directory containing model, tokenizer, generator

model_directory = "/mnt/str/models/llama-13b-4bit-128g/"

56 changes: 56 additions & 0 deletions example_batch.py
@@ -0,0 +1,56 @@
from model import ExLlama, ExLlamaCache, ExLlamaConfig
from tokenizer import ExLlamaTokenizer
from generator import ExLlamaGenerator
import os, glob

# Directory containing model, tokenizer, generator

model_directory = "/mnt/str/models/llama-13b-4bit-128g/"

# Locate files we need within that directory

tokenizer_path = os.path.join(model_directory, "tokenizer.model")
model_config_path = os.path.join(model_directory, "config.json")
st_pattern = os.path.join(model_directory, "*.safetensors")
model_path = glob.glob(st_pattern)[0]

# Batched prompts

prompts = [
    "Once upon a time,",
    "I don't like to",
    "A turbo encabulator is a",
    "In the words of Mark Twain,"
]

# Create config, model, tokenizer and generator

config = ExLlamaConfig(model_config_path) # create config from config.json
config.model_path = model_path # supply path to model weights file

model = ExLlama(config) # create ExLlama instance and load the weights
tokenizer = ExLlamaTokenizer(tokenizer_path) # create tokenizer from tokenizer model file

cache = ExLlamaCache(model, batch_size = len(prompts)) # create cache for inference
generator = ExLlamaGenerator(model, tokenizer, cache) # create generator

# Configure generator

generator.disallow_tokens([tokenizer.eos_token_id])

generator.settings.token_repetition_penalty_max = 1.2
generator.settings.temperature = 0.95
generator.settings.top_p = 0.65
generator.settings.top_k = 100
generator.settings.typical = 0.5

# Generate, batched

for line in prompts:
    print(line)

output = generator.generate_simple(prompts, max_new_tokens = 200)

for line in output:
    print("---")
    print(line)
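
Two usage notes on the example: the cache is allocated with one row per prompt, so its batch_size must match the number of prompts, and with a batched input generate_simple() returns a list with one decoded completion per row. A short sketch of that contract, reusing the names above:

    # Sketch only: one cache row per prompt, one completion per prompt.
    cache = ExLlamaCache(model, batch_size = len(prompts))
    generator = ExLlamaGenerator(model, tokenizer, cache)
    output = generator.generate_simple(prompts, max_new_tokens = 200)
    assert len(output) == len(prompts)

Prompts of unequal length presumably get padded to a common length when encoded as a single batch, which is why decode() filters out pad_token_id (see the tokenizer.py change below).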
27 changes: 17 additions & 10 deletions exllama_ext/exllama_ext.cpp
@@ -686,6 +686,8 @@ void rep_penalty
    int vocab_size = rep_mask.size(0);
    int seq_len = sequence.size(-1);

+    // TODO: Support batch size
+
    rep_penalty_cpu
    (
        vocab_size,
@@ -709,20 +711,25 @@ void apply_rep_penalty
{
    TORCH_CHECK_DTYPE(sequence, kLong);
    TORCH_CHECK_DTYPE(logits, kFloat);
+    TORCH_CHECK_SHAPES(sequence, 0, logits, 0, 1);

    int vocab_size = logits.size(-1);
+    int bsz = sequence.size(0);
    int seq_len = sequence.size(-1);

-    apply_rep_penalty_cpu
-    (
-        vocab_size,
-        (uint64_t*) sequence.data_ptr(),
-        penalty_max,
-        sustain,
-        decay,
-        seq_len,
-        (float*) logits.data_ptr()
-    );
+    for (int i = 0; i < bsz; i++)
+    {
+        apply_rep_penalty_cpu
+        (
+            vocab_size,
+            ((uint64_t*) sequence.data_ptr()) + i * seq_len,
+            penalty_max,
+            sustain,
+            decay,
+            seq_len,
+            ((float*) logits.data_ptr()) + i * vocab_size
+        );
+    }
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
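
The new loop reuses the existing single-sequence routine once per batch row: row i of sequence starts at flat offset i * seq_len and row i of logits at i * vocab_size. A rough Python rendering of that indexing, assuming contiguous row-major tensors and hypothetical shapes:

    import torch

    # Hypothetical shapes for illustration only.
    bsz, seq_len, vocab_size = 4, 9, 32000
    sequence = torch.zeros((bsz, seq_len), dtype = torch.long)
    logits = torch.zeros((bsz, vocab_size), dtype = torch.float)

    flat_seq = sequence.view(-1)
    flat_logits = logits.view(-1)
    for i in range(bsz):
        row_seq = flat_seq[i * seq_len : (i + 1) * seq_len]              # sequence.data_ptr() + i * seq_len
        row_logits = flat_logits[i * vocab_size : (i + 1) * vocab_size]  # logits.data_ptr() + i * vocab_size
        # apply_rep_penalty_cpu() operates on this (row_seq, row_logits) pair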
41 changes: 31 additions & 10 deletions generator.py
@@ -61,11 +61,27 @@ def make_rep_mask(self, penalty_max, sustain, decay):
        return cuda_ext.ext_rep_penalty_mask_cpu(self.model.config.vocab_size, self.sequence, penalty_max, sustain, decay)


+    def batched_sample(self, logits, temperature, top_k, top_p, min_p, typical, num = 1):
+
+        if logits.shape[0] == 1: return self.sample(logits, temperature, top_k, top_p, min_p, typical, num)
+
+        samples = []
+        scores = []
+        for i in range(logits.shape[0]):
+            t, s = self.sample(logits[i, :, :], temperature, top_k, top_p, min_p, typical)
+            samples.append(t)
+            scores.append(s)
+
+        return torch.cat(samples, dim = 0), torch.cat(scores, dim = 0)
+
+
    def sample(self, logits, temperature, top_k, top_p, min_p, typical, num = 1):

        # torch.manual_seed(42)

-        logits = logits[0, -1, :]
+        if logits.dim() == 3: logits = logits[0, -1, :]
+        elif logits.dim() == 2: logits = logits[-1, :]
+        else: raise ValueError("Bad logits dimension")

        # Disallow tokens

@@ -285,7 +301,7 @@ def gen_num_tokens(self):
        return self.sequence_actual.shape[-1]


-    # Generate some number of tokens and append to
+    # Simple generator function

    def generate_simple(self, prompt, max_new_tokens = 128):

@@ -294,11 +310,16 @@ def generate_simple(self, prompt, max_new_tokens = 128):
        ids = self.tokenizer.encode(prompt)
        self.gen_begin(ids)

+        max_new_tokens = min(max_new_tokens, self.model.config.max_seq_len - ids.shape[1])
+
+        eos = torch.zeros((ids.shape[0],), dtype = torch.bool)
        for i in range(max_new_tokens):
            token = self.gen_single_token()
-            if token.item() == self.tokenizer.eos_token_id: break
+            for j in range(token.shape[0]):
+                if token[j, 0].item() == self.tokenizer.eos_token_id: eos[j] = True
+            if eos.all(): break

-        text = self.tokenizer.decode(self.sequence[0])
+        text = self.tokenizer.decode(self.sequence[0] if self.sequence.shape[0] == 1 else self.sequence)
        return text


@@ -327,12 +348,12 @@ def gen_single_token(self, constraints = None):
            for c in constraints: logits[:, :, c] += 10000.0
            logits[:, :, :] -= 10000.0

-            token, _ = self.sample(logits,
-                                   self.settings.temperature,
-                                   self.settings.top_k,
-                                   self.settings.top_p,
-                                   self.settings.min_p + 0.01 if constraints is not None else 0.0,
-                                   self.settings.typical)
+            token, _ = self.batched_sample(logits,
+                                           self.settings.temperature,
+                                           self.settings.top_k,
+                                           self.settings.top_p,
+                                           self.settings.min_p + 0.01 if constraints is not None else 0.0,
+                                           self.settings.typical)

        else:

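
Two points worth noting in the generator changes: batched_sample() simply falls back to sample() for batch size 1 and otherwise samples each row independently, and generate_simple() now tracks EOS per batch row instead of stopping the whole batch on the first EOS. A worked sketch of the EOS bookkeeping, with made-up token ids:

    import torch

    eos_token_id = 2                                 # hypothetical id
    eos = torch.zeros((3,), dtype = torch.bool)      # one flag per batch row

    # One sampled token per row per step, shaped (bsz, 1) like the generator's output.
    for step in ([5, 2, 9], [7, 3, 2], [2, 8, 4]):
        token = torch.tensor(step).unsqueeze(1)
        for j in range(token.shape[0]):
            if token[j, 0].item() == eos_token_id: eos[j] = True
        if eos.all(): break                          # stop only when every row is done

    print(eos)                                       # tensor([True, True, True]) after the third step

Rows that finish early keep generating until the whole batch is done; the decode() change in tokenizer.py below truncates each row at its first EOS so those extra tokens don't appear in the output.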
1 change: 1 addition & 0 deletions tokenizer.py
@@ -47,6 +47,7 @@ def decode(self, ids):
            for i in range(ids.shape[0]):
                seq = ids[i].tolist()
                seq = [t for t in seq if t != self.pad_token_id]
+                if self.eos_token_id in seq: seq = seq[:seq.index(self.eos_token_id)]
                texts.append(self.tokenizer.Decode(seq))
            return texts

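
The added line truncates each decoded row at its first EOS token, which complements the per-row stopping logic in generate_simple() above. A tiny sketch with hypothetical ids:

    eos_token_id, pad_token_id = 2, 0            # hypothetical ids
    seq = [15, 27, 2, 99, 0, 0]                  # row that hit EOS early, then kept going
    seq = [t for t in seq if t != pad_token_id]  # strip padding -> [15, 27, 2, 99]
    if eos_token_id in seq: seq = seq[:seq.index(eos_token_id)]
    print(seq)                                   # [15, 27]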
