diff --git a/feature/rds/auth/connect.go b/feature/rds/auth/connect.go index 4169095070e..164a08444da 100644 --- a/feature/rds/auth/connect.go +++ b/feature/rds/auth/connect.go @@ -69,6 +69,7 @@ func BuildAuthToken(ctx context.Context, endpoint, region, dbUser string, creds if err != nil { return "", err } + req.URL.Path = "/" values := req.URL.Query() values.Set("Action", "connect") values.Set("DBUser", dbUser) diff --git a/feature/rds/auth/connect_test.go b/feature/rds/auth/connect_test.go index 7ccb2332f86..02e007b3b65 100644 --- a/feature/rds/auth/connect_test.go +++ b/feature/rds/auth/connect_test.go @@ -22,13 +22,13 @@ func TestBuildAuthToken(t *testing.T) { endpoint: "https://prod-instance.us-east-1.rds.amazonaws.com:3306", region: "us-west-2", user: "mysqlUser", - expectedRegex: `^prod-instance\.us-east-1\.rds\.amazonaws\.com:3306\?Action=connect.*?DBUser=mysqlUser.*`, + expectedRegex: `^prod-instance\.us-east-1\.rds\.amazonaws\.com:3306/\?Action=connect.*?DBUser=mysqlUser.*`, }, { endpoint: "prod-instance.us-east-1.rds.amazonaws.com:3306", region: "us-west-2", user: "mysqlUser", - expectedRegex: `^prod-instance\.us-east-1\.rds\.amazonaws\.com:3306\?Action=connect.*?DBUser=mysqlUser.*`, + expectedRegex: `^prod-instance\.us-east-1\.rds\.amazonaws\.com:3306/\?Action=connect.*?DBUser=mysqlUser.*`, }, { endpoint: "prod-instance.us-east-1.rds.amazonaws.com", @@ -67,6 +67,40 @@ func TestBuildAuthToken(t *testing.T) { } } +func TestBuildAuthTokenPath(t *testing.T) { + cases := []struct { + name string + endpoint string + }{ + { + name: "endpoint without scheme", + endpoint: "prod-instance.us-east-1.rds.amazonaws.com:3306", + }, + { + name: "endpoint with https scheme", + endpoint: "https://prod-instance.us-east-1.rds.amazonaws.com:3306", + }, + { + name: "postgresql default port", + endpoint: "aurora-cluster.cluster-xxx.us-east-1.rds.amazonaws.com:5432", + }, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + creds := &staticCredentials{AccessKey: "AKID", SecretKey: "SECRET", Session: "SESSION"} + token, err := auth.BuildAuthToken(context.Background(), c.endpoint, "us-east-1", "dbUser", creds) + if err != nil { + t.Fatalf("expect no err, got: %v", err) + } + + if !strings.Contains(token, "/?") { + t.Errorf("expect token to contain '/?' path separator, got: %s", token) + } + }) + } +} + type staticCredentials struct { AccessKey, SecretKey, Session string }