Skip to content

Commit 5fdfdbd

Browse files
BenjaminKazemicopybara-github
authored andcommitted
chore(vertexai): Test the NewGenAIClient function
FUTURE_COPYBARA_INTEGRATE_REVIEW=#14351 from renovate-bot:renovate/main-deps 2e7e530 PiperOrigin-RevId: 895405821
1 parent 5e51e19 commit 5fdfdbd

2 files changed

Lines changed: 47 additions & 0 deletions

File tree

vertexai/genai/client.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -461,10 +461,19 @@ func float32pToInt32p(x *float32) *int32 {
461461

462462
// NewGenAIClient creates a new Google Vertex AI client and configures the the GenAI components.
463463
func NewGenAIClient(ctx context.Context, cc *genai.ClientConfig) (*Client, error) {
464+
if cc == nil {
465+
cc = &genai.ClientConfig{Backend: genai.BackendVertexAI}
466+
}
467+
if cc.Backend == genai.BackendUnspecified {
468+
cc.Backend = genai.BackendVertexAI
469+
}
464470
ac, err := genai.NewInternalAPIClient(ctx, cc)
465471
if err != nil {
466472
return nil, err
467473
}
474+
if ac.ClientConfig().Backend != genai.BackendVertexAI {
475+
return nil, fmt.Errorf("only Vertex AI backend is supported")
476+
}
468477
return &Client{
469478
AgentEngines: &clientAgentEngines{
470479
AgentEngines: AgentEngines{apiClient: ac},

vertexai/genai/client_test.go

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ import (
2828
"testing"
2929

3030
"google.golang.org/api/iterator"
31+
"google.golang.org/genai"
3132
)
3233

3334
const defaultModel = "gemini-1.0-pro"
@@ -739,3 +740,40 @@ func printResponse(resp *GenerateContentResponse) {
739740
}
740741
fmt.Println("---")
741742
}
743+
744+
func TestNewGenAIClient(t *testing.T) {
745+
ctx := context.Background()
746+
for _, test := range []struct {
747+
name string
748+
cc *genai.ClientConfig
749+
}{
750+
{name: "nil config", cc: nil},
751+
{name: "empty config", cc: &genai.ClientConfig{}},
752+
} {
753+
t.Run(test.name, func(t *testing.T) {
754+
client, err := NewGenAIClient(ctx, test.cc)
755+
if err != nil {
756+
t.Fatalf("NewGenAIClient() failed unexpectedly, err: %v", err)
757+
}
758+
if client == nil {
759+
t.Error("client must not be nil")
760+
}
761+
})
762+
}
763+
}
764+
765+
func TestNewGenAIClientErrors(t *testing.T) {
766+
ctx := context.Background()
767+
for _, test := range []struct {
768+
name string
769+
cc *genai.ClientConfig
770+
}{
771+
{name: "gemini backend", cc: &genai.ClientConfig{Backend: genai.BackendGeminiAPI}},
772+
} {
773+
t.Run(test.name, func(t *testing.T) {
774+
if _, err := NewGenAIClient(ctx, test.cc); err == nil {
775+
t.Error("wants error, but got nil")
776+
}
777+
})
778+
}
779+
}

0 commit comments

Comments
 (0)