From 0134730a843379b66e87bf9ca25550c68af410b1 Mon Sep 17 00:00:00 2001 From: Prashanth R Date: Tue, 7 Nov 2023 12:43:02 -0800 Subject: [PATCH] Suppress low-confidence SVs in `strict` mode (#3768) Currently, we have a low-confidence threshold (below which we print a message) of `0.7`, and a drop threshold (below which SVs are dropped) of `0.5` for SVs. The main change in this PR is that, for `strict` mode, we increase the drop threshold to `0.7`. This is done by plumbing in the "threshold" value to the NL server, depending on the setting of `mode=strict`. Importantly, the `strict` mode is: * supported only for `api/data` API that nodejs uses * incompatible with LLM. Concretely, setting `mode=strict` implies: `sv_threshold=0.7`, `detector=heuristic`, `use_default_place=False` Additionally, * Deprecate `udp=` param * Add support for `mode=strict` to NL app, since that exercises the `api/data` flow * Move tests from `explore_test.py` to `nl_test.py`, also for the same reason. --- .vscode/launch.json | 2 +- nl_server/embeddings.py | 11 +- nl_server/routes.py | 15 +- nl_server/tests/custom_embeddings_test.py | 5 +- server/config/subject_page_pb2.py | 366 ++++-------------- server/integration_tests/explore_test.py | 27 +- server/integration_tests/nl_test.py | 38 +- .../query_1}/chart_config.json | 0 .../query_2}/chart_config.json | 2 +- .../query_3/chart_config.json | 47 +++ .../query_1}/chart_config.json | 0 .../query_2/chart_config.json | 0 server/lib/explore/params.py | 17 +- server/lib/nl/common/commentary.py | 6 +- server/lib/nl/detection/heuristic_detector.py | 8 +- server/lib/nl/detection/llm_fallback.py | 2 +- server/lib/nl/detection/types.py | 3 + server/lib/nl/detection/utils.py | 23 +- server/lib/nl/detection/variable.py | 11 +- server/routes/nl/helpers.py | 18 +- server/services/datacommons.py | 4 +- shared/lib/constants.py | 8 + static/js/apps/explore/app.tsx | 19 +- static/js/apps/nl_interface/app_state.ts | 7 + static/js/constants/app/explore_constants.ts | 1 
- .../constants/app/nl_interface_constants.ts | 5 + static/src/server.ts | 8 +- 27 files changed, 263 insertions(+), 390 deletions(-) rename server/integration_tests/test_data/{e2e_default_place/howtoearnmoneyonlinewithoutinvestment => strict_default_place/query_1}/chart_config.json (100%) rename server/integration_tests/test_data/{strict/query_1 => strict_default_place/query_2}/chart_config.json (91%) create mode 100644 server/integration_tests/test_data/strict_default_place/query_3/chart_config.json rename server/integration_tests/test_data/{e2e_default_place/whatdoesadietfordiabeteslooklike => strict_multi_verb/query_1}/chart_config.json (100%) rename server/integration_tests/test_data/{strict => strict_multi_verb}/query_2/chart_config.json (100%) diff --git a/.vscode/launch.json b/.vscode/launch.json index d21f754d22..003fc58f4b 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -51,7 +51,7 @@ "console": "internalConsole" }, { - "name": "Flask (Custom with NL)", + "name": "Flask (Custom DC with NL)", "type": "python", "justMyCode": false, "request": "launch", diff --git a/nl_server/embeddings.py b/nl_server/embeddings.py index f6490e8b2a..5b18f584d1 100644 --- a/nl_server/embeddings.py +++ b/nl_server/embeddings.py @@ -31,9 +31,6 @@ _HIGHEST_SCORE = 1.0 _INIT_SCORE = (_HIGHEST_SCORE + 0.1) -# Scores below this are ignored. -_SV_SCORE_THRESHOLD = 0.5 - _NUM_CANDIDATES_PER_NSPLIT = 3 # Number of matches to find within the SV index. @@ -140,6 +137,7 @@ def _search_embeddings(self, # def detect_svs(self, orig_query: str, + threshold: float = constants.SV_SCORE_DEFAULT_THRESHOLD, skip_multi_sv: bool = False) -> Dict[str, Union[Dict, List]]: # Remove all stop-words. query_monovar = utils.remove_stop_words(orig_query, @@ -153,7 +151,7 @@ def detect_svs(self, # Try to detect multiple SVs. Use the original query so that # the logic can rely on stop-words like `vs`, `and`, etc as hints # for SV delimiters. 
- result_multivar = self._detect_multiple_svs(orig_query) + result_multivar = self._detect_multiple_svs(orig_query, threshold) multi_sv = vars.multivar_candidates_to_dict(result_multivar) # TODO: Rename SV_to_Sentences for consistency. @@ -168,7 +166,8 @@ def detect_svs(self, # Detects one or more SVs from the query. # TODO: Fix the query upstream to ensure the punctuations aren't stripped. # - def _detect_multiple_svs(self, query: str) -> vars.MultiVarCandidates: + def _detect_multiple_svs(self, query: str, + threshold: float) -> vars.MultiVarCandidates: # # Prepare a combination of query-sets. # @@ -234,7 +233,7 @@ def _detect_multiple_svs(self, query: str) -> vars.MultiVarCandidates: total += score candidate.parts.append(part) - if lowest < _SV_SCORE_THRESHOLD: + if lowest < threshold: # A query-part's best SV did not cross our score threshold, # so drop this candidate. continue diff --git a/nl_server/routes.py b/nl_server/routes.py index c0f21526bf..8b59192ecd 100644 --- a/nl_server/routes.py +++ b/nl_server/routes.py @@ -22,6 +22,7 @@ from nl_server import config from nl_server import loader +from shared.lib.constants import SV_SCORE_DEFAULT_THRESHOLD bp = Blueprint('main', __name__, url_prefix='/') @@ -45,12 +46,24 @@ def search_sv(): idx = str(escape(request.args.get('idx', config.DEFAULT_INDEX_TYPE))) if not idx: idx = config.DEFAULT_INDEX_TYPE + + threshold = escape(request.args.get('threshold')) + if threshold: + try: + threshold = float(threshold) + except Exception: + logging.error(f'Found non-float threshold value: {threshold}') + threshold = SV_SCORE_DEFAULT_THRESHOLD + else: + threshold = SV_SCORE_DEFAULT_THRESHOLD + skip_multi_sv = False if request.args.get('skip_multi_sv'): skip_multi_sv = True + try: nl_embeddings = current_app.config[config.NL_EMBEDDINGS_KEY].get(idx) - return json.dumps(nl_embeddings.detect_svs(query, skip_multi_sv)) + return json.dumps(nl_embeddings.detect_svs(query, threshold, skip_multi_sv)) except Exception as e: 
logging.error(f'Embeddings-based SV detection failed with error: {e}') return json.dumps({ diff --git a/nl_server/tests/custom_embeddings_test.py b/nl_server/tests/custom_embeddings_test.py index c48b9d11b0..6ea9ae4409 100644 --- a/nl_server/tests/custom_embeddings_test.py +++ b/nl_server/tests/custom_embeddings_test.py @@ -21,9 +21,8 @@ from parameterized import parameterized from nl_server import config -from nl_server import embeddings from nl_server import embeddings_store as store -from nl_server import gcs +from shared.lib.constants import SV_SCORE_DEFAULT_THRESHOLD from shared.lib.gcs import TEMP_DIR _test_data = os.path.join(os.path.dirname(os.path.abspath(__file__)), @@ -97,7 +96,7 @@ def test_queries(self, dc, query, index, expected): if idx: got = idx.detect_svs(query) for i in range(len(got['SV'])): - if got['CosineScore'][i] >= embeddings._SV_SCORE_THRESHOLD: + if got['CosineScore'][i] >= SV_SCORE_DEFAULT_THRESHOLD: trimmed_svs.append(got['SV'][i]) if not expected: diff --git a/server/config/subject_page_pb2.py b/server/config/subject_page_pb2.py index d141f2fbd4..12c3837953 100644 --- a/server/config/subject_page_pb2.py +++ b/server/config/subject_page_pb2.py @@ -4,9 +4,8 @@ """Generated protocol buffer code.""" from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool as _descriptor_pool -from google.protobuf import message as _message -from google.protobuf import reflection as _reflection from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder # @@protoc_insertion_point(imports) _sym_db = _symbol_database.Default() @@ -16,239 +15,10 @@ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x12subject_page.proto\x12\x0b\x64\x61tacommons\"l\n\x0eSeverityFilter\x12\x0c\n\x04prop\x18\x01 \x01(\t\x12\x14\n\x0c\x64isplay_name\x18\x05 \x01(\t\x12\x0c\n\x04unit\x18\x02 \x01(\t\x12\x13\n\x0blower_limit\x18\x03 
\x01(\x01\x12\x13\n\x0bupper_limit\x18\x04 \x01(\x01\"\x9b\x04\n\rEventTypeSpec\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\x18\n\x10\x65vent_type_dcids\x18\x03 \x03(\t\x12\r\n\x05\x63olor\x18\x04 \x01(\t\x12<\n\x17\x64\x65\x66\x61ult_severity_filter\x18\x05 \x01(\x0b\x32\x1b.datacommons.SeverityFilter\x12[\n\x1aplace_type_severity_filter\x18\n \x03(\x0b\x32\x37.datacommons.EventTypeSpec.PlaceTypeSeverityFilterEntry\x12<\n\x0c\x64isplay_prop\x18\x06 \x03(\x0b\x32&.datacommons.EventTypeSpec.DisplayProp\x12\x15\n\rend_date_prop\x18\x07 \x03(\t\x12\x1d\n\x15polygon_geo_json_prop\x18\x08 \x01(\t\x12\x1a\n\x12path_geo_json_prop\x18\t \x01(\t\x1a[\n\x1cPlaceTypeSeverityFilterEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12*\n\x05value\x18\x02 \x01(\x0b\x32\x1b.datacommons.SeverityFilter:\x02\x38\x01\x1a?\n\x0b\x44isplayProp\x12\x0c\n\x04prop\x18\x01 \x01(\t\x12\x14\n\x0c\x64isplay_name\x18\x02 \x01(\t\x12\x0c\n\x04unit\x18\x03 \x01(\t\"\xe3\x03\n\x0cPageMetadata\x12\x10\n\x08topic_id\x18\x01 \x01(\t\x12\x12\n\ntopic_name\x18\x02 \x01(\t\x12\x12\n\nplace_dcid\x18\x03 \x03(\t\x12Q\n\x15\x63ontained_place_types\x18\x04 \x03(\x0b\x32\x32.datacommons.PageMetadata.ContainedPlaceTypesEntry\x12\x45\n\x0f\x65vent_type_spec\x18\x05 \x03(\x0b\x32,.datacommons.PageMetadata.EventTypeSpecEntry\x12\x39\n\x0bplace_group\x18\x06 \x03(\x0b\x32$.datacommons.PageMetadata.PlaceGroup\x1a:\n\x18\x43ontainedPlaceTypesEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x1aP\n\x12\x45ventTypeSpecEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12)\n\x05value\x18\x02 \x01(\x0b\x32\x1a.datacommons.EventTypeSpec:\x02\x38\x01\x1a\x36\n\nPlaceGroup\x12\x14\n\x0cparent_place\x18\x01 \x01(\t\x12\x12\n\nplace_type\x18\x02 \x01(\t\"\x9f\x01\n\x0bStatVarSpec\x12\x10\n\x08stat_var\x18\x01 \x01(\t\x12\r\n\x05\x64\x65nom\x18\x02 \x01(\t\x12\x0c\n\x04unit\x18\x03 \x01(\t\x12\x0f\n\x07scaling\x18\x04 \x01(\x01\x12\x0b\n\x03log\x18\x05 \x01(\x08\x12\x0c\n\x04name\x18\x06 
\x01(\t\x12\x0c\n\x04\x64\x61te\x18\x07 \x01(\t\x12\x15\n\rno_per_capita\x18\x08 \x01(\x08\x12\x10\n\x08\x66\x61\x63\x65t_id\x18\t \x01(\t\"\xd0\x01\n\x0fRankingTileSpec\x12\x14\n\x0cshow_highest\x18\x01 \x01(\x08\x12\x13\n\x0bshow_lowest\x18\x02 \x01(\x08\x12\x16\n\x0e\x64iff_base_date\x18\x05 \x01(\t\x12\x15\n\rhighest_title\x18\x06 \x01(\t\x12\x14\n\x0clowest_title\x18\x07 \x01(\t\x12\x15\n\rranking_count\x18\n \x01(\x05\x12\x19\n\x11show_multi_column\x18\x0b \x01(\x08\x12\x1b\n\x13show_highest_lowest\x18\x0c \x01(\x08\"u\n\x18\x44isasterEventMapTileSpec\x12\x1c\n\x14point_event_type_key\x18\x01 \x03(\t\x12\x1e\n\x16polygon_event_type_key\x18\x02 \x03(\t\x12\x1b\n\x13path_event_type_key\x18\x03 \x03(\t\"9\n\x11HistogramTileSpec\x12\x16\n\x0e\x65vent_type_key\x18\x01 \x01(\t\x12\x0c\n\x04prop\x18\x02 \x01(\t\"\x9d\x01\n\x10TopEventTileSpec\x12\x16\n\x0e\x65vent_type_key\x18\x01 \x01(\t\x12\x14\n\x0c\x64isplay_prop\x18\x02 \x03(\t\x12\x17\n\x0fshow_start_date\x18\x03 \x01(\x08\x12\x15\n\rshow_end_date\x18\x04 \x01(\x08\x12\x14\n\x0creverse_sort\x18\x05 \x01(\x08\x12\x15\n\rranking_count\x18\x06 \x01(\x05\"\xbc\x01\n\x0fScatterTileSpec\x12\x1b\n\x13highlight_top_right\x18\x01 \x01(\x08\x12\x1a\n\x12highlight_top_left\x18\x02 \x01(\x08\x12\x1e\n\x16highlight_bottom_right\x18\x03 \x01(\x08\x12\x1d\n\x15highlight_bottom_left\x18\x04 \x01(\x08\x12\x19\n\x11show_place_labels\x18\x05 \x01(\x08\x12\x16\n\x0eshow_quadrants\x18\x06 \x01(\x08\"\xac\x03\n\x0b\x42\x61rTileSpec\x12\x19\n\x11x_label_link_root\x18\x01 \x01(\t\x12\x12\n\nbar_height\x18\x02 \x01(\x01\x12\x0e\n\x06\x63olors\x18\x03 \x03(\t\x12\x12\n\nhorizontal\x18\x04 \x01(\x08\x12\x12\n\nmax_places\x18\x05 \x01(\x05\x12\x15\n\rmax_variables\x18\x06 \x01(\x05\x12/\n\x04sort\x18\x07 \x01(\x0e\x32!.datacommons.BarTileSpec.SortType\x12\x0f\n\x07stacked\x18\x08 \x01(\x08\x12\x14\n\x0cuse_lollipop\x18\t \x01(\x08\x12\x15\n\ry_axis_margin\x18\n \x01(\x01\x12\x1b\n\x13variable_name_regex\x18\x0b 
\x01(\t\x12\x1d\n\x15\x64\x65\x66\x61ult_variable_name\x18\x0c \x01(\t\"t\n\x08SortType\x12\x14\n\x10TYPE_UNSPECIFIED\x10\x00\x12\r\n\tASCENDING\x10\x01\x12\x0e\n\nDESCENDING\x10\x02\x12\x18\n\x14\x41SCENDING_POPULATION\x10\x03\x12\x19\n\x15\x44\x45SCENDING_POPULATION\x10\x04\"s\n\rGaugeTileSpec\x12/\n\x05range\x18\x01 \x01(\x0b\x32 .datacommons.GaugeTileSpec.Range\x12\x0e\n\x06\x63olors\x18\x02 \x03(\t\x1a!\n\x05Range\x12\x0b\n\x03min\x18\x01 \x01(\x01\x12\x0b\n\x03max\x18\x02 \x01(\x01\",\n\rDonutTileSpec\x12\x0e\n\x06\x63olors\x18\x01 \x03(\t\x12\x0b\n\x03pie\x18\x02 \x01(\x08\"\xdb\x01\n\x0cLineTileSpec\x12\x0e\n\x06\x63olors\x18\x01 \x03(\t\x12:\n\ttimeScale\x18\x02 \x01(\x0e\x32\'.datacommons.LineTileSpec.TimeScaleType\x12\x1b\n\x13variable_name_regex\x18\x03 \x01(\t\x12\x1d\n\x15\x64\x65\x66\x61ult_variable_name\x18\x04 \x01(\t\"C\n\rTimeScaleType\x12\x14\n\x10TYPE_UNSPECIFIED\x10\x00\x12\t\n\x05MONTH\x10\x01\x12\x08\n\x04YEAR\x10\x02\x12\x07\n\x03\x44\x41Y\x10\x03\"4\n\x0bMapTileSpec\x12\x0e\n\x06\x63olors\x18\x02 \x03(\t\x12\x15\n\rgeo_json_prop\x18\x03 \x01(\t\"\xa0\x08\n\x04Tile\x12\r\n\x05title\x18\x01 \x01(\t\x12\x13\n\x0b\x64\x65scription\x18\x02 \x01(\t\x12(\n\x04type\x18\x03 \x01(\x0e\x32\x1a.datacommons.Tile.TileType\x12\x14\n\x0cstat_var_key\x18\x04 \x03(\t\x12\x19\n\x11\x63omparison_places\x18\x07 \x03(\t\x12\x1b\n\x13place_dcid_override\x18\x0b \x01(\t\x12\x13\n\x0bhide_footer\x18\x11 \x01(\x08\x12\x10\n\x08subtitle\x18\x12 \x01(\t\x12\x17\n\x0fplace_name_prop\x18\x13 \x01(\t\x12\x39\n\x11ranking_tile_spec\x18\x05 \x01(\x0b\x32\x1c.datacommons.RankingTileSpecH\x00\x12M\n\x1c\x64isaster_event_map_tile_spec\x18\x06 \x01(\x0b\x32%.datacommons.DisasterEventMapTileSpecH\x00\x12<\n\x13top_event_tile_spec\x18\x08 \x01(\x0b\x32\x1d.datacommons.TopEventTileSpecH\x00\x12\x39\n\x11scatter_tile_spec\x18\t \x01(\x0b\x32\x1c.datacommons.ScatterTileSpecH\x00\x12=\n\x13histogram_tile_spec\x18\n 
\x01(\x0b\x32\x1e.datacommons.HistogramTileSpecH\x00\x12\x31\n\rbar_tile_spec\x18\x0c \x01(\x0b\x32\x18.datacommons.BarTileSpecH\x00\x12\x35\n\x0fgauge_tile_spec\x18\r \x01(\x0b\x32\x1a.datacommons.GaugeTileSpecH\x00\x12\x35\n\x0f\x64onut_tile_spec\x18\x0e \x01(\x0b\x32\x1a.datacommons.DonutTileSpecH\x00\x12\x33\n\x0eline_tile_spec\x18\x0f \x01(\x0b\x32\x19.datacommons.LineTileSpecH\x00\x12\x31\n\rmap_tile_spec\x18\x10 \x01(\x0b\x32\x18.datacommons.MapTileSpecH\x00\"\xde\x01\n\x08TileType\x12\r\n\tTYPE_NONE\x10\x00\x12\x08\n\x04LINE\x10\x01\x12\x07\n\x03\x42\x41R\x10\x02\x12\x07\n\x03MAP\x10\x03\x12\x0b\n\x07SCATTER\x10\x04\x12\r\n\tBIVARIATE\x10\x05\x12\x0b\n\x07RANKING\x10\x06\x12\r\n\tHIGHLIGHT\x10\x07\x12\x0f\n\x0b\x44\x45SCRIPTION\x10\x08\x12\t\n\x05GAUGE\x10\r\x12\t\n\x05\x44ONUT\x10\x0e\x12\r\n\tHISTOGRAM\x10\n\x12\x12\n\x0ePLACE_OVERVIEW\x10\x0b\x12\r\n\tTOP_EVENT\x10\x0c\x12\x16\n\x12\x44ISASTER_EVENT_MAP\x10\tB\x10\n\x0etile_type_spec\"\xcf\x01\n\x11\x44isasterBlockSpec\x12>\n\ndate_range\x18\x01 \x01(\x0e\x32(.datacommons.DisasterBlockSpec.DateRangeH\x00\x12\x0e\n\x04\x64\x61te\x18\x02 \x01(\tH\x00\"Z\n\tDateRange\x12\r\n\tTYPE_NONE\x10\x00\x12\x0f\n\x0bTHIRTY_DAYS\x10\x01\x12\x0e\n\nSIX_MONTHS\x10\x02\x12\x0c\n\x08ONE_YEAR\x10\x03\x12\x0f\n\x0bTHREE_YEARS\x10\x04\x42\x0e\n\x0c\x64\x65\x66\x61ult_date\"\xec\x02\n\x05\x42lock\x12\r\n\x05title\x18\x01 \x01(\t\x12\x13\n\x0b\x64\x65scription\x18\x02 \x01(\t\x12\x10\n\x08\x66ootnote\x18\x05 \x01(\t\x12*\n\x07\x63olumns\x18\x03 \x03(\x0b\x32\x19.datacommons.Block.Column\x12*\n\x04type\x18\x04 \x01(\x0e\x32\x1c.datacommons.Block.BlockType\x12\r\n\x05\x64\x65nom\x18\x06 \x01(\t\x12\x18\n\x10start_with_denom\x18\x07 \x01(\x08\x12=\n\x13\x64isaster_block_spec\x18\x08 \x01(\x0b\x32\x1e.datacommons.DisasterBlockSpecH\x00\x1a*\n\x06\x43olumn\x12 \n\x05tiles\x18\x01 
\x03(\x0b\x32\x11.datacommons.Tile\".\n\tBlockType\x12\r\n\tTYPE_NONE\x10\x00\x12\x12\n\x0e\x44ISASTER_EVENT\x10\x01\x42\x11\n\x0f\x62lock_type_spec\"\xed\x01\n\x08\x43\x61tegory\x12\r\n\x05title\x18\x01 \x01(\t\x12\x13\n\x0b\x64\x65scription\x18\x02 \x01(\t\x12=\n\rstat_var_spec\x18\x04 \x03(\x0b\x32&.datacommons.Category.StatVarSpecEntry\x12\"\n\x06\x62locks\x18\x03 \x03(\x0b\x32\x12.datacommons.Block\x12\x0c\n\x04\x64\x63id\x18\x05 \x01(\t\x1aL\n\x10StatVarSpecEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\'\n\x05value\x18\x02 \x01(\x0b\x32\x18.datacommons.StatVarSpec:\x02\x38\x01\"q\n\x11SubjectPageConfig\x12+\n\x08metadata\x18\x01 \x01(\x0b\x32\x19.datacommons.PageMetadata\x12)\n\ncategories\x18\x02 \x03(\x0b\x32\x15.datacommons.CategoryJ\x04\x08\x03\x10\x04\x62\x06proto3') - - -_SEVERITYFILTER = DESCRIPTOR.message_types_by_name['SeverityFilter'] -_EVENTTYPESPEC = DESCRIPTOR.message_types_by_name['EventTypeSpec'] -_EVENTTYPESPEC_PLACETYPESEVERITYFILTERENTRY = _EVENTTYPESPEC.nested_types_by_name['PlaceTypeSeverityFilterEntry'] -_EVENTTYPESPEC_DISPLAYPROP = _EVENTTYPESPEC.nested_types_by_name['DisplayProp'] -_PAGEMETADATA = DESCRIPTOR.message_types_by_name['PageMetadata'] -_PAGEMETADATA_CONTAINEDPLACETYPESENTRY = _PAGEMETADATA.nested_types_by_name['ContainedPlaceTypesEntry'] -_PAGEMETADATA_EVENTTYPESPECENTRY = _PAGEMETADATA.nested_types_by_name['EventTypeSpecEntry'] -_PAGEMETADATA_PLACEGROUP = _PAGEMETADATA.nested_types_by_name['PlaceGroup'] -_STATVARSPEC = DESCRIPTOR.message_types_by_name['StatVarSpec'] -_RANKINGTILESPEC = DESCRIPTOR.message_types_by_name['RankingTileSpec'] -_DISASTEREVENTMAPTILESPEC = DESCRIPTOR.message_types_by_name['DisasterEventMapTileSpec'] -_HISTOGRAMTILESPEC = DESCRIPTOR.message_types_by_name['HistogramTileSpec'] -_TOPEVENTTILESPEC = DESCRIPTOR.message_types_by_name['TopEventTileSpec'] -_SCATTERTILESPEC = DESCRIPTOR.message_types_by_name['ScatterTileSpec'] -_BARTILESPEC = DESCRIPTOR.message_types_by_name['BarTileSpec'] -_GAUGETILESPEC = 
DESCRIPTOR.message_types_by_name['GaugeTileSpec'] -_GAUGETILESPEC_RANGE = _GAUGETILESPEC.nested_types_by_name['Range'] -_DONUTTILESPEC = DESCRIPTOR.message_types_by_name['DonutTileSpec'] -_LINETILESPEC = DESCRIPTOR.message_types_by_name['LineTileSpec'] -_MAPTILESPEC = DESCRIPTOR.message_types_by_name['MapTileSpec'] -_TILE = DESCRIPTOR.message_types_by_name['Tile'] -_DISASTERBLOCKSPEC = DESCRIPTOR.message_types_by_name['DisasterBlockSpec'] -_BLOCK = DESCRIPTOR.message_types_by_name['Block'] -_BLOCK_COLUMN = _BLOCK.nested_types_by_name['Column'] -_CATEGORY = DESCRIPTOR.message_types_by_name['Category'] -_CATEGORY_STATVARSPECENTRY = _CATEGORY.nested_types_by_name['StatVarSpecEntry'] -_SUBJECTPAGECONFIG = DESCRIPTOR.message_types_by_name['SubjectPageConfig'] -_BARTILESPEC_SORTTYPE = _BARTILESPEC.enum_types_by_name['SortType'] -_LINETILESPEC_TIMESCALETYPE = _LINETILESPEC.enum_types_by_name['TimeScaleType'] -_TILE_TILETYPE = _TILE.enum_types_by_name['TileType'] -_DISASTERBLOCKSPEC_DATERANGE = _DISASTERBLOCKSPEC.enum_types_by_name['DateRange'] -_BLOCK_BLOCKTYPE = _BLOCK.enum_types_by_name['BlockType'] -SeverityFilter = _reflection.GeneratedProtocolMessageType('SeverityFilter', (_message.Message,), { - 'DESCRIPTOR' : _SEVERITYFILTER, - '__module__' : 'subject_page_pb2' - # @@protoc_insertion_point(class_scope:datacommons.SeverityFilter) - }) -_sym_db.RegisterMessage(SeverityFilter) - -EventTypeSpec = _reflection.GeneratedProtocolMessageType('EventTypeSpec', (_message.Message,), { - - 'PlaceTypeSeverityFilterEntry' : _reflection.GeneratedProtocolMessageType('PlaceTypeSeverityFilterEntry', (_message.Message,), { - 'DESCRIPTOR' : _EVENTTYPESPEC_PLACETYPESEVERITYFILTERENTRY, - '__module__' : 'subject_page_pb2' - # @@protoc_insertion_point(class_scope:datacommons.EventTypeSpec.PlaceTypeSeverityFilterEntry) - }) - , - - 'DisplayProp' : _reflection.GeneratedProtocolMessageType('DisplayProp', (_message.Message,), { - 'DESCRIPTOR' : _EVENTTYPESPEC_DISPLAYPROP, - '__module__' : 
'subject_page_pb2' - # @@protoc_insertion_point(class_scope:datacommons.EventTypeSpec.DisplayProp) - }) - , - 'DESCRIPTOR' : _EVENTTYPESPEC, - '__module__' : 'subject_page_pb2' - # @@protoc_insertion_point(class_scope:datacommons.EventTypeSpec) - }) -_sym_db.RegisterMessage(EventTypeSpec) -_sym_db.RegisterMessage(EventTypeSpec.PlaceTypeSeverityFilterEntry) -_sym_db.RegisterMessage(EventTypeSpec.DisplayProp) - -PageMetadata = _reflection.GeneratedProtocolMessageType('PageMetadata', (_message.Message,), { - - 'ContainedPlaceTypesEntry' : _reflection.GeneratedProtocolMessageType('ContainedPlaceTypesEntry', (_message.Message,), { - 'DESCRIPTOR' : _PAGEMETADATA_CONTAINEDPLACETYPESENTRY, - '__module__' : 'subject_page_pb2' - # @@protoc_insertion_point(class_scope:datacommons.PageMetadata.ContainedPlaceTypesEntry) - }) - , - - 'EventTypeSpecEntry' : _reflection.GeneratedProtocolMessageType('EventTypeSpecEntry', (_message.Message,), { - 'DESCRIPTOR' : _PAGEMETADATA_EVENTTYPESPECENTRY, - '__module__' : 'subject_page_pb2' - # @@protoc_insertion_point(class_scope:datacommons.PageMetadata.EventTypeSpecEntry) - }) - , - - 'PlaceGroup' : _reflection.GeneratedProtocolMessageType('PlaceGroup', (_message.Message,), { - 'DESCRIPTOR' : _PAGEMETADATA_PLACEGROUP, - '__module__' : 'subject_page_pb2' - # @@protoc_insertion_point(class_scope:datacommons.PageMetadata.PlaceGroup) - }) - , - 'DESCRIPTOR' : _PAGEMETADATA, - '__module__' : 'subject_page_pb2' - # @@protoc_insertion_point(class_scope:datacommons.PageMetadata) - }) -_sym_db.RegisterMessage(PageMetadata) -_sym_db.RegisterMessage(PageMetadata.ContainedPlaceTypesEntry) -_sym_db.RegisterMessage(PageMetadata.EventTypeSpecEntry) -_sym_db.RegisterMessage(PageMetadata.PlaceGroup) - -StatVarSpec = _reflection.GeneratedProtocolMessageType('StatVarSpec', (_message.Message,), { - 'DESCRIPTOR' : _STATVARSPEC, - '__module__' : 'subject_page_pb2' - # @@protoc_insertion_point(class_scope:datacommons.StatVarSpec) - }) 
-_sym_db.RegisterMessage(StatVarSpec) - -RankingTileSpec = _reflection.GeneratedProtocolMessageType('RankingTileSpec', (_message.Message,), { - 'DESCRIPTOR' : _RANKINGTILESPEC, - '__module__' : 'subject_page_pb2' - # @@protoc_insertion_point(class_scope:datacommons.RankingTileSpec) - }) -_sym_db.RegisterMessage(RankingTileSpec) - -DisasterEventMapTileSpec = _reflection.GeneratedProtocolMessageType('DisasterEventMapTileSpec', (_message.Message,), { - 'DESCRIPTOR' : _DISASTEREVENTMAPTILESPEC, - '__module__' : 'subject_page_pb2' - # @@protoc_insertion_point(class_scope:datacommons.DisasterEventMapTileSpec) - }) -_sym_db.RegisterMessage(DisasterEventMapTileSpec) - -HistogramTileSpec = _reflection.GeneratedProtocolMessageType('HistogramTileSpec', (_message.Message,), { - 'DESCRIPTOR' : _HISTOGRAMTILESPEC, - '__module__' : 'subject_page_pb2' - # @@protoc_insertion_point(class_scope:datacommons.HistogramTileSpec) - }) -_sym_db.RegisterMessage(HistogramTileSpec) - -TopEventTileSpec = _reflection.GeneratedProtocolMessageType('TopEventTileSpec', (_message.Message,), { - 'DESCRIPTOR' : _TOPEVENTTILESPEC, - '__module__' : 'subject_page_pb2' - # @@protoc_insertion_point(class_scope:datacommons.TopEventTileSpec) - }) -_sym_db.RegisterMessage(TopEventTileSpec) - -ScatterTileSpec = _reflection.GeneratedProtocolMessageType('ScatterTileSpec', (_message.Message,), { - 'DESCRIPTOR' : _SCATTERTILESPEC, - '__module__' : 'subject_page_pb2' - # @@protoc_insertion_point(class_scope:datacommons.ScatterTileSpec) - }) -_sym_db.RegisterMessage(ScatterTileSpec) - -BarTileSpec = _reflection.GeneratedProtocolMessageType('BarTileSpec', (_message.Message,), { - 'DESCRIPTOR' : _BARTILESPEC, - '__module__' : 'subject_page_pb2' - # @@protoc_insertion_point(class_scope:datacommons.BarTileSpec) - }) -_sym_db.RegisterMessage(BarTileSpec) - -GaugeTileSpec = _reflection.GeneratedProtocolMessageType('GaugeTileSpec', (_message.Message,), { - - 'Range' : _reflection.GeneratedProtocolMessageType('Range', 
(_message.Message,), { - 'DESCRIPTOR' : _GAUGETILESPEC_RANGE, - '__module__' : 'subject_page_pb2' - # @@protoc_insertion_point(class_scope:datacommons.GaugeTileSpec.Range) - }) - , - 'DESCRIPTOR' : _GAUGETILESPEC, - '__module__' : 'subject_page_pb2' - # @@protoc_insertion_point(class_scope:datacommons.GaugeTileSpec) - }) -_sym_db.RegisterMessage(GaugeTileSpec) -_sym_db.RegisterMessage(GaugeTileSpec.Range) - -DonutTileSpec = _reflection.GeneratedProtocolMessageType('DonutTileSpec', (_message.Message,), { - 'DESCRIPTOR' : _DONUTTILESPEC, - '__module__' : 'subject_page_pb2' - # @@protoc_insertion_point(class_scope:datacommons.DonutTileSpec) - }) -_sym_db.RegisterMessage(DonutTileSpec) - -LineTileSpec = _reflection.GeneratedProtocolMessageType('LineTileSpec', (_message.Message,), { - 'DESCRIPTOR' : _LINETILESPEC, - '__module__' : 'subject_page_pb2' - # @@protoc_insertion_point(class_scope:datacommons.LineTileSpec) - }) -_sym_db.RegisterMessage(LineTileSpec) - -MapTileSpec = _reflection.GeneratedProtocolMessageType('MapTileSpec', (_message.Message,), { - 'DESCRIPTOR' : _MAPTILESPEC, - '__module__' : 'subject_page_pb2' - # @@protoc_insertion_point(class_scope:datacommons.MapTileSpec) - }) -_sym_db.RegisterMessage(MapTileSpec) - -Tile = _reflection.GeneratedProtocolMessageType('Tile', (_message.Message,), { - 'DESCRIPTOR' : _TILE, - '__module__' : 'subject_page_pb2' - # @@protoc_insertion_point(class_scope:datacommons.Tile) - }) -_sym_db.RegisterMessage(Tile) - -DisasterBlockSpec = _reflection.GeneratedProtocolMessageType('DisasterBlockSpec', (_message.Message,), { - 'DESCRIPTOR' : _DISASTERBLOCKSPEC, - '__module__' : 'subject_page_pb2' - # @@protoc_insertion_point(class_scope:datacommons.DisasterBlockSpec) - }) -_sym_db.RegisterMessage(DisasterBlockSpec) - -Block = _reflection.GeneratedProtocolMessageType('Block', (_message.Message,), { - - 'Column' : _reflection.GeneratedProtocolMessageType('Column', (_message.Message,), { - 'DESCRIPTOR' : _BLOCK_COLUMN, - '__module__' 
: 'subject_page_pb2' - # @@protoc_insertion_point(class_scope:datacommons.Block.Column) - }) - , - 'DESCRIPTOR' : _BLOCK, - '__module__' : 'subject_page_pb2' - # @@protoc_insertion_point(class_scope:datacommons.Block) - }) -_sym_db.RegisterMessage(Block) -_sym_db.RegisterMessage(Block.Column) - -Category = _reflection.GeneratedProtocolMessageType('Category', (_message.Message,), { - - 'StatVarSpecEntry' : _reflection.GeneratedProtocolMessageType('StatVarSpecEntry', (_message.Message,), { - 'DESCRIPTOR' : _CATEGORY_STATVARSPECENTRY, - '__module__' : 'subject_page_pb2' - # @@protoc_insertion_point(class_scope:datacommons.Category.StatVarSpecEntry) - }) - , - 'DESCRIPTOR' : _CATEGORY, - '__module__' : 'subject_page_pb2' - # @@protoc_insertion_point(class_scope:datacommons.Category) - }) -_sym_db.RegisterMessage(Category) -_sym_db.RegisterMessage(Category.StatVarSpecEntry) - -SubjectPageConfig = _reflection.GeneratedProtocolMessageType('SubjectPageConfig', (_message.Message,), { - 'DESCRIPTOR' : _SUBJECTPAGECONFIG, - '__module__' : 'subject_page_pb2' - # @@protoc_insertion_point(class_scope:datacommons.SubjectPageConfig) - }) -_sym_db.RegisterMessage(SubjectPageConfig) - +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'subject_page_pb2', _globals) if _descriptor._USE_C_DESCRIPTORS == False: - DESCRIPTOR._options = None _EVENTTYPESPEC_PLACETYPESEVERITYFILTERENTRY._options = None _EVENTTYPESPEC_PLACETYPESEVERITYFILTERENTRY._serialized_options = b'8\001' @@ -258,68 +28,68 @@ _PAGEMETADATA_EVENTTYPESPECENTRY._serialized_options = b'8\001' _CATEGORY_STATVARSPECENTRY._options = None _CATEGORY_STATVARSPECENTRY._serialized_options = b'8\001' - _SEVERITYFILTER._serialized_start=35 - _SEVERITYFILTER._serialized_end=143 - _EVENTTYPESPEC._serialized_start=146 - _EVENTTYPESPEC._serialized_end=685 - _EVENTTYPESPEC_PLACETYPESEVERITYFILTERENTRY._serialized_start=529 - 
_EVENTTYPESPEC_PLACETYPESEVERITYFILTERENTRY._serialized_end=620 - _EVENTTYPESPEC_DISPLAYPROP._serialized_start=622 - _EVENTTYPESPEC_DISPLAYPROP._serialized_end=685 - _PAGEMETADATA._serialized_start=688 - _PAGEMETADATA._serialized_end=1171 - _PAGEMETADATA_CONTAINEDPLACETYPESENTRY._serialized_start=975 - _PAGEMETADATA_CONTAINEDPLACETYPESENTRY._serialized_end=1033 - _PAGEMETADATA_EVENTTYPESPECENTRY._serialized_start=1035 - _PAGEMETADATA_EVENTTYPESPECENTRY._serialized_end=1115 - _PAGEMETADATA_PLACEGROUP._serialized_start=1117 - _PAGEMETADATA_PLACEGROUP._serialized_end=1171 - _STATVARSPEC._serialized_start=1174 - _STATVARSPEC._serialized_end=1333 - _RANKINGTILESPEC._serialized_start=1336 - _RANKINGTILESPEC._serialized_end=1544 - _DISASTEREVENTMAPTILESPEC._serialized_start=1546 - _DISASTEREVENTMAPTILESPEC._serialized_end=1663 - _HISTOGRAMTILESPEC._serialized_start=1665 - _HISTOGRAMTILESPEC._serialized_end=1722 - _TOPEVENTTILESPEC._serialized_start=1725 - _TOPEVENTTILESPEC._serialized_end=1882 - _SCATTERTILESPEC._serialized_start=1885 - _SCATTERTILESPEC._serialized_end=2073 - _BARTILESPEC._serialized_start=2076 - _BARTILESPEC._serialized_end=2504 - _BARTILESPEC_SORTTYPE._serialized_start=2388 - _BARTILESPEC_SORTTYPE._serialized_end=2504 - _GAUGETILESPEC._serialized_start=2506 - _GAUGETILESPEC._serialized_end=2621 - _GAUGETILESPEC_RANGE._serialized_start=2588 - _GAUGETILESPEC_RANGE._serialized_end=2621 - _DONUTTILESPEC._serialized_start=2623 - _DONUTTILESPEC._serialized_end=2667 - _LINETILESPEC._serialized_start=2670 - _LINETILESPEC._serialized_end=2889 - _LINETILESPEC_TIMESCALETYPE._serialized_start=2822 - _LINETILESPEC_TIMESCALETYPE._serialized_end=2889 - _MAPTILESPEC._serialized_start=2891 - _MAPTILESPEC._serialized_end=2943 - _TILE._serialized_start=2946 - _TILE._serialized_end=4002 - _TILE_TILETYPE._serialized_start=3762 - _TILE_TILETYPE._serialized_end=3984 - _DISASTERBLOCKSPEC._serialized_start=4005 - _DISASTERBLOCKSPEC._serialized_end=4212 - 
_DISASTERBLOCKSPEC_DATERANGE._serialized_start=4106 - _DISASTERBLOCKSPEC_DATERANGE._serialized_end=4196 - _BLOCK._serialized_start=4215 - _BLOCK._serialized_end=4579 - _BLOCK_COLUMN._serialized_start=4470 - _BLOCK_COLUMN._serialized_end=4512 - _BLOCK_BLOCKTYPE._serialized_start=4514 - _BLOCK_BLOCKTYPE._serialized_end=4560 - _CATEGORY._serialized_start=4582 - _CATEGORY._serialized_end=4819 - _CATEGORY_STATVARSPECENTRY._serialized_start=4743 - _CATEGORY_STATVARSPECENTRY._serialized_end=4819 - _SUBJECTPAGECONFIG._serialized_start=4821 - _SUBJECTPAGECONFIG._serialized_end=4934 + _globals['_SEVERITYFILTER']._serialized_start=35 + _globals['_SEVERITYFILTER']._serialized_end=143 + _globals['_EVENTTYPESPEC']._serialized_start=146 + _globals['_EVENTTYPESPEC']._serialized_end=685 + _globals['_EVENTTYPESPEC_PLACETYPESEVERITYFILTERENTRY']._serialized_start=529 + _globals['_EVENTTYPESPEC_PLACETYPESEVERITYFILTERENTRY']._serialized_end=620 + _globals['_EVENTTYPESPEC_DISPLAYPROP']._serialized_start=622 + _globals['_EVENTTYPESPEC_DISPLAYPROP']._serialized_end=685 + _globals['_PAGEMETADATA']._serialized_start=688 + _globals['_PAGEMETADATA']._serialized_end=1171 + _globals['_PAGEMETADATA_CONTAINEDPLACETYPESENTRY']._serialized_start=975 + _globals['_PAGEMETADATA_CONTAINEDPLACETYPESENTRY']._serialized_end=1033 + _globals['_PAGEMETADATA_EVENTTYPESPECENTRY']._serialized_start=1035 + _globals['_PAGEMETADATA_EVENTTYPESPECENTRY']._serialized_end=1115 + _globals['_PAGEMETADATA_PLACEGROUP']._serialized_start=1117 + _globals['_PAGEMETADATA_PLACEGROUP']._serialized_end=1171 + _globals['_STATVARSPEC']._serialized_start=1174 + _globals['_STATVARSPEC']._serialized_end=1333 + _globals['_RANKINGTILESPEC']._serialized_start=1336 + _globals['_RANKINGTILESPEC']._serialized_end=1544 + _globals['_DISASTEREVENTMAPTILESPEC']._serialized_start=1546 + _globals['_DISASTEREVENTMAPTILESPEC']._serialized_end=1663 + _globals['_HISTOGRAMTILESPEC']._serialized_start=1665 + 
_globals['_HISTOGRAMTILESPEC']._serialized_end=1722 + _globals['_TOPEVENTTILESPEC']._serialized_start=1725 + _globals['_TOPEVENTTILESPEC']._serialized_end=1882 + _globals['_SCATTERTILESPEC']._serialized_start=1885 + _globals['_SCATTERTILESPEC']._serialized_end=2073 + _globals['_BARTILESPEC']._serialized_start=2076 + _globals['_BARTILESPEC']._serialized_end=2504 + _globals['_BARTILESPEC_SORTTYPE']._serialized_start=2388 + _globals['_BARTILESPEC_SORTTYPE']._serialized_end=2504 + _globals['_GAUGETILESPEC']._serialized_start=2506 + _globals['_GAUGETILESPEC']._serialized_end=2621 + _globals['_GAUGETILESPEC_RANGE']._serialized_start=2588 + _globals['_GAUGETILESPEC_RANGE']._serialized_end=2621 + _globals['_DONUTTILESPEC']._serialized_start=2623 + _globals['_DONUTTILESPEC']._serialized_end=2667 + _globals['_LINETILESPEC']._serialized_start=2670 + _globals['_LINETILESPEC']._serialized_end=2889 + _globals['_LINETILESPEC_TIMESCALETYPE']._serialized_start=2822 + _globals['_LINETILESPEC_TIMESCALETYPE']._serialized_end=2889 + _globals['_MAPTILESPEC']._serialized_start=2891 + _globals['_MAPTILESPEC']._serialized_end=2943 + _globals['_TILE']._serialized_start=2946 + _globals['_TILE']._serialized_end=4002 + _globals['_TILE_TILETYPE']._serialized_start=3762 + _globals['_TILE_TILETYPE']._serialized_end=3984 + _globals['_DISASTERBLOCKSPEC']._serialized_start=4005 + _globals['_DISASTERBLOCKSPEC']._serialized_end=4212 + _globals['_DISASTERBLOCKSPEC_DATERANGE']._serialized_start=4106 + _globals['_DISASTERBLOCKSPEC_DATERANGE']._serialized_end=4196 + _globals['_BLOCK']._serialized_start=4215 + _globals['_BLOCK']._serialized_end=4579 + _globals['_BLOCK_COLUMN']._serialized_start=4470 + _globals['_BLOCK_COLUMN']._serialized_end=4512 + _globals['_BLOCK_BLOCKTYPE']._serialized_start=4514 + _globals['_BLOCK_BLOCKTYPE']._serialized_end=4560 + _globals['_CATEGORY']._serialized_start=4582 + _globals['_CATEGORY']._serialized_end=4819 + _globals['_CATEGORY_STATVARSPECENTRY']._serialized_start=4743 + 
_globals['_CATEGORY_STATVARSPECENTRY']._serialized_end=4819 + _globals['_SUBJECTPAGECONFIG']._serialized_start=4821 + _globals['_SUBJECTPAGECONFIG']._serialized_end=4934 # @@protoc_insertion_point(module_scope) diff --git a/server/integration_tests/explore_test.py b/server/integration_tests/explore_test.py index a732ae4bb2..cada4c15af 100644 --- a/server/integration_tests/explore_test.py +++ b/server/integration_tests/explore_test.py @@ -30,17 +30,10 @@ class ExploreTest(NLWebServerTestCase): - def run_fulfillment(self, - test_dir, - req_json, - failure='', - test='', - i18n='', - udp=''): - resp = requests.post( - self.get_server_url() + - f'/api/explore/fulfill?test={test}&i18n={i18n}&udp={udp}', - json=req_json).json() + def run_fulfillment(self, test_dir, req_json, failure='', test='', i18n=''): + resp = requests.post(self.get_server_url() + + f'/api/explore/fulfill?test={test}&i18n={i18n}', + json=req_json).json() self.handle_response(json.dumps(req_json), resp, test_dir, '', failure) def run_detection(self, @@ -71,13 +64,12 @@ def run_detect_and_fulfill(self, dc='', failure='', test='', - i18n='', - udp=''): + i18n=''): ctx = {} for (index, q) in enumerate(queries): resp = requests.post( self.get_server_url() + - f'/api/explore/detect-and-fulfill?q={q}&test={test}&i18n={i18n}&udp={udp}', + f'/api/explore/detect-and-fulfill?q={q}&test={test}&i18n={i18n}', json={ 'contextHistory': ctx, 'dc': dc, @@ -410,10 +402,3 @@ def test_e2e_fallbacks(self): # to the place (SC county) to its state (CA). 
'auto thefts in tracts of santa clara county' ]) - - def test_e2e_default_place(self): - self.run_detect_and_fulfill('e2e_default_place', [ - 'what does a diet for diabetes look like?', - 'how to earn money online without investment' - ], - udp='false') diff --git a/server/integration_tests/nl_test.py b/server/integration_tests/nl_test.py index bcac0012ef..035f018943 100644 --- a/server/integration_tests/nl_test.py +++ b/server/integration_tests/nl_test.py @@ -275,9 +275,35 @@ def test_sdg(self): ['how many wise asses live in sunnyvale?'], failure='could not complete') - def test_strict(self): - self.run_sequence('strict', [ - 'how do i build and construct a house and sell it in california with low income', - 'tell me asian california population with low income' - ], - mode='strict') + def test_strict_multi_verb(self): + self.run_sequence( + 'strict_multi_verb', + [ + # This query should return empty results in strict mode. + 'how do i build and construct a house and sell it in california with low income', + # This query should be fine. + 'tell me asian california population with low income', + ], + mode='strict', + expected_detectors=[ + 'Heuristic Based', + 'Heuristic Based', + ]) + + def test_strict_default_place(self): + self.run_sequence( + 'strict_default_place', + [ + # These queries do not have a default place, so should fail. + 'what does a diet for diabetes look like?', + 'how to earn money online without investment', + # This query should return empty result because we don't + # return low-confidence results. 
+ 'number of headless drivers in california', + ], + mode='strict', + expected_detectors=[ + 'Heuristic Based', + 'Heuristic Based', + 'Heuristic Based', + ]) diff --git a/server/integration_tests/test_data/e2e_default_place/howtoearnmoneyonlinewithoutinvestment/chart_config.json b/server/integration_tests/test_data/strict_default_place/query_1/chart_config.json similarity index 100% rename from server/integration_tests/test_data/e2e_default_place/howtoearnmoneyonlinewithoutinvestment/chart_config.json rename to server/integration_tests/test_data/strict_default_place/query_1/chart_config.json diff --git a/server/integration_tests/test_data/strict/query_1/chart_config.json b/server/integration_tests/test_data/strict_default_place/query_2/chart_config.json similarity index 91% rename from server/integration_tests/test_data/strict/query_1/chart_config.json rename to server/integration_tests/test_data/strict_default_place/query_2/chart_config.json index e5fff888fa..0ae3bcb218 100644 --- a/server/integration_tests/test_data/strict/query_1/chart_config.json +++ b/server/integration_tests/test_data/strict_default_place/query_2/chart_config.json @@ -18,6 +18,6 @@ } ], "relatedThings": {}, - "svSource": "CURRENT_QUERY", + "svSource": "UNKNOWN", "userMessage": "" } \ No newline at end of file diff --git a/server/integration_tests/test_data/strict_default_place/query_3/chart_config.json b/server/integration_tests/test_data/strict_default_place/query_3/chart_config.json new file mode 100644 index 0000000000..a845761c70 --- /dev/null +++ b/server/integration_tests/test_data/strict_default_place/query_3/chart_config.json @@ -0,0 +1,47 @@ +{ + "config": { + "categories": [ + { + "blocks": [ + { + "columns": [ + { + "tiles": [ + { + "type": "PLACE_OVERVIEW" + } + ] + } + ] + } + ] + } + ], + "metadata": { + "placeDcid": [ + "geoId/06" + ] + } + }, + "context": {}, + "debug": {}, + "pastSourceContext": "", + "place": { + "dcid": "geoId/06", + "name": "California", + "place_type": 
"State" + }, + "placeFallback": {}, + "placeSource": "CURRENT_QUERY", + "places": [ + { + "dcid": "geoId/06", + "name": "California", + "place_type": "State" + } + ], + "relatedThings": {}, + "showForm": true, + "svSource": "UNRECOGNIZED", + "userMessage": "Could not recognize any topic from the query. See available topic categories for California." +} \ No newline at end of file diff --git a/server/integration_tests/test_data/e2e_default_place/whatdoesadietfordiabeteslooklike/chart_config.json b/server/integration_tests/test_data/strict_multi_verb/query_1/chart_config.json similarity index 100% rename from server/integration_tests/test_data/e2e_default_place/whatdoesadietfordiabeteslooklike/chart_config.json rename to server/integration_tests/test_data/strict_multi_verb/query_1/chart_config.json diff --git a/server/integration_tests/test_data/strict/query_2/chart_config.json b/server/integration_tests/test_data/strict_multi_verb/query_2/chart_config.json similarity index 100% rename from server/integration_tests/test_data/strict/query_2/chart_config.json rename to server/integration_tests/test_data/strict_multi_verb/query_2/chart_config.json diff --git a/server/lib/explore/params.py b/server/lib/explore/params.py index 9ab0c01bb5..990383df6c 100644 --- a/server/lib/explore/params.py +++ b/server/lib/explore/params.py @@ -15,6 +15,8 @@ from enum import Enum from typing import Dict +from shared.lib import constants + class Params(str, Enum): ENTITIES = 'entities' @@ -31,10 +33,10 @@ class Params(str, Enum): # Indicating it's a test query and the type of test, ex, "test=screenshot" TEST = 'test' I18N = 'i18n' - # Whether to use default place when no places are detected from the query. - USE_DEFAULT_PLACE = 'udp' # The mode of query detection. - # - 'strict': detect and fulfill query with much higher specificity criteria. 
+ # - 'strict': detect and fulfill more specific queries (without too many verbs), + # using a higher SV cosine score threshold (0.7), and without + # using a default place (if query doesn't specify places). # Ex, if multiple verbs present, treat as action query and do not fulfill. MODE = 'mode' @@ -46,9 +48,18 @@ class DCNames(str, Enum): class QueryMode(str, Enum): + # NOTE: This mode is incompatible with LLM detector STRICT = 'strict' +# Get the SV score threshold for the given mode. +def sv_threshold(mode: str) -> float: + if mode == QueryMode.STRICT: + return constants.SV_SCORE_HIGH_CONFIDENCE_THRESHOLD + else: + return constants.SV_SCORE_DEFAULT_THRESHOLD + + def is_sdg(insight_ctx: Dict) -> bool: return insight_ctx.get( Params.DC.value) in [DCNames.SDG_DC.value, DCNames.SDG_MINI_DC.value] diff --git a/server/lib/nl/common/commentary.py b/server/lib/nl/common/commentary.py index 46086e6671..a7fbac33ec 100644 --- a/server/lib/nl/common/commentary.py +++ b/server/lib/nl/common/commentary.py @@ -19,13 +19,15 @@ from server.lib.nl.common.utterance import FulfillmentResult from server.lib.nl.common.utterance import Utterance from server.lib.nl.detection.utils import get_top_sv_score +from shared.lib.constants import SV_SCORE_HIGH_CONFIDENCE_THRESHOLD # # List of user messages! # -# If the score is below this, then we report low confidence. -LOW_CONFIDENCE_SCORE_REPORT_THRESHOLD = 0.7 +# If the score is below this, then we report low confidence +# (we reuse the threshold we use for determining something is "high confidence") +LOW_CONFIDENCE_SCORE_REPORT_THRESHOLD = SV_SCORE_HIGH_CONFIDENCE_THRESHOLD LOW_CONFIDENCE_SCORE_MESSAGE = \ 'Low confidence in understanding your query. Displaying the closest results.' 
diff --git a/server/lib/nl/detection/heuristic_detector.py b/server/lib/nl/detection/heuristic_detector.py index c2cfc83a35..6ee3185505 100644 --- a/server/lib/nl/detection/heuristic_detector.py +++ b/server/lib/nl/detection/heuristic_detector.py @@ -16,6 +16,7 @@ import logging from typing import Dict +from server.lib.explore import params from server.lib.explore.params import QueryMode import server.lib.nl.common.counters as ctr from server.lib.nl.detection import heuristic_classifiers @@ -43,17 +44,20 @@ def detect(place_detector_type: PlaceDetectorType, orig_query: str, query = place_detection.query_without_place_substr + sv_threshold = params.sv_threshold(mode) # Step 3: Identify the SV matched based on the query. svs_scores_dict = dutils.empty_svs_score_dict() try: svs_scores_dict = variable.detect_svs( - query, index_type, query_detection_debug_logs["query_transformations"]) + query, index_type, query_detection_debug_logs["query_transformations"], + sv_threshold) except ValueError as e: logging.info(e) logging.info("Using an empty svs_scores_dict") # Set the SVDetection. - sv_detection = dutils.create_sv_detection(query, svs_scores_dict) + sv_detection = dutils.create_sv_detection(query, svs_scores_dict, + sv_threshold) # Step 4: find query classifiers. 
classifications = [ diff --git a/server/lib/nl/detection/llm_fallback.py b/server/lib/nl/detection/llm_fallback.py index 54e97d6500..86a2347f55 100644 --- a/server/lib/nl/detection/llm_fallback.py +++ b/server/lib/nl/detection/llm_fallback.py @@ -111,7 +111,7 @@ def need_llm(heuristic: Detection, prev_uttr: Utterance, def _has_no_sv(d: Detection, ctr: counters.Counters) -> bool: - return not dutils.filter_svs(d.svs_detected.single_sv, ctr) + return not dutils.filter_svs(d.svs_detected, ctr) def _has_no_place(d: Detection) -> bool: diff --git a/server/lib/nl/detection/types.py b/server/lib/nl/detection/types.py index c222df3de8..0c98e70543 100644 --- a/server/lib/nl/detection/types.py +++ b/server/lib/nl/detection/types.py @@ -21,6 +21,7 @@ from typing import Dict, List, Optional from shared.lib import detected_variables as dvars +from shared.lib.constants import SV_SCORE_DEFAULT_THRESHOLD @dataclass @@ -57,6 +58,8 @@ class SVDetection: single_sv: dvars.VarCandidates # Multi SV detection. multi_sv: dvars.MultiVarCandidates + # Input SV Threshold + sv_threshold: float = SV_SCORE_DEFAULT_THRESHOLD class RankingType(IntEnum): diff --git a/server/lib/nl/detection/utils.py b/server/lib/nl/detection/utils.py index 1ad550207c..6eb89c7c15 100644 --- a/server/lib/nl/detection/utils.py +++ b/server/lib/nl/detection/utils.py @@ -29,20 +29,18 @@ from shared.lib import constants as shared_constants from shared.lib import detected_variables as dvars -# We will ignore SV detections that are below this threshold -_SV_THRESHOLD = 0.5 - # # Filter out SVs that are below a score. 
# -def filter_svs(sv: dvars.VarCandidates, counters: ctr.Counters) -> List[str]: +def filter_svs(detection: SVDetection, counters: ctr.Counters) -> List[str]: i = 0 ans = [] blocked_vars = set() - while (i < len(sv.svs)): - if (sv.scores[i] >= _SV_THRESHOLD): - var = sv.svs[i] + single_sv = detection.single_sv + while i < len(single_sv.svs): + if single_sv.scores[i] >= detection.sv_threshold: + var = single_sv.svs[i] # Check if an earlier var blocks this var. if var in blocked_vars: @@ -132,14 +130,19 @@ def empty_svs_score_dict(): return {"SV": [], "CosineScore": [], "SV_to_Sentences": {}, "MultiSV": {}} -def create_sv_detection(query: str, svs_scores_dict: Dict) -> SVDetection: +def create_sv_detection( + query: str, + svs_scores_dict: Dict, + sv_threshold: float = shared_constants.SV_SCORE_DEFAULT_THRESHOLD +) -> SVDetection: return SVDetection(query=query, single_sv=dvars.VarCandidates( svs=svs_scores_dict['SV'], scores=svs_scores_dict['CosineScore'], sv2sentences=svs_scores_dict['SV_to_Sentences']), multi_sv=dvars.dict_to_multivar_candidates( - svs_scores_dict['MultiSV'])) + svs_scores_dict['MultiSV']), + sv_threshold=sv_threshold) def empty_place_detection() -> PlaceDetection: @@ -155,7 +158,7 @@ def create_utterance(query_detection: Detection, counters: ctr.Counters, session_id: str, test: str = '') -> Utterance: - filtered_svs = filter_svs(query_detection.svs_detected.single_sv, counters) + filtered_svs = filter_svs(query_detection.svs_detected, counters) # Construct Utterance datastructure. 
uttr = Utterance(prev_utterance=currentUtterance, diff --git a/server/lib/nl/detection/variable.py b/server/lib/nl/detection/variable.py index 3d0cd8210e..f868ca14be 100644 --- a/server/lib/nl/detection/variable.py +++ b/server/lib/nl/detection/variable.py @@ -18,6 +18,7 @@ from typing import Dict, List, Union from server.services import datacommons as dc +from shared.lib.constants import SV_SCORE_DEFAULT_THRESHOLD import shared.lib.utils as shared_utils # TODO: decouple words removal from detected attributes. Today, the removal @@ -32,8 +33,12 @@ # calls the NL Server and returns a dict with both single-SV and multi-SV # (if relevant) detections. For more details see create_sv_detection(). # -def detect_svs(query: str, index_type: str, - debug_logs: Dict) -> Dict[str, Union[Dict, List]]: +def detect_svs( + query: str, + index_type: str, + debug_logs: Dict, + threshold: float = SV_SCORE_DEFAULT_THRESHOLD +) -> Dict[str, Union[Dict, List]]: # Remove stop words. # Check comment at the top of this file above `ALL_STOP_WORDS` to understand # the potential areas for improvement. For now, this removal blanket removes @@ -45,4 +50,4 @@ def detect_svs(query: str, index_type: str, shared_utils.remove_stop_words(query, ALL_STOP_WORDS) # Make API call to the NL models/embeddings server. - return dc.nl_search_sv(query, index_type) + return dc.nl_search_sv(query, index_type, threshold) diff --git a/server/routes/nl/helpers.py b/server/routes/nl/helpers.py index 65bd65b511..c3d454a7b9 100644 --- a/server/routes/nl/helpers.py +++ b/server/routes/nl/helpers.py @@ -72,14 +72,6 @@ def parse_query_and_detect(request: Dict, app: str, debug_logs: Dict): # i18n param i18n_str = request.args.get(params.Params.I18N.value, '') i18n = i18n_str and i18n_str.lower() == 'true' - # TODO: Deprecate USE_DEFAULT_PLACE param once 'mode=strict' is in use. 
- # use default place param - udp_str = request.args.get(params.Params.USE_DEFAULT_PLACE.value, 'true') - udp = udp_str and udp_str.lower() == 'true' - # mode param - mode = request.args.get(params.Params.MODE.value, '') - if mode == QueryMode.STRICT: - udp = True # Index-type default is in nl_server. embeddings_index_type = request.args.get('idx', '') @@ -102,6 +94,14 @@ def parse_query_and_detect(request: Dict, app: str, debug_logs: Dict): default=RequestedDetectorType.HybridSafetyCheck.value, type=str) + # mode param + use_default_place = True + mode = request.args.get(params.Params.MODE.value, '') + if mode == QueryMode.STRICT: + # Strict mode is compatible only with Heuristic Detector! + detector_type = RequestedDetectorType.Heuristic.value + use_default_place = False + place_detector_type = request.args.get('place_detector', default='dc', type=str).lower() @@ -185,7 +185,7 @@ def parse_query_and_detect(request: Dict, app: str, debug_logs: Dict): if utterance: utterance.i18n_lang = i18n_lang - context.merge_with_context(utterance, is_sdg, udp) + context.merge_with_context(utterance, is_sdg, use_default_place) return utterance, None diff --git a/server/services/datacommons.py b/server/services/datacommons.py index c00dfeaff9..a1b96bbc17 100644 --- a/server/services/datacommons.py +++ b/server/services/datacommons.py @@ -329,9 +329,9 @@ def resolve(nodes, prop): return post(url, {'nodes': nodes, 'property': prop}) -def nl_search_sv(query, index_type): +def nl_search_sv(query, index_type, threshold): """Search sv from NL server.""" - url = f'{current_app.config["NL_ROOT"]}/api/search_sv?q={query}&idx={index_type}' + url = f'{current_app.config["NL_ROOT"]}/api/search_sv?q={query}&idx={index_type}&threshold={threshold}' return get(url) diff --git a/shared/lib/constants.py b/shared/lib/constants.py index bb69caa8b6..71c2696296 100644 --- a/shared/lib/constants.py +++ b/shared/lib/constants.py @@ -403,6 +403,14 @@ "school", ]) +# The default Cosine score threshold 
beyond which Stat Vars +# are considered as a match. +SV_SCORE_DEFAULT_THRESHOLD = 0.5 + +# The default Cosine score threshold beyond which Stat Vars +# are considered a high confidence match. +SV_SCORE_HIGH_CONFIDENCE_THRESHOLD = 0.7 + # A cosine score differential we use to indicate if scores # that differ by up to this amount are "near" SVs. # In Multi-SV detection, if the difference between successive scores exceeds diff --git a/static/js/apps/explore/app.tsx b/static/js/apps/explore/app.tsx index 8f1b6e96f7..4e0bc71a1c 100644 --- a/static/js/apps/explore/app.tsx +++ b/static/js/apps/explore/app.tsx @@ -261,7 +261,6 @@ export function App(props: { isDemo: boolean }): JSX.Element { const llmApi = getSingleParam(hashParams[URL_HASH_PARAMS.LLM_API]); const testMode = getSingleParam(hashParams[URL_HASH_PARAMS.TEST_MODE]); const i18n = getSingleParam(hashParams[URL_HASH_PARAMS.I18N]); - const udp = getSingleParam(hashParams[URL_HASH_PARAMS.USE_DEFAULT_PLACE]); let fulfillmentPromise: Promise; const gaTitle = query @@ -283,8 +282,7 @@ export function App(props: { isDemo: boolean }): JSX.Element { detector, llmApi, testMode, - i18n, - udp + i18n ) .then((resp) => { processFulfillData(resp, false); @@ -305,8 +303,7 @@ export function App(props: { isDemo: boolean }): JSX.Element { [], disableExploreMore, testMode, - i18n, - udp + i18n ) .then((resp) => { processFulfillData(resp, true); @@ -336,8 +333,7 @@ const fetchFulfillData = async ( classificationsJson: any, disableExploreMore: string, testMode: string, - i18n: string, - udp: string + i18n: string ) => { try { const argsMap = new Map(); @@ -347,9 +343,6 @@ const fetchFulfillData = async ( if (i18n) { argsMap.set(URL_HASH_PARAMS.I18N, i18n); } - if (udp) { - argsMap.set(URL_HASH_PARAMS.USE_DEFAULT_PLACE, udp); - } const args = argsMap.size > 0 ? `?${generateArgsParams(argsMap)}` : ""; const startTime = window.performance ? 
window.performance.now() : undefined; const resp = await axios.post(`/api/explore/fulfill${args}`, { @@ -390,8 +383,7 @@ const fetchDetectAndFufillData = async ( detector: string, llmApi: string, testMode: string, - i18n: string, - udp: string + i18n: string ) => { const argsMap = new Map(); if (detector) { @@ -406,9 +398,6 @@ const fetchDetectAndFufillData = async ( if (i18n) { argsMap.set(URL_HASH_PARAMS.I18N, i18n); } - if (udp) { - argsMap.set(URL_HASH_PARAMS.USE_DEFAULT_PLACE, udp); - } const args = argsMap.size > 0 ? `&${generateArgsParams(argsMap)}` : ""; try { const startTime = window.performance ? window.performance.now() : undefined; diff --git a/static/js/apps/nl_interface/app_state.ts b/static/js/apps/nl_interface/app_state.ts index a41c4e9d1a..edd4cd351a 100644 --- a/static/js/apps/nl_interface/app_state.ts +++ b/static/js/apps/nl_interface/app_state.ts @@ -35,6 +35,7 @@ import Papa from "papaparse"; import { NL_DETECTOR_VALS, NL_INDEX_VALS, + NL_MODE_VALS, NL_PLACE_DETECTOR_VALS, NL_URL_PARAMS, } from "../../constants/app/nl_interface_constants"; @@ -118,6 +119,7 @@ interface NLAppConfig { */ hideFeedbackButtons: boolean; + mode: string; detector: string; indexType: string; placeholderQuery: string; @@ -188,6 +190,7 @@ const nlAppModel: NLAppModel = { autoPlayDisableTypingAnimation: false, currentNlQueryContextId: null, hideFeedbackButtons: false, + mode: "", detector: NL_DETECTOR_VALS.HYBRID, indexType: NL_INDEX_VALS.MEDIUM_FT, placeholderQuery: "family earnings in california", @@ -224,6 +227,7 @@ const nlAppActions: NLAppActions = { autoPlayDisableTypingAnimation: !!getUrlToken("d"), autoPlayManuallyShowQuery: !!getUrlToken("m"), hideFeedbackButtons: !!getUrlToken("enable_demo"), + mode: getUrlTokenOrDefault(NL_URL_PARAMS.MODE, ""), detector: getUrlTokenOrDefault( NL_URL_PARAMS.DETECTOR, NL_DETECTOR_VALS.HYBRID @@ -303,6 +307,9 @@ const nlAppActions: NLAppActions = { if (config.placeDetector) { params["place_detector"] = config.placeDetector; } + if 
(config.mode) { + params["mode"] = config.mode; + } const start = Date.now(); try { const resp = await axios.post( diff --git a/static/js/constants/app/explore_constants.ts b/static/js/constants/app/explore_constants.ts index befe78fb1d..18f2982e3f 100644 --- a/static/js/constants/app/explore_constants.ts +++ b/static/js/constants/app/explore_constants.ts @@ -37,7 +37,6 @@ export const URL_HASH_PARAMS = { MAXIMUM_BLOCK: "mb", TEST_MODE: "test", I18N: "i18n", - USE_DEFAULT_PLACE: "udp", }; // Dcid of the default topic to use export const DEFAULT_TOPIC = "dc/topic/Root"; diff --git a/static/js/constants/app/nl_interface_constants.ts b/static/js/constants/app/nl_interface_constants.ts index e6d13a6cef..9705e2b597 100644 --- a/static/js/constants/app/nl_interface_constants.ts +++ b/static/js/constants/app/nl_interface_constants.ts @@ -34,6 +34,7 @@ export const NL_SOURCE_REPLACEMENTS = { export const NL_URL_PARAMS = { DETECTOR: "detector", IDX: "idx", + MODE: "mode", PLACE_DETECTOR: "place_detector", }; @@ -48,6 +49,10 @@ export const NL_DETECTOR_VALS = { LLM: "llm", }; +export const NL_MODE_VALS = { + STRICT: "strict", +}; + export const NL_PLACE_DETECTOR_VALS = { NER: "ner", DC: "dc", diff --git a/static/src/server.ts b/static/src/server.ts index b38bb726f2..8d7f5de70f 100644 --- a/static/src/server.ts +++ b/static/src/server.ts @@ -465,11 +465,9 @@ app.get("/nodejs/query", (req: Request, res: Response) => { const urlRoot = `${req.protocol}://${req.get("host")}`; res.setHeader("Content-Type", "application/json"); axios - // Use "udp=false" to disable using default place. 
- .post( - `${CONFIG.apiRoot}/api/nl/data?q=${query}&detector=heuristic&udp=false&mode=strict`, - {} - ) + // Set "mode=strict" to use heuristic detector, disable using default place, + // use a higher SV threshold and avoid multi-verb queries + .post(`${CONFIG.apiRoot}/api/nl/data?q=${query}&mode=strict`, {}) .then((resp) => { const nlResultTime = process.hrtime.bigint(); const mainPlace = resp.data["place"] || {};