diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 2306dc26fe431..b3b58d1fb3918 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -197,6 +197,7 @@ struct server_task {
     // used by SERVER_TASK_TYPE_INFERENCE
     slot_params params;
     llama_tokens prompt_tokens;
+    llama_tokens prediction_tokens;
     int id_selected_slot = -1;
 
     // used by SERVER_TASK_TYPE_SLOT_SAVE, SERVER_TASK_TYPE_SLOT_RESTORE, SERVER_TASK_TYPE_SLOT_ERASE
@@ -604,6 +605,7 @@ struct server_task_result_cmpl_final : server_task_result {
     int32_t n_decoded;
     int32_t n_prompt_tokens;
     int32_t n_tokens_cached;
+    int32_t n_lookup_used;
     bool has_new_line;
     std::string stopping_word;
     stop_type stop = STOP_TYPE_NONE;
@@ -660,6 +662,7 @@
             {"stopping_word",    stopping_word},
             {"tokens_cached",    n_tokens_cached},
             {"timings",          timings.to_json()},
+            {"prediction_tokens_accepted", n_lookup_used},
         };
         if (!stream && !probs_output.empty()) {
             res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs);
@@ -695,7 +698,10 @@
             {"usage", json {
                 {"completion_tokens", n_decoded},
                 {"prompt_tokens",     n_prompt_tokens},
-                {"total_tokens",      n_decoded + n_prompt_tokens}
+                {"total_tokens",      n_decoded + n_prompt_tokens},
+                {"completion_tokens_details", json {
+                    {"accepted_prediction_tokens", n_lookup_used},
+                }}
             }},
             {"id", oaicompat_cmpl_id}
         };
@@ -771,11 +777,14 @@
             {"usage", json {
                 {"completion_tokens", n_decoded},
                 {"prompt_tokens",     n_prompt_tokens},
-                {"total_tokens",      n_decoded + n_prompt_tokens}
+                {"total_tokens",      n_decoded + n_prompt_tokens},
+                {"completion_tokens_details", json {
+                    {"accepted_prediction_tokens", n_lookup_used},
+                }}
             }},
             {"id", oaicompat_cmpl_id}
         };
-
+
         // extra fields for debugging purposes
         if (verbose) {
             res["__verbose"] = to_json_non_oaicompat();
@@ -811,6 +820,9 @@
             {"completion_tokens", n_decoded},
             {"prompt_tokens",     n_prompt_tokens},
             {"total_tokens",      n_decoded + n_prompt_tokens},
+            {"completion_tokens_details", json {
+                {"accepted_prediction_tokens", n_lookup_used},
+            }}
         }},
     };
 
@@ -1235,16 +1247,22 @@ struct server_slot {
     int32_t n_ctx       = 0;  // context size per slot
     int32_t n_past      = 0;
     int32_t n_decoded   = 0;
+    int32_t n_lookup_used = 0;
     int32_t n_remaining = -1;
     int32_t i_batch     = -1;
     int32_t n_predict   = -1; // TODO: disambiguate from params.n_predict
 
+    // for "predicted outputs"
+    int32_t lookup_n_adaptive = 1;
+    int32_t lookup_index      = 0;
+
     // n_prompt_tokens may not be equal to prompt_tokens.size(), because prompt maybe truncated
     int32_t n_prompt_tokens           = 0;
     int32_t n_prompt_tokens_processed = 0;
 
     // input prompt tokens
     llama_tokens prompt_tokens;
+    llama_tokens prediction_tokens;
 
     size_t last_nl_pos = 0;
@@ -1912,9 +1930,8 @@ struct server_context {
             slot.n_ctx = n_ctx_slot;
             slot.n_predict = params_base.n_predict;
 
+            slot.batch_spec = llama_batch_init(params_base.speculative.n_max + 1, 0, 1);
             if (model_dft) {
-                slot.batch_spec = llama_batch_init(params_base.speculative.n_max + 1, 0, 1);
-
                 slot.ctx_dft = llama_init_from_model(model_dft, cparams_dft);
                 if (slot.ctx_dft == nullptr) {
                     SRV_ERR("%s", "failed to create draft context\n");
@@ -2034,6 +2051,7 @@
         slot.task_type     = task.type;
         slot.params        = std::move(task.params);
         slot.prompt_tokens = std::move(task.prompt_tokens);
+        slot.prediction_tokens = std::move(task.prediction_tokens);
 
         if (!are_lora_equal(task.params.lora, slot.lora)) {
             // if lora is changed, we cannot reuse cached tokens
@@ -2345,6 +2363,7 @@
         res->n_decoded       = slot.n_decoded;
         res->n_prompt_tokens = slot.n_prompt_tokens;
         res->n_tokens_cached = slot.n_past;
+        res->n_lookup_used   = slot.n_lookup_used;
         res->has_new_line    = slot.has_new_line;
         res->stopping_word   = slot.stopping_word;
         res->stop            = slot.stop;
@@ -3217,6 +3236,137 @@
             }
         }
 
+        // apply "predicted outputs", i.e. user-specified speculation,
+        // using a simple lookup decoding method
+        for (auto & slot : slots) {
+            // don't use lookup if we are also using a draft model
+            if (slot.can_speculate() || !slot.is_processing() || slot.prediction_tokens.size() < 2) {
+                continue;
+            }
+
+            if (slot.state != SLOT_STATE_GENERATING) {
+                continue;
+            }
+
+            // adaptive speculation window:
+            // increase the window size every time all drafted tokens were accepted,
+            // otherwise reset it to 1
+            auto draft_start_pos = 1;
+            bool found = false;
+
+            // first look for a match at the expected position
+            SLT_DBG(slot, "Looking up prediction tokens at index %d/%d\n", (int) slot.lookup_index, (int) slot.prediction_tokens.size());
+            if (slot.lookup_index > 0 &&
+                slot.lookup_index < static_cast<int>(slot.prediction_tokens.size()) &&
+                slot.prediction_tokens[slot.lookup_index-1] == slot.sampled) {
+                found = true;
+                draft_start_pos = slot.lookup_index;
+                // TODO: what is a good scaling law here?
+                // going for too large windows too fast will likely fail,
+                // but too small windows in the beginning also hurt perf
+                slot.lookup_n_adaptive = std::max(16, slot.lookup_n_adaptive*2);
+            } else {
+                // find the first match in prediction_tokens
+                slot.lookup_n_adaptive = 1; // default
+                for (; draft_start_pos < static_cast<int>(slot.prediction_tokens.size()); draft_start_pos++) {
+                    if (slot.prediction_tokens[draft_start_pos-1] == slot.sampled) {
+                        found = true;
+                        break;
+                    }
+                }
+            }
+
+            if (!found) {
+                continue;
+            }
+
+            // we erase the accepted tokens later, so we're looking for the same position next time
+            // increment by one because the next token will be generated
+            slot.lookup_index = draft_start_pos + 1;
+
+            llama_tokens draft = std::vector<llama_token>(
+                slot.prediction_tokens.begin() + draft_start_pos,
+                slot.prediction_tokens.end()
+            );
+
+            // determine the max draft that fits the current slot state
+            int n_draft_max = slot.lookup_n_adaptive;
+            n_draft_max = std::min(n_draft_max, slot.n_ctx - slot.n_past - 2);
+
+            if (slot.n_remaining > 0) {
+                n_draft_max = std::min(n_draft_max, slot.n_remaining - 1);
+            }
+
+            n_draft_max = std::min(n_draft_max, static_cast<int>(draft.size()));
+
+            // NOTE: we use speculative.n_max here as the upper limit, but
+            // in general we want to allow large drafts, as opposed to when
+            // using a draft model. But this is linked to the `slot.batch_spec`
+            // size as well.
+            n_draft_max = std::min(n_draft_max, slot.params.speculative.n_max);
+
+            SLT_DBG(slot, "max possible draft: %d\n", n_draft_max);
+
+            draft.resize(n_draft_max);
+
+            llama_token id = slot.sampled;
+
+            // construct the speculation batch
+            common_batch_clear(slot.batch_spec);
+            common_batch_add  (slot.batch_spec, id, slot.n_past, { slot.id }, true);
+
+            for (size_t i = 0; i < draft.size(); ++i) {
+                common_batch_add(slot.batch_spec, draft[i], slot.n_past + 1 + i, { slot.id }, true);
+            }
+
+            llama_decode(ctx, slot.batch_spec);
+
+            // the accepted tokens from the speculation
+            // TODO: can we stream these? Would be nice to reduce jankiness in UIs
+            const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft);
+
+            const auto n_accepted = ids.size() - 1;
+            slot.n_lookup_used += n_accepted;
+
+            if (n_accepted > 0) {
+                // remove the prediction tokens that were used + the next token
+                // (because it will be generated)
+                slot.prediction_tokens.erase(
+                    slot.prediction_tokens.begin() + draft_start_pos,
+                    std::min(
+                        slot.prediction_tokens.end(),
+                        slot.prediction_tokens.begin() + draft_start_pos + n_accepted + 1
+                    )
+                );
+
+                if (n_accepted < draft.size()) {
+                    // reset the speculation window as we didn't use the full draft
+                    slot.lookup_n_adaptive = 1;
+                }
+            }
+
+            for (size_t i = 0; i < ids.size(); ++i) {
+                // NOTE: we need to update these here to avoid stopping early
+                slot.n_past++;
+                slot.n_decoded++;
+
+                completion_token_output result;
+
+                result.tok          = ids[i];
+                result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok));
+                result.prob         = 1.0f; // set later
+
+                // TODO: set result.probs
+                if (!process_token(result, slot)) {
+                    // release slot because of stop condition
+                    slot.release();
+                    slot.print_timings();
+                    send_final_response(slot);
+                    metrics.on_prediction(slot);
+                    break;
+                }
+            }
+
+            slot.cache_tokens.push_back(id);
+            slot.cache_tokens.insert(slot.cache_tokens.end(), ids.begin(), ids.end() - 1);
+
+            llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1);
+
+            SLT_DBG(slot, "accepted %d/%d prediction tokens, new n_past = %d\n", (int) ids.size() - 1, (int) draft.size(), slot.n_past);
+        }
+
         // do speculative decoding
         for (auto & slot : slots) {
             if (!slot.is_processing() || !slot.can_speculate()) {
@@ -3838,10 +3988,17 @@ int main(int argc, char ** argv) {
         try {
             const auto & prompt = data.at("prompt");
+            const auto & prediction_obj = json_value(data, "prediction", json());
+            const auto & prediction     = json_value(prediction_obj, "content", std::string());
 
             // TODO: this log can become very long, put it behind a flag or think about a more compact format
             //SRV_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get<std::string>().c_str() : prompt.dump(2).c_str());
 
             std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, prompt, true, true);
+            std::vector<llama_tokens> tokenized_prediction;
+            if (!prediction.empty()) {
+                tokenized_prediction = tokenize_input_prompts(ctx_server.vocab, prediction, true, true);
+            }
+
             tasks.reserve(tokenized_prompts.size());
             for (size_t i = 0; i < tokenized_prompts.size(); i++) {
                 server_task task = server_task(type);
@@ -3850,6 +4007,10 @@ int main(int argc, char ** argv) {
                 task.index = i;
 
                 task.prompt_tokens = std::move(tokenized_prompts[i]);
+
+                if (!tokenized_prediction.empty()) {
+                    task.prediction_tokens = std::vector<llama_token>(tokenized_prediction[0].begin(), tokenized_prediction[0].end());
+                }
                 task.params = server_task::params_from_json_cmpl(
                     ctx_server.ctx,
                     ctx_server.params_base,
diff --git a/examples/server/tests/unit/test_predicted_outputs.py b/examples/server/tests/unit/test_predicted_outputs.py
new file mode 100644
index 0000000000000..aae456440d197
--- /dev/null
+++ b/examples/server/tests/unit/test_predicted_outputs.py
@@ -0,0 +1,62 @@
+import pytest
+from utils import *
+
+
+@pytest.fixture(scope="module", autouse=True)
+def create_server():
+    global server
+    server = ServerPreset.tinyllama2()
+    server.draft_max = 1024
+    server.debug = True
+
+
+def test_with_and_without_predicted_outputs():
+    global server
+    server.start()
+    res = server.make_request("POST", "/v1/chat/completions", data={
+        "messages": [{"role": "user", "content": "I believe the meaning of life is"}],
+        "temperature": 0.0,
+        "top_k": 1,
+    })
+    assert res.status_code == 200
+    assert res.body["usage"]["completion_tokens_details"]["accepted_prediction_tokens"] == 0
+    content_no_pred = res.body["choices"][0]["message"]["content"]
+    server.stop()
+
+    server.start()
+    res = server.make_request("POST", "/v1/chat/completions", data={
+        "messages": [{"role": "user", "content": "I believe the meaning of life is"}],
+        "temperature": 0.0,
+        "top_k": 1,
+        "prediction": {"content": '''"Here?" Annabyed.
+"Okay, Annabyes!" Annabyed.
+As Annagged, Annap came and said,'''}
+    })
+    assert res.status_code == 200
+    assert res.body["usage"]["completion_tokens_details"]["accepted_prediction_tokens"] == 54
+    content_pred = res.body["choices"][0]["message"]["content"]
+    server.stop()
+
+    assert content_no_pred == content_pred
+
+
+@pytest.mark.parametrize("n_slots,n_requests", [
+    (1, 2),
+    (2, 2),
+])
+def test_multi_requests_parallel(n_slots: int, n_requests: int):
+    global server
+    server.n_slots = n_slots
+    server.start()
+    tasks = []
+    for _ in range(n_requests):
+        tasks.append((server.make_request, ("POST", "/v1/chat/completions", {
+            "messages": [{"role": "user", "content": "I believe the meaning of life is"}],
+            "temperature": 0.0,
+            "top_k": 1,
+            "prediction": {"content": " believe the meaning of life is"}
+        })))
+    results = parallel_function_calls(tasks)
+    for res in results:
+        assert res.status_code == 200
+        assert match_regex("(wise|kind|owl|answer)+", res.body["choices"][0]["message"]["content"])
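
Usage sketch (illustrative only, not part of the diff): with this patch a client passes the expected completion through the OpenAI-compatible `prediction` field and can read back how many of those tokens were accepted via lookup decoding. A minimal Python example, assuming the `requests` package and a llama-server instance already listening on http://localhost:8080 (URL, prompt, and prediction text are placeholders):

    import requests

    payload = {
        "messages": [{"role": "user", "content": "Add type hints to this function: def add(a, b): return a + b"}],
        "temperature": 0.0,
        # the previous version of the expected output serves as the prediction
        "prediction": {"content": "def add(a, b):\n    return a + b\n"},
    }

    res = requests.post("http://localhost:8080/v1/chat/completions", json=payload, timeout=600)
    res.raise_for_status()
    body = res.json()

    print(body["choices"][0]["message"]["content"])
    # number of predicted tokens accepted by the lookup path added above
    print(body["usage"]["completion_tokens_details"]["accepted_prediction_tokens"])

The same counter is exposed on the non-OpenAI-compatible response as `prediction_tokens_accepted` (see `to_json_non_oaicompat` above).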