diff --git a/Cargo.lock b/Cargo.lock
index 96f77ce..9d93c69 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1407,7 +1407,7 @@ dependencies = [
 
 [[package]]
 name = "object_store_ffi"
-version = "0.10.1"
+version = "0.11.0"
 dependencies = [
  "anyhow",
  "async-channel",
@@ -1422,6 +1422,7 @@ dependencies = [
  "flate2",
  "flume",
  "futures-util",
+ "hickory-resolver",
  "hyper",
  "metrics",
  "metrics-util",
diff --git a/Cargo.toml b/Cargo.toml
index bd36be9..4af6651 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "object_store_ffi"
-version = "0.10.1"
+version = "0.11.0"
 edition = "2021"
 
 [[bench]]
@@ -44,6 +44,7 @@ reqwest = { version = "0.12", default-features = false, features = ["rustls-tls"
 # object_store = { version = "0.10.1", features = ["azure", "aws"] }
 # Pinned to a specific commit while waiting for upstream
 object_store = { git = "https://github.com/andrebsguedes/arrow-rs.git", tag = "v0.10.2-beta1", features = ["azure", "aws", "experimental-azure-list-offset", "experimental-arbitrary-list-prefix"] }
+hickory-resolver = "0.24"
 thiserror = "1"
 anyhow = { version = "1", features = ["backtrace"] }
 once_cell = "1.18"
diff --git a/src/lib.rs b/src/lib.rs
index 5f465c9..fe87f7b 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -433,9 +433,9 @@ impl Client {
         let (store, crypto_material_provider, stage_prefix, extension) = build_store_for_snowflake_stage(map, config.retry_config.clone()).await?;
 
         let prefix = match (stage_prefix, config.prefix) {
-            (s, Some(u)) if s.ends_with("/") => Some(format!("{s}{u}")),
-            (s, Some(u)) => Some(format!("{s}/{u}")),
-            (s, None) => Some(s)
+            (Some(s), Some(u)) if s.ends_with("/") => Some(format!("{s}{u}")),
+            (Some(s), Some(u)) => Some(format!("{s}/{u}")),
+            (s, u) => s.or(u)
         };
 
         config.prefix = prefix;
diff --git a/src/snowflake/client.rs b/src/snowflake/client.rs
index 7c470d6..38d4db6 100644
--- a/src/snowflake/client.rs
+++ b/src/snowflake/client.rs
@@ -7,7 +7,7 @@ use zeroize::Zeroize;
 use moka::future::Cache;
 use crate::{duration_on_drop, error::{Error, RetryState}, metrics};
 use crate::util::{deserialize_str, deserialize_slice};
-// use anyhow::anyhow;
+use super::resolver::HickoryResolverWithEdns;
 
 
 #[derive(Debug, Serialize, Deserialize)]
@@ -74,12 +74,42 @@ pub(crate) struct SnowflakeQueryData {
 
 #[derive(Debug, Serialize, Deserialize)]
 #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
-pub(crate) struct SnowflakeStageCreds {
+pub(crate) struct SnowflakeStageAwsCreds {
     pub aws_key_id: String,
     pub aws_secret_key: String,
     pub aws_token: String,
 }
 
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
+pub(crate) struct SnowflakeStageAzureCreds {
+    pub azure_sas_token: String,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(untagged)]
+#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
+pub(crate) enum SnowflakeStageCreds {
+    Aws(SnowflakeStageAwsCreds),
+    Azure(SnowflakeStageAzureCreds),
+}
+
+impl SnowflakeStageCreds {
+    pub(crate) fn as_aws(&self) -> crate::Result<&SnowflakeStageAwsCreds> {
+        match self {
+            SnowflakeStageCreds::Aws(creds) => Ok(creds),
+            SnowflakeStageCreds::Azure(_) => Err(Error::invalid_response("Expected AWS credentials, but got Azure ones")),
+        }
+    }
+
+    pub(crate) fn as_azure(&self) -> crate::Result<&SnowflakeStageAzureCreds> {
+        match self {
+            SnowflakeStageCreds::Azure(creds) => Ok(creds),
+            SnowflakeStageCreds::Aws(_) => Err(Error::invalid_response("Expected Azure credentials, but got AWS ones")),
+        }
+    }
+}
+
 #[derive(Debug, Serialize, Deserialize)]
 #[serde(rename_all = "camelCase")]
 pub(crate) struct SnowflakeStageInfo {
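Reviewer note: `SnowflakeStageCreds` is an untagged enum, so serde picks the variant purely from the field names present in the stage-info JSON (AWS payloads carry AWS_KEY_ID etc., Azure payloads carry AZURE_SAS_TOKEN); there is no explicit tag on the wire. A minimal, self-contained sketch of that mechanism, with made-up payloads rather than real Snowflake responses:

use serde::Deserialize;

#[derive(Debug, Deserialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
struct Aws { aws_key_id: String, aws_secret_key: String, aws_token: String }

#[derive(Debug, Deserialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
struct Azure { azure_sas_token: String }

#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum Creds { Aws(Aws), Azure(Azure) }

fn main() {
    // serde tries each variant in order and keeps the first whose fields match.
    let aws: Creds = serde_json::from_str(
        r#"{"AWS_KEY_ID":"k","AWS_SECRET_KEY":"s","AWS_TOKEN":"t"}"#).unwrap();
    let azure: Creds = serde_json::from_str(
        r#"{"AZURE_SAS_TOKEN":"?sv=2024&sig=x"}"#).unwrap();
    assert!(matches!(aws, Creds::Aws(_)));
    assert!(matches!(azure, Creds::Azure(_)));
}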
@@ -118,6 +148,7 @@ pub(crate) enum NormalizedStageInfo {
         storage_account: String,
         container: String,
         prefix: String,
+        azure_sas_token: String,
         #[serde(skip_serializing_if = "Option::is_none")]
         end_point: Option<String>,
         #[serde(skip_serializing_if = "Option::is_none")]
@@ -132,18 +163,34 @@ impl TryFrom<&SnowflakeStageInfo> for NormalizedStageInfo {
         if value.location_type == "S3" {
             let (bucket, prefix) = value.location.split_once('/')
                 .ok_or_else(|| Error::invalid_response("Stage information from snowflake is missing the bucket name"))?;
+            let creds = value.creds.as_aws()?;
             return Ok(NormalizedStageInfo::S3 {
                 bucket: bucket.to_string(),
                 prefix: prefix.to_string(),
                 region: value.region.clone(),
-                aws_key_id: value.creds.aws_key_id.clone(),
-                aws_secret_key: value.creds.aws_secret_key.clone(),
-                aws_token: value.creds.aws_token.clone(),
+                aws_key_id: creds.aws_key_id.clone(),
+                aws_secret_key: creds.aws_secret_key.clone(),
+                aws_token: creds.aws_token.clone(),
+                end_point: value.end_point.clone(),
+                test_endpoint: value.test_endpoint.clone()
+            })
+        } else if value.location_type == "AZURE" {
+            let (container, prefix) = value.location.split_once('/')
+                .ok_or_else(|| Error::invalid_response("Stage information from snowflake is missing the container name"))?;
+            let creds = value.creds.as_azure()?;
+            let storage_account = value.storage_account
+                .clone()
+                .ok_or_else(|| Error::invalid_response("Stage information from snowflake is missing the storage account name"))?;
+            return Ok(NormalizedStageInfo::BlobStorage {
+                storage_account,
+                container: container.to_string(),
+                prefix: prefix.to_string(),
+                azure_sas_token: creds.azure_sas_token.clone(),
                 end_point: value.end_point.clone(),
                 test_endpoint: value.test_endpoint.clone()
             })
         } else {
-            return Err(Error::not_implemented("Azure BlobStorage is not implemented"));
+            return Err(Error::not_implemented(format!("Location type {} is not implemented", value.location_type)));
         }
     }
 }
@@ -297,6 +344,7 @@ impl SnowflakeClient {
         let client = SnowflakeClient {
             config,
             client: reqwest::Client::builder()
+                .dns_resolver(Arc::new(HickoryResolverWithEdns::default()))
                 .timeout(Duration::from_secs(180))
                 .build().unwrap(),
             token: Arc::new(Mutex::new(None)),
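Reviewer note: `dns_resolver` is reqwest 0.12's hook for swapping the DNS implementation; it accepts an `Arc` of anything implementing `reqwest::dns::Resolve`. A minimal sketch of that trait surface, using a hypothetical `StdResolver` that just defers to the OS resolver on a blocking thread (the real change installs `HickoryResolverWithEdns` from the new src/snowflake/resolver.rs at the end of this diff):

use reqwest::dns::{Addrs, Name, Resolve, Resolving};
use std::net::ToSocketAddrs;
use std::sync::Arc;

// Stand-in resolver: getaddrinfo moved onto a worker thread, only here to
// show the shape `dns_resolver` expects.
struct StdResolver;

impl Resolve for StdResolver {
    fn resolve(&self, name: Name) -> Resolving {
        Box::pin(async move {
            let addrs = tokio::task::spawn_blocking(move || {
                (name.as_str(), 0u16).to_socket_addrs()
            })
            .await
            .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)??;
            Ok(Box::new(addrs) as Addrs)
        })
    }
}

#[tokio::main]
async fn main() -> Result<(), reqwest::Error> {
    // Same hook the diff uses to install HickoryResolverWithEdns.
    let client = reqwest::Client::builder()
        .dns_resolver(Arc::new(StdResolver))
        .build()?;
    let _ = client; // issue requests as usual
    Ok(())
}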
f.debug_struct("SnowflakeStageKms") + f.debug_struct("SnowflakeStageS3Kms") .field("client", &self.client) .field("stage", &self.stage) .field("config", &self.config) @@ -56,14 +56,14 @@ impl std::fmt::Debug for SnowflakeStageKms { } } -impl SnowflakeStageKms { +impl SnowflakeStageS3Kms { pub(crate) fn new( client: Arc, stage: impl Into, prefix: impl Into, config: SnowflakeStageKmsConfig - ) -> SnowflakeStageKms { - SnowflakeStageKms { + ) -> SnowflakeStageS3Kms { + SnowflakeStageS3Kms { client, stage: stage.into(), prefix: prefix.into(), @@ -77,7 +77,7 @@ impl SnowflakeStageKms { } #[async_trait::async_trait] -impl CryptoMaterialProvider for SnowflakeStageKms { +impl CryptoMaterialProvider for SnowflakeStageS3Kms { async fn material_for_write(&self, _path: &str, data_len: Option) -> crate::Result<(ContentCryptoMaterial, Attributes)> { let _guard = duration_on_drop!(metrics::material_for_write_duration); let info = self.client.current_upload_info(&self.stage).await?; @@ -137,40 +137,26 @@ impl CryptoMaterialProvider for SnowflakeStageKms { async fn material_from_metadata(&self, path: &str, attr: &Attributes) -> crate::Result { let _guard = duration_on_drop!(metrics::material_from_metadata_duration); let path = path.strip_prefix(&self.prefix).unwrap_or(path); - let required_attribute = |key: &'static str| { - let v: &str = attr.get(&Attribute::Metadata(key.into())) - .ok_or_else(|| Error::required_config(format!("missing required attribute `{}`", key)))? - .as_ref(); - Ok::<_, Error>(v) - }; - - let material_description: MaterialDescription = deserialize_str(required_attribute("x-amz-matdesc")?) + let material_description: MaterialDescription = + deserialize_str(required_attribute(&attr, "x-amz-matdesc")?) .map_err(Error::deserialize_response_err("failed to deserialize matdesc"))?; - let master_key = self.keyring.try_get_with(material_description.query_id, async { - let info = self.client.fetch_path_info(&self.stage, path).await?; - let position = info.src_locations.iter().position(|l| l == path) - .ok_or_else(|| Error::invalid_response("path not found"))?; - let encryption_material = info.encryption_material.get(position) - .cloned() - .ok_or_else(|| Error::invalid_response("src locations and encryption material length mismatch"))? - .ok_or_else(|| Error::invalid_response("path not encrypted"))?; - - let master_key = Key::from_base64(&encryption_material.query_stage_master_key) - .map_err(ErrorKind::MaterialDecode)?; - counter!(metrics::total_keyring_miss).increment(1); - Ok::<_, Error>(master_key) - }).await?; - counter!(metrics::total_keyring_get).increment(1); - - let cek = EncryptedKey::from_base64(required_attribute("x-amz-key")?) + let master_key = get_master_key( + &self.client, + material_description.query_id.clone(), + path, + &self.stage, + &self.keyring, + ).await?; + + let cek = EncryptedKey::from_base64(required_attribute(&attr, "x-amz-key")?) .map_err(ErrorKind::MaterialDecode)?; let cek = cek.decrypt_aes_128_ecb(&master_key) .map_err(ErrorKind::MaterialCrypt)?; - let iv = Iv::from_base64(required_attribute("x-amz-iv")?) + let iv = Iv::from_base64(required_attribute(&attr, "x-amz-iv")?) 
@@ -193,3 +179,211 @@
         Ok(content_material)
     }
 }
+
+#[derive(Clone)]
+pub(crate) struct SnowflakeStageAzureKms {
+    client: Arc<SnowflakeClient>,
+    stage: String,
+    prefix: String,
+    config: SnowflakeStageKmsConfig,
+    keyring: Cache<String, Key>,
+}
+
+impl std::fmt::Debug for SnowflakeStageAzureKms {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("SnowflakeStageAzureKms")
+            .field("client", &self.client)
+            .field("stage", &self.stage)
+            .field("config", &self.config)
+            .field("keyring", &"redacted")
+            .finish()
+    }
+}
+
+impl SnowflakeStageAzureKms {
+    pub(crate) fn new(
+        client: Arc<SnowflakeClient>,
+        stage: impl Into<String>,
+        prefix: impl Into<String>,
+        config: SnowflakeStageKmsConfig
+    ) -> SnowflakeStageAzureKms {
+        SnowflakeStageAzureKms {
+            client,
+            stage: stage.into(),
+            prefix: prefix.into(),
+            keyring: Cache::builder()
+                .max_capacity(config.keyring_capacity as u64)
+                .time_to_live(config.keyring_ttl)
+                .build(),
+            config
+        }
+    }
+}
+
+const AZURE_MATDESC_KEY: &str = "matdesc";
+const AZURE_ENCDATA_KEY: &str = "encryptiondata";
+
+#[async_trait::async_trait]
+impl CryptoMaterialProvider for SnowflakeStageAzureKms {
+    async fn material_for_write(&self, _path: &str, _data_len: Option<usize>) -> crate::Result<(ContentCryptoMaterial, Attributes)> {
+        let _guard = duration_on_drop!(metrics::material_for_write_duration);
+        let info = self.client.current_upload_info(&self.stage).await?;
+
+        let encryption_material = info.encryption_material.as_ref()
+            .ok_or_else(|| ErrorKind::StorageNotEncrypted(self.stage.clone()))?;
+
+        let description = MaterialDescription {
+            smk_id: encryption_material.smk_id.to_string(),
+            query_id: encryption_material.query_id.clone(),
+            key_size: "128".to_string()
+        };
+        let master_key = Key::from_base64(&encryption_material.query_stage_master_key)
+            .map_err(ErrorKind::MaterialDecode)?;
+
+        let scheme = self.config.crypto_scheme;
+        let material = ContentCryptoMaterial::generate(scheme);
+        let encrypted_cek = material.cek.clone().encrypt_aes_128_ecb(&master_key)
+            .map_err(ErrorKind::MaterialCrypt)?;
+
+        let mut attributes = Attributes::new();
+
+        // We hardcode most of these values as the Go Snowflake client does (see
+        // https://github.com/snowflakedb/gosnowflake/blob/099708d318689634a558f705ccc19b3b7b278972/azure_storage_client.go#L152)
+        let encryption_data = EncryptionData {
+            encryption_mode: "FullBlob".to_string(),
+            wrapped_content_key: WrappedContentKey {
+                key_id: "symmKey1".to_string(),
+                encrypted_key: encrypted_cek.as_base64(),
+                algorithm: "AES_CBC_256".to_string(),
+            },
+            encryption_agent: EncryptionAgent {
+                protocol: "1.0".to_string(),
+                encryption_algorithm: "AES_CBC_128".to_string(),
+            },
+            content_encryption_i_v: material.iv.as_base64(),
+            key_wrapping_metadata: KeyWrappingMetadata {
+                encryption_library: "Java 5.3.0".to_string(),
+            },
+        };
+
+        attributes.insert(
+            Attribute::Metadata(AZURE_ENCDATA_KEY.into()),
+            AttributeValue::from(
+                serde_json::to_string(&encryption_data)
+                    .context("failed to encode encryption data")
+                    .to_err()?
+            )
+        );
+
+        attributes.insert(
+            Attribute::Metadata(AZURE_MATDESC_KEY.into()),
+            AttributeValue::from(
+                serde_json::to_string(&description)
+                    .context("failed to encode matdesc")
+                    .to_err()?
+            )
+        );
+
+        Ok((material, attributes))
+    }
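Reviewer note: `#[serde(rename_all = "PascalCase")]` on the struct definitions below is what makes this serialize into the `encryptiondata` JSON shape gosnowflake and the Azure clients expect. A sketch of the naming behavior, trimmed to two fields; note the deliberate `_i_v` spelling:

use serde::Serialize;

// Hypothetical mirror of the EncryptionData struct added in this diff,
// reduced to show only the serde renaming.
#[derive(Serialize)]
#[serde(rename_all = "PascalCase")]
struct EncryptionData {
    encryption_mode: String,
    content_encryption_i_v: String,
}

fn main() {
    let data = EncryptionData {
        encryption_mode: "FullBlob".into(),
        content_encryption_i_v: "bXktaXY=".into(),
    };
    // Prints {"EncryptionMode":"FullBlob","ContentEncryptionIV":"bXktaXY="}.
    // The field is spelled `content_encryption_i_v` precisely so PascalCase
    // renders the `IV` suffix the expected format uses.
    println!("{}", serde_json::to_string(&data).unwrap());
}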
+ ) + ); + + Ok((material, attributes)) + } + + async fn material_from_metadata(&self, path: &str, attr: &Attributes) -> crate::Result { + let _guard = duration_on_drop!(metrics::material_from_metadata_duration); + let path = path.strip_prefix(&self.prefix).unwrap_or(path); + + let material_description: MaterialDescription = + deserialize_str(required_attribute(&attr, AZURE_MATDESC_KEY)?) + .map_err(Error::deserialize_response_err("failed to deserialize matdesc"))?; + + let master_key = get_master_key( + &self.client, + material_description.query_id.clone(), + path, + &self.stage, + &self.keyring, + ).await?; + + let encryption_data: EncryptionData = + deserialize_str(required_attribute(&attr, AZURE_ENCDATA_KEY)?) + .map_err(Error::deserialize_response_err("failed to deserialize encryption data"))?; + + let cek = EncryptedKey::from_base64(&encryption_data.wrapped_content_key.encrypted_key) + .map_err(ErrorKind::MaterialDecode)?; + let cek = cek.decrypt_aes_128_ecb(&master_key) + .map_err(ErrorKind::MaterialCrypt)?; + let iv = Iv::from_base64(&encryption_data.content_encryption_i_v) + .map_err(ErrorKind::MaterialDecode)?; + + let scheme = match encryption_data.encryption_agent.encryption_algorithm.as_str() { + "AES_CBC_128" => CryptoScheme::Aes128Cbc, + "AES_CBC_256" => CryptoScheme::Aes128Cbc, + v => unimplemented!("encryption algorithm `{}` not implemented", v) + }; + + let content_material = ContentCryptoMaterial { + scheme, + cek, + iv, + aad: None, + }; + + Ok(content_material) + } +} + + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "PascalCase")] +struct EncryptionData { + encryption_mode: String, + wrapped_content_key: WrappedContentKey, + content_encryption_i_v: String, + encryption_agent: EncryptionAgent, + key_wrapping_metadata: KeyWrappingMetadata, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "PascalCase")] +struct WrappedContentKey { + key_id: String, + encrypted_key: String, + algorithm: String, // alg for encrypting the key +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "PascalCase")] +struct EncryptionAgent { + protocol: String, + encryption_algorithm: String, // alg for encryption the content +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "PascalCase")] +struct KeyWrappingMetadata { + encryption_library: String, +} + +async fn get_master_key( + client: &SnowflakeClient, + query_id: String, + path: &str, + stage: &str, + keyring: &Cache, +) -> crate::Result { + let master_key = keyring.try_get_with(query_id, async { + let info = client.fetch_path_info(stage, path).await?; + let position = info.src_locations.iter().position(|l| l == path) + .ok_or_else(|| Error::invalid_response("path not found"))?; + let encryption_material = info.encryption_material.get(position) + .cloned() + .ok_or_else(|| Error::invalid_response("src locations and encryption material length mismatch"))? 
+
+async fn get_master_key(
+    client: &SnowflakeClient,
+    query_id: String,
+    path: &str,
+    stage: &str,
+    keyring: &Cache<String, Key>,
+) -> crate::Result<Key> {
+    let master_key = keyring.try_get_with(query_id, async {
+        let info = client.fetch_path_info(stage, path).await?;
+        let position = info.src_locations.iter().position(|l| l == path)
+            .ok_or_else(|| Error::invalid_response("path not found"))?;
+        let encryption_material = info.encryption_material.get(position)
+            .cloned()
+            .ok_or_else(|| Error::invalid_response("src locations and encryption material length mismatch"))?
+            .ok_or_else(|| Error::invalid_response("path not encrypted"))?;
+
+        let master_key = Key::from_base64(&encryption_material.query_stage_master_key)
+            .map_err(ErrorKind::MaterialDecode)?;
+        counter!(metrics::total_keyring_miss).increment(1);
+        Ok::<_, Error>(master_key)
+    }).await?;
+    counter!(metrics::total_keyring_get).increment(1);
+    Ok(master_key)
+}
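Reviewer note: `get_master_key` relies on moka's `try_get_with`, which coalesces concurrent lookups for the same key, so the Snowflake round-trip runs at most once per `query_id` even under contention. A toy sketch of that coalescing, with a counter standing in for the real key fetch:

use moka::future::Cache;
use std::sync::Arc;
use std::sync::atomic::{AtomicU32, Ordering};

async fn fetch(calls: Arc<AtomicU32>) -> Result<u32, String> {
    // Stand-in for fetch_path_info + Key::from_base64.
    calls.fetch_add(1, Ordering::SeqCst);
    Ok(42)
}

#[tokio::main]
async fn main() {
    let cache: Cache<String, u32> = Cache::builder().max_capacity(64).build();
    let calls = Arc::new(AtomicU32::new(0));

    // Concurrent gets for the same query id: the initializer future runs
    // once and both callers see its value, mirroring the keyring above.
    let (a, b) = tokio::join!(
        cache.try_get_with("query-1".to_string(), fetch(calls.clone())),
        cache.try_get_with("query-1".to_string(), fetch(calls.clone())),
    );
    assert_eq!(a.unwrap(), b.unwrap());
    assert_eq!(calls.load(Ordering::SeqCst), 1);
}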
diff --git a/src/snowflake/mod.rs b/src/snowflake/mod.rs
index 66e6c8c..50e3d60 100644
--- a/src/snowflake/mod.rs
+++ b/src/snowflake/mod.rs
@@ -6,9 +6,11 @@ use anyhow::Context as AnyhowContext;
 use client::{NormalizedStageInfo, SnowflakeClient, SnowflakeClientConfig};
 
 pub(crate) mod kms;
-use kms::{SnowflakeStageKms, SnowflakeStageKmsConfig};
+use kms::{SnowflakeStageS3Kms, SnowflakeStageAzureKms, SnowflakeStageKmsConfig};
 
-use object_store::{RetryConfig, ObjectStore};
+mod resolver;
+
+use object_store::{azure::AzureCredential, RetryConfig, ObjectStore};
 use tokio::sync::Mutex;
 use std::sync::Arc;
 
@@ -64,24 +66,37 @@ impl object_store::CredentialProvider for S3StageCredentialProvider {
             }
         })?;
 
+        if info.stage_info.location_type != "S3" {
+            return Err(object_store::Error::Generic {
+                store: "S3",
+                source: Error::invalid_response("Location type must be S3 for this provider").into()
+            })
+        }
+
+        let new_creds = info.stage_info.creds.as_aws()
+            .map_err(|e| object_store::Error::Generic {
+                store: "S3",
+                source: e.into()
+            })?;
+
         let mut locked = self.cached.lock().await;
 
         match locked.as_ref() {
-            Some(creds) => if creds.key_id == info.stage_info.creds.aws_key_id {
+            Some(creds) => if creds.key_id == new_creds.aws_key_id {
                 return Ok(Arc::clone(creds));
             }
             _ => {}
         }
 
         // The session token is empty when testing against minio
-        let token = match info.stage_info.creds.aws_token.trim() {
+        let token = match new_creds.aws_token.trim() {
             "" => None,
             token => Some(token.to_string())
         };
 
         let creds = Arc::new(object_store::aws::AwsCredential {
-            key_id: info.stage_info.creds.aws_key_id.clone(),
-            secret_key: info.stage_info.creds.aws_secret_key.clone(),
+            key_id: new_creds.aws_key_id.clone(),
+            secret_key: new_creds.aws_secret_key.clone(),
             token
         });
 
@@ -92,6 +107,93 @@ impl object_store::CredentialProvider for S3StageCredentialProvider {
     }
 }
 
+#[derive(Debug)]
+pub(crate) struct SnowflakeAzureExtension {
+    stage: String,
+    client: Arc<SnowflakeClient>,
+}
+
+#[async_trait::async_trait]
+impl Extension for SnowflakeAzureExtension {
+    fn as_any(&self) -> &dyn std::any::Any {
+        self
+    }
+
+    async fn current_stage_info(&self) -> crate::Result<String> {
+        let stage_info = &self
+            .client
+            .current_upload_info(&self.stage)
+            .await?
+            .stage_info;
+        let stage_info: NormalizedStageInfo = stage_info.try_into()?;
+        let string = serde_json::to_string(&stage_info)
+            .context("failed to encode stage_info as json").to_err()?;
+        Ok(string)
+    }
+}
+
+#[derive(Debug)]
+pub(crate) struct AzureStageCredentialProvider {
+    stage: String,
+    client: Arc<SnowflakeClient>,
+    cached: Mutex<Option<Arc<AzureCredential>>>
+}
+
+impl AzureStageCredentialProvider {
+    pub(crate) fn new(stage: impl AsRef<str>, client: Arc<SnowflakeClient>) -> AzureStageCredentialProvider {
+        AzureStageCredentialProvider { stage: stage.as_ref().to_string(), client, cached: Mutex::new(None) }
+    }
+}
+
+#[async_trait::async_trait]
+impl object_store::CredentialProvider for AzureStageCredentialProvider {
+    type Credential = object_store::azure::AzureCredential;
+
+    async fn get_credential(&self) -> object_store::Result<Arc<Self::Credential>> {
+        let info = self.client.current_upload_info(&self.stage).await
+            .map_err(|e| {
+                object_store::Error::Generic {
+                    store: "MicrosoftAzure",
+                    source: e.into()
+                }
+            })?;
+
+        if info.stage_info.location_type != "AZURE" {
+            return Err(object_store::Error::Generic {
+                store: "MicrosoftAzure",
+                source: Error::invalid_response("Location type must be AZURE for this provider").into()
+            })
+        }
+
+        let new_creds = info.stage_info.creds.as_azure()
+            .map_err(|e| object_store::Error::Generic {
+                store: "MicrosoftAzure",
+                source: e.into()
+            })?;
+
+        let token_bytes = new_creds.azure_sas_token.trim_start_matches('?').as_bytes();
+        let new_pairs = url::form_urlencoded::parse(token_bytes)
+            .into_owned()
+            .collect();
+
+        let mut locked = self.cached.lock().await;
+
+        match locked.as_ref() {
+            Some(creds) => {
+                if matches!(creds.as_ref(), AzureCredential::SASToken(pairs) if *pairs == new_pairs) {
+                    return Ok(Arc::clone(creds));
+                }
+            }
+            _ => {}
+        }
+
+        let creds = Arc::new(AzureCredential::SASToken(new_pairs));
+
+        *locked = Some(Arc::clone(&creds));
+
+        Ok(creds)
+    }
+}
+
 #[repr(C)]
 pub struct StageInfoResponse {
     result: CResult,
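Reviewer note: the provider compares SAS tokens as decoded (key, value) pairs rather than raw strings, so cosmetic differences such as a leading '?' or percent-encoding do not defeat the cache. A sketch of that normalization with a made-up token (real ones come from the Snowflake stage info):

use url::form_urlencoded;

fn main() {
    let token = "?sv=2022-11-02&ss=b&sig=abc%2Fdef";

    // Same steps as get_credential: drop the leading '?' and percent-decode
    // into owned pairs.
    let pairs: Vec<(String, String)> =
        form_urlencoded::parse(token.trim_start_matches('?').as_bytes())
            .into_owned()
            .collect();

    assert_eq!(pairs[2], ("sig".to_string(), "abc/def".to_string()));
}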
@@ -235,7 +337,7 @@ pub(crate) async fn build_store_for_snowflake_stage(
 ) -> crate::Result<(
     Arc<dyn ObjectStore>,
     Option<Arc<dyn CryptoMaterialProvider>>,
-    String,
+    Option<String>,
     ClientExtension
 )> {
     let config = validate_config_for_snowflake(&mut config_map, retry_config.clone())?;
@@ -290,7 +392,7 @@ pub(crate) async fn build_store_for_snowflake_stage(
 
             let crypto_material_provider = if info.stage_info.is_client_side_encrypted {
                 let kms_config = config.kms_config.unwrap_or_default();
-                let stage_kms = SnowflakeStageKms::new(client.clone(), &config.stage, stage_prefix, kms_config);
+                let stage_kms = SnowflakeStageS3Kms::new(client.clone(), &config.stage, stage_prefix, kms_config);
                 Some::<Arc<dyn CryptoMaterialProvider>>(Arc::new(stage_kms))
             } else {
                 None
@@ -301,10 +403,71 @@ pub(crate) async fn build_store_for_snowflake_stage(
                 client
             });
 
-            Ok((Arc::new(store), crypto_material_provider, stage_prefix.to_string(), extension))
+            Ok((Arc::new(store), crypto_material_provider, Some(stage_prefix.to_string()), extension))
+        }
+        "AZURE" => {
+            let (container, stage_prefix) = info.stage_info.location.split_once('/')
+                .ok_or_else(|| Error::invalid_response("Stage information from snowflake is missing the container name"))?;
+            let storage_account = info.stage_info.storage_account
+                .clone()
+                .ok_or_else(|| Error::invalid_response("Stage information from snowflake is missing the storage account name"))?;
+
+            let provider = AzureStageCredentialProvider::new(&config.stage, client.clone());
+
+            let mut builder = object_store::azure::MicrosoftAzureBuilder::default()
+                .with_account(storage_account)
+                .with_container_name(container)
+                .with_credentials(Arc::new(provider))
+                .with_retry(retry_config);
+
+            if let Some(test_endpoint) = &info.stage_info.test_endpoint {
+                builder = builder.with_endpoint(test_endpoint.to_string());
+                let mut azurite_host = url::Url::parse(test_endpoint)
+                    .map_err(Error::invalid_config_err("failed to parse azurite_host"))?;
+                azurite_host.set_path("");
+                unsafe { std::env::set_var("AZURITE_BLOB_STORAGE_URL", azurite_host.as_str()) };
+                config_map.insert("allow_invalid_certificates".into(), "true".into());
+                config_map.insert("azure_storage_use_emulator".into(), "true".into());
+            }
+
+            for (key, value) in config_map {
+                builder = builder.with_config(key.parse()?, value);
+            }
+
+            let store = builder.build()?;
+
+            if config.kms_config.is_some() && !info.stage_info.is_client_side_encrypted {
+                return Err(ErrorKind::StorageNotEncrypted(config.stage.clone()).into());
+            }
+
+            let crypto_material_provider = if info.stage_info.is_client_side_encrypted {
+                let kms_config = config.kms_config.unwrap_or_default();
+                let stage_kms = SnowflakeStageAzureKms::new(
+                    client.clone(),
+                    &config.stage,
+                    stage_prefix,
+                    kms_config,
+                );
+                Some::<Arc<dyn CryptoMaterialProvider>>(Arc::new(stage_kms))
+            } else {
+                None
+            };
+
+            let extension = Arc::new(SnowflakeAzureExtension {
+                stage: config.stage.clone(),
+                client
+            });
+
+            let stage_prefix = if stage_prefix.is_empty() {
+                None
+            } else {
+                Some(stage_prefix.to_string())
+            };
+
+            Ok((Arc::new(store), crypto_material_provider, stage_prefix, extension))
         }
         _ => {
-            unimplemented!("unknown stage location type");
+            unimplemented!("unknown stage location type: {}", info.stage_info.location_type);
         }
     }
 }
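Reviewer note: the `builder.with_config(key.parse()?, value)` loop works because object_store's `AzureConfigKey` implements `FromStr` for the same string keys this branch inserts into `config_map`; unknown azure-specific keys fall through to the shared client options. A sketch, assuming the upstream object_store key names (exact coverage depends on the pinned fork):

use object_store::azure::AzureConfigKey;

fn main() -> Result<(), object_store::Error> {
    // The two keys the AZURE branch adds for the Azurite test endpoint.
    let emulator: AzureConfigKey = "azure_storage_use_emulator".parse()?;
    let certs: AzureConfigKey = "allow_invalid_certificates".parse()?;
    println!("{} {}", emulator.as_ref(), certs.as_ref());
    Ok(())
}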
diff --git a/src/snowflake/resolver.rs b/src/snowflake/resolver.rs
new file mode 100644
index 0000000..fccc360
--- /dev/null
+++ b/src/snowflake/resolver.rs
@@ -0,0 +1,47 @@
+use hickory_resolver::{TokioAsyncResolver, system_conf};
+use reqwest::dns::{Addrs, Name, Resolving, Resolve};
+use std::sync::Arc;
+use std::net::SocketAddr;
+use once_cell::sync::OnceCell;
+
+/// A hickory resolver that uses extended DNS (EDNS) to resolve domain names. We use this to
+/// circumvent a bug in the hickory resolver: hickory allocates a buffer of 512 bytes for
+/// name server replies, but we observed replies larger than 512 bytes in Azure. Enabling EDNS
+/// avoids this problem because hickory determines the receive buffer size differently
+/// with EDNS. Unfortunately, we have not found an easier way to enable EDNS than
+/// implementing a custom resolver. Our implementation is based on reqwest's hickory
+/// wrapper, see https://github.com/Xuanwo/reqwest-hickory-resolver/blob/main/src/lib.rs.
+#[derive(Debug, Default, Clone)]
+pub(super) struct HickoryResolverWithEdns {
+    // Delay construction as initialization might happen outside the Tokio runtime context.
+    state: Arc<OnceCell<TokioAsyncResolver>>,
+}
+
+impl Resolve for HickoryResolverWithEdns {
+    fn resolve(&self, name: Name) -> Resolving {
+        let hickory_resolver = self.clone();
+        Box::pin(async move {
+            let resolver = hickory_resolver.state.get_or_try_init(new_resolver)?;
+
+            let lookup = resolver.lookup_ip(name.as_str()).await?;
+
+            let addrs: Addrs = Box::new(lookup.into_iter().map(|addr| SocketAddr::new(addr, 0)));
+            Ok(addrs)
+        })
+    }
+}
+
+/// Create a new resolver with the default configuration,
+/// which reads from `/etc/resolv.conf`.
+fn new_resolver() -> std::io::Result<TokioAsyncResolver> {
+    let (config, mut opts) = system_conf::read_system_conf().map_err(|e| {
+        std::io::Error::new(
+            std::io::ErrorKind::Other,
+            format!("error reading DNS system conf: {}", e),
+        )
+    })?;
+
+    opts.edns0 = true;
+
+    Ok(TokioAsyncResolver::tokio(config, opts))
+}
\ No newline at end of file
diff --git a/src/util.rs b/src/util.rs
index c688243..ef4fee2 100644
--- a/src/util.rs
+++ b/src/util.rs
@@ -8,6 +8,7 @@
 use object_store::path::Path;
 use object_store::{Attribute, AttributeValue, Attributes, GetOptions, ObjectStore, TagSet};
 use pin_project::pin_project;
 use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, AsyncWriteExt};
+use crate::error::Error;
 use crate::error::Kind as ErrorKind;
 use std::error::Error as StdError;
 
@@ -485,3 +486,10 @@ where
     let de = &mut serde_json::Deserializer::from_str(v);
     serde_path_to_error::deserialize(de)
 }
+
+pub(crate) fn required_attribute<'a>(attr: &'a Attributes, key: &'static str) -> Result<&'a str, Error> {
+    let v: &str = attr.get(&Attribute::Metadata(key.into()))
+        .ok_or_else(|| Error::required_config(format!("missing required attribute `{}`", key)))?
+        .as_ref();
+    Ok(v)
+}