diff --git a/src/backend/apidb/readonly_pgsql_selection.cpp b/src/backend/apidb/readonly_pgsql_selection.cpp index 1e3c9ed89..e06a6469d 100644 --- a/src/backend/apidb/readonly_pgsql_selection.cpp +++ b/src/backend/apidb/readonly_pgsql_selection.cpp @@ -11,6 +11,8 @@ #include #include +#include + namespace po = boost::program_options; using std::set; @@ -516,12 +518,20 @@ bool readonly_pgsql_selection::is_user_blocked(const osm_user_id_t id) { return !res.empty(); } -bool readonly_pgsql_selection::get_user_id_pass(const std::string& display_name, osm_user_id_t & id, +bool readonly_pgsql_selection::get_user_id_pass(const std::string& user_name, osm_user_id_t & id, std::string & pass_crypt, std::string & pass_salt) { - auto res = w.prepared("get_user_id_pass")(display_name).exec(); - if (res.empty()) - return false; + std::string email = boost::algorithm::trim_copy(user_name); + + auto res = w.prepared("get_user_id_pass")(email)(user_name).exec(); + + if (res.empty()) { + // try case insensitive query + res = w.prepared("get_user_id_pass_case_insensitive")(email)(user_name).exec(); + // failure, in case no entries or multiple entries were found + if (res.size() != 1) + return false; + } auto row = res[0]; id = row["id"].as(); @@ -835,9 +845,16 @@ readonly_pgsql_selection::factory::factory(const po::variables_map &opts) AND (needs_view or ends_at > (now() at time zone 'utc')) LIMIT 1 )"); m_connection.prepare("get_user_id_pass", - R"(SELECT id, pass_crypt, pass_salt FROM users - WHERE display_name = $1 - AND (status = 'active' or status = 'confirmed') )"); + R"(SELECT id, pass_crypt, pass_salt FROM users + WHERE (email = $1 OR display_name = $2) + AND (status = 'active' or status = 'confirmed') LIMIT 1 + )"); + + m_connection.prepare("get_user_id_pass_case_insensitive", + R"(SELECT id, pass_crypt, pass_salt FROM users + WHERE (LOWER(email) = LOWER($1) OR LOWER(display_name) = LOWER($2)) + AND (status = 'active' or status = 'confirmed') + )"); // clang-format on } diff --git a/src/backend/apidb/writeable_pgsql_selection.cpp b/src/backend/apidb/writeable_pgsql_selection.cpp index bceb91d95..4e8b6577e 100644 --- a/src/backend/apidb/writeable_pgsql_selection.cpp +++ b/src/backend/apidb/writeable_pgsql_selection.cpp @@ -14,9 +14,11 @@ #include #include +#include #include #include + namespace po = boost::program_options; using std::set; using std::list; @@ -379,13 +381,20 @@ bool writeable_pgsql_selection::is_user_blocked(const osm_user_id_t id) { return !res.empty(); } -bool writeable_pgsql_selection::get_user_id_pass(const std::string& display_name, osm_user_id_t & id, +bool writeable_pgsql_selection::get_user_id_pass(const std::string& user_name, osm_user_id_t & id, std::string & pass_crypt, std::string & pass_salt) { - auto res = w.prepared("get_user_id_pass")(display_name).exec(); + std::string email = boost::algorithm::trim_copy(user_name); + + auto res = w.prepared("get_user_id_pass")(email)(user_name).exec(); - if (res.empty()) - return false; + if (res.empty()) { + // try case insensitive query + res = w.prepared("get_user_id_pass_case_insensitive")(email)(user_name).exec(); + // failure, in case no entries or multiple entries were found + if (res.size() != 1) + return false; + } auto row = res[0]; id = row["id"].as(); @@ -823,9 +832,16 @@ writeable_pgsql_selection::factory::factory(const po::variables_map &opts) AND (needs_view or ends_at > (now() at time zone 'utc')) LIMIT 1 )"); m_connection.prepare("get_user_id_pass", - R"(SELECT id, pass_crypt, pass_salt FROM users - WHERE display_name = $1 - AND (status = 'active' or status = 'confirmed') )"); + R"(SELECT id, pass_crypt, pass_salt FROM users + WHERE (email = $1 OR display_name = $2) + AND (status = 'active' or status = 'confirmed') LIMIT 1 + )"); + + m_connection.prepare("get_user_id_pass_case_insensitive", + R"(SELECT id, pass_crypt, pass_salt FROM users + WHERE (LOWER(email) = LOWER($1) OR LOWER(display_name) = LOWER($2)) + AND (status = 'active' or status = 'confirmed') + )"); // clang-format on } diff --git a/src/basicauth.cpp b/src/basicauth.cpp index 5bbf75b4b..7854738a4 100644 --- a/src/basicauth.cpp +++ b/src/basicauth.cpp @@ -101,7 +101,7 @@ namespace basicauth { { PasswordHash pwd_hash; - std::string display_name; + std::string user_name; std::string candidate; osm_user_id_t user_id; @@ -143,16 +143,16 @@ namespace basicauth { return boost::optional{}; try { - display_name = auth.substr(0, pos); + user_name = auth.substr(0, pos); candidate = auth.substr(pos + 1); } catch (std::out_of_range&) { return boost::optional{}; } - if (display_name.empty() || candidate.empty()) + if (user_name.empty() || candidate.empty()) return boost::optional{}; - auto user_exists = selection->get_user_id_pass(display_name, user_id, pass_crypt, pass_salt); + auto user_exists = selection->get_user_id_pass(user_name, user_id, pass_crypt, pass_salt); if (!user_exists) throw http::unauthorized("Incorrect user or password"); diff --git a/test/test_apidb_backend_changeset_uploads.cpp b/test/test_apidb_backend_changeset_uploads.cpp index dd2964372..3c1eacbd7 100644 --- a/test/test_apidb_backend_changeset_uploads.cpp +++ b/test/test_apidb_backend_changeset_uploads.cpp @@ -2072,6 +2072,72 @@ namespace { } + // User logging on with display name (different case) + { + // set up request headers from test case + test_request req; + req.set_header("REQUEST_METHOD", "POST"); + req.set_header("REQUEST_URI", "/api/0.6/changeset/1/upload"); + req.set_header("HTTP_AUTHORIZATION", "Basic REVNTzpwYXNzd29yZA=="); + req.set_header("REMOTE_ADDR", "127.0.0.1"); + + req.set_payload(R"( + + + )" ); + + // execute the request + process_request(req, limiter, generator, route, sel_factory, upd_factory, std::shared_ptr(nullptr)); + + if (req.response_status() != 200) + throw std::runtime_error("Expected HTTP 200 OK: Log on with display name, different case"); + } + + // User logging on with email address rather than display name + { + // set up request headers from test case + test_request req; + req.set_header("REQUEST_METHOD", "POST"); + req.set_header("REQUEST_URI", "/api/0.6/changeset/1/upload"); + req.set_header("HTTP_AUTHORIZATION", "Basic ZGVtb0BleGFtcGxlLmNvbTpwYXNzd29yZA=="); + req.set_header("REMOTE_ADDR", "127.0.0.1"); + + req.set_payload(R"( + + + )" ); + + // execute the request + process_request(req, limiter, generator, route, sel_factory, upd_factory, std::shared_ptr(nullptr)); + + if (req.response_status() != 200) + throw std::runtime_error("Expected HTTP 200 OK: Log on with email address"); + } + + + // User logging on with email address with different case and additional whitespace rather than display name + { + // set up request headers from test case + test_request req; + req.set_header("REQUEST_METHOD", "POST"); + req.set_header("REQUEST_URI", "/api/0.6/changeset/1/upload"); + req.set_header("HTTP_AUTHORIZATION", "Basic ICAgZGVtb0BleGFtcGxlLkNPTSAgIDpwYXNzd29yZA=="); + req.set_header("REMOTE_ADDR", "127.0.0.1"); + + req.set_payload(R"( + + + )" ); + + // execute the request + process_request(req, limiter, generator, route, sel_factory, upd_factory, std::shared_ptr(nullptr)); + + if (req.response_status() != 200) + throw std::runtime_error("Expected HTTP 200 OK: Log on with email address, whitespace, different case"); + } + + + // User is blocked (needs_view) { tdb.run_sql(R"(UPDATE user_blocks SET needs_view = true where user_id = 1;)"); diff --git a/test/test_basicauth.cpp b/test/test_basicauth.cpp index 5c586dd65..6a320ae93 100644 --- a/test/test_basicauth.cpp +++ b/test/test_basicauth.cpp @@ -88,10 +88,10 @@ class basicauth_test_data_selection int select_changesets(const std::vector &) { return 0; } void select_changeset_discussions() {} - bool get_user_id_pass(const std::string& display_name, osm_user_id_t & user_id, + bool get_user_id_pass(const std::string& user_name, osm_user_id_t & user_id, std::string & pass_crypt, std::string & pass_salt) { - if (display_name == "demo") { + if (user_name == "demo") { user_id = 4711; pass_crypt = "3wYbPiOxk/tU0eeIDjUhdvi8aDP3AbFtwYKKxF1IhGg="; pass_salt = "sha512!10000!OUQLgtM7eD8huvanFT5/WtWaCwdOdrir8QOtFwxhO0A="; diff --git a/test/test_database.cpp b/test/test_database.cpp index 143383f76..b2c27deeb 100644 --- a/test/test_database.cpp +++ b/test/test_database.cpp @@ -208,7 +208,20 @@ void test_database::run_update( } catch (const std::exception &e) { throw std::runtime_error( - (boost::format("%1%, in update") % e.what()).str()); + (boost::format("%1%, in update, writable selection") % e.what()).str()); + } + + try { + // clear out database before using it! + pqxx::connection conn((boost::format("dbname=%1%") % m_db_name).str()); + conn.perform(truncate_all_tables()); + + m_use_readonly = true; + func(*this); + + } catch (const std::exception &e) { + throw std::runtime_error( + (boost::format("%1%, in update, read-only selection") % e.what()).str()); } }