diff --git a/codegen/smithy-aws-go-codegen/src/main/java/software/amazon/smithy/aws/go/codegen/AwsHttpChecksumGenerator.java b/codegen/smithy-aws-go-codegen/src/main/java/software/amazon/smithy/aws/go/codegen/AwsHttpChecksumGenerator.java index 43e5001042a..3df36a00e8c 100644 --- a/codegen/smithy-aws-go-codegen/src/main/java/software/amazon/smithy/aws/go/codegen/AwsHttpChecksumGenerator.java +++ b/codegen/smithy-aws-go-codegen/src/main/java/software/amazon/smithy/aws/go/codegen/AwsHttpChecksumGenerator.java @@ -103,6 +103,18 @@ public void processFinalizedModel(GoSettings settings, Model model) { .build()) .build()); + var inputhecksumConfigs = RuntimeClientPlugin.builder() + .servicePredicate(this::supportsTrailingChecksum) + .addConfigField( + ConfigField.builder() + .name("DisableTrailingChecksumSupport") + .type(GoUniverseTypes.Bool) + .documentation("Allow disabling trailing headers for server implementations not supporting this") + .build() + ) + .build(); + runtimeClientPlugins.add(inputhecksumConfigs); + // Output helper String outputHelperFuncName = getAddOutputMiddlewareFuncName( symbolProvider.toSymbol(operation).getName() @@ -248,27 +260,12 @@ private void writeInputMiddlewareHelper( ) { Symbol operationSymbol = symbolProvider.toSymbol(operation); String operationName = operationSymbol.getName(); - StructureShape input = model.expectShape(operation.getInput().get(), StructureShape.class); HttpChecksumTrait trait = operation.expectTrait(HttpChecksumTrait.class); boolean isRequestChecksumRequired = trait.isRequestChecksumRequired(); boolean hasRequestAlgorithmMember = trait.getRequestAlgorithmMember().isPresent(); - boolean supportsTrailingChecksum = false; - for (MemberShape memberShape : input.getAllMembers().values()) { - Shape targetShape = model.expectShape(memberShape.getTarget()); - if (targetShape.hasTrait(StreamingTrait.class) && - !StreamingTrait.isEventStream(model, memberShape) - ) { - if (isS3ServiceShape(model, service) || ( - AwsSignatureVersion4.hasSigV4AuthScheme(model, service, operation) - && !operation.hasTrait(UnsignedPayloadTrait.class))) { - supportsTrailingChecksum = true; - } - } - } - - boolean supportsRequestTrailingChecksum = supportsTrailingChecksum; + boolean supportsRequestTrailingChecksum = supportsTrailingChecksum(model, service, operation); boolean supportsDecodedContentLengthHeader = isS3ServiceShape(model, service); // imports @@ -290,13 +287,41 @@ private void writeInputMiddlewareHelper( hasRequestAlgorithmMember ? getRequestAlgorithmAccessorFuncName(operationName) : "nil", isRequestChecksumRequired, - supportsRequestTrailingChecksum, + supportsRequestTrailingChecksum ? "!options.DisableTrailingChecksumSupport" : "false", supportsDecodedContentLengthHeader); } ); writer.insertTrailingNewline(); } + private boolean supportsTrailingChecksum(Model model, ServiceShape service) { + for (OperationShape operation : TopDownIndex.of(model).getContainedOperations(service)) { + if (supportsTrailingChecksum(model, service, operation)) { + return true; + } + } + return false; + } + + private boolean supportsTrailingChecksum(Model model, ServiceShape service, OperationShape operation) { + StructureShape input = model.expectShape(operation.getInput().get(), StructureShape.class); + + boolean supportsTrailingChecksum = false; + for (MemberShape memberShape : input.getAllMembers().values()) { + Shape targetShape = model.expectShape(memberShape.getTarget()); + if (targetShape.hasTrait(StreamingTrait.class) && + !StreamingTrait.isEventStream(model, memberShape) + ) { + if (isS3ServiceShape(model, service) || ( + AwsSignatureVersion4.hasSigV4AuthScheme(model, service, operation) + && !operation.hasTrait(UnsignedPayloadTrait.class))) { + supportsTrailingChecksum = true; + } + } + } + return supportsTrailingChecksum; + } + // adapted (service/internal/checksum).AddInputMiddleware to give the service client control over its middleware stack, // per #2507 private void writePackageLevelAddInputChecksumMiddleware(GoWriter writer) {