Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Lint rules to ignore
LINTIGNORESINGLEFIGHT='internal/sync/singleflight/singleflight.go:.+error should be the last type'
LINT_IGNORE_S3MANAGER_INPUT='feature/s3/manager/upload.go:.+struct field SSEKMSKeyId should be SSEKMSKeyID'
LINT_IGNORE_RECURSIONDETECTIONTEST='aws/middleware/recursion_detection_test.go:.+should not use basic type untyped string as key in context.WithValue'
# Names of these are tied to endpoint rules and they're internal so ignore them
LINT_IGNORE_AWSRULESFN_ARN='internal/endpoints/awsrulesfn/arn.go'
LINT_IGNORE_AWSRULESFN_PARTITION='internal/endpoints/awsrulesfn/partition.go'
Expand Down Expand Up @@ -430,6 +431,7 @@ lint:
dolint=`echo "$$lint" | grep -E -v \
-e ${LINT_IGNORE_S3MANAGER_INPUT} \
-e ${LINTIGNORESINGLEFIGHT} \
-e ${LINT_IGNORE_RECURSIONDETECTIONTEST} \
-e ${LINT_IGNORE_AWSRULESFN_ARN} \
-e ${LINT_IGNORE_AWSRULESFN_PARTITION} \
-e ${LINT_IGNORE_PRIVATE_METRICS}`; \
Expand Down
6 changes: 5 additions & 1 deletion aws/middleware/recursion_detection.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (

const envAwsLambdaFunctionName = "AWS_LAMBDA_FUNCTION_NAME"
const envAmznTraceID = "_X_AMZN_TRACE_ID"
const ctxKeyAmznTraceID = "x-amzn-trace-id"
const amznTraceIDHeader = "X-Amzn-Trace-Id"

// AddRecursionDetection adds recursionDetection to the middleware stack
Expand Down Expand Up @@ -39,9 +40,12 @@ func (m *RecursionDetection) HandleBuild(

_, hasLambdaEnv := os.LookupEnv(envAwsLambdaFunctionName)
xAmznTraceID, hasTraceID := os.LookupEnv(envAmznTraceID)
if !hasTraceID {
xAmznTraceID, hasTraceID = ctx.Value(ctxKeyAmznTraceID).(string)
}
value := req.Header.Get(amznTraceIDHeader)
// only set the X-Amzn-Trace-Id header when it is not set initially, the
// current environment is Lambda and the _X_AMZN_TRACE_ID env variable exists
// current environment is Lambda and the x-amzn-trace-id value is present in the context
if value != "" || !hasLambdaEnv || !hasTraceID {
return next.HandleBuild(ctx, in)
}
Expand Down
53 changes: 41 additions & 12 deletions aws/middleware/recursion_detection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,33 +11,58 @@ import (
func TestRecursionDetection(t *testing.T) {
cases := map[string]struct {
LambdaFuncName string
TraceID string
EnvTraceID string
CtxTraceID string
HeaderBefore string
HeaderAfter string
}{
"non lambda env and no trace ID header before": {},
"with lambda env but no trace ID env variable, no trace ID header before": {
"with lambda env but no trace ID env value, no trace ID header before": {
LambdaFuncName: "some-function1",
},
"with lambda env and trace ID env variable, no trace ID header before": {
"with lambda env and trace ID env value, no trace ID header before": {
LambdaFuncName: "some-function2",
TraceID: "traceID1",
EnvTraceID: "traceID1",
HeaderAfter: "traceID1",
},
"with lambda env and trace ID env variable, has trace ID header before": {
"with lambda env and trace ID env value, has trace ID header before": {
LambdaFuncName: "some-function3",
TraceID: "traceID2",
EnvTraceID: "traceID2",
HeaderBefore: "traceID1",
HeaderAfter: "traceID1",
},
"with lambda env and trace ID (needs encoding) env variable, no trace ID header before": {
"with lambda env and trace ID (needs encoding) env value, no trace ID header before": {
LambdaFuncName: "some-function4",
TraceID: "traceID3\n",
EnvTraceID: "traceID3\n",
HeaderAfter: "traceID3%0A",
},
"with lambda env and trace ID (contains chars must not be encoded) env variable, no trace ID header before": {
"with lambda env and trace ID (contains chars must not be encoded) env value, no trace ID header before": {
LambdaFuncName: "some-function5",
TraceID: "traceID4-=;:+&[]{}\"'",
EnvTraceID: "traceID4-=;:+&[]{}\"'",
HeaderAfter: "traceID4-=;:+&[]{}\"'",
},
"with lambda env but no trace ID env value, no trace ID header before, with fallback trace ID in ctx": {
LambdaFuncName: "some-function1",
CtxTraceID: "traceIDEnv",
HeaderAfter: "traceIDEnv",
},
"with lambda env and trace ID env value, has trace ID header before, with fallback trace ID in ctx": {
LambdaFuncName: "some-function3",
EnvTraceID: "traceID2",
CtxTraceID: "traceIDEnv",
HeaderBefore: "traceID1",
HeaderAfter: "traceID1",
},
"with lambda env and trace ID (needs encoding) env value, no trace ID header before, with fallback trace ID in ctx": {
LambdaFuncName: "some-function4",
EnvTraceID: "traceID3\n",
CtxTraceID: "traceIDEnv",
HeaderAfter: "traceID3%0A",
},
"with lambda env and trace ID (contains chars must not be encoded) env value, no trace ID header before, with fallback trace ID in ctx": {
LambdaFuncName: "some-function5",
EnvTraceID: "traceID4-=;:+&[]{}\"'",
CtxTraceID: "traceIDEnv",
HeaderAfter: "traceID4-=;:+&[]{}\"'",
},
}
Expand All @@ -49,15 +74,19 @@ func TestRecursionDetection(t *testing.T) {
defer restoreEnv()

setEnvVar(t, envAwsLambdaFunctionName, c.LambdaFuncName)
setEnvVar(t, envAmznTraceID, c.TraceID)
setEnvVar(t, envAmznTraceID, c.EnvTraceID)
ctx := context.Background()
if c.CtxTraceID != "" {
ctx = context.WithValue(context.Background(), ctxKeyAmznTraceID, c.CtxTraceID)
}

req := smithyhttp.NewStackRequest().(*smithyhttp.Request)
if c.HeaderBefore != "" {
req.Header.Set(amznTraceIDHeader, c.HeaderBefore)
}
var updatedRequest *smithyhttp.Request
m := RecursionDetection{}
_, _, err := m.HandleBuild(context.Background(),
_, _, err := m.HandleBuild(ctx,
smithymiddleware.BuildInput{Request: req},
smithymiddleware.BuildHandlerFunc(func(ctx context.Context, input smithymiddleware.BuildInput) (
out smithymiddleware.BuildOutput, metadata smithymiddleware.Metadata, err error) {
Expand Down
Loading