diff --git a/example_basic.py b/example_basic.py
index 8db3b424..e95d0adb 100644
--- a/example_basic.py
+++ b/example_basic.py
@@ -3,7 +3,7 @@
 from generator import ExLlamaGenerator
 import os, glob
 
-# Directory containt model, tokenizer, generator
+# Directory containing model, tokenizer, generator
 
 model_directory = "/mnt/str/models/llama-13b-4bit-128g/"
 
diff --git a/example_batch.py b/example_batch.py
new file mode 100644
index 00000000..179cf2f4
--- /dev/null
+++ b/example_batch.py
@@ -0,0 +1,56 @@
+from model import ExLlama, ExLlamaCache, ExLlamaConfig
+from tokenizer import ExLlamaTokenizer
+from generator import ExLlamaGenerator
+import os, glob
+
+# Directory containing model, tokenizer, generator
+
+model_directory = "/mnt/str/models/llama-13b-4bit-128g/"
+
+# Locate files we need within that directory
+
+tokenizer_path = os.path.join(model_directory, "tokenizer.model")
+model_config_path = os.path.join(model_directory, "config.json")
+st_pattern = os.path.join(model_directory, "*.safetensors")
+model_path = glob.glob(st_pattern)[0]
+
+# Batched prompts
+
+prompts = [
+    "Once upon a time,",
+    "I don't like to",
+    "A turbo encabulator is a",
+    "In the words of Mark Twain,"
+]
+
+# Create config, model, tokenizer and generator
+
+config = ExLlamaConfig(model_config_path)               # create config from config.json
+config.model_path = model_path                          # supply path to model weights file
+
+model = ExLlama(config)                                 # create ExLlama instance and load the weights
+tokenizer = ExLlamaTokenizer(tokenizer_path)            # create tokenizer from tokenizer model file
+
+cache = ExLlamaCache(model, batch_size = len(prompts))  # create cache for inference
+generator = ExLlamaGenerator(model, tokenizer, cache)   # create generator
+
+# Configure generator
+
+generator.disallow_tokens([tokenizer.eos_token_id])
+
+generator.settings.token_repetition_penalty_max = 1.2
+generator.settings.temperature = 0.95
+generator.settings.top_p = 0.65
+generator.settings.top_k = 100
+generator.settings.typical = 0.5
+
+# Generate, batched
+
+for line in prompts:
+    print(line)
+
+output = generator.generate_simple(prompts, max_new_tokens = 200)
+
+for line in output:
+    print("---")
+    print(line)
diff --git a/exllama_ext/exllama_ext.cpp b/exllama_ext/exllama_ext.cpp
index ec330b09..615f7f4f 100644
--- a/exllama_ext/exllama_ext.cpp
+++ b/exllama_ext/exllama_ext.cpp
@@ -686,6 +686,8 @@ void rep_penalty
     int vocab_size = rep_mask.size(0);
     int seq_len = sequence.size(-1);
 
+    // TODO: Support batch size
+
     rep_penalty_cpu
     (
         vocab_size,
@@ -709,20 +711,25 @@ void apply_rep_penalty
 {
     TORCH_CHECK_DTYPE(sequence, kLong);
     TORCH_CHECK_DTYPE(logits, kFloat);
+    TORCH_CHECK_SHAPES(sequence, 0, logits, 0, 1);
 
     int vocab_size = logits.size(-1);
+    int bsz = sequence.size(0);
     int seq_len = sequence.size(-1);
 
-    apply_rep_penalty_cpu
-    (
-        vocab_size,
-        (uint64_t*) sequence.data_ptr(),
-        penalty_max,
-        sustain,
-        decay,
-        seq_len,
-        (float*) logits.data_ptr()
-    );
+    for (int i = 0; i < bsz; i++)
+    {
+        apply_rep_penalty_cpu
+        (
+            vocab_size,
+            ((uint64_t*) sequence.data_ptr()) + i * seq_len,
+            penalty_max,
+            sustain,
+            decay,
+            seq_len,
+            ((float*) logits.data_ptr()) + i * vocab_size
+        );
+    }
 }
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
diff --git a/generator.py b/generator.py
index 8b0dde97..638cba63 100644
--- a/generator.py
+++ b/generator.py
@@ -61,11 +61,27 @@ def make_rep_mask(self, penalty_max, sustain, decay):
 
         return cuda_ext.ext_rep_penalty_mask_cpu(self.model.config.vocab_size, self.sequence, penalty_max, sustain, decay)
 
 
+    def batched_sample(self, logits, temperature, top_k, top_p, min_p, typical, num = 1):
+
+        if logits.shape[0] == 1: return self.sample(logits, temperature, top_k, top_p, min_p, typical, num)
+
+        samples = []
+        scores = []
+        for i in range(logits.shape[0]):
+            t, s = self.sample(logits[i, :, :], temperature, top_k, top_p, min_p, typical)
+            samples.append(t)
+            scores.append(s)
+
+        return torch.cat(samples, dim = 0), torch.cat(scores, dim = 0)
+
+
     def sample(self, logits, temperature, top_k, top_p, min_p, typical, num = 1):
 
         # torch.manual_seed(42)
 
-        logits = logits[0, -1, :]
+        if logits.dim() == 3: logits = logits[0, -1, :]
+        elif logits.dim() == 2: logits = logits[-1, :]
+        else: raise ValueError("Bad logits dimension")
 
         # Disallow tokens
@@ -285,7 +301,7 @@ def gen_num_tokens(self):
         return self.sequence_actual.shape[-1]
 
 
-    # Generate some number of tokens and append to
+    # Simple generator function
 
     def generate_simple(self, prompt, max_new_tokens = 128):
 
@@ -294,11 +310,16 @@ def generate_simple(self, prompt, max_new_tokens = 128):
         ids = self.tokenizer.encode(prompt)
         self.gen_begin(ids)
 
+        max_new_tokens = min(max_new_tokens, self.model.config.max_seq_len - ids.shape[1])
+
+        eos = torch.zeros((ids.shape[0],), dtype = torch.bool)
         for i in range(max_new_tokens):
             token = self.gen_single_token()
-            if token.item() == self.tokenizer.eos_token_id: break
+            for j in range(token.shape[0]):
+                if token[j, 0].item() == self.tokenizer.eos_token_id: eos[j] = True
+            if eos.all(): break
 
-        text = self.tokenizer.decode(self.sequence[0])
+        text = self.tokenizer.decode(self.sequence[0] if self.sequence.shape[0] == 1 else self.sequence)
         return text
 
 
@@ -327,12 +348,12 @@ def gen_single_token(self, constraints = None):
                 for c in constraints: logits[:, :, c] += 10000.0
                 logits[:, :, :] -= 10000.0
 
-            token, _ = self.sample(logits,
-                                   self.settings.temperature,
-                                   self.settings.top_k,
-                                   self.settings.top_p,
-                                   self.settings.min_p + 0.01 if constraints is not None else 0.0,
-                                   self.settings.typical)
+            token, _ = self.batched_sample(logits,
+                                           self.settings.temperature,
+                                           self.settings.top_k,
+                                           self.settings.top_p,
+                                           self.settings.min_p + 0.01 if constraints is not None else 0.0,
+                                           self.settings.typical)
 
         else:
 
diff --git a/tokenizer.py b/tokenizer.py
index 2fbc7c47..eeb4fb5f 100644
--- a/tokenizer.py
+++ b/tokenizer.py
@@ -47,6 +47,7 @@ def decode(self, ids):
             for i in range(ids.shape[0]):
                 seq = ids[i].tolist()
                 seq = [t for t in seq if t != self.pad_token_id]
+                if self.eos_token_id in seq: seq = seq[:seq.index(self.eos_token_id)]
                 texts.append(self.tokenizer.Decode(seq))
             return texts
 
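
For reference, here is a rough standalone sketch (not part of the patch) of what the new per-row loop in apply_rep_penalty does: the C++ code walks the flat sequence and logits buffers in strides of seq_len and vocab_size, which is equivalent to penalizing each batch row independently. The toy penalty function, tensor shapes and constants below are assumptions made only for this illustration.

import torch

def apply_rep_penalty_row(row_ids, row_logits, penalty_max):
    # Toy stand-in for apply_rep_penalty_cpu (ignores sustain/decay):
    # divide the logits of every token already present in this row by penalty_max.
    row_logits[row_ids] /= penalty_max

bsz, seq_len, vocab_size = 3, 8, 32          # illustrative sizes only
sequence = torch.randint(0, vocab_size, (bsz, seq_len))
logits = torch.randn(bsz, vocab_size)

# The C++ loop offsets the flat data pointers by i * seq_len and
# i * vocab_size for batch item i; indexing row i does the same job here.
for i in range(bsz):
    apply_rep_penalty_row(sequence[i], logits[i], penalty_max = 1.2)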
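
Similarly, a minimal sketch (again not part of the patch) of the per-sequence EOS handling that generate_simple now performs for a batch: generation only stops early once every row has produced the EOS token. EOS_TOKEN_ID and the next_tokens() helper are placeholders invented for the illustration.

import torch

EOS_TOKEN_ID = 2                             # placeholder id for the sketch
batch_size, max_new_tokens = 3, 10

def next_tokens():
    # Stand-in for gen_single_token(): one new token id per batch row.
    return torch.randint(0, 8, (batch_size, 1))

eos = torch.zeros((batch_size,), dtype = torch.bool)
for _ in range(max_new_tokens):
    token = next_tokens()
    for j in range(token.shape[0]):
        if token[j, 0].item() == EOS_TOKEN_ID: eos[j] = True
    if eos.all(): break                      # stop only when every row hit EOS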