diff --git a/Makefile b/Makefile index 8b2058bb39d..833de83550b 100644 --- a/Makefile +++ b/Makefile @@ -1,6 +1,7 @@ # Lint rules to ignore LINTIGNORESINGLEFIGHT='internal/sync/singleflight/singleflight.go:.+error should be the last type' LINT_IGNORE_S3MANAGER_INPUT='feature/s3/manager/upload.go:.+struct field SSEKMSKeyId should be SSEKMSKeyID' +LINT_IGNORE_RECURSIONDETECTIONTEST='aws/middleware/recursion_detection_test.go:.+should not use basic type untyped string as key in context.WithValue' # Names of these are tied to endpoint rules and they're internal so ignore them LINT_IGNORE_AWSRULESFN_ARN='internal/endpoints/awsrulesfn/arn.go' LINT_IGNORE_AWSRULESFN_PARTITION='internal/endpoints/awsrulesfn/partition.go' @@ -430,6 +431,7 @@ lint: dolint=`echo "$$lint" | grep -E -v \ -e ${LINT_IGNORE_S3MANAGER_INPUT} \ -e ${LINTIGNORESINGLEFIGHT} \ + -e ${LINT_IGNORE_RECURSIONDETECTIONTEST} \ -e ${LINT_IGNORE_AWSRULESFN_ARN} \ -e ${LINT_IGNORE_AWSRULESFN_PARTITION} \ -e ${LINT_IGNORE_PRIVATE_METRICS}`; \ diff --git a/aws/middleware/recursion_detection.go b/aws/middleware/recursion_detection.go index 3f6aaf231e1..7a079083508 100644 --- a/aws/middleware/recursion_detection.go +++ b/aws/middleware/recursion_detection.go @@ -10,6 +10,7 @@ import ( const envAwsLambdaFunctionName = "AWS_LAMBDA_FUNCTION_NAME" const envAmznTraceID = "_X_AMZN_TRACE_ID" +const ctxKeyAmznTraceID = "x-amzn-trace-id" const amznTraceIDHeader = "X-Amzn-Trace-Id" // AddRecursionDetection adds recursionDetection to the middleware stack @@ -39,9 +40,12 @@ func (m *RecursionDetection) HandleBuild( _, hasLambdaEnv := os.LookupEnv(envAwsLambdaFunctionName) xAmznTraceID, hasTraceID := os.LookupEnv(envAmznTraceID) + if !hasTraceID { + xAmznTraceID, hasTraceID = ctx.Value(ctxKeyAmznTraceID).(string) + } value := req.Header.Get(amznTraceIDHeader) // only set the X-Amzn-Trace-Id header when it is not set initially, the - // current environment is Lambda and the _X_AMZN_TRACE_ID env variable exists + // current environment is Lambda and the x-amzn-trace-id value is present in the context if value != "" || !hasLambdaEnv || !hasTraceID { return next.HandleBuild(ctx, in) } diff --git a/aws/middleware/recursion_detection_test.go b/aws/middleware/recursion_detection_test.go index 0f4122231a2..3c457f4bfe2 100644 --- a/aws/middleware/recursion_detection_test.go +++ b/aws/middleware/recursion_detection_test.go @@ -11,33 +11,58 @@ import ( func TestRecursionDetection(t *testing.T) { cases := map[string]struct { LambdaFuncName string - TraceID string + EnvTraceID string + CtxTraceID string HeaderBefore string HeaderAfter string }{ "non lambda env and no trace ID header before": {}, - "with lambda env but no trace ID env variable, no trace ID header before": { + "with lambda env but no trace ID env value, no trace ID header before": { LambdaFuncName: "some-function1", }, - "with lambda env and trace ID env variable, no trace ID header before": { + "with lambda env and trace ID env value, no trace ID header before": { LambdaFuncName: "some-function2", - TraceID: "traceID1", + EnvTraceID: "traceID1", HeaderAfter: "traceID1", }, - "with lambda env and trace ID env variable, has trace ID header before": { + "with lambda env and trace ID env value, has trace ID header before": { LambdaFuncName: "some-function3", - TraceID: "traceID2", + EnvTraceID: "traceID2", HeaderBefore: "traceID1", HeaderAfter: "traceID1", }, - "with lambda env and trace ID (needs encoding) env variable, no trace ID header before": { + "with lambda env and trace ID (needs encoding) env value, no trace ID header before": { LambdaFuncName: "some-function4", - TraceID: "traceID3\n", + EnvTraceID: "traceID3\n", HeaderAfter: "traceID3%0A", }, - "with lambda env and trace ID (contains chars must not be encoded) env variable, no trace ID header before": { + "with lambda env and trace ID (contains chars must not be encoded) env value, no trace ID header before": { LambdaFuncName: "some-function5", - TraceID: "traceID4-=;:+&[]{}\"'", + EnvTraceID: "traceID4-=;:+&[]{}\"'", + HeaderAfter: "traceID4-=;:+&[]{}\"'", + }, + "with lambda env but no trace ID env value, no trace ID header before, with fallback trace ID in ctx": { + LambdaFuncName: "some-function1", + CtxTraceID: "traceIDEnv", + HeaderAfter: "traceIDEnv", + }, + "with lambda env and trace ID env value, has trace ID header before, with fallback trace ID in ctx": { + LambdaFuncName: "some-function3", + EnvTraceID: "traceID2", + CtxTraceID: "traceIDEnv", + HeaderBefore: "traceID1", + HeaderAfter: "traceID1", + }, + "with lambda env and trace ID (needs encoding) env value, no trace ID header before, with fallback trace ID in ctx": { + LambdaFuncName: "some-function4", + EnvTraceID: "traceID3\n", + CtxTraceID: "traceIDEnv", + HeaderAfter: "traceID3%0A", + }, + "with lambda env and trace ID (contains chars must not be encoded) env value, no trace ID header before, with fallback trace ID in ctx": { + LambdaFuncName: "some-function5", + EnvTraceID: "traceID4-=;:+&[]{}\"'", + CtxTraceID: "traceIDEnv", HeaderAfter: "traceID4-=;:+&[]{}\"'", }, } @@ -49,7 +74,11 @@ func TestRecursionDetection(t *testing.T) { defer restoreEnv() setEnvVar(t, envAwsLambdaFunctionName, c.LambdaFuncName) - setEnvVar(t, envAmznTraceID, c.TraceID) + setEnvVar(t, envAmznTraceID, c.EnvTraceID) + ctx := context.Background() + if c.CtxTraceID != "" { + ctx = context.WithValue(context.Background(), ctxKeyAmznTraceID, c.CtxTraceID) + } req := smithyhttp.NewStackRequest().(*smithyhttp.Request) if c.HeaderBefore != "" { @@ -57,7 +86,7 @@ func TestRecursionDetection(t *testing.T) { } var updatedRequest *smithyhttp.Request m := RecursionDetection{} - _, _, err := m.HandleBuild(context.Background(), + _, _, err := m.HandleBuild(ctx, smithymiddleware.BuildInput{Request: req}, smithymiddleware.BuildHandlerFunc(func(ctx context.Context, input smithymiddleware.BuildInput) ( out smithymiddleware.BuildOutput, metadata smithymiddleware.Metadata, err error) {