diff --git a/credentials/static_provider.go b/credentials/static_provider.go index a469abdb790..4e361dfd1bf 100644 --- a/credentials/static_provider.go +++ b/credentials/static_provider.go @@ -2,6 +2,7 @@ package credentials import ( "context" + "strings" "github.com/aws/aws-sdk-go-v2/aws" ) @@ -18,6 +19,27 @@ func (*StaticCredentialsEmptyError) Error() string { return "static credentials are empty" } +// AccessKeyIDInvalidWhitespaceError is emitted when AccessKeyID contains invalid whitespace. +type AccessKeyIDInvalidWhitespaceError struct{} + +func (*AccessKeyIDInvalidWhitespaceError) Error() string { + return "AccessKeyID contains invalid whitespace" +} + +// SecretAccessKeyInvalidWhitespaceError is emitted when SecretAccessKey contains invalid whitespace. +type SecretAccessKeyInvalidWhitespaceError struct{} + +func (*SecretAccessKeyInvalidWhitespaceError) Error() string { + return "SecretAccessKey contains invalid whitespace" +} + +// SessionTokenInvalidWhitespaceError is emitted when SessionToken contains invalid whitespace. +type SessionTokenInvalidWhitespaceError struct{} + +func (*SessionTokenInvalidWhitespaceError) Error() string { + return "SessionToken contains invalid whitespace" +} + // A StaticCredentialsProvider is a set of credentials which are set, and will // never expire. type StaticCredentialsProvider struct { @@ -49,6 +71,23 @@ func NewStaticCredentialsProvider(key, secret, session string) StaticCredentials // Retrieve returns the credentials or error if the credentials are invalid. func (s StaticCredentialsProvider) Retrieve(_ context.Context) (aws.Credentials, error) { v := s.Value + + if strings.ContainsAny(v.AccessKeyID, " \t\r\n") { + return aws.Credentials{ + Source: StaticCredentialsName, + }, &AccessKeyIDInvalidWhitespaceError{} + } + if strings.ContainsAny(v.SecretAccessKey, " \t\r\n") { + return aws.Credentials{ + Source: StaticCredentialsName, + }, &SecretAccessKeyInvalidWhitespaceError{} + } + if strings.ContainsAny(v.SessionToken, " \t\r\n") { + return aws.Credentials{ + Source: StaticCredentialsName, + }, &SessionTokenInvalidWhitespaceError{} + } + if v.AccessKeyID == "" || v.SecretAccessKey == "" { return aws.Credentials{ Source: StaticCredentialsName, diff --git a/credentials/static_provider_test.go b/credentials/static_provider_test.go index 431e7e05431..a23a3a6a517 100644 --- a/credentials/static_provider_test.go +++ b/credentials/static_provider_test.go @@ -49,3 +49,102 @@ func TestStaticCredentialsProviderIsExpired(t *testing.T) { t.Errorf("expect static credentials to never expire") } } + +func TestStaticCredentialsProviderValidation(t *testing.T) { + tests := []struct { + name string + accessKey string + secretKey string + sessionToken string + shouldError bool + expectedMsg string + }{ + { + name: "trailing newline in secret (issue #3304)", + accessKey: "AKID", + secretKey: "SECRET\n", + sessionToken: "", + shouldError: true, + expectedMsg: "SecretAccessKey contains invalid whitespace", + }, + { + name: "trailing space in access key", + accessKey: "AKID ", + secretKey: "SECRET", + sessionToken: "", + shouldError: true, + expectedMsg: "AccessKeyID contains invalid whitespace", + }, + { + name: "leading whitespace in token", + accessKey: "AKID", + secretKey: "SECRET", + sessionToken: " TOKEN", + shouldError: true, + expectedMsg: "SessionToken contains invalid whitespace", + }, + { + name: "tabs in secret key", + accessKey: "AKID", + secretKey: "\tSECRET\r", + sessionToken: "", + shouldError: true, + expectedMsg: "SecretAccessKey contains invalid whitespace", + }, + { + name: "valid credentials without whitespace", + accessKey: "AKID", + secretKey: "SECRET", + sessionToken: "TOKEN", + shouldError: false, + }, + { + name: "empty access key", + accessKey: "", + secretKey: "SECRET", + sessionToken: "", + shouldError: true, + expectedMsg: "static credentials are empty", + }, + { + name: "empty secret key", + accessKey: "AKID", + secretKey: "", + sessionToken: "", + shouldError: true, + expectedMsg: "static credentials are empty", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := NewStaticCredentialsProvider(tt.accessKey, tt.secretKey, tt.sessionToken) + + creds, err := s.Retrieve(context.Background()) + + if tt.shouldError { + if err == nil { + t.Fatal("expected error for credentials with whitespace, got nil") + } + + if err.Error() != tt.expectedMsg { + t.Errorf("expected error %q, got %q", tt.expectedMsg, err.Error()) + } + } else { + if err != nil { + t.Errorf("expect no error, got %v", err) + } + + if e, a := tt.accessKey, creds.AccessKeyID; e != a { + t.Errorf("expect AccessKeyID %q, got %q", e, a) + } + if e, a := tt.secretKey, creds.SecretAccessKey; e != a { + t.Errorf("expect SecretAccessKey %q, got %q", e, a) + } + if e, a := tt.sessionToken, creds.SessionToken; e != a { + t.Errorf("expect SessionToken %q, got %q", e, a) + } + } + }) + } +}