feat: openai-style lookup decoding for server
Eero Lihavainen committed Mar 1, 2025
1 parent 06c2b15 commit dd3e54f
Showing 2 changed files with 228 additions and 5 deletions.
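
For context, the feature follows the OpenAI "predicted outputs" request shape: the client sends a `prediction` object with a `content` string alongside the normal chat request, and the response usage reports how many of those tokens were accepted during decoding. A minimal sketch of such a request (illustrative only; the host, port, prompt, and prediction text are assumptions, not part of this commit):

import requests

# minimal sketch, assuming a llama-server instance is listening on localhost:8080
payload = {
    "messages": [{"role": "user", "content": "Add a docstring to this function."}],
    "temperature": 0.0,
    # OpenAI-style predicted output: text the model is expected to largely repeat
    "prediction": {"content": "def scale(x):\n    return x * 2\n"},
}
res = requests.post("http://localhost:8080/v1/chat/completions", json=payload)
usage = res.json()["usage"]
# number of prediction tokens accepted by the lookup decoder
print(usage["completion_tokens_details"]["accepted_prediction_tokens"])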
171 changes: 166 additions & 5 deletions examples/server/server.cpp
@@ -197,6 +197,7 @@ struct server_task {
// used by SERVER_TASK_TYPE_INFERENCE
slot_params params;
llama_tokens prompt_tokens;
llama_tokens prediction_tokens;
int id_selected_slot = -1;

// used by SERVER_TASK_TYPE_SLOT_SAVE, SERVER_TASK_TYPE_SLOT_RESTORE, SERVER_TASK_TYPE_SLOT_ERASE
@@ -604,6 +605,7 @@ struct server_task_result_cmpl_final : server_task_result {
int32_t n_decoded;
int32_t n_prompt_tokens;
int32_t n_tokens_cached;
int32_t n_lookup_used;
bool has_new_line;
std::string stopping_word;
stop_type stop = STOP_TYPE_NONE;
@@ -660,6 +662,7 @@ struct server_task_result_cmpl_final : server_task_result {
{"stopping_word", stopping_word},
{"tokens_cached", n_tokens_cached},
{"timings", timings.to_json()},
{"prediction_tokens_accepted", n_lookup_used},
};
if (!stream && !probs_output.empty()) {
res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs);
@@ -695,7 +698,10 @@ struct server_task_result_cmpl_final : server_task_result {
{"usage", json {
{"completion_tokens", n_decoded},
{"prompt_tokens", n_prompt_tokens},
{"total_tokens", n_decoded + n_prompt_tokens}
{"total_tokens", n_decoded + n_prompt_tokens},
{"completion_tokens_details", json {
{"accepted_prediction_tokens", n_lookup_used },
}}
}},
{"id", oaicompat_cmpl_id}
};
@@ -771,11 +777,14 @@ struct server_task_result_cmpl_final : server_task_result {
{"usage", json {
{"completion_tokens", n_decoded},
{"prompt_tokens", n_prompt_tokens},
{"total_tokens", n_decoded + n_prompt_tokens}
{"total_tokens", n_decoded + n_prompt_tokens},
{"completion_tokens_details", json {
{"accepted_prediction_tokens", n_lookup_used },
}}
}},
{"id", oaicompat_cmpl_id}
};

// extra fields for debugging purposes
if (verbose) {
res["__verbose"] = to_json_non_oaicompat();
@@ -811,6 +820,9 @@ struct server_task_result_cmpl_final : server_task_result {
{"completion_tokens", n_decoded},
{"prompt_tokens", n_prompt_tokens},
{"total_tokens", n_decoded + n_prompt_tokens},
{"completion_tokens_details", json {
{"accepted_prediction_tokens", n_lookup_used },
}}
}},
};

@@ -1235,16 +1247,22 @@ struct server_slot {
int32_t n_ctx = 0; // context size per slot
int32_t n_past = 0;
int32_t n_decoded = 0;
int32_t n_lookup_used = 0;
int32_t n_remaining = -1;
int32_t i_batch = -1;
int32_t n_predict = -1; // TODO: disambiguate from params.n_predict

// for "predicted outputs"
int32_t lookup_n_adaptive = 1;
int32_t lookup_index = 0;

// n_prompt_tokens may not be equal to prompt_tokens.size(), because the prompt may be truncated
int32_t n_prompt_tokens = 0;
int32_t n_prompt_tokens_processed = 0;

// input prompt tokens
llama_tokens prompt_tokens;
llama_tokens prediction_tokens;

size_t last_nl_pos = 0;

@@ -1912,9 +1930,8 @@ struct server_context {
slot.n_ctx = n_ctx_slot;
slot.n_predict = params_base.n_predict;

slot.batch_spec = llama_batch_init(params_base.speculative.n_max + 1, 0, 1);
if (model_dft) {
slot.batch_spec = llama_batch_init(params_base.speculative.n_max + 1, 0, 1);

slot.ctx_dft = llama_init_from_model(model_dft, cparams_dft);
if (slot.ctx_dft == nullptr) {
SRV_ERR("%s", "failed to create draft context\n");
@@ -2034,6 +2051,7 @@ struct server_context {
slot.task_type = task.type;
slot.params = std::move(task.params);
slot.prompt_tokens = std::move(task.prompt_tokens);
slot.prediction_tokens = std::move(task.prediction_tokens);

if (!are_lora_equal(task.params.lora, slot.lora)) {
// if lora is changed, we cannot reuse cached tokens
@@ -2345,6 +2363,7 @@ struct server_context {
res->n_decoded = slot.n_decoded;
res->n_prompt_tokens = slot.n_prompt_tokens;
res->n_tokens_cached = slot.n_past;
res->n_lookup_used = slot.n_lookup_used;
res->has_new_line = slot.has_new_line;
res->stopping_word = slot.stopping_word;
res->stop = slot.stop;
@@ -3217,6 +3236,137 @@ struct server_context {
}
}

// apply "predicted outputs" i.e. user-specified speculation
// using a simple lookup decoding method
for (auto & slot : slots) {
// don't use lookup if we are also using a draft model
if (slot.can_speculate() || !slot.is_processing() || slot.prediction_tokens.size() < 2) {
continue;
}
if (slot.state != SLOT_STATE_GENERATING) {
continue;
}

// adaptive speculation window:
// grow the window each time all drafted tokens were accepted,
// otherwise reset it to the minimum
auto draft_start_pos = 1;
bool found = false;
// first look for a match from the expected position
SLT_DBG(slot, "Looking up prediction tokens at index %d/%d\n", (int) slot.lookup_index, (int) slot.prediction_tokens.size());
if (slot.lookup_index > 0 &&
slot.lookup_index < static_cast<int32_t>(slot.prediction_tokens.size()) &&
slot.prediction_tokens[slot.lookup_index-1] == slot.sampled) {
found = true;
draft_start_pos = slot.lookup_index;
// TODO: what is a good scaling law here?
// growing the window too fast will likely fail,
// but starting with too small a window hurts performance
slot.lookup_n_adaptive = std::max(16, slot.lookup_n_adaptive*2);
} else {
// find first match in prediction_tokens
slot.lookup_n_adaptive = 1; // default
for (; draft_start_pos < static_cast<int32_t>(slot.prediction_tokens.size()); draft_start_pos++) {
if (slot.prediction_tokens[draft_start_pos-1] == slot.sampled) {
found = true;
break;
}
}
}
if (!found) continue;

// we erase the accepted tokens later, so we're looking for the same position next time
// increment by one because the next token will be generated
slot.lookup_index = draft_start_pos + 1;

llama_tokens draft = std::vector(
slot.prediction_tokens.begin() + draft_start_pos,
slot.prediction_tokens.end()
);

// determine the max draft that fits the current slot state
int n_draft_max = slot.lookup_n_adaptive;
n_draft_max = std::min(n_draft_max, slot.n_ctx - slot.n_past - 2);

if (slot.n_remaining > 0) {
n_draft_max = std::min(n_draft_max, slot.n_remaining - 1);
}

n_draft_max = std::min(n_draft_max, static_cast<int>(draft.size()));
// NOTE: we use speculative.n_max as the upper limit here; in general we
// want to allow larger drafts than when using a draft model, but the
// limit is also tied to the size of `slot.batch_spec`.
n_draft_max = std::min(n_draft_max, slot.params.speculative.n_max);

SLT_DBG(slot, "max possible draft: %d\n", n_draft_max);

draft.resize(n_draft_max);

llama_token id = slot.sampled;

// construct the speculation batch
common_batch_clear(slot.batch_spec);
common_batch_add (slot.batch_spec, id, slot.n_past, { slot.id }, true);

for (size_t i = 0; i < draft.size(); ++i) {
common_batch_add(slot.batch_spec, draft[i], slot.n_past + 1 + i, { slot.id }, true);
}

llama_decode(ctx, slot.batch_spec);

// the accepted tokens from the speculation
// TODO can we stream these? Would be nice to reduce jankiness in UIs
const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft);

const auto n_accepted = ids.size() - 1;
slot.n_lookup_used += n_accepted;

if (n_accepted > 0) {
// remove the prediction tokens that were used + the next token
// (because it will be generated)
slot.prediction_tokens.erase(
slot.prediction_tokens.begin() + draft_start_pos,
std::min(
slot.prediction_tokens.end(),
slot.prediction_tokens.begin() + draft_start_pos + n_accepted + 1
)
);
if (n_accepted < draft.size()) {
// reset speculation as we didn't use the full draft
slot.lookup_n_adaptive = 1;
}
}

for (size_t i = 0; i < ids.size(); ++i) {
// NOTE: we need to update these here to avoid stopping early
slot.n_past++;
slot.n_decoded++;
completion_token_output result;

result.tok = ids[i];
result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok));
result.prob = 1.0f; // set later

// TODO: set result.probs
if (!process_token(result, slot)) {
// release slot because of stop condition
slot.release();
slot.print_timings();
send_final_response(slot);
metrics.on_prediction(slot);
break;
}
}

slot.cache_tokens.push_back(id);
slot.cache_tokens.insert(slot.cache_tokens.end(), ids.begin(), ids.end() - 1);

llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1);

SLT_DBG(slot, "accepted %d/%d prediction tokens, new n_past = %d\n", (int) ids.size() - 1, (int) draft.size(), slot.n_past);
}

// do speculative decoding
for (auto & slot : slots) {
if (!slot.is_processing() || !slot.can_speculate()) {
@@ -3838,10 +3988,17 @@ int main(int argc, char ** argv) {

try {
const auto & prompt = data.at("prompt");
const auto & prediction_obj = json_value(data, "prediction", json());
const auto & prediction = json_value(prediction_obj, "content", std::string());
// TODO: this log can become very long, put it behind a flag or think about a more compact format
//SRV_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get<std::string>().c_str() : prompt.dump(2).c_str());

std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, prompt, true, true);
std::vector<llama_tokens> tokenized_prediction;
if (!prediction.empty()) {
tokenized_prediction = tokenize_input_prompts(ctx_server.vocab, prediction, true, true);
}

tasks.reserve(tokenized_prompts.size());
for (size_t i = 0; i < tokenized_prompts.size(); i++) {
server_task task = server_task(type);
task.index = i;

task.prompt_tokens = std::move(tokenized_prompts[i]);

if (!tokenized_prediction.empty()) {
task.prediction_tokens = std::vector(tokenized_prediction[0].begin(), tokenized_prediction[0].end());
}
task.params = server_task::params_from_json_cmpl(
ctx_server.ctx,
ctx_server.params_base,
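
To recap the control flow added above: after each sampled token, the slot searches the remaining prediction tokens for a match, drafts a window of the following tokens into `batch_spec`, and grows that window only while matches keep landing at the expected position. A standalone sketch of the window policy (plain Python for illustration; names mirror the slot fields, but this is not the committed code):

def next_draft(prediction_tokens, sampled, lookup_index, window):
    # match at the expected position: grow the speculation window
    if 0 < lookup_index < len(prediction_tokens) and prediction_tokens[lookup_index - 1] == sampled:
        start = lookup_index
        window = max(16, window * 2)
    else:
        # otherwise search for the first occurrence of the sampled token and reset the window
        window = 1
        for start in range(1, len(prediction_tokens)):
            if prediction_tokens[start - 1] == sampled:
                break
        else:
            return None, 0, 1  # no match: skip lookup decoding this step
    draft = prediction_tokens[start:start + window]
    # the next expected match position is one past the start of this draft
    return draft, start + 1, window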
62 changes: 62 additions & 0 deletions examples/server/tests/unit/test_predicted_outputs.py
@@ -0,0 +1,62 @@
import pytest
from utils import *


@pytest.fixture(scope="module", autouse=True)
def create_server():
global server
server = ServerPreset.tinyllama2()
server.draft_max = 1024
server.debug = True


def test_with_and_without_predicted_outputs():
global server
server.start()
res = server.make_request("POST", "/v1/chat/completions", data={
"messages": [{"role": "user", "content": "I believe the meaning of life is"}],
"temperature": 0.0,
"top_k": 1,
})
assert res.status_code == 200
assert res.body["usage"]["completion_tokens_details"]["accepted_prediction_tokens"] == 0
content_no_pred = res.body["choices"][0]["message"]["content"]
server.stop()

server.start()
res = server.make_request("POST", "/v1/chat/completions", data={
"messages": [{"role": "user", "content": "I believe the meaning of life is"}],
"temperature": 0.0,
"top_k": 1,
"prediction": {"content": '''"Here?" Annabyed.
"Okay, Annabyes!" Annabyed.
As Annagged, Annap came and said,'''}
})
assert res.status_code == 200
assert res.body["usage"]["completion_tokens_details"]["accepted_prediction_tokens"] == 54
content_pred = res.body["choices"][0]["message"]["content"]
server.stop()

assert content_no_pred == content_pred


@pytest.mark.parametrize("n_slots,n_requests", [
(1, 2),
(2, 2),
])
def test_multi_requests_parallel(n_slots: int, n_requests: int):
global server
server.n_slots = n_slots
server.start()
    tasks = []
    for _ in range(n_requests):
        tasks.append((server.make_request, ("POST", "/v1/chat/completions", {
            "messages": [{"role": "user", "content": "I believe the meaning of life is"}],
            "temperature": 0.0,
            "top_k": 1,
            "prediction": {"content": " believe the meaning of life is"},
        })))
    results = parallel_function_calls(tasks)
    for res in results:
        assert res.status_code == 200
        assert match_regex("(wise|kind|owl|answer)+", res.body["choices"][0]["message"]["content"])
