Skip to content

Commit 3e9156f

Browse files
committed
Fix bugs
1 parent 6997b09 commit 3e9156f

1 file changed

Lines changed: 47 additions & 60 deletions

File tree

src/main/java/edu/harvard/iq/dataverse/dataaccess/S3AccessIO.java

Lines changed: 47 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99
import software.amazon.awssdk.core.ResponseInputStream;
1010
import software.amazon.awssdk.core.async.AsyncRequestBody;
1111
import software.amazon.awssdk.core.async.AsyncResponseTransformer;
12-
import software.amazon.awssdk.http.async.SdkAsyncHttpClient;
1312
import software.amazon.awssdk.http.nio.netty.NettyNioAsyncHttpClient;
13+
import software.amazon.awssdk.http.nio.netty.NettyNioAsyncHttpClient.Builder;
1414
import software.amazon.awssdk.regions.Region;
1515
import software.amazon.awssdk.services.s3.S3AsyncClient;
1616
import software.amazon.awssdk.services.s3.S3AsyncClientBuilder;
@@ -118,6 +118,7 @@ public S3AccessIO(T dvObject, DataAccessRequest req, String driverId) {
118118
try {
119119
bucketName = getBucketName(driverId);
120120
minPartSize = getMinPartSize(driverId);
121+
credentialsProvider = getCredentialsProvider(driverId);
121122
s3 = getClient(driverId);
122123
tm = getTransferManager(driverId);
123124
endpoint = getConfigParam(CUSTOM_ENDPOINT_URL, "");
@@ -268,22 +269,9 @@ public void open(DataAccessOption... options) throws IOException {
268269
int retries = 20;
269270
while (retries > 0) {
270271
try {
272+
// Since s3 is an S3AsyncClient, we need to call .get() to wait for the result.
271273
HeadObjectResponse headObjectResponse = s3
272-
.headObject(HeadObjectRequest.builder().bucket(bucketName).key(key).build()).get(); // Since
273-
// s3
274-
// is
275-
// an
276-
// S3AsyncClient,
277-
// we
278-
// need
279-
// to
280-
// call
281-
// .get()
282-
// to
283-
// wait
284-
// for
285-
// the
286-
// result
274+
.headObject(HeadObjectRequest.builder().bucket(bucketName).key(key).build()).get();
287275
contentLength = headObjectResponse.contentLength();
288276
if (retries != 20) {
289277
logger.warning("Success for key: " + key + " after " + ((20 - retries) * 3) + " seconds");
@@ -1014,9 +1002,6 @@ public String generateTemporaryDownloadUrl(String auxiliaryTag, String auxiliary
10141002
String fileName = auxiliaryFileName == null ? this.getDataFile().getDisplayName() : auxiliaryFileName;
10151003
String contentType = auxiliaryType == null ? this.getDataFile().getContentType() : auxiliaryType;
10161004

1017-
// Get the stored credentials provider
1018-
AwsCredentialsProvider credentialsProvider = driverCredentialsProviderMap.get(this.driverId);
1019-
10201005
// Create S3Presigner
10211006
S3Presigner s3Presigner = S3Presigner.builder()
10221007
.region(Region.of(s3.serviceClientConfiguration().region().toString()))
@@ -1075,24 +1060,22 @@ private String generateTemporaryS3UploadUrl(String key, Date expiration) throws
10751060

10761061
Duration expirationDuration = Duration.between(Instant.now(), expiration.toInstant());
10771062

1078-
// Get the stored credentials provider
1079-
AwsCredentialsProvider credentialsProvider = driverCredentialsProviderMap.get(this.driverId);
1080-
10811063
// Create S3Presigner
10821064
S3Presigner s3Presigner = S3Presigner.builder()
10831065
.region(Region.of(s3.serviceClientConfiguration().region().toString()))
10841066
.credentialsProvider(credentialsProvider).build();
1085-
1067+
logger.info("Bucket when signing = " + bucketName);
10861068
PutObjectPresignRequest.Builder presignRequestBuilder = PutObjectPresignRequest.builder()
1087-
.signatureDuration(expirationDuration).putObjectRequest(req -> req.bucket(bucketName).key(key));
1069+
.signatureDuration(expirationDuration);
10881070

10891071
// Add tagging if not disabled
10901072
final boolean taggingDisabled = JvmSettings.DISABLE_S3_TAGGING.lookupOptional(Boolean.class, this.driverId)
10911073
.orElse(false);
10921074
if (!taggingDisabled) {
1093-
presignRequestBuilder.putObjectRequest(req -> req.tagging("dv-state=temp"));
1075+
presignRequestBuilder.putObjectRequest(req -> req.tagging("dv-state=temp").bucket(bucketName).key(key));
1076+
} else {
1077+
presignRequestBuilder.putObjectRequest(req -> req.bucket(bucketName).key(key));
10941078
}
1095-
10961079
PutObjectPresignRequest presignRequest = presignRequestBuilder.build();
10971080

10981081
PresignedPutObjectRequest presignedRequest;
@@ -1136,9 +1119,6 @@ public JsonObjectBuilder generateTemporaryS3UploadUrls(String globalId, String s
11361119
} else {
11371120
JsonObjectBuilder urls = Json.createObjectBuilder();
11381121

1139-
// Get the stored credentials provider
1140-
AwsCredentialsProvider credentialsProvider = driverCredentialsProviderMap.get(this.driverId);
1141-
11421122
// Create S3Client
11431123
S3Client s3Client = S3Client.builder()
11441124
.region(Region.of(s3.serviceClientConfiguration().region().toString()))
@@ -1250,6 +1230,7 @@ private static S3TransferManager getTransferManager(String driverId) {
12501230

12511231
@SuppressWarnings("deprecation")
12521232
private static S3AsyncClient getClient(String driverId) {
1233+
12531234
if (driverClientMap.containsKey(driverId)) {
12541235
return driverClientMap.get(driverId);
12551236
} else {
@@ -1258,19 +1239,19 @@ private static S3AsyncClient getClient(String driverId) {
12581239

12591240
// Create a custom HTTP client with the desired pool size
12601241
Integer poolSize = Integer.getInteger("dataverse.files." + driverId + ".connection-pool-size", 256);
1261-
SdkAsyncHttpClient httpClient = NettyNioAsyncHttpClient.builder().maxConcurrency(poolSize).build();
1242+
Builder httpClientBuilder = NettyNioAsyncHttpClient.builder().maxConcurrency(poolSize);
12621243

12631244
// Apply the custom HTTP client to the S3AsyncClientBuilder
1264-
s3CB.httpClient(httpClient);
1245+
s3CB.httpClientBuilder(httpClientBuilder);
12651246

12661247
// Configure endpoint and region
12671248
String s3CEUrl = getConfigParamForDriver(driverId, CUSTOM_ENDPOINT_URL, "");
12681249
String s3CERegion = getConfigParamForDriver(driverId, CUSTOM_ENDPOINT_REGION, "dataverse");
12691250

12701251
if (!s3CEUrl.isEmpty()) {
12711252
s3CB.endpointOverride(URI.create(s3CEUrl));
1253+
s3CB.region(Region.of(s3CERegion));
12721254
}
1273-
s3CB.region(Region.of(s3CERegion));
12741255

12751256
// Configure path style access
12761257
Boolean s3pathStyleAccess = Boolean
@@ -1286,8 +1267,7 @@ private static S3AsyncClient getClient(String driverId) {
12861267
s3CB.serviceConfiguration(S3Configuration.builder().chunkedEncodingEnabled(s3chunkedEncoding).build());
12871268

12881269
// Configure credentials
1289-
AwsCredentialsProviderChain credentialsProvider = buildCredentialsProviderChain(driverId);
1290-
s3CB.credentialsProvider(credentialsProvider);
1270+
s3CB.credentialsProvider(getCredentialsProvider(driverId));
12911271

12921272
// Build the client
12931273
S3AsyncClient client = s3CB.build();
@@ -1296,39 +1276,46 @@ private static S3AsyncClient getClient(String driverId) {
12961276
}
12971277
}
12981278

1299-
private static AwsCredentialsProviderChain buildCredentialsProviderChain(String driverId) {
1300-
List<AwsCredentialsProvider> providers = new ArrayList<>();
1279+
private static AwsCredentialsProvider getCredentialsProvider(String driverId) {
1280+
if (driverCredentialsProviderMap.containsKey(driverId)) {
1281+
return driverCredentialsProviderMap.get(driverId);
1282+
} else {
1283+
List<AwsCredentialsProvider> providers = new ArrayList<>();
13011284

1302-
String s3profile = getConfigParamForDriver(driverId, PROFILE);
1303-
boolean allowInstanceCredentials = true;
1285+
String s3profile = getConfigParamForDriver(driverId, PROFILE);
1286+
boolean allowInstanceCredentials = true;
13041287

1305-
if (s3profile != null) {
1306-
allowInstanceCredentials = false;
1307-
}
1288+
if (s3profile != null) {
1289+
allowInstanceCredentials = false;
1290+
}
13081291

1309-
Optional<String> accessKey = config.getOptionalValue("dataverse.files." + driverId + ".access-key",
1310-
String.class);
1311-
Optional<String> secretKey = config.getOptionalValue("dataverse.files." + driverId + ".secret-key",
1312-
String.class);
1292+
Optional<String> accessKey = config.getOptionalValue("dataverse.files." + driverId + ".access-key",
1293+
String.class);
1294+
Optional<String> secretKey = config.getOptionalValue("dataverse.files." + driverId + ".secret-key",
1295+
String.class);
1296+
1297+
if (accessKey.isPresent() && secretKey.isPresent()) {
1298+
allowInstanceCredentials = false;
1299+
providers.add(
1300+
StaticCredentialsProvider.create(AwsBasicCredentials.create(accessKey.get(), secretKey.get())));
1301+
} else if (s3profile == null) {
1302+
s3profile = "default";
1303+
}
13131304

1314-
if (accessKey.isPresent() && secretKey.isPresent()) {
1315-
allowInstanceCredentials = false;
1316-
providers.add(
1317-
StaticCredentialsProvider.create(AwsBasicCredentials.create(accessKey.get(), secretKey.get())));
1318-
} else if (s3profile == null) {
1319-
s3profile = "default";
1320-
}
1305+
if (s3profile != null) {
1306+
providers.add(ProfileCredentialsProvider.create(s3profile));
1307+
}
13211308

1322-
if (s3profile != null) {
1323-
providers.add(ProfileCredentialsProvider.create(s3profile));
1324-
}
1309+
if (allowInstanceCredentials) {
1310+
providers.add(InstanceProfileCredentialsProvider.create());
1311+
}
13251312

1326-
if (allowInstanceCredentials) {
1327-
providers.add(InstanceProfileCredentialsProvider.create());
1313+
Collections.reverse(providers);
1314+
AwsCredentialsProvider provider = AwsCredentialsProviderChain.builder().credentialsProviders(providers)
1315+
.build();
1316+
driverCredentialsProviderMap.put(driverId, provider);
1317+
return provider;
13281318
}
1329-
1330-
Collections.reverse(providers);
1331-
return AwsCredentialsProviderChain.builder().credentialsProviders(providers).build();
13321319
}
13331320

13341321
public void removeTempTag() throws IOException {

0 commit comments

Comments
 (0)