Skip to content

Commit

Permalink
Workaround for sqlx bug (#15)
Browse files Browse the repository at this point in the history
  • Loading branch information
mdecimus committed Jul 24, 2023
1 parent e8df912 commit 9725aa7
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 20 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions crates/directory/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ scrypt = "0.11.0"
sha1 = "0.10.5"
sha2 = "0.10.6"
md5 = "0.7.0"
futures = "0.3"

[dev-dependencies]
tokio = { version = "1.23", features = ["full"] }
52 changes: 32 additions & 20 deletions crates/directory/src/sql/lookup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
* for more details.
*/

use futures::TryStreamExt;
use mail_send::Credentials;
use sqlx::{any::AnyRow, Column, Row};

Expand Down Expand Up @@ -48,18 +49,20 @@ impl Directory for SqlDirectory {
}

async fn principal(&self, name: &str) -> crate::Result<Option<Principal>> {
if let Some(row) = sqlx::query(&self.mappings.query_name)
let result = sqlx::query(&self.mappings.query_name)
.bind(name)
.fetch_optional(&self.pool)
.await?
{
.fetch(&self.pool)
.try_next()
.await?;
if let Some(row) = result {
// Map row to principal
let mut principal = self.mappings.row_to_principal(row)?;

// Obtain members
principal.member_of = sqlx::query_scalar::<_, String>(&self.mappings.query_members)
.bind(name)
.fetch_all(&self.pool)
.fetch(&self.pool)
.try_collect::<Vec<_>>()
.await?;

// Check whether the user is a superuser
Expand All @@ -81,22 +84,25 @@ impl Directory for SqlDirectory {
async fn emails_by_name(&self, name: &str) -> crate::Result<Vec<String>> {
sqlx::query_scalar::<_, String>(&self.mappings.query_emails)
.bind(name)
.fetch_all(&self.pool)
.fetch(&self.pool)
.try_collect::<Vec<_>>()
.await
.map_err(Into::into)
}

async fn names_by_email(&self, address: &str) -> crate::Result<Vec<String>> {
match sqlx::query_scalar::<_, String>(&self.mappings.query_recipients)
let result = sqlx::query_scalar::<_, String>(&self.mappings.query_recipients)
.bind(unwrap_subaddress(address, self.opt.subaddressing).as_ref())
.fetch_all(&self.pool)
.await
{
.fetch(&self.pool)
.try_collect::<Vec<_>>()
.await;
match result {
Ok(ids) if !ids.is_empty() => Ok(ids),
Ok(_) if self.opt.catch_all => {
sqlx::query_scalar::<_, String>(&self.mappings.query_recipients)
.bind(to_catch_all_address(address))
.fetch_all(&self.pool)
.fetch(&self.pool)
.try_collect::<Vec<_>>()
.await
.map_err(Into::into)
}
Expand All @@ -106,15 +112,17 @@ impl Directory for SqlDirectory {
}

async fn rcpt(&self, address: &str) -> crate::Result<bool> {
match sqlx::query(&self.mappings.query_recipients)
let result = sqlx::query(&self.mappings.query_recipients)
.bind(unwrap_subaddress(address, self.opt.subaddressing).as_ref())
.fetch_optional(&self.pool)
.await
{
.fetch(&self.pool)
.try_next()
.await;
match result {
Ok(Some(_)) => Ok(true),
Ok(None) if self.opt.catch_all => sqlx::query(&self.mappings.query_recipients)
.bind(to_catch_all_address(address))
.fetch_optional(&self.pool)
.fetch(&self.pool)
.try_next()
.await
.map(|id| id.is_some())
.map_err(Into::into),
Expand All @@ -126,15 +134,17 @@ impl Directory for SqlDirectory {
async fn vrfy(&self, address: &str) -> crate::Result<Vec<String>> {
sqlx::query_scalar::<_, String>(&self.mappings.query_verify)
.bind(unwrap_subaddress(address, self.opt.subaddressing).as_ref())
.fetch_all(&self.pool)
.fetch(&self.pool)
.try_collect::<Vec<_>>()
.await
.map_err(Into::into)
}

async fn expn(&self, address: &str) -> crate::Result<Vec<String>> {
sqlx::query_scalar::<_, String>(&self.mappings.query_expand)
.bind(unwrap_subaddress(address, self.opt.subaddressing).as_ref())
.fetch_all(&self.pool)
.fetch(&self.pool)
.try_collect::<Vec<_>>()
.await
.map_err(Into::into)
}
Expand All @@ -146,7 +156,8 @@ impl Directory for SqlDirectory {
q = q.bind(param);
}

q.fetch_optional(&self.pool)
q.fetch(&self.pool)
.try_next()
.await
.map(|r| r.is_some())
.map_err(Into::into)
Expand All @@ -155,7 +166,8 @@ impl Directory for SqlDirectory {
async fn is_local_domain(&self, domain: &str) -> crate::Result<bool> {
sqlx::query(&self.mappings.query_domains)
.bind(domain)
.fetch_optional(&self.pool)
.fetch(&self.pool)
.try_next()
.await
.map(|id| id.is_some())
.map_err(Into::into)
Expand Down
1 change: 1 addition & 0 deletions tests/src/directory/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ const CONFIG: &str = r#"
[directory."sql"]
type = "sql"
address = "sqlite::memory:"
#address = "mysql://root:secret@localhost:3306/stalwart?ssl_mode=disabled"
[directory."sql".options]
catch-all = true
Expand Down

0 comments on commit 9725aa7

Please sign in to comment.