Skip to content

Commit 11d2fe3

Browse files
authored
Expose credential provider (#4235)
1 parent ca0278d commit 11d2fe3

6 files changed

Lines changed: 137 additions & 82 deletions

File tree

object_store/src/aws/mod.rs

Lines changed: 86 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,7 @@ use url::Url;
4747

4848
pub use crate::aws::checksum::Checksum;
4949
use crate::aws::client::{S3Client, S3Config};
50-
use crate::aws::credential::{
51-
AwsCredential, InstanceCredentialProvider, WebIdentityProvider,
52-
};
50+
use crate::aws::credential::{InstanceCredentialProvider, WebIdentityProvider};
5351
use crate::client::header::header_meta;
5452
use crate::client::{
5553
ClientConfigKey, CredentialProvider, StaticCredentialProvider,
@@ -85,7 +83,9 @@ const STRICT_PATH_ENCODE_SET: percent_encoding::AsciiSet = STRICT_ENCODE_SET.rem
8583

8684
const STORE: &str = "S3";
8785

88-
type AwsCredentialProvider = Arc<dyn CredentialProvider<Credential = AwsCredential>>;
86+
/// [`CredentialProvider`] for [`AmazonS3`]
87+
pub type AwsCredentialProvider = Arc<dyn CredentialProvider<Credential = AwsCredential>>;
88+
pub use credential::AwsCredential;
8989

9090
/// Default metadata endpoint
9191
static METADATA_ENDPOINT: &str = "http://169.254.169.254";
@@ -209,6 +209,13 @@ impl std::fmt::Display for AmazonS3 {
209209
}
210210
}
211211

212+
impl AmazonS3 {
213+
/// Returns the [`AwsCredentialProvider`] used by [`AmazonS3`]
214+
pub fn credentials(&self) -> &AwsCredentialProvider {
215+
&self.client.config().credentials
216+
}
217+
}
218+
212219
#[async_trait]
213220
impl ObjectStore for AmazonS3 {
214221
async fn put(&self, location: &Path, bytes: Bytes) -> Result<()> {
@@ -424,6 +431,8 @@ pub struct AmazonS3Builder {
424431
profile: Option<String>,
425432
/// Client options
426433
client_options: ClientOptions,
434+
/// Credentials
435+
credentials: Option<AwsCredentialProvider>,
427436
}
428437

429438
/// Configuration keys for [`AmazonS3Builder`]
@@ -879,6 +888,12 @@ impl AmazonS3Builder {
879888
self
880889
}
881890

891+
/// Set the credential provider overriding any other options
892+
pub fn with_credentials(mut self, credentials: AwsCredentialProvider) -> Self {
893+
self.credentials = Some(credentials);
894+
self
895+
}
896+
882897
/// Sets what protocol is allowed. If `allow_http` is :
883898
/// * false (default): Only HTTPS are allowed
884899
/// * true: HTTP and HTTPS are allowed
@@ -992,7 +1007,7 @@ impl AmazonS3Builder {
9921007
self.parse_url(&url)?;
9931008
}
9941009

995-
let region = match (self.region.clone(), self.profile.clone()) {
1010+
let region = match (self.region, self.profile.clone()) {
9961011
(Some(region), _) => Some(region),
9971012
(None, Some(profile)) => profile_region(profile),
9981013
(None, None) => None,
@@ -1002,76 +1017,74 @@ impl AmazonS3Builder {
10021017
let region = region.context(MissingRegionSnafu)?;
10031018
let checksum = self.checksum_algorithm.map(|x| x.get()).transpose()?;
10041019

1005-
let credentials = match (self.access_key_id, self.secret_access_key, self.token) {
1006-
(Some(key_id), Some(secret_key), token) => {
1007-
info!("Using Static credential provider");
1008-
let credential = AwsCredential {
1009-
key_id,
1010-
secret_key,
1011-
token,
1012-
};
1013-
Arc::new(StaticCredentialProvider::new(credential)) as _
1014-
}
1015-
(None, Some(_), _) => return Err(Error::MissingAccessKeyId.into()),
1016-
(Some(_), None, _) => return Err(Error::MissingSecretAccessKey.into()),
1017-
// TODO: Replace with `AmazonS3Builder::credentials_from_env`
1018-
_ => match (
1019-
std::env::var("AWS_WEB_IDENTITY_TOKEN_FILE"),
1020-
std::env::var("AWS_ROLE_ARN"),
1021-
) {
1022-
(Ok(token_path), Ok(role_arn)) => {
1023-
info!("Using WebIdentity credential provider");
1024-
1025-
let session_name = std::env::var("AWS_ROLE_SESSION_NAME")
1026-
.unwrap_or_else(|_| "WebIdentitySession".to_string());
1027-
1028-
let endpoint = format!("https://sts.{region}.amazonaws.com");
1029-
1030-
// Disallow non-HTTPs requests
1031-
let client = self
1032-
.client_options
1033-
.clone()
1034-
.with_allow_http(false)
1035-
.client()?;
1036-
1037-
let token = WebIdentityProvider {
1038-
token_path,
1039-
session_name,
1040-
role_arn,
1041-
endpoint,
1042-
};
1043-
1044-
Arc::new(TokenCredentialProvider::new(
1020+
let credentials = if let Some(credentials) = self.credentials {
1021+
credentials
1022+
} else if self.access_key_id.is_some() || self.secret_access_key.is_some() {
1023+
match (self.access_key_id, self.secret_access_key, self.token) {
1024+
(Some(key_id), Some(secret_key), token) => {
1025+
info!("Using Static credential provider");
1026+
let credential = AwsCredential {
1027+
key_id,
1028+
secret_key,
10451029
token,
1046-
client,
1047-
self.retry_config.clone(),
1048-
)) as _
1030+
};
1031+
Arc::new(StaticCredentialProvider::new(credential)) as _
10491032
}
1050-
_ => match self.profile {
1051-
Some(profile) => {
1052-
info!("Using profile \"{}\" credential provider", profile);
1053-
profile_credentials(profile, region.clone())?
1054-
}
1055-
None => {
1056-
info!("Using Instance credential provider");
1057-
1058-
let token = InstanceCredentialProvider {
1059-
cache: Default::default(),
1060-
imdsv1_fallback: self.imdsv1_fallback.get()?,
1061-
metadata_endpoint: self
1062-
.metadata_endpoint
1063-
.unwrap_or_else(|| METADATA_ENDPOINT.into()),
1064-
};
1065-
1066-
Arc::new(TokenCredentialProvider::new(
1067-
token,
1068-
// The instance metadata endpoint is access over HTTP
1069-
self.client_options.clone().with_allow_http(true).client()?,
1070-
self.retry_config.clone(),
1071-
)) as _
1072-
}
1073-
},
1074-
},
1033+
(None, Some(_), _) => return Err(Error::MissingAccessKeyId.into()),
1034+
(Some(_), None, _) => return Err(Error::MissingSecretAccessKey.into()),
1035+
(None, None, _) => unreachable!(),
1036+
}
1037+
} else if let (Ok(token_path), Ok(role_arn)) = (
1038+
std::env::var("AWS_WEB_IDENTITY_TOKEN_FILE"),
1039+
std::env::var("AWS_ROLE_ARN"),
1040+
) {
1041+
// TODO: Replace with `AmazonS3Builder::credentials_from_env`
1042+
info!("Using WebIdentity credential provider");
1043+
1044+
let session_name = std::env::var("AWS_ROLE_SESSION_NAME")
1045+
.unwrap_or_else(|_| "WebIdentitySession".to_string());
1046+
1047+
let endpoint = format!("https://sts.{region}.amazonaws.com");
1048+
1049+
// Disallow non-HTTPs requests
1050+
let client = self
1051+
.client_options
1052+
.clone()
1053+
.with_allow_http(false)
1054+
.client()?;
1055+
1056+
let token = WebIdentityProvider {
1057+
token_path,
1058+
session_name,
1059+
role_arn,
1060+
endpoint,
1061+
};
1062+
1063+
Arc::new(TokenCredentialProvider::new(
1064+
token,
1065+
client,
1066+
self.retry_config.clone(),
1067+
)) as _
1068+
} else if let Some(profile) = self.profile {
1069+
info!("Using profile \"{}\" credential provider", profile);
1070+
profile_credentials(profile, region.clone())?
1071+
} else {
1072+
info!("Using Instance credential provider");
1073+
1074+
let token = InstanceCredentialProvider {
1075+
cache: Default::default(),
1076+
imdsv1_fallback: self.imdsv1_fallback.get()?,
1077+
metadata_endpoint: self
1078+
.metadata_endpoint
1079+
.unwrap_or_else(|| METADATA_ENDPOINT.into()),
1080+
};
1081+
1082+
Arc::new(TokenCredentialProvider::new(
1083+
token,
1084+
// The instance metadata endpoint is access over HTTP
1085+
self.client_options.clone().with_allow_http(true).client()?,
1086+
self.retry_config.clone(),
1087+
)) as _
10751088
};
10761089

10771090
let endpoint: String;

object_store/src/azure/mod.rs

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,6 @@ use std::{collections::BTreeSet, str::FromStr};
4848
use tokio::io::AsyncWrite;
4949
use url::Url;
5050

51-
use crate::azure::credential::AzureCredential;
5251
use crate::client::header::header_meta;
5352
use crate::client::{
5453
ClientConfigKey, CredentialProvider, StaticCredentialProvider,
@@ -60,7 +59,10 @@ pub use credential::authority_hosts;
6059
mod client;
6160
mod credential;
6261

63-
type AzureCredentialProvider = Arc<dyn CredentialProvider<Credential = AzureCredential>>;
62+
/// [`CredentialProvider`] for [`MicrosoftAzure`]
63+
pub type AzureCredentialProvider =
64+
Arc<dyn CredentialProvider<Credential = AzureCredential>>;
65+
pub use credential::AzureCredential;
6466

6567
const STORE: &str = "MicrosoftAzure";
6668

@@ -153,6 +155,13 @@ pub struct MicrosoftAzure {
153155
client: Arc<client::AzureClient>,
154156
}
155157

158+
impl MicrosoftAzure {
159+
/// Returns the [`AzureCredentialProvider`] used by [`MicrosoftAzure`]
160+
pub fn credentials(&self) -> &AzureCredentialProvider {
161+
&self.client.config().credentials
162+
}
163+
}
164+
156165
impl std::fmt::Display for MicrosoftAzure {
157166
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
158167
write!(
@@ -374,6 +383,8 @@ pub struct MicrosoftAzureBuilder {
374383
retry_config: RetryConfig,
375384
/// Client options
376385
client_options: ClientOptions,
386+
/// Credentials
387+
credentials: Option<AzureCredentialProvider>,
377388
}
378389

379390
/// Configuration keys for [`MicrosoftAzureBuilder`]
@@ -840,6 +851,12 @@ impl MicrosoftAzureBuilder {
840851
self
841852
}
842853

854+
/// Set the credential provider overriding any other options
855+
pub fn with_credentials(mut self, credentials: AzureCredentialProvider) -> Self {
856+
self.credentials = Some(credentials);
857+
self
858+
}
859+
843860
/// Set if the Azure emulator should be used (defaults to false)
844861
pub fn with_use_emulator(mut self, use_emulator: bool) -> Self {
845862
self.use_emulator = use_emulator.into();
@@ -937,7 +954,9 @@ impl MicrosoftAzureBuilder {
937954
let url = Url::parse(&account_url)
938955
.context(UnableToParseUrlSnafu { url: account_url })?;
939956

940-
let credential = if let Some(bearer_token) = self.bearer_token {
957+
let credential = if let Some(credential) = self.credentials {
958+
credential
959+
} else if let Some(bearer_token) = self.bearer_token {
941960
static_creds(AzureCredential::BearerToken(bearer_token))
942961
} else if let Some(access_key) = self.access_key {
943962
static_creds(AzureCredential::AccessKey(access_key))

object_store/src/client/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -509,8 +509,10 @@ impl GetOptionsExt for RequestBuilder {
509509
/// Provides credentials for use when signing requests
510510
#[async_trait]
511511
pub trait CredentialProvider: std::fmt::Debug + Send + Sync {
512+
/// The type of credential returned by this provider
512513
type Credential;
513514

515+
/// Return a credential
514516
async fn get_credential(&self) -> Result<Arc<Self::Credential>>;
515517
}
516518

object_store/src/gcp/credential.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ impl From<Error> for crate::Error {
8282
}
8383
}
8484

85+
/// A Google Cloud Storage Credential
8586
#[derive(Debug, Eq, PartialEq)]
8687
pub struct GcpCredential {
8788
/// An HTTP bearer token

object_store/src/gcp/mod.rs

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,6 @@ use crate::client::{
5252
ClientConfigKey, CredentialProvider, GetOptionsExt, StaticCredentialProvider,
5353
TokenCredentialProvider,
5454
};
55-
use crate::gcp::credential::{application_default_credentials, GcpCredential};
5655
use crate::{
5756
multipart::{CloudMultiPartUpload, CloudMultiPartUploadImpl, UploadPart},
5857
path::{Path, DELIMITER},
@@ -61,15 +60,18 @@ use crate::{
6160
ObjectStore, Result, RetryConfig,
6261
};
6362

64-
use self::credential::{
65-
default_gcs_base_url, InstanceCredentialProvider, ServiceAccountCredentials,
63+
use credential::{
64+
application_default_credentials, default_gcs_base_url, InstanceCredentialProvider,
65+
ServiceAccountCredentials,
6666
};
6767

6868
mod credential;
6969

7070
const STORE: &str = "GCS";
7171

72-
type GcpCredentialProvider = Arc<dyn CredentialProvider<Credential = GcpCredential>>;
72+
/// [`CredentialProvider`] for [`GoogleCloudStorage`]
73+
pub type GcpCredentialProvider = Arc<dyn CredentialProvider<Credential = GcpCredential>>;
74+
pub use credential::GcpCredential;
7375

7476
#[derive(Debug, Snafu)]
7577
enum Error {
@@ -205,6 +207,13 @@ impl std::fmt::Display for GoogleCloudStorage {
205207
}
206208
}
207209

210+
impl GoogleCloudStorage {
211+
/// Returns the [`GcpCredentialProvider`] used by [`GoogleCloudStorage`]
212+
pub fn credentials(&self) -> &GcpCredentialProvider {
213+
&self.client.credentials
214+
}
215+
}
216+
208217
#[derive(Debug)]
209218
struct GoogleCloudStorageClient {
210219
client: Client,
@@ -696,6 +705,8 @@ pub struct GoogleCloudStorageBuilder {
696705
retry_config: RetryConfig,
697706
/// Client options
698707
client_options: ClientOptions,
708+
/// Credentials
709+
credentials: Option<GcpCredentialProvider>,
699710
}
700711

701712
/// Configuration keys for [`GoogleCloudStorageBuilder`]
@@ -794,6 +805,7 @@ impl Default for GoogleCloudStorageBuilder {
794805
retry_config: Default::default(),
795806
client_options: ClientOptions::new().with_allow_http(true),
796807
url: None,
808+
credentials: None,
797809
}
798810
}
799811
}
@@ -1006,6 +1018,12 @@ impl GoogleCloudStorageBuilder {
10061018
self
10071019
}
10081020

1021+
/// Set the credential provider overriding any other options
1022+
pub fn with_credentials(mut self, credentials: GcpCredentialProvider) -> Self {
1023+
self.credentials = Some(credentials);
1024+
self
1025+
}
1026+
10091027
/// Set the retry configuration
10101028
pub fn with_retry(mut self, retry_config: RetryConfig) -> Self {
10111029
self.retry_config = retry_config;
@@ -1072,7 +1090,9 @@ impl GoogleCloudStorageBuilder {
10721090
let scope = "https://www.googleapis.com/auth/devstorage.full_control";
10731091
let audience = "https://www.googleapis.com/oauth2/v4/token";
10741092

1075-
let credentials = if disable_oauth {
1093+
let credentials = if let Some(credentials) = self.credentials {
1094+
credentials
1095+
} else if disable_oauth {
10761096
Arc::new(StaticCredentialProvider::new(GcpCredential {
10771097
bearer: "".to_string(),
10781098
})) as _

object_store/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,7 @@ pub mod throttle;
245245
mod client;
246246

247247
#[cfg(any(feature = "gcp", feature = "aws", feature = "azure", feature = "http"))]
248-
pub use client::{backoff::BackoffConfig, retry::RetryConfig};
248+
pub use client::{backoff::BackoffConfig, retry::RetryConfig, CredentialProvider};
249249

250250
#[cfg(any(feature = "gcp", feature = "aws", feature = "azure", feature = "http"))]
251251
mod config;

0 commit comments

Comments
 (0)