Skip to content

Commit 3e6f1e8

Browse files
chore(vertexai): Test the NewGenAIClient function
FUTURE_COPYBARA_INTEGRATE_REVIEW=#14351 from renovate-bot:renovate/main-deps 2e7e530 PiperOrigin-RevId: 895405821
1 parent 5e51e19 commit 3e6f1e8

2 files changed

Lines changed: 59 additions & 0 deletions

File tree

vertexai/genai/client.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -461,10 +461,19 @@ func float32pToInt32p(x *float32) *int32 {
461461

462462
// NewGenAIClient creates a new Google Vertex AI client and configures the the GenAI components.
463463
func NewGenAIClient(ctx context.Context, cc *genai.ClientConfig) (*Client, error) {
464+
if cc == nil {
465+
cc = &genai.ClientConfig{Backend: genai.BackendVertexAI}
466+
}
467+
if cc.Backend == genai.BackendUnspecified {
468+
cc.Backend = genai.BackendVertexAI
469+
}
464470
ac, err := genai.NewInternalAPIClient(ctx, cc)
465471
if err != nil {
466472
return nil, err
467473
}
474+
if ac.ClientConfig().Backend != genai.BackendVertexAI {
475+
return nil, fmt.Errorf("only Vertex AI backend is supported")
476+
}
468477
return &Client{
469478
AgentEngines: &clientAgentEngines{
470479
AgentEngines: AgentEngines{apiClient: ac},

vertexai/genai/client_test.go

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ import (
2828
"testing"
2929

3030
"google.golang.org/api/iterator"
31+
"google.golang.org/genai"
3132
)
3233

3334
const defaultModel = "gemini-1.0-pro"
@@ -739,3 +740,52 @@ func printResponse(resp *GenerateContentResponse) {
739740
}
740741
fmt.Println("---")
741742
}
743+
744+
func TestNewGenAIClient(t *testing.T) {
745+
ctx := context.Background()
746+
projectKey := "GOOGLE_CLOUD_PROJECT"
747+
locationKey := "GOOGLE_CLOUD_LOCATION"
748+
originalProjectValue, _ := os.LookupEnv(projectKey)
749+
originalLocationValue, _ := os.LookupEnv(locationKey)
750+
os.Setenv(projectKey, "test-gcp-project")
751+
os.Setenv(locationKey, "us-central1")
752+
753+
t.Cleanup(func() {
754+
os.Setenv(locationKey, originalLocationValue)
755+
os.Setenv(projectKey, originalProjectValue)
756+
})
757+
758+
for _, test := range []struct {
759+
name string
760+
cc *genai.ClientConfig
761+
}{
762+
{name: "nil config", cc: nil},
763+
{name: "empty config", cc: &genai.ClientConfig{}},
764+
} {
765+
t.Run(test.name, func(t *testing.T) {
766+
client, err := NewGenAIClient(ctx, test.cc)
767+
if err != nil {
768+
t.Fatalf("NewGenAIClient() failed unexpectedly, err: %v", err)
769+
}
770+
if client == nil {
771+
t.Error("client must not be nil")
772+
}
773+
})
774+
}
775+
}
776+
777+
func TestNewGenAIClientErrors(t *testing.T) {
778+
ctx := context.Background()
779+
for _, test := range []struct {
780+
name string
781+
cc *genai.ClientConfig
782+
}{
783+
{name: "gemini backend", cc: &genai.ClientConfig{Backend: genai.BackendGeminiAPI}},
784+
} {
785+
t.Run(test.name, func(t *testing.T) {
786+
if _, err := NewGenAIClient(ctx, test.cc); err == nil {
787+
t.Error("wants error, but got nil")
788+
}
789+
})
790+
}
791+
}

0 commit comments

Comments
 (0)