Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(search_family): Add options test for the FT.AGGREGATE command #4479

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/facade/cmd_arg_parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ struct CmdArgParser {

// Check if the next value is equal to a specific tag. If equal, it's consumed.
template <class... Args> bool Check(std::string_view tag, Args*... args) {
if (cur_i_ + sizeof...(Args) >= args_.size())
if (cur_i_ + sizeof...(Args) >= args_.size() || error_)
return false;

std::string_view arg = SafeSV(cur_i_);
Expand Down
3 changes: 2 additions & 1 deletion src/server/search/doc_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ struct SearchParams {
Only one of load_fields and return_fields should be set.
*/
std::optional<SearchFieldsList> load_fields;
bool no_content = false;

std::optional<search::SortOption> sort_option;
search::QueryParams query_params;
Expand All @@ -157,7 +158,7 @@ struct SearchParams {
}

bool IdsOnly() const {
return return_fields && return_fields->empty();
return no_content || (return_fields && return_fields->empty());
}

bool ShouldReturnField(std::string_view alias) const;
Expand Down
71 changes: 36 additions & 35 deletions src/server/search/search_family.cc
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ void ParseLoadFields(CmdArgParser* parser, std::optional<SearchFieldsList>* load
load_fields->emplace();
}

while (num_fields--) {
while (parser->HasNext() && num_fields--) {
string_view str = parser->Next();

if (absl::StartsWith(str, "@"sv)) {
Expand Down Expand Up @@ -274,7 +274,7 @@ optional<SearchParams> ParseSearchParamsOrReply(CmdArgParser* parser, SinkReplyB
* AS a */
size_t num_fields = parser->Next<size_t>();
params.return_fields.emplace();
while (params.return_fields->size() < num_fields) {
while (parser->HasNext() && params.return_fields->size() < num_fields) {
StringOrView name = StringOrView::FromString(parser->Next<std::string>());

if (parser->Check("AS")) {
Expand All @@ -285,8 +285,7 @@ optional<SearchParams> ParseSearchParamsOrReply(CmdArgParser* parser, SinkReplyB
}
}
} else if (parser->Check("NOCONTENT")) { // NOCONTENT
params.load_fields.emplace();
params.return_fields.emplace();
params.no_content = true;
} else if (parser->Check("PARAMS")) { // [PARAMS num(ignored) name(ignored) knn_vector]
params.query_params = ParseQueryParams(parser);
} else if (parser->Check("SORTBY")) {
Expand Down Expand Up @@ -342,26 +341,26 @@ std::optional<aggregate::SortParams> ParseAggregatorSortParams(CmdArgParser* par
return sort_params;
}

optional<AggregateParams> ParseAggregatorParamsOrReply(CmdArgParser parser,
optional<AggregateParams> ParseAggregatorParamsOrReply(CmdArgParser* parser,
SinkReplyBuilder* builder) {
AggregateParams params;
tie(params.index, params.query) = parser.Next<string_view, string_view>();
tie(params.index, params.query) = parser->Next<string_view, string_view>();

// Parse LOAD count field [field ...]
// LOAD options are at the beginning of the query, so we need to parse them first
while (parser.HasNext() && parser.Check("LOAD")) {
ParseLoadFields(&parser, &params.load_fields);
while (parser->HasNext() && parser->Check("LOAD")) {
ParseLoadFields(parser, &params.load_fields);
}

while (parser.HasNext()) {
while (parser->HasNext()) {
// GROUPBY nargs property [property ...]
if (parser.Check("GROUPBY")) {
size_t num_fields = parser.Next<size_t>();
if (parser->Check("GROUPBY")) {
size_t num_fields = parser->Next<size_t>();

std::vector<std::string> fields;
fields.reserve(num_fields);
while (num_fields > 0 && parser.HasNext()) {
auto parsed_field = ParseFieldWithAtSign(&parser);
while (parser->HasNext() && num_fields > 0) {
auto parsed_field = ParseFieldWithAtSign(parser);

/*
TODO: Throw an error if the field has no '@' sign at the beginning
Expand All @@ -376,27 +375,27 @@ optional<AggregateParams> ParseAggregatorParamsOrReply(CmdArgParser parser,
}

vector<aggregate::Reducer> reducers;
while (parser.Check("REDUCE")) {
while (parser->Check("REDUCE")) {
using RF = aggregate::ReducerFunc;
auto func_name =
parser.TryMapNext("COUNT", RF::COUNT, "COUNT_DISTINCT", RF::COUNT_DISTINCT, "SUM",
RF::SUM, "AVG", RF::AVG, "MAX", RF::MAX, "MIN", RF::MIN);
parser->TryMapNext("COUNT", RF::COUNT, "COUNT_DISTINCT", RF::COUNT_DISTINCT, "SUM",
RF::SUM, "AVG", RF::AVG, "MAX", RF::MAX, "MIN", RF::MIN);

if (!func_name) {
builder->SendError(absl::StrCat("reducer function ", parser.Next(), " not found"));
builder->SendError(absl::StrCat("reducer function ", parser->Next(), " not found"));
return nullopt;
}

auto func = aggregate::FindReducerFunc(*func_name);
auto nargs = parser.Next<size_t>();
auto nargs = parser->Next<size_t>();

string source_field;
if (nargs > 0) {
source_field = ParseField(&parser);
source_field = ParseField(parser);
}

parser.ExpectTag("AS");
string result_field = parser.Next<string>();
parser->ExpectTag("AS");
string result_field = parser->Next<string>();

reducers.push_back(
aggregate::Reducer{std::move(source_field), std::move(result_field), std::move(func)});
Expand All @@ -407,8 +406,8 @@ optional<AggregateParams> ParseAggregatorParamsOrReply(CmdArgParser parser,
}

// SORTBY nargs
if (parser.Check("SORTBY")) {
auto sort_params = ParseAggregatorSortParams(&parser);
if (parser->Check("SORTBY")) {
auto sort_params = ParseAggregatorSortParams(parser);
if (!sort_params) {
builder->SendError("bad arguments for SORTBY: specified invalid number of strings");
return nullopt;
Expand All @@ -419,29 +418,24 @@ optional<AggregateParams> ParseAggregatorParamsOrReply(CmdArgParser parser,
}

// LIMIT
if (parser.Check("LIMIT")) {
auto [offset, num] = parser.Next<size_t, size_t>();
if (parser->Check("LIMIT")) {
auto [offset, num] = parser->Next<size_t, size_t>();
params.steps.push_back(aggregate::MakeLimitStep(offset, num));
continue;
}

// PARAMS
if (parser.Check("PARAMS")) {
params.params = ParseQueryParams(&parser);
if (parser->Check("PARAMS")) {
params.params = ParseQueryParams(parser);
continue;
}

if (parser.Check("LOAD")) {
if (parser->Check("LOAD")) {
builder->SendError("LOAD cannot be applied after projectors or reducers");
return nullopt;
}

builder->SendError(absl::StrCat("Unknown clause: ", parser.Peek()));
return nullopt;
}

if (auto err = parser.Error(); err) {
builder->SendError(err->MakeReply());
builder->SendError(absl::StrCat("Unknown clause: ", parser->Peek()));
return nullopt;
}

Expand Down Expand Up @@ -995,11 +989,18 @@ void SearchFamily::FtTagVals(CmdArgList args, const CommandContext& cmd_cntx) {
}

void SearchFamily::FtAggregate(CmdArgList args, const CommandContext& cmd_cntx) {
CmdArgParser parser{args};
auto* builder = cmd_cntx.rb;
const auto params = ParseAggregatorParamsOrReply(args, builder);

const auto params = ParseAggregatorParamsOrReply(&parser, builder);
if (!params)
return;

if (auto err = parser.Error(); err) {
builder->SendError(err->MakeReply());
return;
}

search::SearchAlgorithm search_algo;
if (!search_algo.Init(params->query, &params->params, nullptr))
return builder->SendError("Query syntax error");
Expand Down
138 changes: 131 additions & 7 deletions src/server/search/search_family_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@

#include "base/gtest.h"
#include "base/logging.h"
#include "facade/error.h"
#include "facade/facade_test.h"
#include "server/command_registry.h"
#include "server/test_utils.h"

using namespace testing;
using namespace std;
using namespace util;
using namespace facade;

namespace dfly {

Expand Down Expand Up @@ -630,10 +632,10 @@ TEST_F(SearchFamilyTest, TestReturn) {

// Check no fields are returned
resp = Run({"ft.search", "i1", "@justA:0", "return", "0"});
EXPECT_THAT(resp, RespArray(ElementsAre(IntArg(1), "k0")));
EXPECT_THAT(resp, IsArray(IntArg(1), "k0"));

resp = Run({"ft.search", "i1", "@justA:0", "nocontent"});
EXPECT_THAT(resp, RespArray(ElementsAre(IntArg(1), "k0")));
EXPECT_THAT(resp, IsArray(IntArg(1), "k0"));

// Check only one field is returned (and with original identifier)
resp = Run({"ft.search", "i1", "@justA:0", "return", "1", "longA"});
Expand Down Expand Up @@ -858,7 +860,7 @@ TEST_F(SearchFamilyTest, FtProfileErrorReply) {
EXPECT_THAT(resp, ErrArg("no `SEARCH` or `AGGREGATE` provided"));

resp = Run({"ft.profile", "i1", "search", "not_query", "(a | b) c d"});
EXPECT_THAT(resp, ErrArg("syntax error"));
EXPECT_THAT(resp, ErrArg(kSyntaxErr));

resp = Run({"ft.profile", "non_existent_key", "search", "query", "(a | b) c d"});
EXPECT_THAT(resp, ErrArg("non_existent_key: no such index"));
Expand Down Expand Up @@ -1805,15 +1807,15 @@ TEST_F(SearchFamilyTest, AggregateSortByParsingErrors) {

// Test SORTBY with negative argument count
resp = Run({"FT.AGGREGATE", "index", "*", "SORTBY", "-3", "@name", "@number", "DESC"});
EXPECT_THAT(resp, ErrArg("value is not an integer or out of range"));
EXPECT_THAT(resp, ErrArg(kInvalidIntErr));

// Test MAX with invalid value
resp = Run({"FT.AGGREGATE", "index", "*", "SORTBY", "1", "@name", "MAX", "-10"});
EXPECT_THAT(resp, ErrArg("value is not an integer or out of range"));
EXPECT_THAT(resp, ErrArg(kInvalidIntErr));

// Test MAX without a value
resp = Run({"FT.AGGREGATE", "index", "*", "SORTBY", "1", "@name", "MAX"});
EXPECT_THAT(resp, ErrArg("syntax error"));
EXPECT_THAT(resp, ErrArg(kSyntaxErr));

// Test SORTBY with a non-existing field
/* Temporarily unsupported
Expand All @@ -1822,7 +1824,129 @@ TEST_F(SearchFamilyTest, AggregateSortByParsingErrors) {

// Test SORTBY with an invalid value
resp = Run({"FT.AGGREGATE", "index", "*", "SORTBY", "notvalue", "@name"});
EXPECT_THAT(resp, ErrArg("value is not an integer or out of range"));
EXPECT_THAT(resp, ErrArg(kInvalidIntErr));
}

// Verifies how FT.SEARCH handles malformed, out-of-range, duplicated and combined
// LIMIT / LOAD / RETURN / NOCONTENT options. Every case runs against the same
// one-document index and checks either the error reply or the returned fields.
TEST_F(SearchFamilyTest, InvalidSearchOptions) {
// Fixture: one JSON document "j1" with two string fields, indexed by "idx" with
// both fields declared as TEXT under the aliases "field1" and "field2".
Run({"JSON.SET", "j1", ".", R"({"field1":"first","field2":"second"})"});
Run({"FT.CREATE", "idx", "ON", "JSON", "SCHEMA", "$.field1", "AS", "field1", "TEXT", "$.field2",
"AS", "field2", "TEXT"});

// The following case is disabled (commented out); see the TODO inside it.
/* Test with an empty query and LOAD. TODO: Add separate test for query syntax
auto resp = Run({"FT.SEARCH", "idx", "", "LOAD", "1", "@field1"});
EXPECT_THAT(resp, IsMapWithSize()); */

// Test with LIMIT missing arguments: only one value ("0") follows LIMIT,
// and the reply is expected to be kSyntaxErr.
auto resp = Run({"FT.SEARCH", "idx", "*", "LIMIT", "0"});
EXPECT_THAT(resp, ErrArg(kSyntaxErr));

// Test with LIMIT exceeding the maximum allowed value
// (the second value, 10^20, does not fit into a 64-bit integer).
resp = Run({"FT.SEARCH", "idx", "*", "LIMIT", "0", "100000000000000000000"});
EXPECT_THAT(resp, ErrArg(kInvalidIntErr));

// Test with LIMIT and negative arguments
resp = Run({"FT.SEARCH", "idx", "*", "LIMIT", "-1", "10"});
EXPECT_THAT(resp, ErrArg(kInvalidIntErr));

// Test with LIMIT and invalid argument types (non-numeric strings)
resp = Run({"FT.SEARCH", "idx", "*", "LIMIT", "start", "count"});
EXPECT_THAT(resp, ErrArg(kInvalidIntErr));

// Test with invalid LOAD arguments: the field count is omitted, so "@field1"
// sits where the count is expected.
resp = Run({"FT.SEARCH", "idx", "*", "LOAD", "@field1", "@field2"});
EXPECT_THAT(resp, ErrArg(kInvalidIntErr));

// Test with duplicate fields in LOAD: the expected reply lists each field once,
// together with a "$" entry holding the whole document.
resp = Run({"FT.SEARCH", "idx", "*", "LOAD", "4", "@field1", "@field1", "@field2", "@field2"});
EXPECT_THAT(resp, IsMapWithSize("j1", IsMap("field1", "\"first\"", "field2", "\"second\"", "$",
R"({"field1":"first","field2":"second"})")));

// Test with LOAD exceeding maximum allowed count
resp = Run({"FT.SEARCH", "idx", "*", "LOAD", "100000000000000000000", "@field1", "@field2"});
EXPECT_THAT(resp, ErrArg(kInvalidIntErr));

// Test with invalid RETURN syntax (missing count)
resp = Run({"FT.SEARCH", "idx", "*", "RETURN", "@field1", "@field2"});
EXPECT_THAT(resp, ErrArg(kInvalidIntErr));

// Test with RETURN having duplicate fields: the expected reply lists each field once.
resp = Run({"FT.SEARCH", "idx", "*", "RETURN", "4", "field1", "field1", "field2", "field2"});
EXPECT_THAT(resp, IsMapWithSize("j1", IsMap("field1", "\"first\"", "field2", "\"second\"")));

// Test with RETURN exceeding maximum allowed count
resp = Run({"FT.SEARCH", "idx", "*", "RETURN", "100000000000000000000", "@field1", "@field2"});
EXPECT_THAT(resp, ErrArg(kInvalidIntErr));

// Test with NOCONTENT and LOAD: the expected reply holds only IntArg(1) and the
// key "j1", i.e. no field values even though LOAD names two fields.
resp = Run({"FT.SEARCH", "idx", "*", "NOCONTENT", "LOAD", "2", "@field1", "@field2"});
EXPECT_THAT(resp, IsArray(IntArg(1), "j1"));

// Test with NOCONTENT and RETURN: same expectation as above, with RETURN
// naming the two fields instead of LOAD.
resp = Run({"FT.SEARCH", "idx", "*", "NOCONTENT", "RETURN", "2", "@field1", "@field2"});
EXPECT_THAT(resp, IsArray(IntArg(1), "j1"));
}

// Verifies the error replies of FT.AGGREGATE for malformed GROUPBY, REDUCE,
// SORTBY, LIMIT and LOAD clauses. Every case runs against the same
// one-document index.
TEST_F(SearchFamilyTest, InvalidAggregateOptions) {
// Fixture: one JSON document "j1" with two string fields, indexed by "idx" with
// both fields declared as TEXT under the aliases "field1" and "field2".
Run({"JSON.SET", "j1", ".", R"({"field1":"first","field2":"second"})"});
Run({"FT.CREATE", "idx", "ON", "JSON", "SCHEMA", "$.field1", "AS", "field1", "TEXT", "$.field2",
"AS", "field2", "TEXT"});

// Test GROUPBY with no arguments (the command ends right after the keyword)
auto resp = Run({"FT.AGGREGATE", "idx", "*", "GROUPBY"});
EXPECT_THAT(resp, ErrArg(kSyntaxErr));

// Test GROUPBY with invalid count: first a negative count, then 10^20, which
// does not fit into a 64-bit integer.
resp = Run({"FT.AGGREGATE", "idx", "*", "GROUPBY", "-1", "@field1"});
EXPECT_THAT(resp, ErrArg(kInvalidIntErr));

resp =
Run({"FT.AGGREGATE", "idx", "*", "GROUPBY", "100000000000000000000", "@field1", "@field2"});
EXPECT_THAT(resp, ErrArg(kInvalidIntErr));

// The next two REDUCE cases are disabled (commented out).
// Test REDUCE with no REDUCE function
/* resp = Run({"FT.AGGREGATE", "idx", "*", "GROUPBY", "1", "@field1", "REDUCE"});
EXPECT_THAT(resp, ErrArg("Bad arguments for REDUCE: SUCCESS"));
*/

/* // Test REDUCE with COUNT function
resp = Run({"FT.AGGREGATE", "idx", "*", "GROUPBY", "1", "@field1", "REDUCE", "COUNT", "0"});
EXPECT_THAT(resp, IsMapWithSize("__generated_aliascount", "1", "field1", "first")); */

// Test REDUCE with invalid function: the expected error message names the
// unknown function.
resp = Run({"FT.AGGREGATE", "idx", "*", "GROUPBY", "1", "@field1", "REDUCE", "INVALIDFUNC", "0",
"AS", "result"});
EXPECT_THAT(resp, ErrArg("reducer function INVALIDFUNC not found"));

// Test SORTBY with no arguments (the command ends right after the keyword)
resp = Run({"FT.AGGREGATE", "idx", "*", "SORTBY"});
EXPECT_THAT(resp, ErrArg(kSyntaxErr));

// Test SORTBY with invalid count: negative, then too large for a 64-bit integer
resp = Run({"FT.AGGREGATE", "idx", "*", "SORTBY", "-1", "@field1"});
EXPECT_THAT(resp, ErrArg(kInvalidIntErr));

resp = Run({"FT.AGGREGATE", "idx", "*", "SORTBY", "100000000000000000000", "@field1"});
EXPECT_THAT(resp, ErrArg(kInvalidIntErr));

// Test LIMIT with invalid arguments: a missing second value, a negative
// offset, and a value too large for a 64-bit integer.
resp = Run({"FT.AGGREGATE", "idx", "*", "LIMIT", "0"});
EXPECT_THAT(resp, ErrArg(kSyntaxErr));

resp = Run({"FT.AGGREGATE", "idx", "*", "LIMIT", "-1", "10"});
EXPECT_THAT(resp, ErrArg(kInvalidIntErr));

resp = Run({"FT.AGGREGATE", "idx", "*", "LIMIT", "0", "100000000000000000000"});
EXPECT_THAT(resp, ErrArg(kInvalidIntErr));

// Test LOAD with invalid arguments: a missing count, a negative count, and a
// count too large for a 64-bit integer.
resp = Run({"FT.AGGREGATE", "idx", "*", "LOAD", "@field1", "@field2"});
EXPECT_THAT(resp, ErrArg(kInvalidIntErr));

resp = Run({"FT.AGGREGATE", "idx", "*", "LOAD", "-1", "@field1"});
EXPECT_THAT(resp, ErrArg(kInvalidIntErr));

resp = Run({"FT.AGGREGATE", "idx", "*", "LOAD", "100000000000000000000", "@field1", "@field2"});
EXPECT_THAT(resp, ErrArg(kInvalidIntErr));
}

} // namespace dfly
Loading