99import software .amazon .awssdk .core .ResponseInputStream ;
1010import software .amazon .awssdk .core .async .AsyncRequestBody ;
1111import software .amazon .awssdk .core .async .AsyncResponseTransformer ;
12- import software .amazon .awssdk .http .async .SdkAsyncHttpClient ;
1312import software .amazon .awssdk .http .nio .netty .NettyNioAsyncHttpClient ;
13+ import software .amazon .awssdk .http .nio .netty .NettyNioAsyncHttpClient .Builder ;
1414import software .amazon .awssdk .regions .Region ;
1515import software .amazon .awssdk .services .s3 .S3AsyncClient ;
1616import software .amazon .awssdk .services .s3 .S3AsyncClientBuilder ;
@@ -118,6 +118,7 @@ public S3AccessIO(T dvObject, DataAccessRequest req, String driverId) {
118118 try {
119119 bucketName = getBucketName (driverId );
120120 minPartSize = getMinPartSize (driverId );
121+ credentialsProvider = getCredentialsProvider (driverId );
121122 s3 = getClient (driverId );
122123 tm = getTransferManager (driverId );
123124 endpoint = getConfigParam (CUSTOM_ENDPOINT_URL , "" );
@@ -268,22 +269,9 @@ public void open(DataAccessOption... options) throws IOException {
268269 int retries = 20 ;
269270 while (retries > 0 ) {
270271 try {
272+ // Since s3 is an S3AsyncClient, we need to call .get() to wait for the result.
271273 HeadObjectResponse headObjectResponse = s3
272- .headObject (HeadObjectRequest .builder ().bucket (bucketName ).key (key ).build ()).get (); // Since
273- // s3
274- // is
275- // an
276- // S3AsyncClient,
277- // we
278- // need
279- // to
280- // call
281- // .get()
282- // to
283- // wait
284- // for
285- // the
286- // result
274+ .headObject (HeadObjectRequest .builder ().bucket (bucketName ).key (key ).build ()).get ();
287275 contentLength = headObjectResponse .contentLength ();
288276 if (retries != 20 ) {
289277 logger .warning ("Success for key: " + key + " after " + ((20 - retries ) * 3 ) + " seconds" );
@@ -1014,9 +1002,6 @@ public String generateTemporaryDownloadUrl(String auxiliaryTag, String auxiliary
10141002 String fileName = auxiliaryFileName == null ? this .getDataFile ().getDisplayName () : auxiliaryFileName ;
10151003 String contentType = auxiliaryType == null ? this .getDataFile ().getContentType () : auxiliaryType ;
10161004
1017- // Get the stored credentials provider
1018- AwsCredentialsProvider credentialsProvider = driverCredentialsProviderMap .get (this .driverId );
1019-
10201005 // Create S3Presigner
10211006 S3Presigner s3Presigner = S3Presigner .builder ()
10221007 .region (Region .of (s3 .serviceClientConfiguration ().region ().toString ()))
@@ -1075,24 +1060,22 @@ private String generateTemporaryS3UploadUrl(String key, Date expiration) throws
10751060
10761061 Duration expirationDuration = Duration .between (Instant .now (), expiration .toInstant ());
10771062
1078- // Get the stored credentials provider
1079- AwsCredentialsProvider credentialsProvider = driverCredentialsProviderMap .get (this .driverId );
1080-
10811063 // Create S3Presigner
10821064 S3Presigner s3Presigner = S3Presigner .builder ()
10831065 .region (Region .of (s3 .serviceClientConfiguration ().region ().toString ()))
10841066 .credentialsProvider (credentialsProvider ).build ();
1085-
1067+ logger . info ( "Bucket when signing = " + bucketName );
10861068 PutObjectPresignRequest .Builder presignRequestBuilder = PutObjectPresignRequest .builder ()
1087- .signatureDuration (expirationDuration ). putObjectRequest ( req -> req . bucket ( bucketName ). key ( key )) ;
1069+ .signatureDuration (expirationDuration );
10881070
10891071 // Add tagging if not disabled
10901072 final boolean taggingDisabled = JvmSettings .DISABLE_S3_TAGGING .lookupOptional (Boolean .class , this .driverId )
10911073 .orElse (false );
10921074 if (!taggingDisabled ) {
1093- presignRequestBuilder .putObjectRequest (req -> req .tagging ("dv-state=temp" ));
1075+ presignRequestBuilder .putObjectRequest (req -> req .tagging ("dv-state=temp" ).bucket (bucketName ).key (key ));
1076+ } else {
1077+ presignRequestBuilder .putObjectRequest (req -> req .bucket (bucketName ).key (key ));
10941078 }
1095-
10961079 PutObjectPresignRequest presignRequest = presignRequestBuilder .build ();
10971080
10981081 PresignedPutObjectRequest presignedRequest ;
@@ -1136,9 +1119,6 @@ public JsonObjectBuilder generateTemporaryS3UploadUrls(String globalId, String s
11361119 } else {
11371120 JsonObjectBuilder urls = Json .createObjectBuilder ();
11381121
1139- // Get the stored credentials provider
1140- AwsCredentialsProvider credentialsProvider = driverCredentialsProviderMap .get (this .driverId );
1141-
11421122 // Create S3Client
11431123 S3Client s3Client = S3Client .builder ()
11441124 .region (Region .of (s3 .serviceClientConfiguration ().region ().toString ()))
@@ -1250,6 +1230,7 @@ private static S3TransferManager getTransferManager(String driverId) {
12501230
12511231 @ SuppressWarnings ("deprecation" )
12521232 private static S3AsyncClient getClient (String driverId ) {
1233+
12531234 if (driverClientMap .containsKey (driverId )) {
12541235 return driverClientMap .get (driverId );
12551236 } else {
@@ -1258,19 +1239,19 @@ private static S3AsyncClient getClient(String driverId) {
12581239
12591240 // Create a custom HTTP client with the desired pool size
12601241 Integer poolSize = Integer .getInteger ("dataverse.files." + driverId + ".connection-pool-size" , 256 );
1261- SdkAsyncHttpClient httpClient = NettyNioAsyncHttpClient .builder ().maxConcurrency (poolSize ). build ( );
1242+ Builder httpClientBuilder = NettyNioAsyncHttpClient .builder ().maxConcurrency (poolSize );
12621243
12631244 // Apply the custom HTTP client to the S3AsyncClientBuilder
1264- s3CB .httpClient ( httpClient );
1245+ s3CB .httpClientBuilder ( httpClientBuilder );
12651246
12661247 // Configure endpoint and region
12671248 String s3CEUrl = getConfigParamForDriver (driverId , CUSTOM_ENDPOINT_URL , "" );
12681249 String s3CERegion = getConfigParamForDriver (driverId , CUSTOM_ENDPOINT_REGION , "dataverse" );
12691250
12701251 if (!s3CEUrl .isEmpty ()) {
12711252 s3CB .endpointOverride (URI .create (s3CEUrl ));
1253+ s3CB .region (Region .of (s3CERegion ));
12721254 }
1273- s3CB .region (Region .of (s3CERegion ));
12741255
12751256 // Configure path style access
12761257 Boolean s3pathStyleAccess = Boolean
@@ -1286,8 +1267,7 @@ private static S3AsyncClient getClient(String driverId) {
12861267 s3CB .serviceConfiguration (S3Configuration .builder ().chunkedEncodingEnabled (s3chunkedEncoding ).build ());
12871268
12881269 // Configure credentials
1289- AwsCredentialsProviderChain credentialsProvider = buildCredentialsProviderChain (driverId );
1290- s3CB .credentialsProvider (credentialsProvider );
1270+ s3CB .credentialsProvider (getCredentialsProvider (driverId ));
12911271
12921272 // Build the client
12931273 S3AsyncClient client = s3CB .build ();
@@ -1296,39 +1276,46 @@ private static S3AsyncClient getClient(String driverId) {
12961276 }
12971277 }
12981278
1299- private static AwsCredentialsProviderChain buildCredentialsProviderChain (String driverId ) {
1300- List <AwsCredentialsProvider > providers = new ArrayList <>();
1279+ private static AwsCredentialsProvider getCredentialsProvider (String driverId ) {
1280+ if (driverCredentialsProviderMap .containsKey (driverId )) {
1281+ return driverCredentialsProviderMap .get (driverId );
1282+ } else {
1283+ List <AwsCredentialsProvider > providers = new ArrayList <>();
13011284
1302- String s3profile = getConfigParamForDriver (driverId , PROFILE );
1303- boolean allowInstanceCredentials = true ;
1285+ String s3profile = getConfigParamForDriver (driverId , PROFILE );
1286+ boolean allowInstanceCredentials = true ;
13041287
1305- if (s3profile != null ) {
1306- allowInstanceCredentials = false ;
1307- }
1288+ if (s3profile != null ) {
1289+ allowInstanceCredentials = false ;
1290+ }
13081291
1309- Optional <String > accessKey = config .getOptionalValue ("dataverse.files." + driverId + ".access-key" ,
1310- String .class );
1311- Optional <String > secretKey = config .getOptionalValue ("dataverse.files." + driverId + ".secret-key" ,
1312- String .class );
1292+ Optional <String > accessKey = config .getOptionalValue ("dataverse.files." + driverId + ".access-key" ,
1293+ String .class );
1294+ Optional <String > secretKey = config .getOptionalValue ("dataverse.files." + driverId + ".secret-key" ,
1295+ String .class );
1296+
1297+ if (accessKey .isPresent () && secretKey .isPresent ()) {
1298+ allowInstanceCredentials = false ;
1299+ providers .add (
1300+ StaticCredentialsProvider .create (AwsBasicCredentials .create (accessKey .get (), secretKey .get ())));
1301+ } else if (s3profile == null ) {
1302+ s3profile = "default" ;
1303+ }
13131304
1314- if (accessKey .isPresent () && secretKey .isPresent ()) {
1315- allowInstanceCredentials = false ;
1316- providers .add (
1317- StaticCredentialsProvider .create (AwsBasicCredentials .create (accessKey .get (), secretKey .get ())));
1318- } else if (s3profile == null ) {
1319- s3profile = "default" ;
1320- }
1305+ if (s3profile != null ) {
1306+ providers .add (ProfileCredentialsProvider .create (s3profile ));
1307+ }
13211308
1322- if ( s3profile != null ) {
1323- providers .add (ProfileCredentialsProvider .create (s3profile ));
1324- }
1309+ if ( allowInstanceCredentials ) {
1310+ providers .add (InstanceProfileCredentialsProvider .create ());
1311+ }
13251312
1326- if (allowInstanceCredentials ) {
1327- providers .add (InstanceProfileCredentialsProvider .create ());
1313+ Collections .reverse (providers );
1314+ AwsCredentialsProvider provider = AwsCredentialsProviderChain .builder ().credentialsProviders (providers )
1315+ .build ();
1316+ driverCredentialsProviderMap .put (driverId , provider );
1317+ return provider ;
13281318 }
1329-
1330- Collections .reverse (providers );
1331- return AwsCredentialsProviderChain .builder ().credentialsProviders (providers ).build ();
13321319 }
13331320
13341321 public void removeTempTag () throws IOException {
0 commit comments