Skip to content

Commit

Permalink
Suppress low-confidence SVs in strict mode (#3768)
Browse files Browse the repository at this point in the history
Currently, we have a low-confidence threshold (below which we print a
message) of `0.7`, and a drop threshold (below which SVs are dropped) of
`0.5` for SVs.

The main change in this PR is that, for `strict` mode, we increase the drop
threshold to `0.7`. This is done by plumbing in the "threshold" value to
the NL server, depending on the setting of `mode=strict`.

Importantly, the `strict` mode is:
* supported only for `api/data` API that nodejs uses
* incompatible with LLM.

Concretely, setting `mode=strict` implies: `sv_threshold=0.7`,
`detector=heuristic`, `use_default_place=False`

Additionally,
* Deprecate `udp=` param
* Add support for `mode=strict` to NL app, since that exercises the
`api/data` flow
* Move tests from `explore_test.py` to `nl_test.py`, also for the same
reason.
  • Loading branch information
pradh authored Nov 7, 2023
1 parent 2d1945f commit 0134730
Show file tree
Hide file tree
Showing 27 changed files with 263 additions and 390 deletions.
2 changes: 1 addition & 1 deletion .vscode/launch.json
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
"console": "internalConsole"
},
{
"name": "Flask (Custom with NL)",
"name": "Flask (Custom DC with NL)",
"type": "python",
"justMyCode": false,
"request": "launch",
Expand Down
11 changes: 5 additions & 6 deletions nl_server/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,6 @@
_HIGHEST_SCORE = 1.0
_INIT_SCORE = (_HIGHEST_SCORE + 0.1)

# Scores below this are ignored.
_SV_SCORE_THRESHOLD = 0.5

_NUM_CANDIDATES_PER_NSPLIT = 3

# Number of matches to find within the SV index.
Expand Down Expand Up @@ -140,6 +137,7 @@ def _search_embeddings(self,
#
def detect_svs(self,
orig_query: str,
threshold: float = constants.SV_SCORE_DEFAULT_THRESHOLD,
skip_multi_sv: bool = False) -> Dict[str, Union[Dict, List]]:
# Remove all stop-words.
query_monovar = utils.remove_stop_words(orig_query,
Expand All @@ -153,7 +151,7 @@ def detect_svs(self,
# Try to detect multiple SVs. Use the original query so that
# the logic can rely on stop-words like `vs`, `and`, etc as hints
# for SV delimiters.
result_multivar = self._detect_multiple_svs(orig_query)
result_multivar = self._detect_multiple_svs(orig_query, threshold)
multi_sv = vars.multivar_candidates_to_dict(result_multivar)

# TODO: Rename SV_to_Sentences for consistency.
Expand All @@ -168,7 +166,8 @@ def detect_svs(self,
# Detects one or more SVs from the query.
# TODO: Fix the query upstream to ensure the punctuations aren't stripped.
#
def _detect_multiple_svs(self, query: str) -> vars.MultiVarCandidates:
def _detect_multiple_svs(self, query: str,
threshold: float) -> vars.MultiVarCandidates:
#
# Prepare a combination of query-sets.
#
Expand Down Expand Up @@ -234,7 +233,7 @@ def _detect_multiple_svs(self, query: str) -> vars.MultiVarCandidates:
total += score
candidate.parts.append(part)

if lowest < _SV_SCORE_THRESHOLD:
if lowest < threshold:
# A query-part's best SV did not cross our score threshold,
# so drop this candidate.
continue
Expand Down
15 changes: 14 additions & 1 deletion nl_server/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

from nl_server import config
from nl_server import loader
from shared.lib.constants import SV_SCORE_DEFAULT_THRESHOLD

bp = Blueprint('main', __name__, url_prefix='/')

Expand All @@ -45,12 +46,24 @@ def search_sv():
idx = str(escape(request.args.get('idx', config.DEFAULT_INDEX_TYPE)))
if not idx:
idx = config.DEFAULT_INDEX_TYPE

threshold = escape(request.args.get('threshold'))
if threshold:
try:
threshold = float(threshold)
except Exception:
logging.error(f'Found non-float threshold value: {threshold}')
threshold = SV_SCORE_DEFAULT_THRESHOLD
else:
threshold = SV_SCORE_DEFAULT_THRESHOLD

skip_multi_sv = False
if request.args.get('skip_multi_sv'):
skip_multi_sv = True

try:
nl_embeddings = current_app.config[config.NL_EMBEDDINGS_KEY].get(idx)
return json.dumps(nl_embeddings.detect_svs(query, skip_multi_sv))
return json.dumps(nl_embeddings.detect_svs(query, threshold, skip_multi_sv))
except Exception as e:
logging.error(f'Embeddings-based SV detection failed with error: {e}')
return json.dumps({
Expand Down
5 changes: 2 additions & 3 deletions nl_server/tests/custom_embeddings_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,8 @@
from parameterized import parameterized

from nl_server import config
from nl_server import embeddings
from nl_server import embeddings_store as store
from nl_server import gcs
from shared.lib.constants import SV_SCORE_DEFAULT_THRESHOLD
from shared.lib.gcs import TEMP_DIR

_test_data = os.path.join(os.path.dirname(os.path.abspath(__file__)),
Expand Down Expand Up @@ -97,7 +96,7 @@ def test_queries(self, dc, query, index, expected):
if idx:
got = idx.detect_svs(query)
for i in range(len(got['SV'])):
if got['CosineScore'][i] >= embeddings._SV_SCORE_THRESHOLD:
if got['CosineScore'][i] >= SV_SCORE_DEFAULT_THRESHOLD:
trimmed_svs.append(got['SV'][i])

if not expected:
Expand Down
366 changes: 68 additions & 298 deletions server/config/subject_page_pb2.py

Large diffs are not rendered by default.

27 changes: 6 additions & 21 deletions server/integration_tests/explore_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,17 +30,10 @@

class ExploreTest(NLWebServerTestCase):

def run_fulfillment(self,
test_dir,
req_json,
failure='',
test='',
i18n='',
udp=''):
resp = requests.post(
self.get_server_url() +
f'/api/explore/fulfill?test={test}&i18n={i18n}&udp={udp}',
json=req_json).json()
def run_fulfillment(self, test_dir, req_json, failure='', test='', i18n=''):
resp = requests.post(self.get_server_url() +
f'/api/explore/fulfill?test={test}&i18n={i18n}',
json=req_json).json()
self.handle_response(json.dumps(req_json), resp, test_dir, '', failure)

def run_detection(self,
Expand Down Expand Up @@ -71,13 +64,12 @@ def run_detect_and_fulfill(self,
dc='',
failure='',
test='',
i18n='',
udp=''):
i18n=''):
ctx = {}
for (index, q) in enumerate(queries):
resp = requests.post(
self.get_server_url() +
f'/api/explore/detect-and-fulfill?q={q}&test={test}&i18n={i18n}&udp={udp}',
f'/api/explore/detect-and-fulfill?q={q}&test={test}&i18n={i18n}',
json={
'contextHistory': ctx,
'dc': dc,
Expand Down Expand Up @@ -410,10 +402,3 @@ def test_e2e_fallbacks(self):
# to the place (SC county) to its state (CA).
'auto thefts in tracts of santa clara county'
])

def test_e2e_default_place(self):
self.run_detect_and_fulfill('e2e_default_place', [
'what does a diet for diabetes look like?',
'how to earn money online without investment'
],
udp='false')
38 changes: 32 additions & 6 deletions server/integration_tests/nl_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,9 +275,35 @@ def test_sdg(self):
['how many wise asses live in sunnyvale?'],
failure='could not complete')

def test_strict(self):
self.run_sequence('strict', [
'how do i build and construct a house and sell it in california with low income',
'tell me asian california population with low income'
],
mode='strict')
def test_strict_multi_verb(self):
  # Runs a two-query sequence with mode='strict', passing 'Heuristic Based'
  # as the expected detector for each query.
  self.run_sequence(
      'strict_multi_verb',
      [
          # This query should return empty results in strict mode.
          'how do i build and construct a house and sell it in california with low income',
          # This query should be fine.
          'tell me asian california population with low income',
      ],
      mode='strict',
      expected_detectors=[
          'Heuristic Based',
          'Heuristic Based',
      ])

def test_strict_default_place(self):
  # Runs a three-query sequence with mode='strict', passing 'Heuristic Based'
  # as the expected detector for each query.
  self.run_sequence(
      'strict_default_place',
      [
          # These queries do not mention a place, and strict mode does not
          # fall back to a default place, so they should fail.
          'what does a diet for diabetes look like?',
          'how to earn money online without investment',
          # This query should return empty result because we don't
          # return low-confidence results.
          'number of headless drivers in california',
      ],
      mode='strict',
      expected_detectors=[
          'Heuristic Based',
          'Heuristic Based',
          'Heuristic Based',
      ])
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,6 @@
}
],
"relatedThings": {},
"svSource": "CURRENT_QUERY",
"svSource": "UNKNOWN",
"userMessage": ""
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
{
"config": {
"categories": [
{
"blocks": [
{
"columns": [
{
"tiles": [
{
"type": "PLACE_OVERVIEW"
}
]
}
]
}
]
}
],
"metadata": {
"placeDcid": [
"geoId/06"
]
}
},
"context": {},
"debug": {},
"pastSourceContext": "",
"place": {
"dcid": "geoId/06",
"name": "California",
"place_type": "State"
},
"placeFallback": {},
"placeSource": "CURRENT_QUERY",
"places": [
{
"dcid": "geoId/06",
"name": "California",
"place_type": "State"
}
],
"relatedThings": {},
"showForm": true,
"svSource": "UNRECOGNIZED",
"userMessage": "Could not recognize any topic from the query. See available topic categories for California."
}
17 changes: 14 additions & 3 deletions server/lib/explore/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
from enum import Enum
from typing import Dict

from shared.lib import constants


class Params(str, Enum):
ENTITIES = 'entities'
Expand All @@ -31,10 +33,10 @@ class Params(str, Enum):
# Indicating it's a test query and the type of test, ex, "test=screenshot"
TEST = 'test'
I18N = 'i18n'
# Whether to use default place when no places are detected from the query.
USE_DEFAULT_PLACE = 'udp'
# The mode of query detection.
# - 'strict': detect and fulfill query with much higher specificity criteria.
# - 'strict': detect and fulfill more specific queries (without too many verbs),
# using a higher SV cosine score threshold (0.7), and without
# using a default place (if query doesn't specify places).
# Ex, if multiple verbs present, treat as action query and do not fulfill.
MODE = 'mode'

Expand All @@ -46,9 +48,18 @@ class DCNames(str, Enum):


class QueryMode(str, Enum):
  # Recognized values of the query detection mode.
  #
  # NOTE: This mode is incompatible with LLM detector
  STRICT = 'strict'


def sv_threshold(mode: str) -> float:
  """Get the SV score threshold for the given mode.

  Args:
    mode: The query detection mode; compared against `QueryMode.STRICT`.

  Returns:
    `constants.SV_SCORE_HIGH_CONFIDENCE_THRESHOLD` when `mode` is
    `QueryMode.STRICT`, and `constants.SV_SCORE_DEFAULT_THRESHOLD` for any
    other value.
  """
  # The result is a score threshold (a float), not a flag, so the return
  # annotation is `float` rather than `bool`.
  if mode == QueryMode.STRICT:
    return constants.SV_SCORE_HIGH_CONFIDENCE_THRESHOLD
  return constants.SV_SCORE_DEFAULT_THRESHOLD


def is_sdg(insight_ctx: Dict) -> bool:
  """Report whether the context's DC param is one of the SDG DC names."""
  dc_name = insight_ctx.get(Params.DC.value)
  sdg_dc_names = [DCNames.SDG_DC.value, DCNames.SDG_MINI_DC.value]
  return dc_name in sdg_dc_names
6 changes: 4 additions & 2 deletions server/lib/nl/common/commentary.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,15 @@
from server.lib.nl.common.utterance import FulfillmentResult
from server.lib.nl.common.utterance import Utterance
from server.lib.nl.detection.utils import get_top_sv_score
from shared.lib.constants import SV_SCORE_HIGH_CONFIDENCE_THRESHOLD

#
# List of user messages!
#

# If the score is below this, then we report low confidence.
LOW_CONFIDENCE_SCORE_REPORT_THRESHOLD = 0.7
# If the score is below this, then we report low confidence
# (we reuse the threshold we use for determining something is "high confidence")
LOW_CONFIDENCE_SCORE_REPORT_THRESHOLD = SV_SCORE_HIGH_CONFIDENCE_THRESHOLD
LOW_CONFIDENCE_SCORE_MESSAGE = \
'Low confidence in understanding your query. Displaying the closest results.'

Expand Down
8 changes: 6 additions & 2 deletions server/lib/nl/detection/heuristic_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import logging
from typing import Dict

from server.lib.explore import params
from server.lib.explore.params import QueryMode
import server.lib.nl.common.counters as ctr
from server.lib.nl.detection import heuristic_classifiers
Expand Down Expand Up @@ -43,17 +44,20 @@ def detect(place_detector_type: PlaceDetectorType, orig_query: str,

query = place_detection.query_without_place_substr

sv_threshold = params.sv_threshold(mode)
# Step 3: Identify the SV matched based on the query.
svs_scores_dict = dutils.empty_svs_score_dict()
try:
svs_scores_dict = variable.detect_svs(
query, index_type, query_detection_debug_logs["query_transformations"])
query, index_type, query_detection_debug_logs["query_transformations"],
sv_threshold)
except ValueError as e:
logging.info(e)
logging.info("Using an empty svs_scores_dict")

# Set the SVDetection.
sv_detection = dutils.create_sv_detection(query, svs_scores_dict)
sv_detection = dutils.create_sv_detection(query, svs_scores_dict,
sv_threshold)

# Step 4: find query classifiers.
classifications = [
Expand Down
2 changes: 1 addition & 1 deletion server/lib/nl/detection/llm_fallback.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def need_llm(heuristic: Detection, prev_uttr: Utterance,


def _has_no_sv(d: Detection, ctr: counters.Counters) -> bool:
return not dutils.filter_svs(d.svs_detected.single_sv, ctr)
return not dutils.filter_svs(d.svs_detected, ctr)


def _has_no_place(d: Detection) -> bool:
Expand Down
3 changes: 3 additions & 0 deletions server/lib/nl/detection/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from typing import Dict, List, Optional

from shared.lib import detected_variables as dvars
from shared.lib.constants import SV_SCORE_DEFAULT_THRESHOLD


@dataclass
Expand Down Expand Up @@ -57,6 +58,8 @@ class SVDetection:
single_sv: dvars.VarCandidates
# Multi SV detection.
multi_sv: dvars.MultiVarCandidates
# Input SV Threshold
sv_threshold: float = SV_SCORE_DEFAULT_THRESHOLD


class RankingType(IntEnum):
Expand Down
Loading

0 comments on commit 0134730

Please sign in to comment.