diff --git a/src/backend/apidb/readonly_pgsql_selection.cpp b/src/backend/apidb/readonly_pgsql_selection.cpp index 1e3c9ed89..e06a6469d 100644 --- a/src/backend/apidb/readonly_pgsql_selection.cpp +++ b/src/backend/apidb/readonly_pgsql_selection.cpp @@ -11,6 +11,8 @@ #include #include +#include + namespace po = boost::program_options; using std::set; @@ -516,12 +518,20 @@ bool readonly_pgsql_selection::is_user_blocked(const osm_user_id_t id) { return !res.empty(); } -bool readonly_pgsql_selection::get_user_id_pass(const std::string& display_name, osm_user_id_t & id, +bool readonly_pgsql_selection::get_user_id_pass(const std::string& user_name, osm_user_id_t & id, std::string & pass_crypt, std::string & pass_salt) { - auto res = w.prepared("get_user_id_pass")(display_name).exec(); - if (res.empty()) - return false; + std::string email = boost::algorithm::trim_copy(user_name); + + auto res = w.prepared("get_user_id_pass")(email)(user_name).exec(); + + if (res.empty()) { + // try case insensitive query + res = w.prepared("get_user_id_pass_case_insensitive")(email)(user_name).exec(); + // failure, in case no entries or multiple entries were found + if (res.size() != 1) + return false; + } auto row = res[0]; id = row["id"].as(); @@ -835,9 +845,16 @@ readonly_pgsql_selection::factory::factory(const po::variables_map &opts) AND (needs_view or ends_at > (now() at time zone 'utc')) LIMIT 1 )"); m_connection.prepare("get_user_id_pass", - R"(SELECT id, pass_crypt, pass_salt FROM users - WHERE display_name = $1 - AND (status = 'active' or status = 'confirmed') )"); + R"(SELECT id, pass_crypt, pass_salt FROM users + WHERE (email = $1 OR display_name = $2) + AND (status = 'active' or status = 'confirmed') LIMIT 1 + )"); + + m_connection.prepare("get_user_id_pass_case_insensitive", + R"(SELECT id, pass_crypt, pass_salt FROM users + WHERE (LOWER(email) = LOWER($1) OR LOWER(display_name) = LOWER($2)) + AND (status = 'active' or status = 'confirmed') + )"); // clang-format on } diff --git a/src/backend/apidb/writeable_pgsql_selection.cpp b/src/backend/apidb/writeable_pgsql_selection.cpp index bceb91d95..4e8b6577e 100644 --- a/src/backend/apidb/writeable_pgsql_selection.cpp +++ b/src/backend/apidb/writeable_pgsql_selection.cpp @@ -14,9 +14,11 @@ #include #include +#include #include #include + namespace po = boost::program_options; using std::set; using std::list; @@ -379,13 +381,20 @@ bool writeable_pgsql_selection::is_user_blocked(const osm_user_id_t id) { return !res.empty(); } -bool writeable_pgsql_selection::get_user_id_pass(const std::string& display_name, osm_user_id_t & id, +bool writeable_pgsql_selection::get_user_id_pass(const std::string& user_name, osm_user_id_t & id, std::string & pass_crypt, std::string & pass_salt) { - auto res = w.prepared("get_user_id_pass")(display_name).exec(); + std::string email = boost::algorithm::trim_copy(user_name); + + auto res = w.prepared("get_user_id_pass")(email)(user_name).exec(); - if (res.empty()) - return false; + if (res.empty()) { + // try case insensitive query + res = w.prepared("get_user_id_pass_case_insensitive")(email)(user_name).exec(); + // failure, in case no entries or multiple entries were found + if (res.size() != 1) + return false; + } auto row = res[0]; id = row["id"].as(); @@ -823,9 +832,16 @@ writeable_pgsql_selection::factory::factory(const po::variables_map &opts) AND (needs_view or ends_at > (now() at time zone 'utc')) LIMIT 1 )"); m_connection.prepare("get_user_id_pass", - R"(SELECT id, pass_crypt, pass_salt FROM users - WHERE display_name = $1 - AND (status = 'active' or status = 'confirmed') )"); + R"(SELECT id, pass_crypt, pass_salt FROM users + WHERE (email = $1 OR display_name = $2) + AND (status = 'active' or status = 'confirmed') LIMIT 1 + )"); + + m_connection.prepare("get_user_id_pass_case_insensitive", + R"(SELECT id, pass_crypt, pass_salt FROM users + WHERE (LOWER(email) = LOWER($1) OR LOWER(display_name) = LOWER($2)) + AND (status = 'active' or status = 'confirmed') + )"); // clang-format on } diff --git a/src/basicauth.cpp b/src/basicauth.cpp index 5bbf75b4b..7854738a4 100644 --- a/src/basicauth.cpp +++ b/src/basicauth.cpp @@ -101,7 +101,7 @@ namespace basicauth { { PasswordHash pwd_hash; - std::string display_name; + std::string user_name; std::string candidate; osm_user_id_t user_id; @@ -143,16 +143,16 @@ namespace basicauth { return boost::optional{}; try { - display_name = auth.substr(0, pos); + user_name = auth.substr(0, pos); candidate = auth.substr(pos + 1); } catch (std::out_of_range&) { return boost::optional{}; } - if (display_name.empty() || candidate.empty()) + if (user_name.empty() || candidate.empty()) return boost::optional{}; - auto user_exists = selection->get_user_id_pass(display_name, user_id, pass_crypt, pass_salt); + auto user_exists = selection->get_user_id_pass(user_name, user_id, pass_crypt, pass_salt); if (!user_exists) throw http::unauthorized("Incorrect user or password"); diff --git a/test/test_basicauth.cpp b/test/test_basicauth.cpp index 5c586dd65..6a320ae93 100644 --- a/test/test_basicauth.cpp +++ b/test/test_basicauth.cpp @@ -88,10 +88,10 @@ class basicauth_test_data_selection int select_changesets(const std::vector &) { return 0; } void select_changeset_discussions() {} - bool get_user_id_pass(const std::string& display_name, osm_user_id_t & user_id, + bool get_user_id_pass(const std::string& user_name, osm_user_id_t & user_id, std::string & pass_crypt, std::string & pass_salt) { - if (display_name == "demo") { + if (user_name == "demo") { user_id = 4711; pass_crypt = "3wYbPiOxk/tU0eeIDjUhdvi8aDP3AbFtwYKKxF1IhGg="; pass_salt = "sha512!10000!OUQLgtM7eD8huvanFT5/WtWaCwdOdrir8QOtFwxhO0A=";