Skip to content

Commit 93413f8

Browse files
BenjaminKazemicopybara-github
authored andcommitted
chore(vertexai): Test the NewGenAIClient function
FUTURE_COPYBARA_INTEGRATE_REVIEW=#14351 from renovate-bot:renovate/main-deps 2e7e530 PiperOrigin-RevId: 895405821
1 parent 5e51e19 commit 93413f8

2 files changed

Lines changed: 56 additions & 0 deletions

File tree

vertexai/genai/client.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -461,10 +461,19 @@ func float32pToInt32p(x *float32) *int32 {
461461

462462
// NewGenAIClient creates a new Google Vertex AI client and configures the the GenAI components.
463463
func NewGenAIClient(ctx context.Context, cc *genai.ClientConfig) (*Client, error) {
464+
if cc == nil {
465+
cc = &genai.ClientConfig{Backend: genai.BackendVertexAI}
466+
}
467+
if cc.Backend == genai.BackendUnspecified {
468+
cc.Backend = genai.BackendVertexAI
469+
}
464470
ac, err := genai.NewInternalAPIClient(ctx, cc)
465471
if err != nil {
466472
return nil, err
467473
}
474+
if ac.ClientConfig().Backend != genai.BackendVertexAI {
475+
return nil, fmt.Errorf("only Vertex AI backend is supported")
476+
}
468477
return &Client{
469478
AgentEngines: &clientAgentEngines{
470479
AgentEngines: AgentEngines{apiClient: ac},

vertexai/genai/client_test.go

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ import (
2828
"testing"
2929

3030
"google.golang.org/api/iterator"
31+
"google.golang.org/genai"
3132
)
3233

3334
const defaultModel = "gemini-1.0-pro"
@@ -739,3 +740,49 @@ func printResponse(resp *GenerateContentResponse) {
739740
}
740741
fmt.Println("---")
741742
}
743+
744+
func TestNewGenAIClient(t *testing.T) {
745+
ctx := context.Background()
746+
key := "SECRET_TOKEN"
747+
originalValue, exists := os.LookupEnv(key)
748+
os.Unsetenv(key)
749+
// Restore later
750+
t.Cleanup(func() {
751+
if exists {
752+
os.Setenv(key, originalValue)
753+
}
754+
})
755+
for _, test := range []struct {
756+
name string
757+
cc *genai.ClientConfig
758+
}{
759+
{name: "nil config", cc: nil},
760+
{name: "empty config", cc: &genai.ClientConfig{}},
761+
} {
762+
t.Run(test.name, func(t *testing.T) {
763+
client, err := NewGenAIClient(ctx, test.cc)
764+
if err != nil {
765+
t.Fatalf("NewGenAIClient() failed unexpectedly, err: %v", err)
766+
}
767+
if client == nil {
768+
t.Error("client must not be nil")
769+
}
770+
})
771+
}
772+
}
773+
774+
func TestNewGenAIClientErrors(t *testing.T) {
775+
ctx := context.Background()
776+
for _, test := range []struct {
777+
name string
778+
cc *genai.ClientConfig
779+
}{
780+
{name: "gemini backend", cc: &genai.ClientConfig{Backend: genai.BackendGeminiAPI}},
781+
} {
782+
t.Run(test.name, func(t *testing.T) {
783+
if _, err := NewGenAIClient(ctx, test.cc); err == nil {
784+
t.Error("wants error, but got nil")
785+
}
786+
})
787+
}
788+
}

0 commit comments

Comments
 (0)