Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions credentials/static_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package credentials

import (
"context"
"strings"

"github.com/aws/aws-sdk-go-v2/aws"
)
Expand All @@ -18,6 +19,27 @@ func (*StaticCredentialsEmptyError) Error() string {
return "static credentials are empty"
}

// AccessKeyIDInvalidWhitespaceError is emitted when AccessKeyID contains invalid whitespace.
type AccessKeyIDInvalidWhitespaceError struct{}

func (*AccessKeyIDInvalidWhitespaceError) Error() string {
return "AccessKeyID contains invalid whitespace"
}

// SecretAccessKeyInvalidWhitespaceError is emitted when SecretAccessKey contains invalid whitespace.
type SecretAccessKeyInvalidWhitespaceError struct{}

func (*SecretAccessKeyInvalidWhitespaceError) Error() string {
return "SecretAccessKey contains invalid whitespace"
}

// SessionTokenInvalidWhitespaceError is emitted when SessionToken contains invalid whitespace.
type SessionTokenInvalidWhitespaceError struct{}

func (*SessionTokenInvalidWhitespaceError) Error() string {
return "SessionToken contains invalid whitespace"
}

// A StaticCredentialsProvider is a set of credentials which are set, and will
// never expire.
type StaticCredentialsProvider struct {
Expand Down Expand Up @@ -49,6 +71,23 @@ func NewStaticCredentialsProvider(key, secret, session string) StaticCredentials
// Retrieve returns the credentials or error if the credentials are invalid.
func (s StaticCredentialsProvider) Retrieve(_ context.Context) (aws.Credentials, error) {
v := s.Value

if strings.ContainsAny(v.AccessKeyID, " \t\r\n") {
return aws.Credentials{
Source: StaticCredentialsName,
}, &AccessKeyIDInvalidWhitespaceError{}
}
if strings.ContainsAny(v.SecretAccessKey, " \t\r\n") {
return aws.Credentials{
Source: StaticCredentialsName,
}, &SecretAccessKeyInvalidWhitespaceError{}
}
if strings.ContainsAny(v.SessionToken, " \t\r\n") {
return aws.Credentials{
Source: StaticCredentialsName,
}, &SessionTokenInvalidWhitespaceError{}
}

if v.AccessKeyID == "" || v.SecretAccessKey == "" {
return aws.Credentials{
Source: StaticCredentialsName,
Expand Down
99 changes: 99 additions & 0 deletions credentials/static_provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,102 @@ func TestStaticCredentialsProviderIsExpired(t *testing.T) {
t.Errorf("expect static credentials to never expire")
}
}

func TestStaticCredentialsProviderValidation(t *testing.T) {
tests := []struct {
name string
accessKey string
secretKey string
sessionToken string
shouldError bool
expectedMsg string
}{
{
name: "trailing newline in secret (issue #3304)",
accessKey: "AKID",
secretKey: "SECRET\n",
sessionToken: "",
shouldError: true,
expectedMsg: "SecretAccessKey contains invalid whitespace",
},
{
name: "trailing space in access key",
accessKey: "AKID ",
secretKey: "SECRET",
sessionToken: "",
shouldError: true,
expectedMsg: "AccessKeyID contains invalid whitespace",
},
{
name: "leading whitespace in token",
accessKey: "AKID",
secretKey: "SECRET",
sessionToken: " TOKEN",
shouldError: true,
expectedMsg: "SessionToken contains invalid whitespace",
},
{
name: "tabs in secret key",
accessKey: "AKID",
secretKey: "\tSECRET\r",
sessionToken: "",
shouldError: true,
expectedMsg: "SecretAccessKey contains invalid whitespace",
},
{
name: "valid credentials without whitespace",
accessKey: "AKID",
secretKey: "SECRET",
sessionToken: "TOKEN",
shouldError: false,
},
{
name: "empty access key",
accessKey: "",
secretKey: "SECRET",
sessionToken: "",
shouldError: true,
expectedMsg: "static credentials are empty",
},
{
name: "empty secret key",
accessKey: "AKID",
secretKey: "",
sessionToken: "",
shouldError: true,
expectedMsg: "static credentials are empty",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := NewStaticCredentialsProvider(tt.accessKey, tt.secretKey, tt.sessionToken)

creds, err := s.Retrieve(context.Background())

if tt.shouldError {
if err == nil {
t.Fatal("expected error for credentials with whitespace, got nil")
}

if err.Error() != tt.expectedMsg {
t.Errorf("expected error %q, got %q", tt.expectedMsg, err.Error())
}
} else {
if err != nil {
t.Errorf("expect no error, got %v", err)
}

if e, a := tt.accessKey, creds.AccessKeyID; e != a {
t.Errorf("expect AccessKeyID %q, got %q", e, a)
}
if e, a := tt.secretKey, creds.SecretAccessKey; e != a {
t.Errorf("expect SecretAccessKey %q, got %q", e, a)
}
if e, a := tt.sessionToken, creds.SessionToken; e != a {
t.Errorf("expect SessionToken %q, got %q", e, a)
}
}
})
}
}