Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Support API token for scanning hf:// #17682

Merged
merged 1 commit into from
Jul 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion crates/polars-io/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,17 @@ async = [
"polars-error/regex",
"polars-parquet?/async",
]
cloud = ["object_store", "async", "polars-error/object_store", "url", "serde_json", "serde", "file_cache", "reqwest"]
cloud = [
"object_store",
"async",
"polars-error/object_store",
"url",
"serde_json",
"serde",
"file_cache",
"reqwest",
"http",
]
file_cache = ["async", "dep:blake3", "dep:fs4"]
aws = ["object_store/aws", "cloud", "reqwest"]
azure = ["object_store/azure", "cloud"]
Expand Down
9 changes: 3 additions & 6 deletions crates/polars-io/src/cloud/object_store_setup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@ pub async fn build_object_store(
}
}

#[cfg(any(feature = "aws", feature = "gcp", feature = "azure"))]
let options = options.map(std::borrow::Cow::Borrowed).unwrap_or_default();

let cloud_type = CloudType::from_url(&parsed)?;
Expand Down Expand Up @@ -111,16 +110,14 @@ pub async fn build_object_store(
allow_cache = false;
#[cfg(feature = "http")]
{
let store = object_store::http::HttpBuilder::new()
.with_url(url)
.with_client_options(super::get_client_options())
.build()?;
Ok::<_, PolarsError>(Arc::new(store) as Arc<dyn ObjectStore>)
let store = options.build_http(url)?;
PolarsResult::Ok(Arc::new(store) as Arc<dyn ObjectStore>)
}
}
#[cfg(not(feature = "http"))]
return err_missing_feature("http", &cloud_location.scheme);
},
CloudType::Hf => panic!("impl error: unresolved hf:// path"),
}?;
if allow_cache {
let mut cache = OBJECT_STORE_CACHE.write().await;
Expand Down
141 changes: 105 additions & 36 deletions crates/polars-io/src/cloud/options.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ use polars_error::*;
use polars_utils::cache::FastFixedCache;
#[cfg(feature = "aws")]
use regex::Regex;
#[cfg(feature = "http")]
use reqwest::header::HeaderMap;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
#[cfg(feature = "aws")]
Expand Down Expand Up @@ -54,19 +56,27 @@ static BUCKET_REGION: Lazy<std::sync::Mutex<FastFixedCache<SmartString, SmartStr
#[allow(dead_code)]
type Configs<T> = Vec<(T, String)>;

#[derive(Clone, Debug, PartialEq, Hash, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub(crate) enum CloudConfig {
#[cfg(feature = "aws")]
Aws(Configs<AmazonS3ConfigKey>),
#[cfg(feature = "azure")]
Azure(Configs<AzureConfigKey>),
#[cfg(feature = "gcp")]
Gcp(Configs<GoogleConfigKey>),
#[cfg(feature = "http")]
Http { headers: Vec<(String, String)> },
}

#[derive(Clone, Debug, PartialEq, Hash, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
/// Options to connect to various cloud providers.
pub struct CloudOptions {
pub max_retries: usize,
#[cfg(feature = "file_cache")]
pub file_cache_ttl: u64,
#[cfg(feature = "aws")]
aws: Option<Configs<AmazonS3ConfigKey>>,
#[cfg(feature = "azure")]
azure: Option<Configs<AzureConfigKey>>,
#[cfg(feature = "gcp")]
gcp: Option<Configs<GoogleConfigKey>>,
pub(crate) config: Option<CloudConfig>,
}

impl Default for CloudOptions {
Expand All @@ -75,16 +85,29 @@ impl Default for CloudOptions {
max_retries: 2,
#[cfg(feature = "file_cache")]
file_cache_ttl: get_env_file_cache_ttl(),
#[cfg(feature = "aws")]
aws: Default::default(),
#[cfg(feature = "azure")]
azure: Default::default(),
#[cfg(feature = "gcp")]
gcp: Default::default(),
config: None,
}
}
}

#[cfg(feature = "http")]
pub(crate) fn try_build_http_header_map_from_items_slice<S: AsRef<str>>(
headers: &[(S, S)],
) -> PolarsResult<HeaderMap> {
use reqwest::header::{HeaderName, HeaderValue};

let mut map = HeaderMap::with_capacity(headers.len());
for (k, v) in headers {
let (k, v) = (k.as_ref(), v.as_ref());
map.insert(
HeaderName::from_str(k).map_err(to_compute_err)?,
HeaderValue::from_str(v).map_err(to_compute_err)?,
);
}

Ok(map)
}

#[allow(dead_code)]
/// Parse an untype configuration hashmap to a typed configuration for the given configuration key type.
fn parsed_untyped_config<T, I: IntoIterator<Item = (impl AsRef<str>, impl Into<String>)>>(
Expand Down Expand Up @@ -112,6 +135,7 @@ pub enum CloudType {
File,
Gcp,
Http,
Hf,
}

impl CloudType {
Expand All @@ -123,6 +147,7 @@ impl CloudType {
"gs" | "gcp" | "gcs" => Self::Gcp,
"file" => Self::File,
"http" | "https" => Self::Http,
"hf" => Self::Hf,
_ => polars_bail!(ComputeError: "unknown url scheme"),
})
}
Expand Down Expand Up @@ -225,21 +250,20 @@ impl CloudOptions {
mut self,
configs: I,
) -> Self {
self.aws = Some(
configs
.into_iter()
.map(|(k, v)| (k, v.into()))
.collect::<Configs<AmazonS3ConfigKey>>(),
);
self.config = Some(CloudConfig::Aws(
configs.into_iter().map(|(k, v)| (k, v.into())).collect(),
));
self
}

/// Build the [`object_store::ObjectStore`] implementation for AWS.
#[cfg(feature = "aws")]
pub async fn build_aws(&self, url: &str) -> PolarsResult<impl object_store::ObjectStore> {
let options = self.aws.as_ref();
let mut builder = AmazonS3Builder::from_env().with_url(url);
if let Some(options) = options {
if let Some(options) = &self.config {
let CloudConfig::Aws(options) = options else {
panic!("impl error: cloud type mismatch")
};
for (key, value) in options.iter() {
builder = builder.with_config(*key, value);
}
Expand Down Expand Up @@ -328,21 +352,20 @@ impl CloudOptions {
mut self,
configs: I,
) -> Self {
self.azure = Some(
configs
.into_iter()
.map(|(k, v)| (k, v.into()))
.collect::<Configs<AzureConfigKey>>(),
);
self.config = Some(CloudConfig::Azure(
configs.into_iter().map(|(k, v)| (k, v.into())).collect(),
));
self
}

/// Build the [`object_store::ObjectStore`] implementation for Azure.
#[cfg(feature = "azure")]
pub fn build_azure(&self, url: &str) -> PolarsResult<impl object_store::ObjectStore> {
let options = self.azure.as_ref();
let mut builder = MicrosoftAzureBuilder::from_env();
if let Some(options) = options {
if let Some(options) = &self.config {
let CloudConfig::Azure(options) = options else {
panic!("impl error: cloud type mismatch")
};
for (key, value) in options.iter() {
builder = builder.with_config(*key, value);
}
Expand All @@ -362,21 +385,20 @@ impl CloudOptions {
mut self,
configs: I,
) -> Self {
self.gcp = Some(
configs
.into_iter()
.map(|(k, v)| (k, v.into()))
.collect::<Configs<GoogleConfigKey>>(),
);
self.config = Some(CloudConfig::Gcp(
configs.into_iter().map(|(k, v)| (k, v.into())).collect(),
));
self
}

/// Build the [`object_store::ObjectStore`] implementation for GCP.
#[cfg(feature = "gcp")]
pub fn build_gcp(&self, url: &str) -> PolarsResult<impl object_store::ObjectStore> {
let options = self.gcp.as_ref();
let mut builder = GoogleCloudStorageBuilder::from_env();
if let Some(options) = options {
if let Some(options) = &self.config {
let CloudConfig::Gcp(options) = options else {
panic!("impl error: cloud type mismatch")
};
for (key, value) in options.iter() {
builder = builder.with_config(*key, value);
}
Expand All @@ -390,6 +412,23 @@ impl CloudOptions {
.map_err(to_compute_err)
}

#[cfg(feature = "http")]
pub fn build_http(&self, url: &str) -> PolarsResult<impl object_store::ObjectStore> {
object_store::http::HttpBuilder::new()
.with_url(url)
.with_client_options({
let mut opts = super::get_client_options();
if let Some(CloudConfig::Http { headers }) = &self.config {
opts = opts.with_default_headers(try_build_http_header_map_from_items_slice(
headers.as_slice(),
)?);
}
opts
})
.build()
.map_err(to_compute_err)
}

/// Parse a configuration from a Hashmap. This is the interface from Python.
#[allow(unused_variables)]
pub fn from_untyped_config<I: IntoIterator<Item = (impl AsRef<str>, impl Into<String>)>>(
Expand Down Expand Up @@ -432,6 +471,36 @@ impl CloudOptions {
polars_bail!(ComputeError: "'gcp' feature is not enabled");
}
},
CloudType::Hf => {
#[cfg(feature = "http")]
{
let mut this = Self::default();

if let Ok(v) = std::env::var("HF_TOKEN") {
this.config = Some(CloudConfig::Http {
headers: vec![("Authorization".into(), format!("Bearer {}", v))],
})
}

for (i, (k, v)) in config.into_iter().enumerate() {
let (k, v) = (k.as_ref(), v.into());

if i == 0 && k == "token" {
this.config = Some(CloudConfig::Http {
headers: vec![("Authorization".into(), format!("Bearer {}", v))],
})
} else {
polars_bail!(ComputeError: "unknown configuration key: {}", k)
}
}

Ok(this)
}
#[cfg(not(feature = "http"))]
{
polars_bail!(ComputeError: "'http' feature is not enabled");
}
},
}
}
}
Expand Down
48 changes: 31 additions & 17 deletions crates/polars-io/src/path_utils/hugging_face.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,12 @@
use std::collections::VecDeque;
use std::path::PathBuf;

use polars_error::{polars_bail, to_compute_err, PolarsResult};
use polars_error::{polars_bail, polars_err, to_compute_err, PolarsResult};

use crate::cloud::{extract_prefix_expansion, Matcher};
use crate::cloud::{
extract_prefix_expansion, try_build_http_header_map_from_items_slice, CloudConfig,
CloudOptions, Matcher,
};
use crate::path_utils::HiveIdxTracker;
use crate::pl_async::with_concurrency_budget;

Expand Down Expand Up @@ -198,14 +201,25 @@ impl<'a> GetPages<'a> {
pub(super) async fn expand_paths_hf(
paths: &[PathBuf],
check_directory_level: bool,
cloud_options: Option<&CloudOptions>,
) -> PolarsResult<(usize, Vec<PathBuf>)> {
assert!(!paths.is_empty());

let client = &reqwest::ClientBuilder::new()
.http1_only()
.https_only(true)
.build()
.unwrap();
let client = reqwest::ClientBuilder::new().http1_only().https_only(true);

let client = if let Some(CloudOptions {
config: Some(CloudConfig::Http { headers }),
..
}) = cloud_options
{
client.default_headers(try_build_http_header_map_from_items_slice(
headers.as_slice(),
)?)
} else {
client
};

let client = &client.build().unwrap();

let mut out_paths = vec![];
let mut stack = VecDeque::new();
Expand Down Expand Up @@ -263,26 +277,26 @@ pub(super) async fn expand_paths_hf(
client,
};

fn try_parse_api_response(bytes: &[u8]) -> PolarsResult<Vec<HFAPIResponse>> {
serde_json::from_slice::<Vec<HFAPIResponse>>(bytes).map_err(
|e| polars_err!(ComputeError: "failed to parse API response as JSON: error: {}, value: {}", e, std::str::from_utf8(bytes).unwrap()),
)
}

if let Some(matcher) = expansion_matcher {
while let Some(bytes) = gp.next().await {
let bytes = bytes?;
let bytes = bytes.as_ref();
entries.extend(
serde_json::from_slice::<Vec<HFAPIResponse>>(bytes)
.map_err(to_compute_err)?
.into_iter()
.filter(|x| {
matcher.is_matching(x.path.as_str()) && (!x.is_file() || x.size > 0)
}),
);
entries.extend(try_parse_api_response(bytes)?.into_iter().filter(|x| {
matcher.is_matching(x.path.as_str()) && (!x.is_file() || x.size > 0)
}));
}
} else {
while let Some(bytes) = gp.next().await {
let bytes = bytes?;
let bytes = bytes.as_ref();
entries.extend(
serde_json::from_slice::<Vec<HFAPIResponse>>(bytes)
.map_err(to_compute_err)?
try_parse_api_response(bytes)?
.into_iter()
.filter(|x| !x.is_file() || x.size > 0),
);
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-io/src/path_utils/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ pub fn expand_paths_hive(
if first_path.starts_with("hf://") {
let (expand_start_idx, paths) =
crate::pl_async::get_runtime().block_on_potential_spawn(
hugging_face::expand_paths_hf(paths, check_directory_level),
hugging_face::expand_paths_hf(paths, check_directory_level, cloud_options),
)?;

return Ok((Arc::from(paths), expand_start_idx));
Expand Down
Loading