Skip to content

Commit

Permalink
refactor: Update Box<dyn Error> to add Send + Sync
Browse files Browse the repository at this point in the history
  • Loading branch information
peasee committed Sep 19, 2024
1 parent aab717f commit 2819dd9
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 19 deletions.
23 changes: 12 additions & 11 deletions src/client.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
use crate::{
config::{SPICE_CLOUD_FIRECACHE_ADDR, SPICE_CLOUD_FLIGHT_ADDR, SPICE_LOCAL_FLIGHT_ADDR},
config::{
GenericError, SPICE_CLOUD_FIRECACHE_ADDR, SPICE_CLOUD_FLIGHT_ADDR, SPICE_LOCAL_FLIGHT_ADDR,
},
flight::SqlFlightClient,
tls::new_tls_flight_channel,
};
use arrow_flight::decode::FlightRecordBatchStream;
use futures::try_join;
use std::error::Error;
use tonic::transport::Channel;

struct SpiceClientConfig {
Expand All @@ -21,7 +22,7 @@ impl SpiceClientConfig {
}
}

pub async fn load_from_default() -> Result<SpiceClientConfig, Box<dyn Error>> {
pub async fn load_from_default() -> Result<SpiceClientConfig, GenericError> {
let (flight_chan, firecache_chan) = try_join!(
new_tls_flight_channel(SPICE_CLOUD_FLIGHT_ADDR),
new_tls_flight_channel(SPICE_CLOUD_FIRECACHE_ADDR)
Expand Down Expand Up @@ -52,8 +53,8 @@ impl SpiceClient {
///
/// ## Errors
///
/// - `Box<dyn Error>` for any query error
pub async fn new(api_key: &str) -> Result<Self, Box<dyn Error>> {
/// - `Box<dyn Error + Send + Sync>` for any query error
pub async fn new(api_key: &str) -> Result<Self, GenericError> {
let config = SpiceClientConfig::load_from_default().await?;

Ok(Self {
Expand All @@ -80,8 +81,8 @@ impl SpiceClient {
///
/// ## Errors
///
/// - `Box<dyn Error>` for any query error
pub async fn query(&mut self, query: &str) -> Result<FlightRecordBatchStream, Box<dyn Error>> {
/// - `Box<dyn Error + Send + Sync>` for any query error
pub async fn query(&mut self, query: &str) -> Result<FlightRecordBatchStream, GenericError> {
self.flight.query(query).await
}

Expand All @@ -98,11 +99,11 @@ impl SpiceClient {
///
/// ## Errors
///
/// - `Box<dyn Error>` for any query error
/// - `Box<dyn Error + Send + Sync>` for any query error
pub async fn fire_query(
&mut self,
query: &str,
) -> Result<FlightRecordBatchStream, Box<dyn Error>> {
) -> Result<FlightRecordBatchStream, GenericError> {
self.firecache.query(query).await
}
}
Expand Down Expand Up @@ -194,8 +195,8 @@ impl SpiceClientBuilder {
///
/// ## Errors
///
/// - `Box<dyn Error>` if flight or firecache channel creation fails
pub async fn build(self) -> Result<SpiceClient, Box<dyn Error>> {
/// - `Box<dyn Error + Send + Sync>` if flight or firecache channel creation fails
pub async fn build(self) -> Result<SpiceClient, GenericError> {
let flight_channel = match self.flight_url {
Some(url) => new_tls_flight_channel(&url).await?,
None => new_tls_flight_channel(SPICE_LOCAL_FLIGHT_ADDR).await?,
Expand Down
6 changes: 4 additions & 2 deletions src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@ pub const SPICE_CLOUD_FIRECACHE_ADDR: &str = "https://firecache.spiceai.io";
// default address for local spice runtime
pub const SPICE_LOCAL_FLIGHT_ADDR: &str = "http://localhost:50051";

pub type GenericError = Box<dyn std::error::Error + Send + Sync>;

#[cfg(target_family = "unix")]
fn get_os_release() -> Result<String, Box<dyn std::error::Error>> {
fn get_os_release() -> Result<String, GenericError> {
// call uname -r to get release text
use std::process::Command;
let output = Command::new("uname").arg("-r").output()?;
Expand All @@ -15,7 +17,7 @@ fn get_os_release() -> Result<String, Box<dyn std::error::Error>> {
}

#[cfg(target_family = "windows")]
fn get_os_release() -> Result<String, Box<dyn std::error::Error>> {
fn get_os_release() -> Result<String, GenericError> {
use winver::WindowsVersion;
if let Some(version) = WindowsVersion::detect() {
Ok(version.to_string())
Expand Down
6 changes: 3 additions & 3 deletions src/flight.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::config::get_user_agent;
use crate::config::GenericError;
use arrow::error::ArrowError;
use arrow_flight::decode::FlightRecordBatchStream;
use arrow_flight::error::FlightError;
Expand All @@ -12,7 +13,6 @@ use bytes::Bytes;
use futures::stream;
use futures::TryStreamExt;
use std::collections::HashMap;
use std::error::Error;
use std::str::FromStr;
use tonic::metadata::AsciiMetadataKey;
use tonic::transport::Channel;
Expand Down Expand Up @@ -86,7 +86,7 @@ impl SqlFlightClient {
Ok(resp)
}

async fn authenticate(&mut self, api_key: &str) -> std::result::Result<(), Box<dyn Error>> {
async fn authenticate(&mut self, api_key: &str) -> std::result::Result<(), GenericError> {
if api_key.split('|').collect::<String>().len() < 2 {
return Err("Invalid API key format".into());
}
Expand Down Expand Up @@ -119,7 +119,7 @@ impl SqlFlightClient {
pub async fn query(
&mut self,
query: &str,
) -> std::result::Result<FlightRecordBatchStream, Box<dyn Error>> {
) -> std::result::Result<FlightRecordBatchStream, GenericError> {
let api_key = self.api_key.clone();
if let Some(api_key) = api_key {
self.authenticate(&api_key).await?;
Expand Down
7 changes: 4 additions & 3 deletions src/tls.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
use std::error::Error;
use std::str::FromStr;
use tonic::transport::channel::{ClientTlsConfig, Endpoint};
use tonic::transport::Channel;

pub fn system_tls_certificate() -> Result<tonic::transport::Certificate, Box<dyn Error>> {
use crate::config::GenericError;

pub fn system_tls_certificate() -> Result<tonic::transport::Certificate, GenericError> {
// Load root certificates found in the platform’s native certificate store.
let certs = rustls_native_certs::load_native_certs()?;

Expand All @@ -19,7 +20,7 @@ pub fn system_tls_certificate() -> Result<tonic::transport::Certificate, Box<dyn
Ok(tonic::transport::Certificate::from_pem(concatenated_pems))
}

pub async fn new_tls_flight_channel(https_url: &str) -> Result<Channel, Box<dyn Error>> {
pub async fn new_tls_flight_channel(https_url: &str) -> Result<Channel, GenericError> {
let mut endpoint = Endpoint::from_str(https_url)?;

if https_url.starts_with("https://") {
Expand Down

0 comments on commit 2819dd9

Please sign in to comment.