diff --git a/src/llama-context.cpp b/src/llama-context.cpp index c599801763181..5ac28f983027e 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -45,6 +45,137 @@ static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t // llama_context_base // +class llama_graph_input_embd : public llama_graph_input_i { +public: + llama_graph_input_embd() = default; + virtual ~llama_graph_input_embd() = default; + + void set_input(const llama_ubatch * ubatch) override; + + ggml_tensor * tokens = nullptr; // I32 [n_batch] + ggml_tensor * embd = nullptr; // F32 [n_embd, n_batch] +}; + +void llama_graph_input_embd::set_input(const llama_ubatch * ubatch) { + if (ubatch->token) { + const int64_t n_tokens = ubatch->n_tokens; + + ggml_backend_tensor_set(tokens, ubatch->token, 0, n_tokens*ggml_element_size(tokens)); + } + + if (ubatch->embd) { + const int64_t n_embd = embd->ne[0]; + const int64_t n_tokens = ubatch->n_tokens; + + ggml_backend_tensor_set(embd, ubatch->embd, 0, n_tokens*n_embd*ggml_element_size(embd)); + } +} + +class llama_graph_input_attn_base : public llama_graph_input_attn_i { +public: + llama_graph_input_attn_base(const llama_hparams & hparams, const llama_cparams & cparams) : + hparams(hparams), + cparams(cparams) { + } + + void set_input(const llama_ubatch * ubatch) override; + + ggml_tensor * get_kq_mask() override { return kq_mask_cnv; } + + ggml_tensor * kq_mask = nullptr; // F32 [n_tokens, n_batch] + ggml_tensor * kq_mask_cnv = nullptr; // [n_tokens, n_batch] + + const llama_hparams & hparams; + const llama_cparams & cparams; +}; + +void llama_graph_input_attn_base::set_input(const llama_ubatch * ubatch) { + if (kq_mask) { + if (cparams.causal_attn) { + const int64_t n_kv = ubatch->n_tokens; + const int64_t n_tokens = ubatch->n_tokens; + const int64_t n_seq_tokens = ubatch->n_seq_tokens; + const int64_t n_seqs = ubatch->n_seqs; + + GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer)); + float * data = (float *) kq_mask->data; + + for (int h = 0; h < 1; ++h) { + for (int s1 = 0; s1 < n_seqs; ++s1) { + const llama_seq_id seq_id = ubatch->seq_id[s1][0]; + + for (int j = 0; j < n_seq_tokens; ++j) { + const int32_t tj = s1*n_seq_tokens + j; + + for (int s0 = 0; s0 < n_seqs; ++s0) { + for (int i = 0; i < n_seq_tokens; ++i) { + const int32_t ti = s0*n_seq_tokens + i; + float f = -INFINITY; + + for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) { + if (ubatch->seq_id[s0][s] == seq_id && ubatch->pos[ti] <= ubatch->pos[tj]) { + if (hparams.use_alibi) { + f = -std::abs(ubatch->pos[ti] - ubatch->pos[tj]); + } else { + f = 0.0f; + } + break; + } + } + + data[h*(n_kv*n_tokens) + tj*n_kv + ti] = f; + } + } + } + } + } + } else { + const int64_t n_tokens = ubatch->n_tokens; + const int64_t n_seq_tokens = ubatch->n_seq_tokens; + const int64_t n_seqs = ubatch->n_seqs; + const int64_t n_stride = ubatch->n_tokens; + + GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer)); + + float * data = (float *) kq_mask->data; + + for (int h = 0; h < 1; ++h) { + for (int s1 = 0; s1 < n_seqs; ++s1) { + const llama_seq_id seq_id = ubatch->seq_id[s1][0]; + + for (int j = 0; j < n_seq_tokens; ++j) { + const int32_t tj = s1*n_seq_tokens + j; + + for (int s0 = 0; s0 < n_seqs; ++s0) { + for (int i = 0; i < n_seq_tokens; ++i) { + const int32_t ti = s0*n_seq_tokens + i; + float f = -INFINITY; + + for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) { + if (ubatch->seq_id[s0][s] == seq_id) { + if (hparams.use_alibi) { + f = -std::abs(ubatch->pos[ti] - ubatch->pos[tj]); + } else { + f = 0.0f; + 
} + break; + } + } + + data[h*(n_tokens*n_tokens) + tj*n_stride + ti] = f; + } + } + + for (int i = n_tokens; i < n_stride; ++i) { + data[h*(n_tokens*n_tokens) + tj*n_stride + i] = -INFINITY; + } + } + } + } + } + } +} + llama_context_base::llama_context_base( const llama_model & model, llama_context_params params, @@ -714,7 +845,8 @@ int llama_context_base::encode(llama_batch & inp_batch) { ggml_backend_sched_alloc_graph(sched.get(), gf); - input_set(ubatch); + res->set_inputs(&ubatch); + input_set(ubatch); // TODO: remove, tmp here, until all inputs are migrated outside the context const auto compute_status = graph_compute(gf, n_tokens > 1); switch (compute_status) { @@ -729,7 +861,7 @@ int llama_context_base::encode(llama_batch & inp_batch) { return -3; } - auto * t_embd = res.t_embd_pooled ? res.t_embd_pooled : res.t_embd; + auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd(); // extract embeddings if (t_embd) { @@ -870,7 +1002,8 @@ int llama_context_base::decode(llama_batch & inp_batch) { ggml_backend_sched_alloc_graph(sched.get(), gf); - input_set(ubatch); + res->set_inputs(&ubatch); + input_set(ubatch); // TODO: remove const auto compute_status = graph_compute(gf, ubatch.n_tokens > 1); if (compute_status != GGML_STATUS_SUCCESS) { @@ -885,11 +1018,11 @@ int llama_context_base::decode(llama_batch & inp_batch) { } } - auto * t_logits = cparams.embeddings ? nullptr : res.t_logits; - auto * t_embd = cparams.embeddings ? res.t_embd : nullptr; + auto * t_logits = cparams.embeddings ? nullptr : res->get_logits(); + auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr; - if (t_embd && res.t_embd_pooled) { - t_embd = res.t_embd_pooled; + if (t_embd && res->get_embd_pooled()) { + t_embd = res->get_embd_pooled(); } // extract logits @@ -1002,19 +1135,6 @@ int64_t llama_context_base::n_pos_per_token() const { void llama_context_base::input_set(const llama_ubatch & ubatch) { const llama_hparams & hparams = model.hparams; - if (ubatch.token) { - const int64_t n_tokens = ubatch.n_tokens; - - ggml_backend_tensor_set(inp.tokens, ubatch.token, 0, n_tokens*ggml_element_size(inp.tokens)); - } - - if (ubatch.embd) { - const int64_t n_embd = hparams.n_embd; - const int64_t n_tokens = ubatch.n_tokens; - - ggml_backend_tensor_set(inp.embd, ubatch.embd, 0, n_tokens*n_embd*ggml_element_size(inp.embd)); - } - if (ubatch.pos && inp.pos) { const int64_t n_tokens = ubatch.n_tokens; @@ -1159,91 +1279,6 @@ void llama_context_base::input_set(const llama_ubatch & ubatch) { } } - if (inp.kq_mask) { - if (cparams.causal_attn) { - const int64_t n_kv = ubatch.n_tokens; - const int64_t n_tokens = ubatch.n_tokens; - const int64_t n_seq_tokens = ubatch.n_seq_tokens; - const int64_t n_seqs = ubatch.n_seqs; - - GGML_ASSERT(ggml_backend_buffer_is_host(inp.kq_mask->buffer)); - float * data = (float *) inp.kq_mask->data; - - for (int h = 0; h < 1; ++h) { - for (int s1 = 0; s1 < n_seqs; ++s1) { - const llama_seq_id seq_id = ubatch.seq_id[s1][0]; - - for (int j = 0; j < n_seq_tokens; ++j) { - const int32_t tj = s1*n_seq_tokens + j; - - for (int s0 = 0; s0 < n_seqs; ++s0) { - for (int i = 0; i < n_seq_tokens; ++i) { - const int32_t ti = s0*n_seq_tokens + i; - float f = -INFINITY; - - for (int s = 0; s < ubatch.n_seq_id[s0]; ++s) { - if (ubatch.seq_id[s0][s] == seq_id && ubatch.pos[ti] <= ubatch.pos[tj]) { - if (hparams.use_alibi) { - f = -std::abs(ubatch.pos[ti] - ubatch.pos[tj]); - } else { - f = 0.0f; - } - break; - } - } - - data[h*(n_kv*n_tokens) + tj*n_kv + ti] = f; - } - } - } - } - } - 
} else { - const int64_t n_tokens = ubatch.n_tokens; - const int64_t n_seq_tokens = ubatch.n_seq_tokens; - const int64_t n_seqs = ubatch.n_seqs; - const int64_t n_stride = ubatch.n_tokens; - - GGML_ASSERT(ggml_backend_buffer_is_host(inp.kq_mask->buffer)); - - float * data = (float *) inp.kq_mask->data; - - for (int h = 0; h < 1; ++h) { - for (int s1 = 0; s1 < n_seqs; ++s1) { - const llama_seq_id seq_id = ubatch.seq_id[s1][0]; - - for (int j = 0; j < n_seq_tokens; ++j) { - const int32_t tj = s1*n_seq_tokens + j; - - for (int s0 = 0; s0 < n_seqs; ++s0) { - for (int i = 0; i < n_seq_tokens; ++i) { - const int32_t ti = s0*n_seq_tokens + i; - float f = -INFINITY; - - for (int s = 0; s < ubatch.n_seq_id[s0]; ++s) { - if (ubatch.seq_id[s0][s] == seq_id) { - if (hparams.use_alibi) { - f = -std::abs(ubatch.pos[ti] - ubatch.pos[tj]); - } else { - f = 0.0f; - } - break; - } - } - - data[h*(n_tokens*n_tokens) + tj*n_stride + ti] = f; - } - } - - for (int i = n_tokens; i < n_stride; ++i) { - data[h*(n_tokens*n_tokens) + tj*n_stride + i] = -INFINITY; - } - } - } - } - } - } - if (inp.pos_bucket) { const int64_t n_tokens = ubatch.n_tokens; @@ -1401,7 +1436,7 @@ ggml_cgraph * llama_context_base::graph_init() { return ggml_new_graph_custom(ctx_compute.get(), graph_max_nodes(), false); } -llama_graph_result llama_context_base::graph_build( +llama_graph_result_ptr llama_context_base::graph_build( ggml_context * ctx, ggml_cgraph * gf, const llama_ubatch & ubatch) { @@ -1604,21 +1639,24 @@ ggml_tensor * llama_context_base::build_rope_shift( } ggml_tensor * llama_context_base::build_inp_embd( - ggml_context * ctx0, - ggml_tensor * tok_embd, - const llama_ubatch & ubatch) { + llama_graph_result * res, + ggml_context * ctx0, + ggml_tensor * tok_embd, + const llama_ubatch & ubatch) const { const auto & hparams = model.hparams; const int64_t n_embd = hparams.n_embd; + auto inp = std::make_shared(); + struct ggml_tensor * inpL; if (ubatch.token) { - inp.tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens); - //cb(inp.tokens, "inp_tokens", -1); - ggml_set_input(inp.tokens); + inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens); + //cb(inp->tokens, "inp_tokens", -1); + ggml_set_input(inp->tokens); - inpL = ggml_get_rows(ctx0, tok_embd, inp.tokens); + inpL = ggml_get_rows(ctx0, tok_embd, inp->tokens); // apply lora for embedding tokens if needed for (const auto & lora : loras) { @@ -1632,15 +1670,15 @@ ggml_tensor * llama_context_base::build_inp_embd( struct ggml_tensor * inpL_delta = ggml_scale(ctx0, ggml_mul_mat( ctx0, lw->b, // non-transposed lora_b - ggml_get_rows(ctx0, lw->a, inp.tokens) + ggml_get_rows(ctx0, lw->a, inp->tokens) ), scale); inpL = ggml_add(ctx0, inpL, inpL_delta); } } else { - inp.embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, ubatch.n_tokens); - inpL = inp.embd; - ggml_set_input(inp.embd); + inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, ubatch.n_tokens); + inpL = inp->embd; + ggml_set_input(inp->embd); } // For Granite architecture @@ -1648,6 +1686,8 @@ ggml_tensor * llama_context_base::build_inp_embd( inpL = ggml_scale(ctx0, inpL, hparams.f_embedding_scale); } + res->add_input(std::move(inp)); + //cb(inpL, "inp_embd", -1); return inpL; @@ -1699,23 +1739,31 @@ ggml_tensor * llama_context_base::build_inp_cls( return inp.cls; } -void llama_context_base::build_attn_inp( - ggml_context * ctx0, - int32_t n_tokens, - bool causal, - bool swa) { +llama_graph_input_attn_ptr llama_context_base::build_attn_inp( + llama_graph_result * res, + ggml_context * 
ctx0, + int32_t n_tokens, + bool causal, + bool swa) const { + auto inp = std::make_shared(model.hparams, cparams); + // note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch GGML_UNUSED(causal); GGML_UNUSED(swa); - inp.kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); + inp->kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); //cb(inp_kq_mask, "KQ_mask", -1); - ggml_set_input(inp.kq_mask); + ggml_set_input(inp->kq_mask); + + inp->kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->kq_mask, GGML_TYPE_F16) : inp->kq_mask; - inp.kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp.kq_mask, GGML_TYPE_F16) : inp.kq_mask; + res->add_input(inp); + + return inp; } ggml_tensor * llama_context_base::build_attn( + llama_graph_input_attn_i * inp, ggml_context * ctx0, ggml_cgraph * gf, ggml_tensor * q_cur, @@ -1723,10 +1771,10 @@ ggml_tensor * llama_context_base::build_attn( ggml_tensor * v_cur, ggml_tensor * kq_b, float kq_scale, - int il) { + int il) const { GGML_UNUSED(il); - const auto & kq_mask = inp.kq_mask_cnv; + const auto & kq_mask = inp->get_kq_mask(); ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3); //cb(q, "q", il); @@ -1751,7 +1799,7 @@ ggml_tensor * llama_context_base::build_attn_mha( ggml_tensor * kq_b, ggml_tensor * kq_mask, bool v_trans, - float kq_scale) { + float kq_scale) const { const auto & hparams = model.hparams; //const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); @@ -2380,6 +2428,156 @@ size_t llama_context_base::state_seq_read_data(llama_io_read_i & io, llama_seq_i // llama_context_kv_self // +class llama_graph_input_attn_kv_self : public llama_graph_input_attn_i { +public: + llama_graph_input_attn_kv_self( + const llama_hparams & hparams, + const llama_cparams & cparams, + const llama_kv_cache_unified * kv_self) : + hparams(hparams), + cparams(cparams), + kv_self(kv_self) { + } + + void set_input(const llama_ubatch * ubatch) override; + + ggml_tensor * get_kq_mask() override { return self_kq_mask_cnv; } + ggml_tensor * get_kq_mask_swa() override { return self_kq_mask_swa_cnv; } + + ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch] + ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch] + ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch] + ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch] + + const llama_hparams & hparams; + const llama_cparams & cparams; + + const llama_kv_cache_unified * kv_self; +}; + +void llama_graph_input_attn_kv_self::set_input(const llama_ubatch * ubatch) { + if (self_kq_mask || self_kq_mask_swa) { + // NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache. + if (cparams.causal_attn) { + const int64_t n_kv = kv_self->n; + const int64_t n_tokens = ubatch->n_tokens; + const int64_t n_seq_tokens = ubatch->n_seq_tokens; + const int64_t n_seqs = ubatch->n_seqs; + + float * data = nullptr; + float * data_swa = nullptr; + + if (self_kq_mask) { + GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer)); + data = (float *) self_kq_mask->data; + } + + if (self_kq_mask_swa) { + GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask_swa->buffer)); + data_swa = (float *) self_kq_mask_swa->data; + } + + // For causal attention, use only the previous KV cells + // of the correct sequence for each token of the ubatch. + // It's assumed that if a token in the batch has multiple sequences, they are equivalent. 
+ for (int h = 0; h < 1; ++h) { + for (int s = 0; s < n_seqs; ++s) { + const llama_seq_id seq_id = ubatch->seq_id[s][0]; + + for (int j = 0; j < n_seq_tokens; ++j) { + const llama_pos pos = ubatch->pos[s*n_seq_tokens + j]; + + for (int i = 0; i < n_kv; ++i) { + float f; + if (!kv_self->cells[i].has_seq_id(seq_id) || kv_self->cells[i].pos > pos) { + f = -INFINITY; + } else { + if (hparams.use_alibi) { + f = -std::abs(kv_self->cells[i].pos - pos); + } else { + f = 0.0f; + } + } + + if (data) { + data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f; + } + + // may need to cut off old tokens for sliding window + if (data_swa) { + if (pos - kv_self->cells[i].pos >= (int32_t)hparams.n_swa) { + f = -INFINITY; + } + data_swa[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f; + } + } + } + } + + if (data) { + for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) { + for (int j = 0; j < n_kv; ++j) { + data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY; + } + } + } + + if (data_swa) { + for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) { + for (int j = 0; j < n_kv; ++j) { + data_swa[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY; + } + } + } + } + } else { + const int64_t n_tokens = ubatch->n_tokens; + const int64_t n_seq_tokens = ubatch->n_seq_tokens; + const int64_t n_seqs = ubatch->n_seqs; + // when using kv cache, the mask needs to match the kv cache size + const int64_t n_stride = n_tokens; + + GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer)); + + float * data = (float *) self_kq_mask->data; + + for (int h = 0; h < 1; ++h) { + for (int s1 = 0; s1 < n_seqs; ++s1) { + const llama_seq_id seq_id = ubatch->seq_id[s1][0]; + + for (int j = 0; j < n_seq_tokens; ++j) { + const int32_t tj = s1*n_seq_tokens + j; + + for (int s0 = 0; s0 < n_seqs; ++s0) { + for (int i = 0; i < n_seq_tokens; ++i) { + const int32_t ti = s0*n_seq_tokens + i; + float f = -INFINITY; + + for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) { + if (ubatch->seq_id[s0][s] == seq_id) { + if (hparams.use_alibi) { + f = -std::abs(ubatch->pos[ti] - ubatch->pos[tj]); + } else { + f = 0.0f; + } + break; + } + } + + data[h*(n_tokens*n_tokens) + tj*n_stride + ti] = f; + } + } + + for (int i = n_tokens; i < n_stride; ++i) { + data[h*(n_tokens*n_tokens) + tj*n_stride + i] = -INFINITY; + } + } + } + } + } + } +} + llama_context_kv_self::llama_context_kv_self( const llama_model & model, llama_context_params params, @@ -2593,7 +2791,8 @@ int llama_context_kv_self::encode(llama_batch & inp_batch) { ggml_backend_sched_alloc_graph(sched.get(), gf); - input_set(ubatch); + res->set_inputs(&ubatch); + input_set(ubatch); // TODO: remove const auto compute_status = graph_compute(gf, n_tokens > 1); switch (compute_status) { @@ -2608,7 +2807,7 @@ int llama_context_kv_self::encode(llama_batch & inp_batch) { return -3; } - auto * t_embd = res.t_embd_pooled ? res.t_embd_pooled : res.t_embd; + auto * t_embd = res->get_embd_pooled() ? 
res->get_embd_pooled() : res->get_embd(); // extract embeddings if (t_embd) { @@ -2831,7 +3030,8 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) { ggml_backend_sched_alloc_graph(sched.get(), gf); - input_set(ubatch); + res->set_inputs(&ubatch); + input_set(ubatch); // TODO: remove const auto compute_status = graph_compute(gf, ubatch.n_tokens > 1); if (compute_status != GGML_STATUS_SUCCESS) { @@ -2861,11 +3061,11 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) { // ggml_graph_dump_dot(gf, NULL, "llama.dot"); //} - auto * t_logits = cparams.embeddings ? nullptr : res.t_logits; - auto * t_embd = cparams.embeddings ? res.t_embd : nullptr; + auto * t_logits = cparams.embeddings ? nullptr : res->get_logits(); + auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr; - if (t_embd && res.t_embd_pooled) { - t_embd = res.t_embd_pooled; + if (t_embd && res->get_embd_pooled()) { + t_embd = res->get_embd_pooled(); } // extract logits @@ -3009,127 +3209,6 @@ void llama_context_kv_self::input_set(const llama_ubatch & ubatch) { // call base functionality llama_context_base::input_set(ubatch); - if (inp.self_kq_mask || inp.self_kq_mask_swa) { - // NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache. - if (cparams.causal_attn) { - const int64_t n_kv = kv_self->n; - const int64_t n_tokens = ubatch.n_tokens; - const int64_t n_seq_tokens = ubatch.n_seq_tokens; - const int64_t n_seqs = ubatch.n_seqs; - - float * data = nullptr; - float * data_swa = nullptr; - - if (inp.self_kq_mask) { - GGML_ASSERT(ggml_backend_buffer_is_host(inp.self_kq_mask->buffer)); - data = (float *) inp.self_kq_mask->data; - } - - if (inp.self_kq_mask_swa) { - GGML_ASSERT(ggml_backend_buffer_is_host(inp.self_kq_mask_swa->buffer)); - data_swa = (float *) inp.self_kq_mask_swa->data; - } - - // For causal attention, use only the previous KV cells - // of the correct sequence for each token of the ubatch. - // It's assumed that if a token in the batch has multiple sequences, they are equivalent. 
- for (int h = 0; h < 1; ++h) { - for (int s = 0; s < n_seqs; ++s) { - const llama_seq_id seq_id = ubatch.seq_id[s][0]; - - for (int j = 0; j < n_seq_tokens; ++j) { - const llama_pos pos = ubatch.pos[s*n_seq_tokens + j]; - - for (int i = 0; i < n_kv; ++i) { - float f; - if (!kv_self->cells[i].has_seq_id(seq_id) || kv_self->cells[i].pos > pos) { - f = -INFINITY; - } else { - if (hparams.use_alibi) { - f = -std::abs(kv_self->cells[i].pos - pos); - } else { - f = 0.0f; - } - } - - if (data) { - data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f; - } - - // may need to cut off old tokens for sliding window - if (data_swa) { - if (pos - kv_self->cells[i].pos >= (int32_t)hparams.n_swa) { - f = -INFINITY; - } - data_swa[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f; - } - } - } - } - - if (data) { - for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) { - for (int j = 0; j < n_kv; ++j) { - data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY; - } - } - } - - if (data_swa) { - for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) { - for (int j = 0; j < n_kv; ++j) { - data_swa[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY; - } - } - } - } - } else { - const int64_t n_tokens = ubatch.n_tokens; - const int64_t n_seq_tokens = ubatch.n_seq_tokens; - const int64_t n_seqs = ubatch.n_seqs; - // when using kv cache, the mask needs to match the kv cache size - const int64_t n_stride = n_tokens; - - GGML_ASSERT(ggml_backend_buffer_is_host(inp.self_kq_mask->buffer)); - - float * data = (float *) inp.self_kq_mask->data; - - for (int h = 0; h < 1; ++h) { - for (int s1 = 0; s1 < n_seqs; ++s1) { - const llama_seq_id seq_id = ubatch.seq_id[s1][0]; - - for (int j = 0; j < n_seq_tokens; ++j) { - const int32_t tj = s1*n_seq_tokens + j; - - for (int s0 = 0; s0 < n_seqs; ++s0) { - for (int i = 0; i < n_seq_tokens; ++i) { - const int32_t ti = s0*n_seq_tokens + i; - float f = -INFINITY; - - for (int s = 0; s < ubatch.n_seq_id[s0]; ++s) { - if (ubatch.seq_id[s0][s] == seq_id) { - if (hparams.use_alibi) { - f = -std::abs(ubatch.pos[ti] - ubatch.pos[tj]); - } else { - f = 0.0f; - } - break; - } - } - - data[h*(n_tokens*n_tokens) + tj*n_stride + ti] = f; - } - } - - for (int i = n_tokens; i < n_stride; ++i) { - data[h*(n_tokens*n_tokens) + tj*n_stride + i] = -INFINITY; - } - } - } - } - } - } - if (inp.self_pos_bucket) { const int64_t n_tokens = ubatch.n_tokens; @@ -3173,37 +3252,45 @@ ggml_tensor * llama_context_kv_self::build_inp_pos_bucket( return inp.self_pos_bucket; } -void llama_context_kv_self::build_attn_inp( - ggml_context * ctx0, - int32_t n_tokens, - bool causal, - bool swa) { +llama_graph_input_attn_ptr llama_context_kv_self::build_attn_inp( + llama_graph_result * res, + ggml_context * ctx0, + int32_t n_tokens, + bool causal, + bool swa) const { + auto inp = std::make_shared(model.hparams, cparams, kv_self.get()); + const auto n_kv = kv_self->n; - inp.self_kq_mask = causal + inp->self_kq_mask = causal ? ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)) : ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); - //cb(inp.self_kq_mask, "KQ_mask", -1); - ggml_set_input(inp.self_kq_mask); + //cb(inp->self_kq_mask, "KQ_mask", -1); + ggml_set_input(inp->self_kq_mask); - inp.self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp.self_kq_mask, GGML_TYPE_F16) : inp.self_kq_mask; + inp->self_kq_mask_cnv = cparams.flash_attn ? 
ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; if (swa) { const auto & hparams = model.hparams; GGML_ASSERT(hparams.n_swa > 0); - inp.self_kq_mask_swa = causal + inp->self_kq_mask_swa = causal ? ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)) : ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); - //cb(inp.self_kq_mask_swa, "KQ_mask_swa", -1); - ggml_set_input(inp.self_kq_mask_swa); + //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1); + ggml_set_input(inp->self_kq_mask_swa); - inp.self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp.self_kq_mask_swa, GGML_TYPE_F16) : inp.self_kq_mask_swa; + inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa; } + + res->add_input(inp); + + return inp; } ggml_tensor * llama_context_kv_self::build_attn( + llama_graph_input_attn_i * inp, ggml_context * ctx0, ggml_cgraph * gf, ggml_tensor * q_cur, @@ -3211,7 +3298,7 @@ ggml_tensor * llama_context_kv_self::build_attn( ggml_tensor * v_cur, ggml_tensor * kq_b, float kq_scale, - int il) { + int il) const { const auto & hparams = model.hparams; const auto & n_ctx = cparams.n_ctx; @@ -3280,7 +3367,7 @@ ggml_tensor * llama_context_kv_self::build_attn( } }; - const auto & kq_mask = is_sliding ? inp.self_kq_mask_swa_cnv : inp.self_kq_mask_cnv; + const auto & kq_mask = is_sliding ? inp->get_kq_mask_swa() : inp->get_kq_mask(); const auto n_kv = kv_self->n; @@ -3897,7 +3984,8 @@ int llama_context_recurrent::decode(llama_batch & inp_batch) { ggml_backend_sched_alloc_graph(sched.get(), gf); - input_set(ubatch); + res->set_inputs(&ubatch); + input_set(ubatch); // TODO: remove const auto compute_status = graph_compute(gf, ubatch.n_tokens > 1); if (compute_status != GGML_STATUS_SUCCESS) { @@ -3927,11 +4015,11 @@ int llama_context_recurrent::decode(llama_batch & inp_batch) { // ggml_graph_dump_dot(gf, NULL, "llama.dot"); //} - auto * t_logits = cparams.embeddings ? nullptr : res.t_logits; - auto * t_embd = cparams.embeddings ? res.t_embd : nullptr; + auto * t_logits = cparams.embeddings ? nullptr : res->get_logits(); + auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr; - if (t_embd && res.t_embd_pooled) { - t_embd = res.t_embd_pooled; + if (t_embd && res->get_embd_pooled()) { + t_embd = res->get_embd_pooled(); } // extract logits @@ -4604,7 +4692,8 @@ int llama_context_enc::encode(llama_batch & inp_batch) { ggml_backend_sched_alloc_graph(sched.get(), gf); - input_set(ubatch); + res->set_inputs(&ubatch); + input_set(ubatch); // TODO: remove const auto compute_status = graph_compute(gf, n_tokens > 1); switch (compute_status) { @@ -4619,7 +4708,7 @@ int llama_context_enc::encode(llama_batch & inp_batch) { return -3; } - auto * t_embd = res.t_embd_pooled ? res.t_embd_pooled : res.t_embd; + auto * t_embd = res->get_embd_pooled() ? 
res->get_embd_pooled() : res->get_embd(); // extract embeddings if (t_embd) { @@ -4693,38 +4782,41 @@ int llama_context_enc::encode(llama_batch & inp_batch) { // llama_context_dec // -void llama_context_dec::reserve() { - // simulate full KV cache - cross->t_embd = nullptr; +class llama_graph_input_attn_dec : public llama_graph_input_attn_i { +public: + llama_graph_input_attn_dec( + llama_graph_input_attn_i * inp_kv_self, + const llama_cross * cross) : inp_kv_self(inp_kv_self), cross(cross) {} - llama_context_kv_self::reserve(); -} + void set_input(const llama_ubatch * ubatch) override; -void llama_context_dec::input_set(const llama_ubatch & ubatch) { - // call base functionality - llama_context_kv_self::input_set(ubatch); + ggml_tensor * get_kq_mask() override { return inp_kv_self->get_kq_mask(); } + ggml_tensor * get_kq_mask_swa() override { return inp_kv_self->get_kq_mask_swa(); } + ggml_tensor * get_kq_mask_cross() override { return cross_kq_mask_cnv; } - if (inp.cross_embd && cross->t_embd) { - assert(inp.cross_embd->type == GGML_TYPE_F32); + ggml_tensor * cross_kq_mask = nullptr; // F32 [n_outputs_enc, n_batch] + ggml_tensor * cross_kq_mask_cnv = nullptr; // F32 [n_outputs_enc, n_batch] - ggml_backend_tensor_set(inp.cross_embd, cross->v_embd, 0, ggml_nbytes(inp.cross_embd)); - } + llama_graph_input_attn_i * inp_kv_self = nullptr; + const llama_cross * cross = nullptr; +}; - if (inp.cross_kq_mask) { - const int64_t n_enc = inp.cross_kq_mask->ne[0]; - const int64_t n_tokens = ubatch.n_tokens; +void llama_graph_input_attn_dec::set_input(const llama_ubatch * ubatch) { + if (cross_kq_mask) { + const int64_t n_enc = cross_kq_mask->ne[0]; + const int64_t n_tokens = ubatch->n_tokens; - GGML_ASSERT(ggml_backend_buffer_is_host(inp.cross_kq_mask->buffer)); - GGML_ASSERT(!ubatch.equal_seqs); // TODO: use ubatch.n_seqs instead of failing + GGML_ASSERT(ggml_backend_buffer_is_host(cross_kq_mask->buffer)); + GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing - float * data = (float *) inp.cross_kq_mask->data; + float * data = (float *) cross_kq_mask->data; for (int h = 0; h < 1; ++h) { for (int j = 0; j < n_tokens; ++j) { for (int i = 0; i < n_enc; ++i) { float f = -INFINITY; - for (int s = 0; s < ubatch.n_seq_id[j]; ++s) { - const llama_seq_id seq_id = ubatch.seq_id[j][s]; + for (int s = 0; s < ubatch->n_seq_id[j]; ++s) { + const llama_seq_id seq_id = ubatch->seq_id[j][s]; if (cross->seq_ids_enc[i].find(seq_id) != cross->seq_ids_enc[i].end()) { f = 0.0f; } @@ -4742,6 +4834,25 @@ void llama_context_dec::input_set(const llama_ubatch & ubatch) { } } +void llama_context_dec::reserve() { + // simulate full KV cache + cross->t_embd = nullptr; + + llama_context_kv_self::reserve(); +} + +void llama_context_dec::input_set(const llama_ubatch & ubatch) { + // call base functionality + llama_context_kv_self::input_set(ubatch); + + if (inp.cross_embd && cross->t_embd) { + assert(inp.cross_embd->type == GGML_TYPE_F32); + + ggml_backend_tensor_set(inp.cross_embd, cross->v_embd, 0, ggml_nbytes(inp.cross_embd)); + } + +} + ggml_cgraph * llama_context_dec::graph_init() { inp = {}; @@ -4769,22 +4880,30 @@ ggml_tensor * llama_context_dec::build_inp_cross_embd( return inp.cross_embd; } -void llama_context_dec::build_attn_inp( - ggml_context * ctx0, - int32_t n_tokens, - bool causal, - bool swa) { - llama_context_kv_self::build_attn_inp(ctx0, n_tokens, causal, swa); +llama_graph_input_attn_ptr llama_context_dec::build_attn_inp( + llama_graph_result * res, + ggml_context * ctx0, + int32_t 
n_tokens, + bool causal, + bool swa) const { + auto inp_kv_self = llama_context_kv_self::build_attn_inp(res, ctx0, n_tokens, causal, swa); + + auto inp = std::make_shared(inp_kv_self.get(), cross); const int32_t n_enc = cross->t_embd ? cross->t_embd->ne[1] : model.hparams.n_ctx_train; - inp.cross_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_enc, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); - ggml_set_input(inp.cross_kq_mask); + inp->cross_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_enc, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); + ggml_set_input(inp->cross_kq_mask); + + inp->cross_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->cross_kq_mask, GGML_TYPE_F16) : inp->cross_kq_mask; + + res->add_input(inp); - inp.cross_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp.cross_kq_mask, GGML_TYPE_F16) : inp.cross_kq_mask; + return inp; } ggml_tensor * llama_context_dec::build_attn_cross( + llama_graph_input_attn_i * inp, ggml_context * ctx0, ggml_cgraph * gf, ggml_tensor * q_cur, @@ -4792,10 +4911,10 @@ ggml_tensor * llama_context_dec::build_attn_cross( ggml_tensor * v_cur, ggml_tensor * kq_b, float kq_scale, - int il) { + int il) const { GGML_UNUSED(il); - const auto & kq_mask = inp.cross_kq_mask_cnv; + const auto & kq_mask = inp->get_kq_mask_cross(); ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3); //cb(q, "q", il); diff --git a/src/llama-context.h b/src/llama-context.h index f44652e2d1f18..0f248537eded3 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -251,22 +251,18 @@ class llama_context_base : public llama_context, public llama_graph_i { // when the compute graph is built, it creates the input tensors that it needs // the contents of the input tensors are set by the input_set() function + // TODO: remove, replace by llama_graph_input_i->set_input() virtual void input_set(const llama_ubatch & ubatch); private: + // TODO: remove, implement as llama_graph_input_xxx struct { // base input tensors - ggml_tensor * tokens; // I32 [n_batch] - ggml_tensor * embd; // F32 [n_embd, n_batch] ggml_tensor * pos; // I32 [n_batch] ggml_tensor * pos_bucket; // I32 [n_batch, n_batch] ggml_tensor * out_ids; // I32 [n_outputs] ggml_tensor * mean; // F32 [n_batch, n_batch] ggml_tensor * cls; // I32 [n_batch] - - // KQ mask input tensors - ggml_tensor * kq_mask; // F32 [n_tokens, n_batch] - ggml_tensor * kq_mask_cnv; // [n_tokens, n_batch] } inp; protected: @@ -292,7 +288,7 @@ class llama_context_base : public llama_context, public llama_graph_i { virtual ggml_cgraph * graph_init(); // TODO: add encode/decode graphs - virtual llama_graph_result graph_build( + virtual llama_graph_result_ptr graph_build( ggml_context * ctx, ggml_cgraph * gf, const llama_ubatch & ubatch); @@ -344,9 +340,10 @@ class llama_context_base : public llama_context, public llama_graph_i { ggml_backend_buffer * bbuf) override; ggml_tensor * build_inp_embd( - ggml_context * ctx0, - ggml_tensor * tok_embd, - const llama_ubatch & ubatch) override; + llama_graph_result * res, + ggml_context * ctx0, + ggml_tensor * tok_embd, + const llama_ubatch & ubatch) const override; ggml_tensor * build_inp_pos( ggml_context * ctx0, @@ -367,21 +364,23 @@ class llama_context_base : public llama_context, public llama_graph_i { ggml_context * ctx0, int32_t n_tokens) override; - void build_attn_inp( + llama_graph_input_attn_ptr build_attn_inp( + llama_graph_result * res, ggml_context * ctx0, int32_t n_tokens, bool causal, - bool swa) override; + bool swa) const override; ggml_tensor * build_attn( + llama_graph_input_attn_i * 
inp, ggml_context * ctx0, ggml_cgraph * gf, ggml_tensor * q_cur, ggml_tensor * k_cur, ggml_tensor * v_cur, ggml_tensor * kq_b, - float kq_scale, - int il) override; + float kq_scale, + int il) const override; protected: virtual ggml_tensor * build_attn_mha( @@ -393,7 +392,7 @@ class llama_context_base : public llama_context, public llama_graph_i { ggml_tensor * kq_b, ggml_tensor * kq_mask, bool v_trans, - float kq_scale); + float kq_scale) const; virtual ggml_tensor * build_inp_self_k_shift( ggml_context * ctx0); @@ -563,10 +562,6 @@ class llama_context_kv_self : public llama_context_base { private: struct { ggml_tensor * self_pos_bucket; // I32 [n_kv, n_batch] - ggml_tensor * self_kq_mask; // F32 [n_kv, n_batch] - ggml_tensor * self_kq_mask_cnv; // [n_kv, n_batch] - ggml_tensor * self_kq_mask_swa; // F32 [n_kv, n_batch] - ggml_tensor * self_kq_mask_swa_cnv; // [n_kv, n_batch] ggml_tensor * self_k_shift; // I32 [kv_size] } inp; @@ -586,21 +581,23 @@ class llama_context_kv_self : public llama_context_base { ggml_context * ctx0, int32_t n_tokens) override; - void build_attn_inp( + llama_graph_input_attn_ptr build_attn_inp( + llama_graph_result * res, ggml_context * ctx0, int32_t n_tokens, bool causal, - bool swa) override; + bool swa) const override; ggml_tensor * build_attn( + llama_graph_input_attn_i * inp, ggml_context * ctx0, ggml_cgraph * gf, ggml_tensor * q_cur, ggml_tensor * k_cur, ggml_tensor * v_cur, ggml_tensor * kq_b, - float kq_scale, - int il) override; + float kq_scale, + int il) const override; protected: ggml_tensor * build_inp_self_k_shift(ggml_context * ctx0) override; @@ -786,8 +783,6 @@ class llama_context_dec : public llama_context_kv_self { private: struct { ggml_tensor * cross_embd; // F32 [n_embd, n_outputs_enc] - ggml_tensor * cross_kq_mask; // F32 [n_outputs_enc, n_batch] - ggml_tensor * cross_kq_mask_cnv; // F32 [n_outputs_enc, n_batch] } inp; protected: @@ -800,13 +795,15 @@ class llama_context_dec : public llama_context_kv_self { ggml_tensor * build_inp_cross_embd( ggml_context * ctx0) override; - void build_attn_inp( + llama_graph_input_attn_ptr build_attn_inp( + llama_graph_result * res, ggml_context * ctx0, int32_t n_tokens, bool causal, - bool swa) override; + bool swa) const override; ggml_tensor * build_attn_cross( + llama_graph_input_attn_i * inp, ggml_context * ctx0, ggml_cgraph * gf, ggml_tensor * q_cur, @@ -814,7 +811,7 @@ class llama_context_dec : public llama_context_kv_self { ggml_tensor * v_cur, ggml_tensor * kq_b, float kq_scale, - int il) override; + int il) const override; public: llama_cross * cross = nullptr; diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 1e336e844ada0..549a42c53ba22 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -2,17 +2,34 @@ #include "llama-impl.h" +ggml_tensor * llama_graph_input_attn_i::get_kq_mask() { + LLAMA_LOG_ERROR("%s: not implemented\n", __func__); + return nullptr; +} + +ggml_tensor * llama_graph_input_attn_i::get_kq_mask_swa() { + LLAMA_LOG_ERROR("%s: not implemented\n", __func__); + return nullptr; +} + +ggml_tensor * llama_graph_input_attn_i::get_kq_mask_cross() { + LLAMA_LOG_ERROR("%s: not implemented\n", __func__); + return nullptr; +} + llama_graph_i::llama_graph_i(llama_graph_type type) : type(type) {} ggml_tensor * llama_graph_i::build_attn( + llama_graph_input_attn_i * inp, ggml_context * ctx0, ggml_cgraph * gf, ggml_tensor * q_cur, ggml_tensor * k_cur, ggml_tensor * v_cur, ggml_tensor * kq_b, - float kq_scale, - int il) { + float kq_scale, + int il) const { + GGML_UNUSED(inp); 
GGML_UNUSED(ctx0);
GGML_UNUSED(gf);
GGML_UNUSED(q_cur);
@@ -27,6 +44,7 @@ ggml_tensor * llama_graph_i::build_attn(
}
ggml_tensor * llama_graph_i::build_attn_cross(
+ llama_graph_input_attn_i * inp,
ggml_context * ctx0,
ggml_cgraph * gf,
ggml_tensor * q_cur,
@@ -34,7 +52,8 @@ ggml_tensor * llama_graph_i::build_attn_cross(
ggml_tensor * v_cur,
ggml_tensor * kq_b,
float kq_scale,
- int il) {
+ int il) const {
+ GGML_UNUSED(inp);
GGML_UNUSED(ctx0);
GGML_UNUSED(gf);
GGML_UNUSED(q_cur);
diff --git a/src/llama-graph.h b/src/llama-graph.h
index 28e8a563067db..a6a9ef00ca860 100644
--- a/src/llama-graph.h
+++ b/src/llama-graph.h
@@ -1,6 +1,8 @@
#pragma once
#include <cstdint>
+#include <memory>
+#include <vector>
// note: do not add high-level objects here, such as llama_context, llama_kv_cache, etc.
// not sure about llama_batch/llama_sbatch yet
struct ggml_cgraph;
struct ggml_context;
struct ggml_tensor;
struct ggml_backend_buffer;
+
struct llama_ubatch;
enum llama_graph_type {
@@ -17,13 +20,78 @@ enum llama_graph_type {
 LLAMA_GRAPH_TYPE_DECODER,
};
-struct llama_graph_result {
+//
+// llama_graph_input
+//
+
+class llama_graph_input_i {
+public:
+ virtual ~llama_graph_input_i() = default;
+
+ virtual void set_input(const llama_ubatch * ubatch) = 0;
+};
+
+using llama_graph_input_ptr = std::shared_ptr<llama_graph_input_i>;
+
+class llama_graph_input_attn_i : public llama_graph_input_i {
+public:
+ virtual ~llama_graph_input_attn_i() = default;
+
+ virtual ggml_tensor * get_kq_mask();
+ virtual ggml_tensor * get_kq_mask_swa();
+ virtual ggml_tensor * get_kq_mask_cross();
+};
+
+using llama_graph_input_attn_ptr = std::shared_ptr<llama_graph_input_attn_i>;
+
+//
+// llama_graph_result
+//
+
+class llama_graph_result_i {
+public:
+ virtual ~llama_graph_result_i() = default;
+
+ virtual ggml_tensor * get_logits() = 0;
+ virtual ggml_tensor * get_embd() = 0;
+ virtual ggml_tensor * get_embd_pooled() = 0;
+
+ virtual void set_inputs(const llama_ubatch * ubatch) = 0;
+};
+
+using llama_graph_result_ptr = std::unique_ptr<llama_graph_result_i>;
+
+class llama_graph_result : public llama_graph_result_i {
+public:
+ llama_graph_result() = default;
+ virtual ~llama_graph_result() = default;
+
+ ggml_tensor * get_logits() override { return t_logits; }
+ ggml_tensor * get_embd() override { return t_embd; }
+ ggml_tensor * get_embd_pooled() override { return t_embd_pooled; }
+
+ void set_inputs(const llama_ubatch * ubatch) override {
+ for (auto & input : inputs) {
+ input->set_input(ubatch);
+ }
+ }
+
+ void add_input(llama_graph_input_ptr && input) {
+ inputs.emplace_back(std::move(input));
+ }
+
 // important graph nodes
 ggml_tensor * t_logits = nullptr;
 ggml_tensor * t_embd = nullptr;
 ggml_tensor * t_embd_pooled = nullptr;
+
+ std::vector<llama_graph_input_ptr> inputs;
};
+//
+// llama_graph
+//
+
// TODO: can become more granular in the future
class llama_graph_i {
public:
@@ -75,9 +143,10 @@ class llama_graph_i {
 // graph build API (context-specific)
 virtual ggml_tensor * build_inp_embd(
+ llama_graph_result * res,
 ggml_context * ctx0,
 ggml_tensor * tok_embd,
- const llama_ubatch & ubatch) = 0;
+ const llama_ubatch & ubatch) const = 0;
// note these methods will become const, i.e. they don't mutate the llama_context that implements them
 virtual ggml_tensor * build_inp_pos(
 ggml_context * ctx0,
@@ -98,23 +167,26 @@ class llama_graph_i {
 ggml_context * ctx0,
 int32_t n_tokens) = 0;
- virtual void build_attn_inp(
+ virtual llama_graph_input_attn_ptr build_attn_inp(
+ llama_graph_result * res,
 ggml_context * ctx0,
 int32_t n_tokens,
 bool causal,
- bool swa) = 0;
+ bool swa) const = 0;
 virtual ggml_tensor * build_attn(
+ llama_graph_input_attn_i * inp,
 ggml_context * ctx0,
 ggml_cgraph * gf,
 ggml_tensor * q_cur,
 ggml_tensor * k_cur,
 ggml_tensor * v_cur,
 ggml_tensor * kq_b,
- float kq_scale,
- int il);
+ float kq_scale,
+ int il) const;
 virtual ggml_tensor * build_attn_cross(
+ llama_graph_input_attn_i * inp,
 ggml_context * ctx0,
 ggml_cgraph * gf,
 ggml_tensor * q_cur,
@@ -122,7 +194,7 @@ class llama_graph_i {
 ggml_tensor * v_cur,
 ggml_tensor * kq_b,
 float kq_scale,
- int il);
+ int il) const;
 virtual ggml_tensor * build_inp_cross_embd(
 ggml_context * ctx0);
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index 25a705c657cd9..b6adbb1a1bbed 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -2,7 +2,6 @@
 #include "llama-impl.h"
 #include "llama-mmap.h"
-#include "llama-graph.h"
 #include "llama-batch.h"
 #include "llama-cparams.h"
 #include "llama-model-loader.h"
@@ -3853,7 +3852,7 @@ struct llm_build_context {
 ggml_context * ctx0 = nullptr;
 llama_graph_i * lgf = nullptr;
- llama_graph_result res;
+ std::unique_ptr<llama_graph_result> res;
 // TODO: consider making the entire interface noexcept
 llm_build_context(
@@ -3892,7 +3891,8 @@ struct llm_build_context {
 pooling_type (cparams.pooling_type),
 rope_type (hparams.rope_type),
 ctx0 (ctx),
- lgf (lgf) {
+ lgf (lgf),
+ res (std::make_unique<llama_graph_result>()) {
 }
 // TODO: tmp
@@ -3902,7 +3902,7 @@ struct llm_build_context {
 // TODO: tmp
 struct ggml_tensor * build_inp_embd(struct ggml_tensor * tok_embd) {
- struct ggml_tensor * inpL = lgf->build_inp_embd(ctx0, tok_embd, ubatch);
+ struct ggml_tensor * inpL = lgf->build_inp_embd(res.get(), ctx0, tok_embd, ubatch);
 cb(inpL, "inp_embd", -1);
 return inpL;
@@ -4259,15 +4259,16 @@ struct llm_build_context {
 }
 struct ggml_tensor * build_attn(
- struct ggml_cgraph * gf,
- struct ggml_tensor * wo,
- struct ggml_tensor * wo_b,
- struct ggml_tensor * q_cur,
- struct ggml_tensor * k_cur,
- struct ggml_tensor * v_cur,
- int32_t n_tokens, // TODO: remove
- float kq_scale,
- int il) {
+ llama_graph_input_attn_i * inp,
+ ggml_cgraph * gf,
+ ggml_tensor * wo,
+ ggml_tensor * wo_b,
+ ggml_tensor * q_cur,
+ ggml_tensor * k_cur,
+ ggml_tensor * v_cur,
+ int32_t n_tokens, // TODO: remove
+ float kq_scale,
+ int il) {
 GGML_UNUSED(n_tokens);
 // these nodes are added to the graph together so that they are not reordered
@@ -4276,7 +4277,7 @@ struct llm_build_context {
 ggml_build_forward_expand(gf, k_cur);
 ggml_build_forward_expand(gf, v_cur);
- ggml_tensor * cur = lgf->build_attn(ctx0, gf, q_cur, k_cur, v_cur, nullptr, kq_scale, il);
+ ggml_tensor * cur = lgf->build_attn(inp, ctx0, gf, q_cur, k_cur, v_cur, nullptr, kq_scale, il);
 cb(cur, "kqv_out", il);
 if (wo) {
@@ -4295,15 +4296,16 @@ struct llm_build_context {
 }
 struct ggml_tensor * build_attn_cross(
- struct ggml_cgraph * gf,
- struct ggml_tensor * wo,
- struct ggml_tensor * wo_b,
- struct ggml_tensor * q_cur,
- struct ggml_tensor * k_cur,
- struct ggml_tensor * v_cur,
- int32_t n_tokens, // TODO: remove
- float kq_scale,
- int il) {
+ llama_graph_input_attn_i * inp,
+ ggml_cgraph * gf,
+ ggml_tensor * wo,
+ ggml_tensor * wo_b,
+ ggml_tensor * q_cur,
+ ggml_tensor * k_cur,
+ ggml_tensor * v_cur, + int32_t n_tokens, // TODO: remove + float kq_scale, + int il) { GGML_UNUSED(n_tokens); // these nodes are added to the graph together so that they are not reordered @@ -4312,7 +4314,7 @@ struct llm_build_context { ggml_build_forward_expand(gf, k_cur); ggml_build_forward_expand(gf, v_cur); - ggml_tensor * cur = lgf->build_attn_cross(ctx0, gf, q_cur, k_cur, v_cur, nullptr, kq_scale, il); + ggml_tensor * cur = lgf->build_attn_cross(inp, ctx0, gf, q_cur, k_cur, v_cur, nullptr, kq_scale, il); cb(cur, "kqv_out", il); if (wo) { @@ -4331,16 +4333,17 @@ struct llm_build_context { } struct ggml_tensor * build_attn_with_kq_b( - struct ggml_cgraph * gf, - struct ggml_tensor * wo, - struct ggml_tensor * wo_b, - struct ggml_tensor * q_cur, - struct ggml_tensor * k_cur, - struct ggml_tensor * v_cur, - struct ggml_tensor * kq_b, - int32_t n_tokens, // TODO: remove - float kq_scale, - int il) { + llama_graph_input_attn_i * inp, + ggml_cgraph * gf, + ggml_tensor * wo, + ggml_tensor * wo_b, + ggml_tensor * q_cur, + ggml_tensor * k_cur, + ggml_tensor * v_cur, + ggml_tensor * kq_b, + int32_t n_tokens, // TODO: remove + float kq_scale, + int il) { GGML_UNUSED(n_tokens); // these nodes are added to the graph together so that they are not reordered @@ -4349,7 +4352,7 @@ struct llm_build_context { ggml_build_forward_expand(gf, k_cur); ggml_build_forward_expand(gf, v_cur); - ggml_tensor * cur = lgf->build_attn(ctx0, gf, q_cur, k_cur, v_cur, kq_b, kq_scale, il); + ggml_tensor * cur = lgf->build_attn(inp, ctx0, gf, q_cur, k_cur, v_cur, kq_b, kq_scale, il); cb(cur, "kqv_out", il); if (wo) { @@ -4397,7 +4400,7 @@ struct llm_build_context { } void append_pooling(struct ggml_cgraph * gf) { - struct ggml_tensor * inp = res.t_embd; + struct ggml_tensor * inp = res->t_embd; //// find result_norm tensor for input //for (int i = ggml_graph_n_nodes(gf) - 1; i >= 0; --i) { @@ -4457,7 +4460,7 @@ struct llm_build_context { } cb(cur, "result_embd_pooled", -1); - res.t_embd_pooled = cur; + res->t_embd_pooled = cur; ggml_build_forward_expand(gf, cur); } @@ -4495,7 +4498,7 @@ struct llm_build_context { // inp_pos - contains the positions struct ggml_tensor * inp_pos = build_inp_pos(); - lgf->build_attn_inp(ctx0, n_tokens, true, false); + auto inp_attn = lgf->build_attn_inp(res.get(), ctx0, n_tokens, true, false); const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale; for (int il = 0; il < n_layer; ++il) { @@ -4548,7 +4551,7 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = build_attn(gf, + cur = build_attn(inp_attn.get(), gf, model.layers[il].wo, model.layers[il].bo, Qcur, Kcur, Vcur, n_tokens, kq_scale, il); } @@ -4626,7 +4629,7 @@ struct llm_build_context { LLM_NORM_RMS, -1); cb(cur, "result_norm", -1); - res.t_embd = cur; + res->t_embd = cur; // lm_head cur = build_lora_mm(model.output, cur); @@ -4637,7 +4640,7 @@ struct llm_build_context { } cb(cur, "result_output", -1); - res.t_logits = cur; + res->t_logits = cur; ggml_build_forward_expand(gf, cur); } @@ -4656,7 +4659,7 @@ struct llm_build_context { // inp_pos - contains the positions struct ggml_tensor * inp_pos = build_inp_pos(); - lgf->build_attn_inp(ctx0, n_tokens, true, false); + auto inp_attn = lgf->build_attn_inp(res.get(), ctx0, n_tokens, true, false); const float kq_scale = hparams.f_attention_scale == 0.0f ? 
1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale; for (int il = 0; il < n_layer; ++il) { @@ -4720,7 +4723,7 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = build_attn(gf, + cur = build_attn(inp_attn.get(), gf, model.layers[il].wo, model.layers[il].bo, Qcur, Kcur, Vcur, n_tokens, kq_scale, il); } @@ -4782,7 +4785,7 @@ struct llm_build_context { LLM_NORM_RMS, -1); cb(cur, "result_norm", -1); - res.t_embd = cur; + res->t_embd = cur; // lm_head cur = build_lora_mm(model.output, cur); @@ -4793,7 +4796,7 @@ struct llm_build_context { } cb(cur, "result_output", -1); - res.t_logits = cur; + res->t_logits = cur; ggml_build_forward_expand(gf, cur); } @@ -4812,7 +4815,7 @@ struct llm_build_context { // inp_pos - contains the positions struct ggml_tensor * inp_pos = model.type == LLM_TYPE_7B ? build_inp_pos() : nullptr; - lgf->build_attn_inp(ctx0, n_tokens, true, false); + auto inp_attn = lgf->build_attn_inp(res.get(), ctx0, n_tokens, true, false); for (int il = 0; il < n_layer; ++il) { struct ggml_tensor * inpSA = inpL; @@ -4856,7 +4859,7 @@ struct llm_build_context { cb(Qcur, "Qcur", il); cb(Kcur, "Kcur", il); - cur = build_attn(gf, + cur = build_attn(inp_attn.get(), gf, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, n_tokens, 1.0f/sqrtf(float(n_embd_head)), il); } @@ -4903,13 +4906,13 @@ struct llm_build_context { LLM_NORM_RMS, -1); cb(cur, "result_norm", -1); - res.t_embd = cur; + res->t_embd = cur; // lm_head cur = build_lora_mm(model.output, cur); cb(cur, "result_output", -1); - res.t_logits = cur; + res->t_logits = cur; ggml_build_forward_expand(gf, cur); } @@ -4928,7 +4931,7 @@ struct llm_build_context { // inp_pos - contains the positions struct ggml_tensor * inp_pos = build_inp_pos(); - lgf->build_attn_inp(ctx0, n_tokens, true, false); + auto inp_attn = lgf->build_attn_inp(res.get(), ctx0, n_tokens, true, false); for (int il = 0; il < n_layer; ++il) { struct ggml_tensor * inpSA = inpL; @@ -4962,7 +4965,7 @@ struct llm_build_context { ext_factor, attn_factor, beta_fast, beta_slow ); cb(Kcur, "Kcur", il); - cur = build_attn(gf, + cur = build_attn(inp_attn.get(), gf, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, n_tokens, 1.0f/sqrtf(float(n_embd_head)), il); } @@ -5007,13 +5010,13 @@ struct llm_build_context { cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1); cb(cur, "result_norm", -1); - res.t_embd = cur; + res->t_embd = cur; // lm_head cur = build_lora_mm(model.output, cur); cb(cur, "result_output", -1); - res.t_logits = cur; + res->t_logits = cur; ggml_build_forward_expand(gf, cur); } @@ -5033,7 +5036,7 @@ struct llm_build_context { // inp_pos - contains the positions struct ggml_tensor * inp_pos = build_inp_pos(); - lgf->build_attn_inp(ctx0, n_tokens, true, false); + auto inp_attn = lgf->build_attn_inp(res.get(), ctx0, n_tokens, true, false); for (int il = 0; il < n_layer; ++il) { struct ggml_tensor * attn_norm; @@ -5084,7 +5087,7 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = build_attn(gf, + cur = build_attn(inp_attn.get(), gf, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, n_tokens, 1.0f/sqrtf(float(n_embd_head)), il); } @@ -5129,12 +5132,12 @@ struct llm_build_context { LLM_NORM, -1); cb(cur, "result_norm", -1); - res.t_embd = cur; + res->t_embd = cur; cur = build_lora_mm(model.output, cur); cb(cur, "result_output", -1); - res.t_logits = cur; + res->t_logits = cur; ggml_build_forward_expand(gf, cur); } @@ -5156,7 +5159,7 @@ struct llm_build_context { // inp_pos - contains the positions struct ggml_tensor * inp_pos = build_inp_pos(); - 
lgf->build_attn_inp(ctx0, n_tokens, true, false); + auto inp_attn = lgf->build_attn_inp(res.get(), ctx0, n_tokens, true, false); for (int il = 0; il < n_layer; ++il) { struct ggml_tensor * inpSA = inpL; @@ -5206,7 +5209,7 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = build_attn(gf, + cur = build_attn(inp_attn.get(), gf, model.layers[il].wo, model.layers[il].bo, Qcur, Kcur, Vcur, n_tokens, 1.0f, il); } @@ -5277,7 +5280,7 @@ struct llm_build_context { LLM_NORM_RMS, -1); cb(cur, "result_norm", -1); - res.t_embd = cur; + res->t_embd = cur; // lm_head cur = build_lora_mm(model.output, cur); @@ -5288,7 +5291,7 @@ struct llm_build_context { cur = ggml_scale(ctx0, cur, 0.5773502691896257f); cb(cur, "result_output", -1); - res.t_logits = cur; + res->t_logits = cur; ggml_build_forward_expand(gf, cur); } @@ -5308,7 +5311,7 @@ struct llm_build_context { // inp_pos - contains the positions struct ggml_tensor * inp_pos = build_inp_pos(); - lgf->build_attn_inp(ctx0, n_tokens, true, false); + auto inp_attn = lgf->build_attn_inp(res.get(), ctx0, n_tokens, true, false); for (int il = 0; il < n_layer; ++il) { struct ggml_tensor * inpSA = inpL; @@ -5353,7 +5356,7 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = build_attn(gf, + cur = build_attn(inp_attn.get(), gf, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, n_tokens, 1.0f/sqrtf(float(n_embd_head)), il); } @@ -5405,13 +5408,13 @@ struct llm_build_context { LLM_NORM, -1); cb(cur, "result_norm", -1); - res.t_embd = cur; + res->t_embd = cur; // lm_head cur = build_lora_mm(model.output, cur); cb(cur, "result_output", -1); - res.t_logits = cur; + res->t_logits = cur; ggml_build_forward_expand(gf, cur); } @@ -5430,7 +5433,7 @@ struct llm_build_context { // inp_pos - contains the positions struct ggml_tensor * inp_pos = build_inp_pos(); - lgf->build_attn_inp(ctx0, n_tokens, true, false); + auto inp_attn = lgf->build_attn_inp(res.get(), ctx0, n_tokens, true, false); struct ggml_tensor * pos = ggml_get_rows(ctx0, model.pos_embd, inp_pos); cb(pos, "pos_embd", -1); @@ -5463,7 +5466,7 @@ struct llm_build_context { Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - cur = build_attn(gf, + cur = build_attn(inp_attn.get(), gf, model.layers[il].wo, model.layers[il].bo, Qcur, Kcur, Vcur, n_tokens, 1.0f/sqrtf(float(n_embd_head)), il); } @@ -5511,12 +5514,12 @@ struct llm_build_context { LLM_NORM, -1); cb(cur, "result_norm", -1); - res.t_embd = cur; + res->t_embd = cur; cur = build_lora_mm(model.output, cur); cb(cur, "result_output", -1); - res.t_logits = cur; + res->t_logits = cur; ggml_build_forward_expand(gf, cur); } @@ -5531,7 +5534,7 @@ struct llm_build_context { inpL = build_inp_embd(model.tok_embd); - lgf->build_attn_inp(ctx0, n_tokens, true, false); + auto inp_attn = lgf->build_attn_inp(res.get(), ctx0, n_tokens, true, false); for (int il = 0; il < n_layer; ++il) { struct ggml_tensor * inpSA = inpL; @@ -5558,7 +5561,7 @@ struct llm_build_context { Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); cb(Qcur, "Qcur", il); - cur = build_attn(gf, + cur = build_attn(inp_attn.get(), gf, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, n_tokens, 1.0f/sqrtf(float(n_embd_head)), il); } @@ -5605,13 +5608,13 @@ struct llm_build_context { LLM_NORM_RMS, -1); cb(cur, "result_norm", -1); - res.t_embd = cur; + res->t_embd = cur; // lm_head cur = build_lora_mm(model.output, cur); cb(cur, "result_output", -1); - res.t_logits = cur; + res->t_logits = cur; ggml_build_forward_expand(gf, cur); } @@ -5645,7 +5648,7 @@ struct 
llm_build_context { inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1); cb(inpL, "inp_norm", -1); - lgf->build_attn_inp(ctx0, n_tokens, false, false); + auto inp_attn = lgf->build_attn_inp(res.get(), ctx0, n_tokens, false, false); // iterate layers for (int il = 0; il < n_layer; ++il) { @@ -5710,7 +5713,7 @@ struct llm_build_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - cur = build_attn(gf, + cur = build_attn(inp_attn.get(), gf, model.layers[il].wo, model.layers[il].bo, Qcur, Kcur, Vcur, n_tokens, 1.0f/sqrtf(float(n_embd_head)), il); cb(cur, "kqv_out", il); @@ -5774,7 +5777,7 @@ struct llm_build_context { cur = inpL; cb(cur, "result_embd", -1); - res.t_embd = cur; + res->t_embd = cur; ggml_build_forward_expand(gf, cur); } @@ -5790,7 +5793,7 @@ struct llm_build_context { inpL = build_inp_embd(model.tok_embd); - lgf->build_attn_inp(ctx0, n_tokens, true, false); + auto inp_attn = lgf->build_attn_inp(res.get(), ctx0, n_tokens, true, false); inpL = build_norm(inpL, model.tok_norm, @@ -5823,7 +5826,7 @@ struct llm_build_context { Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - cur = build_attn(gf, + cur = build_attn(inp_attn.get(), gf, model.layers[il].wo, model.layers[il].bo, Qcur, Kcur, Vcur, n_tokens, 1.0f/sqrtf(float(n_embd_head)), il); } @@ -5871,12 +5874,12 @@ struct llm_build_context { LLM_NORM, -1); cb(cur, "result_norm", -1); - res.t_embd = cur; + res->t_embd = cur; cur = build_lora_mm(model.output, cur); cb(cur, "result_output", -1); - res.t_logits = cur; + res->t_logits = cur; ggml_build_forward_expand(gf, cur); } @@ -5893,7 +5896,7 @@ struct llm_build_context { inpL = build_inp_embd(model.tok_embd); - lgf->build_attn_inp(ctx0, n_tokens, true, false); + auto inp_attn = lgf->build_attn_inp(res.get(), ctx0, n_tokens, true, false); if (model.pos_embd) { // inp_pos - contains the positions @@ -5956,13 +5959,13 @@ struct llm_build_context { Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - cur = build_attn(gf, + cur = build_attn(inp_attn.get(), gf, model.layers[il].wo, model.layers[il].bo, Qcur, Kcur, Vcur, n_tokens, 1.0f/sqrtf(float(n_embd_head)), il); } else { Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - cur = build_attn(gf, + cur = build_attn(inp_attn.get(), gf, model.layers[il].wo, model.layers[il].bo, Qcur, Kcur, Vcur, n_tokens, 1.0f/sqrtf(float(n_embd_head)), il); } @@ -6012,12 +6015,12 @@ struct llm_build_context { LLM_NORM, -1); cb(cur, "result_norm", -1); - res.t_embd = cur; + res->t_embd = cur; cur = build_lora_mm(model.output, cur); cb(cur, "result_output", -1); - res.t_logits = cur; + res->t_logits = cur; ggml_build_forward_expand(gf, cur); } @@ -6035,7 +6038,7 @@ struct llm_build_context { // inp_pos - contains the positions struct ggml_tensor * inp_pos = build_inp_pos(); - lgf->build_attn_inp(ctx0, n_tokens, true, false); + auto inp_attn = lgf->build_attn_inp(res.get(), ctx0, n_tokens, true, false); for (int il = 0; il < n_layer; ++il) { @@ -6108,7 +6111,7 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = build_attn(gf, + cur = build_attn(inp_attn.get(), gf, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, n_tokens, 1.0f/sqrtf(float(n_embd_head)), il); } @@ -6162,13 +6165,13 @@ struct llm_build_context { LLM_NORM, -1); cb(cur, "result_norm", -1); - res.t_embd = cur; + res->t_embd = cur; // lm_head cur = build_lora_mm(model.output, cur); cb(cur, "result_output", -1); - res.t_logits = cur; + res->t_logits = 
         ggml_build_forward_expand(gf, cur);
     }
@@ -6186,7 +6189,7 @@ struct llm_build_context {
         // inp_pos - contains the positions
         struct ggml_tensor * inp_pos = build_inp_pos();

-        lgf->build_attn_inp(ctx0, n_tokens, true, false);
+        auto inp_attn = lgf->build_attn_inp(res.get(), ctx0, n_tokens, true, false);

         for (int il = 0; il < n_layer; ++il) {
             struct ggml_tensor * inpSA = inpL;
@@ -6228,7 +6231,7 @@ struct llm_build_context {
                         );
                 cb(Kcur, "Kcur", il);

-                cur = build_attn(gf,
+                cur = build_attn(inp_attn.get(), gf,
                         model.layers[il].wo, NULL,
                         Qcur, Kcur, Vcur, n_tokens, 1.0f/sqrtf(float(n_embd_head)), il);
             }
@@ -6275,13 +6278,13 @@ struct llm_build_context {
                 LLM_NORM_RMS, -1);
         cb(cur, "result_norm", -1);

-        res.t_embd = cur;
+        res->t_embd = cur;

         // lm_head
         cur = build_lora_mm(model.output, cur);
         cb(cur, "result_output", -1);

-        res.t_logits = cur;
+        res->t_logits = cur;

         ggml_build_forward_expand(gf, cur);
     }
@@ -6300,7 +6303,7 @@ struct llm_build_context {
         // inp_pos - contains the positions
         struct ggml_tensor * inp_pos = build_inp_pos();

-        lgf->build_attn_inp(ctx0, n_tokens, true, false);
+        auto inp_attn = lgf->build_attn_inp(res.get(), ctx0, n_tokens, true, false);

         for (int il = 0; il < n_layer; ++il) {
             struct ggml_tensor * inpSA = inpL;
@@ -6343,7 +6346,7 @@ struct llm_build_context {
                         );
                 cb(Kcur, "Kcur", il);

-                cur = build_attn(gf,
+                cur = build_attn(inp_attn.get(), gf,
                         model.layers[il].wo, model.layers[il].bo,
                         Qcur, Kcur, Vcur, n_tokens, 1.0f/sqrtf(float(n_embd_head)), il);
             }
@@ -6388,13 +6391,13 @@ struct llm_build_context {
                 LLM_NORM_RMS, -1);
         cb(cur, "result_norm", -1);

-        res.t_embd = cur;
+        res->t_embd = cur;

         // lm_head
         cur = build_lora_mm(model.output, cur);
         cb(cur, "result_output", -1);

-        res.t_logits = cur;
+        res->t_logits = cur;

         ggml_build_forward_expand(gf, cur);
     }
@@ -6413,7 +6416,7 @@ struct llm_build_context {
         // inp_pos - contains the positions
         struct ggml_tensor * inp_pos = build_inp_pos();

-        lgf->build_attn_inp(ctx0, n_tokens, true, false);
+        auto inp_attn = lgf->build_attn_inp(res.get(), ctx0, n_tokens, true, false);

         int sections[4];
         std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections);
@@ -6461,7 +6464,7 @@ struct llm_build_context {
                         );
                 cb(Kcur, "Kcur", il);

-                cur = build_attn(gf,
+                cur = build_attn(inp_attn.get(), gf,
                         model.layers[il].wo, model.layers[il].bo,
                         Qcur, Kcur, Vcur, n_tokens, 1.0f/sqrtf(float(n_embd_head)), il);
             }
@@ -6506,13 +6509,13 @@ struct llm_build_context {
                 LLM_NORM_RMS, -1);
         cb(cur, "result_norm", -1);

-        res.t_embd = cur;
+        res->t_embd = cur;

         // lm_head
         cur = build_lora_mm(model.output, cur);
         cb(cur, "result_output", -1);

-        res.t_logits = cur;
+        res->t_logits = cur;

         ggml_build_forward_expand(gf, cur);
     }
@@ -6531,7 +6534,7 @@ struct llm_build_context {
         // inp_pos - contains the positions
         struct ggml_tensor * inp_pos = build_inp_pos();

-        lgf->build_attn_inp(ctx0, n_tokens, true, false);
+        auto inp_attn = lgf->build_attn_inp(res.get(), ctx0, n_tokens, true, false);

         for (int il = 0; il < n_layer; ++il) {
             struct ggml_tensor * inpSA = inpL;
@@ -6574,7 +6577,7 @@ struct llm_build_context {
                         );
                 cb(Kcur, "Kcur", il);

-                cur = build_attn(gf,
+                cur = build_attn(inp_attn.get(), gf,
                         model.layers[il].wo, model.layers[il].bo,
                         Qcur, Kcur, Vcur, n_tokens, 1.0f/sqrtf(float(n_embd_head)), il);
             }
@@ -6651,13 +6654,13 @@ struct llm_build_context {
                 LLM_NORM_RMS, -1);
         cb(cur, "result_norm", -1);

-        res.t_embd = cur;
+        res->t_embd = cur;

         // lm_head
         cur = build_lora_mm(model.output, cur);
         cb(cur, "result_output", -1);

-        res.t_logits = cur;
+        res->t_logits = cur;
         ggml_build_forward_expand(gf, cur);
     }
@@ -6678,7 +6681,7 @@ struct llm_build_context {
         // inp_pos - contains the positions
         struct ggml_tensor * inp_pos = build_inp_pos();

-        lgf->build_attn_inp(ctx0, n_tokens, true, false);
+        auto inp_attn = lgf->build_attn_inp(res.get(), ctx0, n_tokens, true, false);

         for (int il = 0; il < n_layer; ++il) {
             attn_norm_output = build_norm(inpL,
@@ -6733,7 +6736,7 @@ struct llm_build_context {
                         );
                 cb(Kcur, "Kcur", il);

-                cur = build_attn(gf,
+                cur = build_attn(inp_attn.get(), gf,
                         model.layers[il].wo, model.layers[il].bo,
                         Qcur, Kcur, Vcur, n_tokens, 1.0f, il);
             }
@@ -6773,7 +6776,7 @@ struct llm_build_context {
                 LLM_NORM, -1);
         cb(cur, "result_norm", -1);

-        res.t_embd = cur;
+        res->t_embd = cur;

         cur = build_lora_mm(model.output, cur);
         cb(cur, "result_output_no_bias", -1);
@@ -6781,7 +6784,7 @@ struct llm_build_context {
         cur = ggml_add(ctx0, cur, model.output_b);
         cb(cur, "result_output", -1);

-        res.t_logits = cur;
+        res->t_logits = cur;

         ggml_build_forward_expand(gf, cur);
     }
@@ -6801,7 +6804,7 @@ struct llm_build_context {
         struct ggml_tensor * inp_pos = build_inp_pos();

         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        lgf->build_attn_inp(ctx0, n_tokens, true, true);
+        auto inp_attn = lgf->build_attn_inp(res.get(), ctx0, n_tokens, true, true);

         for (int il = 0; il < n_layer; ++il) {
             auto * residual = inpL;
@@ -6856,7 +6859,7 @@ struct llm_build_context {
                         );
                 cb(Kcur, "Kcur", il);

-                cur = build_attn(gf,
+                cur = build_attn(inp_attn.get(), gf,
                         model.layers[il].wo, model.layers[il].bo,
                         Qcur, Kcur, Vcur, n_tokens, 1.0f, il);
             }
@@ -6916,7 +6919,7 @@ struct llm_build_context {
                 LLM_NORM_RMS, -1);
         cb(cur, "result_norm", -1);

-        res.t_embd = cur;
+        res->t_embd = cur;

         cur = build_lora_mm(model.output, cur);
@@ -6926,7 +6929,7 @@ struct llm_build_context {
         }
         cb(cur, "result_output", -1);

-        res.t_logits = cur;
+        res->t_logits = cur;

         ggml_build_forward_expand(gf, cur);
     }
@@ -6945,7 +6948,7 @@ struct llm_build_context {
         // inp_pos - contains the positions
         struct ggml_tensor * inp_pos = build_inp_pos();

-        lgf->build_attn_inp(ctx0, n_tokens, true, false);
+        auto inp_attn = lgf->build_attn_inp(res.get(), ctx0, n_tokens, true, false);

         for (int il = 0; il < n_layer; ++il) {
@@ -6981,7 +6984,7 @@ struct llm_build_context {
                         ext_factor, attn_factor, beta_fast, beta_slow);
                 cb(Kcur, "Kcur", il);

-                cur = build_attn(gf,
+                cur = build_attn(inp_attn.get(), gf,
                         model.layers[il].wo, NULL,
                         Qcur, Kcur, Vcur, n_tokens, 1.0f/sqrtf(float(n_embd_head)), il);
             }
@@ -7025,13 +7028,13 @@ struct llm_build_context {
                 LLM_NORM_RMS, -1);
         cb(cur, "result_norm", -1);

-        res.t_embd = cur;
+        res->t_embd = cur;

         // lm_head
         cur = build_lora_mm(model.output, cur);
         cb(cur, "result_output", -1);

-        res.t_logits = cur;
+        res->t_logits = cur;

         ggml_build_forward_expand(gf, cur);
     }
@@ -7051,7 +7054,7 @@ struct llm_build_context {
         // inp_pos - contains the positions
         struct ggml_tensor * inp_pos = build_inp_pos();

-        lgf->build_attn_inp(ctx0, n_tokens, true, false);
+        auto inp_attn = lgf->build_attn_inp(res.get(), ctx0, n_tokens, true, false);

         pos = ggml_get_rows(ctx0, model.pos_embd, inp_pos);
         cb(pos, "pos_embd", -1);
@@ -7084,7 +7087,7 @@ struct llm_build_context {
                 Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);

-                cur = build_attn(gf,
+                cur = build_attn(inp_attn.get(), gf,
                         model.layers[il].wo, model.layers[il].bo,
                         Qcur, Kcur, Vcur, n_tokens, 1.0f/sqrtf(float(n_embd_head)), il);
             }
@@ -7132,12 +7135,12 @@ struct llm_build_context {
                 LLM_NORM, -1);
         cb(cur, "result_norm", -1);

-        res.t_embd = cur;
+        res->t_embd = cur;

         cur = build_lora_mm(model.output, cur);
         cb(cur, "result_output", -1);

-        res.t_logits = cur;
+        res->t_logits = cur;

         ggml_build_forward_expand(gf, cur);
     }
@@ -7157,7 +7160,7 @@ struct llm_build_context {
         // inp_pos - contains the positions
         struct ggml_tensor * inp_pos = build_inp_pos();

-        lgf->build_attn_inp(ctx0, n_tokens, true, false);
+        auto inp_attn = lgf->build_attn_inp(res.get(), ctx0, n_tokens, true, false);

         for (int il = 0; il < n_layer; ++il) {
             cur = build_norm(inpL,
@@ -7196,7 +7199,7 @@ struct llm_build_context {
                         );
                 cb(Kcur, "Kcur", il);

-                cur = build_attn(gf,
+                cur = build_attn(inp_attn.get(), gf,
                         model.layers[il].wo, model.layers[il].bo,
                         Qcur, Kcur, Vcur, n_tokens, 1.0f/sqrtf(float(n_embd_head)), il);
             }
@@ -7244,12 +7247,12 @@ struct llm_build_context {
                 LLM_NORM, -1);
         cb(cur, "result_norm", -1);

-        res.t_embd = cur;
+        res->t_embd = cur;

         cur = build_lora_mm(model.output, cur);
         cb(cur, "result_output", -1);

-        res.t_logits = cur;
+        res->t_logits = cur;

         ggml_build_forward_expand(gf, cur);
     }
@@ -7268,7 +7271,7 @@ struct llm_build_context {
         // inp_pos - contains the positions
         struct ggml_tensor * inp_pos = build_inp_pos();

-        lgf->build_attn_inp(ctx0, n_tokens, true, false);
+        auto inp_attn = lgf->build_attn_inp(res.get(), ctx0, n_tokens, true, false);

         for (int il = 0; il < n_layer; ++il) {
             struct ggml_tensor * inpSA = inpL;
@@ -7317,7 +7320,7 @@ struct llm_build_context {
                         );
                 cb(Kcur, "Kcur", il);

-                cur = build_attn(gf,
+                cur = build_attn(inp_attn.get(), gf,
                         model.layers[il].wo, NULL,
                         Qcur, Kcur, Vcur, n_tokens, 1.0f/sqrtf(float(n_embd_head)), il);
             }
@@ -7362,13 +7365,13 @@ struct llm_build_context {
                 LLM_NORM, -1);
         cb(cur, "result_norm", -1);

-        res.t_embd = cur;
+        res->t_embd = cur;

         // lm_head
         cur = build_lora_mm(model.output, cur);
         cb(cur, "result_output", -1);

-        res.t_logits = cur;
+        res->t_logits = cur;

         ggml_build_forward_expand(gf, cur);
     }
@@ -7387,7 +7390,7 @@ struct llm_build_context {
         // inp_pos - contains the positions
         struct ggml_tensor * inp_pos = build_inp_pos();

-        lgf->build_attn_inp(ctx0, n_tokens, true, false);
+        auto inp_attn = lgf->build_attn_inp(res.get(), ctx0, n_tokens, true, false);

         for (int il = 0; il < n_layer; ++il) {
             struct ggml_tensor * inpSA = inpL;
@@ -7436,7 +7439,7 @@ struct llm_build_context {
                         );
                 cb(Kcur, "Kcur", il);

-                cur = build_attn(gf,
+                cur = build_attn(inp_attn.get(), gf,
                         model.layers[il].wo, model.layers[il].bo,
                         Qcur, Kcur, Vcur, n_tokens, 1.0f/sqrtf(float(n_embd_head)), il);
             }
@@ -7481,13 +7484,13 @@ struct llm_build_context {
                 LLM_NORM_RMS, -1);
         cb(cur, "result_norm", -1);

-        res.t_embd = cur;
+        res->t_embd = cur;

         // lm_head
         cur = build_lora_mm(model.output, cur);
         cb(cur, "result_output", -1);

-        res.t_logits = cur;
+        res->t_logits = cur;

         ggml_build_forward_expand(gf, cur);
     }
@@ -7515,7 +7518,7 @@ struct llm_build_context {
         // inp_pos - contains the positions
         struct ggml_tensor * inp_pos = build_inp_pos();

-        lgf->build_attn_inp(ctx0, n_tokens, true, false);
+        auto inp_attn = lgf->build_attn_inp(res.get(), ctx0, n_tokens, true, false);

         for (int il = 0; il < n_layer; ++il) {
             struct ggml_tensor * inpSA = inpL;
@@ -7630,7 +7633,7 @@ struct llm_build_context {
                 struct ggml_tensor * k_states = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_pe, q_pe), 0);
                 cb(k_states, "k_states", il);

-                cur = build_attn(gf,
+                cur = build_attn(inp_attn.get(), gf,
                         model.layers[il].wo, NULL,
                         q_states, k_states, v_states, n_tokens, kq_scale, il);
             }
@@ -7686,7 +7689,7 @@ struct llm_build_context {
                 LLM_NORM_RMS, -1);
         cb(cur, "result_norm", -1);
-        res.t_embd = cur;
+        res->t_embd = cur;

         // lm_head scaling
         const float scale_lmhead = float(n_embd_base)/float(n_embd);
@@ -7697,7 +7700,7 @@ struct llm_build_context {
         cur = build_lora_mm(model.output, cur);
         cb(cur, "result_output", -1);

-        res.t_logits = cur;
+        res->t_logits = cur;

         ggml_build_forward_expand(gf, cur);
     }
@@ -7716,7 +7719,7 @@ struct llm_build_context {
         // inp_pos - contains the positions
         struct ggml_tensor * inp_pos = build_inp_pos();

-        lgf->build_attn_inp(ctx0, n_tokens, true, false);
+        auto inp_attn = lgf->build_attn_inp(res.get(), ctx0, n_tokens, true, false);

         for (int il = 0; il < n_layer; ++il) {
             // norm
@@ -7752,7 +7755,7 @@ struct llm_build_context {
                         ext_factor, attn_factor, beta_fast, beta_slow);
                 cb(Kcur, "Kcur", il);

-                cur = build_attn(gf,
+                cur = build_attn(inp_attn.get(), gf,
                         model.layers[il].wo, NULL,
                         Qcur, Kcur, Vcur, n_tokens, 1.0f, il);
             }
@@ -7799,13 +7802,13 @@ struct llm_build_context {
                 LLM_NORM_RMS, -1);
         cb(cur, "result_norm", -1);

-        res.t_embd = cur;
+        res->t_embd = cur;

         // lm_head
         cur = build_lora_mm(model.output, cur);
         cb(cur, "result_output", -1);

-        res.t_logits = cur;
+        res->t_logits = cur;

         ggml_build_forward_expand(gf, cur);
     }
@@ -7824,7 +7827,7 @@ struct llm_build_context {
         // inp_pos - contains the positions
         struct ggml_tensor * inp_pos = build_inp_pos();

-        lgf->build_attn_inp(ctx0, n_tokens, true, true);
+        auto inp_attn = lgf->build_attn_inp(res.get(), ctx0, n_tokens, true, true);

         for (int il = 0; il < n_layer; ++il) {
             // norm
@@ -7866,7 +7869,7 @@ struct llm_build_context {
                         ext_factor, attn_factor, beta_fast, beta_slow);
                 cb(Kcur, "Kcur", il);

-                cur = build_attn(gf,
+                cur = build_attn(inp_attn.get(), gf,
                         model.layers[il].wo, NULL,
                         Qcur, Kcur, Vcur, n_tokens, 1.0f, il);
             }
@@ -7923,7 +7926,7 @@ struct llm_build_context {
                 LLM_NORM_RMS, -1);
         cb(cur, "result_norm", -1);

-        res.t_embd = cur;
+        res->t_embd = cur;

         // lm_head
         cur = build_lora_mm(model.output, cur);
@@ -7934,7 +7937,7 @@ struct llm_build_context {
         cur = ggml_scale(ctx0, cur, hparams.f_final_logit_softcapping);
         cb(cur, "result_output", -1);

-        res.t_logits = cur;
+        res->t_logits = cur;

         ggml_build_forward_expand(gf, cur);
     }
@@ -7954,7 +7957,7 @@ struct llm_build_context {
         // inp_pos - contains the positions
         struct ggml_tensor * inp_pos = build_inp_pos();

-        lgf->build_attn_inp(ctx0, n_tokens, true, false);
+        auto inp_attn = lgf->build_attn_inp(res.get(), ctx0, n_tokens, true, false);

         for (int il = 0; il < n_layer; ++il) {
             struct ggml_tensor * inpSA = inpL;
@@ -8003,7 +8006,7 @@ struct llm_build_context {
                         );
                 cb(Kcur, "Kcur", il);

-                cur = build_attn(gf,
+                cur = build_attn(inp_attn.get(), gf,
                         model.layers[il].wo, model.layers[il].bo,
                         Qcur, Kcur, Vcur, n_tokens, 1.0f/sqrtf(float(n_embd_head)), il);
             }
@@ -8049,13 +8052,13 @@ struct llm_build_context {
                 LLM_NORM, -1);
         cb(cur, "result_norm", -1);

-        res.t_embd = cur;
+        res->t_embd = cur;

         // lm_head
         cur = build_lora_mm(model.output, cur);
         cb(cur, "result_output", -1);

-        res.t_logits = cur;
+        res->t_logits = cur;

         ggml_build_forward_expand(gf, cur);
     }
@@ -8103,13 +8106,13 @@ struct llm_build_context {
                 LLM_NORM_RMS, -1);
         cb(cur, "result_norm", -1);

-        res.t_embd = cur;
+        res->t_embd = cur;

         // lm_head
         cur = build_lora_mm(model.output, cur);
         cb(cur, "result_output", -1);

-        res.t_logits = cur;
+        res->t_logits = cur;

         ggml_build_forward_expand(gf, cur);
     }
@@ -8129,7 +8132,7 @@ struct llm_build_context {
         // inp_pos - contains the positions
         struct ggml_tensor * inp_pos = build_inp_pos();

-        lgf->build_attn_inp(ctx0, n_tokens, true, false);
+        auto inp_attn = lgf->build_attn_inp(res.get(), ctx0, n_tokens, true, false);
         for (int il = 0; il < n_layer; ++il) {
@@ -8203,7 +8206,7 @@ struct llm_build_context {
                         );
                 cb(Kcur, "Kcur", il);

-                cur = build_attn(gf,
+                cur = build_attn(inp_attn.get(), gf,
                         model.layers[il].wo, model.layers[il].bo,
                         Qcur, Kcur, Vcur, n_tokens, 1.0f/sqrtf(float(n_embd_head)), il);
             }
@@ -8247,7 +8250,7 @@ struct llm_build_context {
                 LLM_NORM, -1);
         cb(cur, "result_norm", -1);

-        res.t_embd = cur;
+        res->t_embd = cur;

         // lm_head
         cur = build_lora_mm(model.output, cur);
@@ -8257,7 +8260,7 @@ struct llm_build_context {
         }
         cb(cur, "result_output", -1);

-        res.t_logits = cur;
+        res->t_logits = cur;

         ggml_build_forward_expand(gf, cur);
     }
@@ -8277,7 +8280,7 @@ struct llm_build_context {
         // inp_pos - contains the positions
         struct ggml_tensor * inp_pos = build_inp_pos();

-        lgf->build_attn_inp(ctx0, n_tokens, true, true);
+        auto inp_attn = lgf->build_attn_inp(res.get(), ctx0, n_tokens, true, true);

         // sliding window switch pattern
         const int32_t sliding_window_pattern = 4;
@@ -8338,7 +8341,7 @@ struct llm_build_context {
                     cb(Kcur, "Kcur", il);
                 }

-                cur = build_attn(gf,
+                cur = build_attn(inp_attn.get(), gf,
                         model.layers[il].wo, model.layers[il].bo,
                         Qcur, Kcur, Vcur, n_tokens, 1.0f/sqrtf(float(n_embd_head)), il);
             }
@@ -8377,7 +8380,7 @@ struct llm_build_context {
         cur = build_norm(cur, model.output_norm, NULL, LLM_NORM, -1);
         cb(cur, "result_norm", -1);

-        res.t_embd = cur;
+        res->t_embd = cur;

         // lm_head
         cur = build_lora_mm(model.output, cur);
@@ -8387,7 +8390,7 @@ struct llm_build_context {
         }
         cb(cur, "result_output", -1);

-        res.t_logits = cur;
+        res->t_logits = cur;

         ggml_build_forward_expand(gf, cur);
     }
@@ -8412,7 +8415,7 @@ struct llm_build_context {
         // inp_pos - contains the positions
         struct ggml_tensor * inp_pos = build_inp_pos();

-        lgf->build_attn_inp(ctx0, n_tokens, true, false);
+        auto inp_attn = lgf->build_attn_inp(res.get(), ctx0, n_tokens, true, false);

         for (int il = 0; il < n_layer; ++il) {
             struct ggml_tensor * inpSA = inpL;
@@ -8461,7 +8464,7 @@ struct llm_build_context {
                         );
                 cb(Kcur, "Kcur", il);

-                cur = build_attn(gf,
+                cur = build_attn(inp_attn.get(), gf,
                         model.layers[il].wo, nullptr,
                         Qcur, Kcur, Vcur, n_tokens, 1.0f/sqrtf(float(n_embd_head)), il);
             }
@@ -8507,13 +8510,13 @@ struct llm_build_context {
                 LLM_NORM, -1);
         cb(cur, "result_norm", -1);

-        res.t_embd = cur;
+        res->t_embd = cur;

         // lm_head
         cur = build_lora_mm(model.output, cur);
         cb(cur, "result_output", -1);

-        res.t_logits = cur;
+        res->t_logits = cur;

         ggml_build_forward_expand(gf, cur);
     }
@@ -8532,7 +8535,7 @@ struct llm_build_context {
         // inp_pos - contains the positions
         struct ggml_tensor * inp_pos = build_inp_pos();

-        lgf->build_attn_inp(ctx0, n_tokens, true, false);
+        auto inp_attn = lgf->build_attn_inp(res.get(), ctx0, n_tokens, true, false);

         for (int il = 0; il < n_layer; ++il) {
             struct ggml_tensor * inpSA = inpL;
@@ -8576,7 +8579,7 @@ struct llm_build_context {
                         );
                 cb(Kcur, "Kcur_rope", il);

-                cur = build_attn(gf,
+                cur = build_attn(inp_attn.get(), gf,
                         model.layers[il].wo, NULL,
                         Qcur, Kcur, Vcur, n_tokens, 1.0f/sqrtf(float(n_embd_head)), il);
             }
@@ -8627,13 +8630,13 @@ struct llm_build_context {
                 LLM_NORM_RMS, -1);
         cb(cur, "result_norm", -1);

-        res.t_embd = cur;
+        res->t_embd = cur;

         // lm_head
         cur = build_lora_mm(model.output, cur);
         cb(cur, "result_output", -1);

-        res.t_logits = cur;
+        res->t_logits = cur;

         ggml_build_forward_expand(gf, cur);
     }
@@ -8656,7 +8659,7 @@ struct llm_build_context {
         // inp_pos - contains the positions
         struct ggml_tensor * inp_pos = build_inp_pos();

-        lgf->build_attn_inp(ctx0, n_tokens, true, false);
+        auto inp_attn = lgf->build_attn_inp(res.get(), ctx0, n_tokens, true, false);

         for (int il = 0; il < n_layer; ++il) {
             struct ggml_tensor * inpSA = inpL;
@@ -8704,7 +8707,7 @@ struct llm_build_context {
                         );
                 cb(Kcur, "Kcur_rope", il);

-                cur = build_attn(gf,
+                cur = build_attn(inp_attn.get(), gf,
                         model.layers[il].wo, NULL,
                         Qcur, Kcur, Vcur, n_tokens, 1.0f/sqrtf(float(n_embd_head)), il);
             }
@@ -8754,13 +8757,13 @@ struct llm_build_context {
                 LLM_NORM_RMS, -1);
         cb(cur, "result_norm", -1);

-        res.t_embd = cur;
+        res->t_embd = cur;

         // lm_head
         cur = build_lora_mm(model.output, cur);
         cb(cur, "result_output", -1);

-        res.t_logits = cur;
+        res->t_logits = cur;

         ggml_build_forward_expand(gf, cur);
     }
@@ -8777,7 +8780,7 @@ struct llm_build_context {
         // inp_pos - contains the positions
         struct ggml_tensor * inp_pos = build_inp_pos();

-        lgf->build_attn_inp(ctx0, n_tokens, true, false);
+        auto inp_attn = lgf->build_attn_inp(res.get(), ctx0, n_tokens, true, false);

         for (int il = 0; il < n_layer; ++il) {
             const int64_t n_head    = hparams.n_head(il);
@@ -8834,7 +8837,7 @@ struct llm_build_context {
                 Vcur = ggml_reshape_2d(ctx0, Vcur, n_embd_head * n_head_kv, n_tokens);
                 cb(Qcur, "Vcur", il);

-                cur = build_attn(gf,
+                cur = build_attn(inp_attn.get(), gf,
                         model.layers[il].wo, NULL,
                         Qcur, Kcur, Vcur, n_tokens, 1.0f/sqrtf(float(n_embd_head)), il);
             }
@@ -8881,12 +8884,12 @@ struct llm_build_context {
                 LLM_NORM_RMS, -1);
         cb(cur, "result_norm", -1);

-        res.t_embd = cur;
+        res->t_embd = cur;

         cur = build_lora_mm(model.output, cur);
         cb(cur, "result_output", -1);

-        res.t_logits = cur;
+        res->t_logits = cur;

         ggml_build_forward_expand(gf, cur);
     }
@@ -8905,7 +8908,7 @@ struct llm_build_context {
         // inp_pos - contains the positions
         struct ggml_tensor * inp_pos = build_inp_pos();

-        lgf->build_attn_inp(ctx0, n_tokens, true, false);
+        auto inp_attn = lgf->build_attn_inp(res.get(), ctx0, n_tokens, true, false);

         for (int il = 0; il < n_layer; ++il) {
             cur = build_norm(inpL,
@@ -8944,7 +8947,7 @@ struct llm_build_context {
                         );
                 cb(Kcur, "Kcur", il);

-                cur = build_attn(gf,
+                cur = build_attn(inp_attn.get(), gf,
                         model.layers[il].wo, model.layers[il].bo,
                         Qcur, Kcur, Vcur, n_tokens, 1.0f/sqrtf(float(n_embd_head)), il);
             }
@@ -9025,12 +9028,12 @@ struct llm_build_context {
                 LLM_NORM, -1);
         cb(cur, "result_norm", -1);

-        res.t_embd = cur;
+        res->t_embd = cur;

         cur = build_lora_mm(model.output, cur);
         cb(cur, "result_output", -1);

-        res.t_logits = cur;
+        res->t_logits = cur;

         ggml_build_forward_expand(gf, cur);
     }
@@ -9049,7 +9052,7 @@ struct llm_build_context {
         // inp_pos - contains the positions
         struct ggml_tensor * inp_pos = build_inp_pos();

-        lgf->build_attn_inp(ctx0, n_tokens, true, false);
+        auto inp_attn = lgf->build_attn_inp(res.get(), ctx0, n_tokens, true, false);

         for (int il = 0; il < n_layer; ++il) {
             struct ggml_tensor * inpSA = inpL;
@@ -9086,7 +9089,7 @@ struct llm_build_context {
                         );
                 cb(Kcur, "Kcur", il);

-                cur = build_attn(gf,
+                cur = build_attn(inp_attn.get(), gf,
                         model.layers[il].wo, NULL,
                         Qcur, Kcur, Vcur, n_tokens, 1.0f/sqrtf(float(n_embd_head)), il);
             }
@@ -9154,13 +9157,13 @@ struct llm_build_context {
                 LLM_NORM_RMS, -1);
         cb(cur, "result_norm", -1);

-        res.t_embd = cur;
+        res->t_embd = cur;

         // lm_head
         cur = build_lora_mm(model.output, cur);
         cb(cur, "result_output", -1);

-        res.t_logits = cur;
+        res->t_logits = cur;

         ggml_build_forward_expand(gf, cur);
     }
@@ -9179,7 +9182,7 @@ struct llm_build_context {
         // inp_pos - contains the positions
         struct ggml_tensor * inp_pos = build_inp_pos();

-        lgf->build_attn_inp(ctx0, n_tokens, true, false);
+        auto inp_attn = lgf->build_attn_inp(res.get(), ctx0, n_tokens, true, false);

         const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
@@ -9233,7 +9236,7 @@ struct llm_build_context {
                         );
                 cb(Kcur, "Kcur", il);

-                cur = build_attn(gf,
+                cur = build_attn(inp_attn.get(), gf,
                         model.layers[il].wo, model.layers[il].bo,
                         Qcur, Kcur, Vcur, n_tokens, kq_scale, il);
             }
@@ -9309,13 +9312,13 @@ struct llm_build_context {
                 LLM_NORM_RMS, -1);
         cb(cur, "result_norm", -1);

-        res.t_embd = cur;
+        res->t_embd = cur;

         // lm_head
         cur = build_lora_mm(model.output, cur);
         cb(cur, "result_output", -1);

-        res.t_logits = cur;
+        res->t_logits = cur;

         ggml_build_forward_expand(gf, cur);
     }
@@ -9342,7 +9345,7 @@ struct llm_build_context {
         // inp_pos - contains the positions
         struct ggml_tensor * inp_pos = build_inp_pos();

-        lgf->build_attn_inp(ctx0, n_tokens, true, false);
+        auto inp_attn = lgf->build_attn_inp(res.get(), ctx0, n_tokens, true, false);

         for (int il = 0; il < n_layer; ++il) {
             struct ggml_tensor * inpSA = inpL;
@@ -9461,7 +9464,7 @@ struct llm_build_context {
                 struct ggml_tensor * k_states = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_pe, q_pe), 0);
                 cb(k_states, "k_states", il);

-                cur = build_attn(gf,
+                cur = build_attn(inp_attn.get(), gf,
                         model.layers[il].wo, NULL,
                         q_states, k_states, v_states, n_tokens, kq_scale, il);
             }
@@ -9536,13 +9539,13 @@ struct llm_build_context {
                 LLM_NORM_RMS, -1);
         cb(cur, "result_norm", -1);

-        res.t_embd = cur;
+        res->t_embd = cur;

         // lm_head
         cur = ggml_mul_mat(ctx0, model.output, cur);
         cb(cur, "result_output", -1);

-        res.t_logits = cur;
+        res->t_logits = cur;

         ggml_build_forward_expand(gf, cur);
     }
@@ -9560,7 +9563,7 @@ struct llm_build_context {
         // inp_pos - contains the positions
        struct ggml_tensor * inp_pos = build_inp_pos();

-        lgf->build_attn_inp(ctx0, n_tokens, true, false);
+        auto inp_attn = lgf->build_attn_inp(res.get(), ctx0, n_tokens, true, false);

         for (int il = 0; il < n_layer; ++il) {
             struct ggml_tensor * inpSA = inpL;
@@ -9619,7 +9622,7 @@ struct llm_build_context {
                         );
                 cb(Kcur, "Kcur", il);

-                cur = build_attn(gf,
+                cur = build_attn(inp_attn.get(), gf,
                         NULL, NULL,
                         Qcur, Kcur, Vcur, n_tokens, 1.0f/sqrtf(float(n_embd_head)), il);
@@ -9687,14 +9690,14 @@ struct llm_build_context {
                 LLM_NORM_RMS, -1);
         cb(cur, "result_norm", -1);

-        res.t_embd = cur;
+        res->t_embd = cur;

         // lm_head
         // FIXME: do not use model.tok_embd directly, duplicate as model.output
         cur = build_lora_mm(model.tok_embd, cur);
         cb(cur, "result_output", -1);

-        res.t_logits = cur;
+        res->t_logits = cur;

         ggml_build_forward_expand(gf, cur);
     }
@@ -9711,7 +9714,7 @@ struct llm_build_context {
         struct ggml_tensor * pos_bucket_enc = build_pos_bucket();

-        lgf->build_attn_inp(ctx0, n_tokens, false, false);
+        auto inp_attn = lgf->build_attn_inp(res.get(), ctx0, n_tokens, false, false);

         for (int il = 0; il < n_layer; ++il) {
             struct ggml_tensor * inpSA = inpL;
@@ -9740,7 +9743,7 @@ struct llm_build_context {
                 struct ggml_tensor * attn_rel_b = model.layers[il].attn_rel_b_enc ? model.layers[il].attn_rel_b_enc : model.layers[0].attn_rel_b_enc;
                 struct ggml_tensor * kq_b = build_pos_bias(pos_bucket_enc, attn_rel_b);

-                cur = build_attn_with_kq_b(gf,
+                cur = build_attn_with_kq_b(inp_attn.get(), gf,
                         model.layers[il].wo_enc, nullptr,
                         Qcur, Kcur, Vcur, kq_b, n_tokens, 1.0f, il);
                 cb(cur, "kqv_out", il);
@@ -9793,7 +9796,7 @@ struct llm_build_context {
                 LLM_NORM_RMS, -1);
         cb(cur, "result_norm", -1);

-        res.t_embd = cur;
+        res->t_embd = cur;

         ggml_build_forward_expand(gf, cur);
     }
@@ -9814,7 +9817,7 @@ struct llm_build_context {
         const int64_t n_outputs_enc = embd_enc->ne[1];

-        lgf->build_attn_inp(ctx0, n_tokens, true, false);
+        auto inp_attn = lgf->build_attn_inp(res.get(), ctx0, n_tokens, true, false);

         for (int il = 0; il < n_layer; ++il) {
             struct ggml_tensor * inpSA = inpL;
@@ -9843,7 +9846,7 @@ struct llm_build_context {
                 struct ggml_tensor * attn_rel_b = model.layers[il].attn_rel_b ? model.layers[il].attn_rel_b : model.layers[0].attn_rel_b;
                 struct ggml_tensor * kq_b = build_pos_bias(pos_bucket_dec, attn_rel_b);

-                cur = build_attn_with_kq_b(gf,
+                cur = build_attn_with_kq_b(inp_attn.get(), gf,
                         model.layers[il].wo, model.layers[il].bo,
                         Qcur, Kcur, Vcur, kq_b, n_tokens, 1.0f, il);
                 cb(cur, "kqv_out", il);
@@ -9875,7 +9878,7 @@ struct llm_build_context {
                 Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_outputs_enc);
                 Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_outputs_enc);

-                cur = build_attn_cross(gf,
+                cur = build_attn_cross(inp_attn.get(), gf,
                         model.layers[il].wo_cross, nullptr,
                         Qcur, Kcur, Vcur, n_tokens, 1.0f, il);
                 cb(cur, "kqv_out", il);
@@ -9955,13 +9958,13 @@ struct llm_build_context {
                 LLM_NORM_RMS, -1);
         cb(cur, "result_norm", -1);

-        res.t_embd = cur;
+        res->t_embd = cur;

         // lm_head
         cur = build_lora_mm(model.output, cur);
         cb(cur, "result_output", -1);

-        res.t_logits = cur;
+        res->t_logits = cur;

         ggml_build_forward_expand(gf, cur);
     }
@@ -9977,7 +9980,7 @@ struct llm_build_context {
         inpL = build_inp_embd(model.tok_embd);

-        lgf->build_attn_inp(ctx0, n_tokens, true, false);
+        auto inp_attn = lgf->build_attn_inp(res.get(), ctx0, n_tokens, true, false);

         for (int il = 0; il < n_layer; ++il) {
             cur = build_norm(inpL,
@@ -10004,7 +10007,7 @@ struct llm_build_context {
                 Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);

-                cur = build_attn(gf,
+                cur = build_attn(inp_attn.get(), gf,
                         model.layers[il].wo, model.layers[il].bo,
                         Qcur, Kcur, Vcur, n_tokens, 1.0f/float(n_embd_head), il);
             }
@@ -10047,12 +10050,12 @@ struct llm_build_context {
                 LLM_NORM, -1);
         cb(cur, "result_norm", -1);

-        res.t_embd = cur;
+        res->t_embd = cur;

         cur = build_lora_mm(model.output, cur);
         cb(cur, "result_output", -1);

-        res.t_logits = cur;
+        res->t_logits = cur;

         ggml_build_forward_expand(gf, cur);
     }
@@ -10071,7 +10074,7 @@ struct llm_build_context {
         // inp_pos - contains the positions
         struct ggml_tensor * inp_pos = build_inp_pos();

-        lgf->build_attn_inp(ctx0, n_tokens, true, false);
+        auto inp_attn = lgf->build_attn_inp(res.get(), ctx0, n_tokens, true, false);

         for (int il = 0; il < n_layer; ++il) {
             struct ggml_tensor * inpSA = inpL;
@@ -10132,7 +10135,7 @@ struct llm_build_context {
                         );
                 cb(Kcur, "Kcur_rope", il);

-                cur = build_attn(gf,
+                cur = build_attn(inp_attn.get(), gf,
                         model.layers[il].wo, NULL,
                         Qcur, Kcur, Vcur, n_tokens, 1.0f/sqrtf(float(n_embd_head)), il);
@@ -10177,12 +10180,12 @@ struct llm_build_context {
                 LLM_NORM_RMS, -1);
         cb(cur, "result_norm", -1);

-        res.t_embd = cur;
+        res->t_embd = cur;

         cur = build_lora_mm(model.output, cur);
         cb(cur, "result_output", -1);

-        res.t_logits = cur;
+        res->t_logits = cur;

         ggml_build_forward_expand(gf, cur);
     }
@@ -10201,7 +10204,7 @@ struct llm_build_context {
         // inp_pos - contains the positions
         struct ggml_tensor * inp_pos = build_inp_pos();

-        lgf->build_attn_inp(ctx0, n_tokens, true, false);
+        auto inp_attn = lgf->build_attn_inp(res.get(), ctx0, n_tokens, true, false);

         for (int il = 0; il < n_layer; ++il) {
             struct ggml_tensor * inpSA = inpL;
@@ -10251,7 +10254,7 @@ struct llm_build_context {
                         );
                 cb(Kcur, "Kcur", il);

-                cur = build_attn(gf,
+                cur = build_attn(inp_attn.get(), gf,
                         model.layers[il].wo, model.layers[il].bo,
                         Qcur, Kcur, Vcur, n_tokens, 1.0f/sqrtf(float(n_embd_head)), il);
             }
@@ -10297,13 +10300,13 @@ struct llm_build_context {
                 LLM_NORM, -1);
         cb(cur, "result_norm", -1);

-        res.t_embd = cur;
+        res->t_embd = cur;

         // lm_head
         cur = build_lora_mm(model.output, cur);
         cb(cur, "result_output", -1);

-        res.t_logits = cur;
+        res->t_logits = cur;

         ggml_build_forward_expand(gf, cur);
     }
@@ -10322,7 +10325,7 @@ struct llm_build_context {
         // inp_pos - contains the positions
         struct ggml_tensor * inp_pos = build_inp_pos();

-        lgf->build_attn_inp(ctx0, n_tokens, true, false);
+        auto inp_attn = lgf->build_attn_inp(res.get(), ctx0, n_tokens, true, false);

         for (int il = 0; il < n_layer; ++il) {
             struct ggml_tensor * inpSA = inpL;
@@ -10374,7 +10377,7 @@ struct llm_build_context {
                         );
                 cb(Kcur, "Kcur", il);

-                cur = build_attn(gf,
+                cur = build_attn(inp_attn.get(), gf,
                         model.layers[il].wo, model.layers[il].bo,
                         Qcur, Kcur, Vcur, n_tokens, 1.0f/sqrtf(float(n_embd_head)), il);
             }
@@ -10420,13 +10423,13 @@ struct llm_build_context {
                 LLM_NORM_RMS, -1);
         cb(cur, "result_norm", -1);

-        res.t_embd = cur;
+        res->t_embd = cur;

         // lm_head
         cur = build_lora_mm(model.output, cur);
         cb(cur, "result_output", -1);

-        res.t_logits = cur;
+        res->t_logits = cur;

         ggml_build_forward_expand(gf, cur);
     }
@@ -10513,12 +10516,12 @@ struct llm_build_context {
         cur = build_norm(cur, model.output_norm, model.output_norm_b, LLM_NORM, -1);
         cb(cur, "result_norm", -1);

-        res.t_embd = cur;
+        res->t_embd = cur;

         cur = build_lora_mm(model.output, cur);
         cb(cur, "result_output", -1);

-        res.t_logits = cur;
+        res->t_logits = cur;

         ggml_build_forward_expand(gf, cur);
     }
@@ -10597,12 +10600,12 @@ struct llm_build_context {
         cur = build_norm(cur, model.output_norm, model.output_norm_b, LLM_NORM_RMS, -1);
         cb(cur, "result_norm", -1);

-        res.t_embd = cur;
+        res->t_embd = cur;

         cur = build_lora_mm(model.output, cur);
         cb(cur, "result_output", -1);

-        res.t_logits = cur;
+        res->t_logits = cur;

         ggml_build_forward_expand(gf, cur);
     }
@@ -10627,7 +10630,7 @@ struct llm_build_context {
         // inp_pos - contains the positions
         struct ggml_tensor * inp_pos = build_inp_pos();

-        lgf->build_attn_inp(ctx0, n_tokens, true, false);
+        auto inp_attn = lgf->build_attn_inp(res.get(), ctx0, n_tokens, true, false);

         for (int il = 0; il < n_layer; ++il) {
             struct ggml_tensor * inpSA = inpL;
@@ -10696,7 +10699,7 @@ struct llm_build_context {
                         );
                 cb(Kcur, "Kcur", il);

-                cur = build_attn(gf,
+                cur = build_attn(inp_attn.get(), gf,
                         model.layers[il].wo, nullptr,
                         Qcur, Kcur, Vcur, n_tokens, 1.0f/sqrtf(float(n_embd_head)), il);
@@ -10757,7 +10760,7 @@ struct llm_build_context {
                 LLM_NORM_RMS, -1);
         cb(cur, "result_norm", -1);

-        res.t_embd = cur;
+        res->t_embd = cur;

         // lm_head
         cur = build_lora_mm(model.output, cur);
@@ -10777,7 +10780,7 @@ struct llm_build_context {
         cur = ggml_set_1d(ctx0, cur, img_logits, ggml_element_size(cur) * img_token_start_idx);
         cb(cur, "result_output", -1);

-        res.t_logits = cur;
+        res->t_logits = cur;
         ggml_build_forward_expand(gf, cur);
     }
@@ -10927,13 +10930,13 @@ struct llm_build_context {
         cur = ggml_add(ctx0, cur, model.output_b);
         cb(cur, "result_embd", -1);

-        res.t_embd = cur;
+        res->t_embd = cur;

         ggml_build_forward_expand(gf, cur);
     }
};

-llama_graph_result llama_model::build_graph(
+llama_graph_result_ptr llama_model::build_graph(
         ggml_context * ctx,
         ggml_cgraph * gf,
         llama_graph_i * lgf,
@@ -11166,7 +11169,7 @@ llama_graph_result llama_model::build_graph(
         llm.append_pooling(gf);
     }

-    return llm.res;
+    return std::move(llm.res);
}

//
diff --git a/src/llama-model.h b/src/llama-model.h
index 447fc0d0576d6..2d64c0d242c09 100644
--- a/src/llama-model.h
+++ b/src/llama-model.h
@@ -3,6 +3,7 @@
 #include "llama.h"
 #include "llama-arch.h"
 #include "llama-hparams.h"
+#include "llama-graph.h"
 #include "llama-vocab.h"

 #include
@@ -10,11 +11,9 @@
 #include
 #include

-class llama_graph_i;
 struct llama_cparams;
 struct llama_ubatch;
 struct llama_model_loader;
-struct llama_graph_result;

// available models
enum llm_type {
@@ -367,7 +366,7 @@ struct llama_model {
     const struct ggml_tensor * get_tensor(const char * name) const;

     // TODO: add encode/decode graphs
-    llama_graph_result build_graph(
+    llama_graph_result_ptr build_graph(
             ggml_context * ctx,
             ggml_cgraph * gf,
             llama_graph_i * lgf,