From aa8900bb07551e38be18f01f4f53739451d1befc Mon Sep 17 00:00:00 2001
From: Michal Siedlaczek
Date: Tue, 6 Feb 2024 20:40:06 -0500
Subject: [PATCH] Refactor QLD

Make the implementation cleaner and compute term component only once per
term scorer instead of for each score.

Signed-off-by: Michal Siedlaczek
---
Review notes (text between the "---" line and the diff is ignored when the
patch is applied):
 * The lambda added in qld.hpp no longer captures `term_id`: its body only
   reads `mu`, `term_component` and `this->m_wdata`, so the capture was dead.
 * Two explanatory comments were added to the new qld.hpp code; the hunk
   size and the diffstat below were updated to match.
 * NOTE(review): this copy of the patch appears to have lost angle-bracket
   content (e.g. `std::function {`, `index_scorer {`, the e-mail addresses).
   Those tokens are left exactly as found -- restore them from the upstream
   commit before applying.
 * NOTE(review): the quantized.hpp hunk adds `assert(...)` but no include;
   confirm that quantized.hpp already pulls in <cassert>.
 * NOTE(review): line breaks and indentation were reconstructed from a
   whitespace-collapsed copy; confirm against the upstream commit.

 include/pisa/scorer/qld.hpp       | 20 ++++++++++++--------
 include/pisa/scorer/quantized.hpp |  4 +++-
 2 files changed, 15 insertions(+), 9 deletions(-)

diff --git a/include/pisa/scorer/qld.hpp b/include/pisa/scorer/qld.hpp
index 610cca2a..dc3ea19b 100644
--- a/include/pisa/scorer/qld.hpp
+++ b/include/pisa/scorer/qld.hpp
@@ -23,14 +23,18 @@ struct qld: public index_scorer {
     qld(const Wand& wdata, const float mu) : index_scorer(wdata), m_mu(mu) {}
 
     term_scorer_t term_scorer(uint64_t term_id) const override {
-        auto s = [&, term_id](uint32_t doc, uint32_t freq) {
-            float numerator = 1
-                + freq
-                    / (this->m_mu
-                       * ((float)this->m_wdata.term_occurrence_count(term_id)
-                          / this->m_wdata.collection_len()));
-            float denominator = this->m_mu / (this->m_wdata.doc_len(doc) + this->m_mu);
-            return std::max(0.F, std::log(numerator) + std::log(denominator));
+        // Term-only factor collection_len / (mu * term_occurrences), computed once per scorer.
+        float mu = this->m_mu;
+        float collection_len = this->m_wdata.collection_len();
+        float term_occurrences = this->m_wdata.term_occurrence_count(term_id);
+        float term_component = collection_len / (mu * term_occurrences);
+
+        auto s = [this, mu, term_component](uint32_t doc, uint32_t freq) {
+            // score = log(mu / (doclen + mu)) + log(1 + freq * term_component); negative sums clamp to 0.
+            float doclen = this->m_wdata.doc_len(doc);
+            float a = std::log(mu / (doclen + mu));
+            float b = std::log1p(freq * term_component);
+            return std::max(0.F, a + b);
         };
         return s;
     }
diff --git a/include/pisa/scorer/quantized.hpp b/include/pisa/scorer/quantized.hpp
index 9b87e36f..c0afefc8 100644
--- a/include/pisa/scorer/quantized.hpp
+++ b/include/pisa/scorer/quantized.hpp
@@ -37,7 +37,9 @@ class QuantizingScorer {
         -> std::function {
         return
             [this, scorer = m_scorer->term_scorer(term_id)](std::uint32_t doc, std::uint32_t freq) {
-                return this->m_quantizer(scorer(doc, freq));
+                auto score = scorer(doc, freq);
+                assert(score >= 0.0);
+                return this->m_quantizer(score);
             };
     }
 };