forked from awsdocs/aws-doc-sdk-examples
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathinvoke_model_test.go
More file actions
115 lines (93 loc) · 2.8 KB
/
invoke_model_test.go
File metadata and controls
115 lines (93 loc) · 2.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
// Unit tests for the bedrock runtime actions.
package actions
import (
"context"
"encoding/json"
"log"
"testing"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
"github.com/awsdocs/aws-doc-sdk-examples/gov2/bedrock-runtime/stubs"
"github.com/awsdocs/aws-doc-sdk-examples/gov2/testtools"
)
// Model IDs and prompt shared by the scenario runner and the stub builders in
// this file. stubInvokeModel selects which canned request/response pair to
// build by comparing its argument against these model IDs, and both the
// scenario and the stubs use the same prompt text.
const (
	CLAUDE_MODEL_ID      = "anthropic.claude-v2"
	TITAN_IMAGE_MODEL_ID = "amazon.titan-image-generator-v2:0"
	prompt               = "A test prompt"
)
func CallInvokeModelActions(sdkConfig aws.Config) {
defer func() {
if r := recover(); r != nil {
log.Println(r)
}
}()
client := bedrockruntime.NewFromConfig(sdkConfig)
wrapper := InvokeModelWrapper{client}
ctx := context.Background()
claudeCompletion, err := wrapper.InvokeClaude(ctx, prompt)
if err != nil {
panic(err)
}
log.Println(claudeCompletion)
seed := int64(0)
titanImageCompletion, err := wrapper.InvokeTitanImage(ctx, prompt, seed)
if err != nil {
panic(err)
}
log.Println(titanImageCompletion)
log.Printf("Thanks for watching!")
}
// TestInvokeModels runs the invoke-model scenario through
// testtools.RunScenarioTests, using InvokeModelActionsTest to supply the
// stubs and to drive each sub-test.
func TestInvokeModels(t *testing.T) {
	testtools.RunScenarioTests(&InvokeModelActionsTest{}, t)
}
// InvokeModelActionsTest provides the hooks that testtools.RunScenarioTests
// calls: it supplies the InvokeModel stubs, runs the actions against the
// stubbed SDK config, and has no teardown work.
type InvokeModelActionsTest struct{}

// SetupDataAndStubs returns one InvokeModel stub per model. The order matches
// the order in which CallInvokeModelActions invokes them: Claude first, then
// Titan Image.
func (s *InvokeModelActionsTest) SetupDataAndStubs() []testtools.Stub {
	return []testtools.Stub{
		stubInvokeModel(CLAUDE_MODEL_ID),
		stubInvokeModel(TITAN_IMAGE_MODEL_ID),
	}
}

// RunSubTest runs the invoke-model actions using the stubber's SDK config.
func (s *InvokeModelActionsTest) RunSubTest(stubber *testtools.AwsmStubber) {
	CallInvokeModelActions(*stubber.SdkConfig)
}

// Cleanup does nothing; this scenario has no teardown.
func (s *InvokeModelActionsTest) Cleanup() {}
// stubInvokeModel returns a testtools.Stub, built with stubs.StubInvokeModel,
// for one successful InvokeModel call to modelId.
//
// The stubbed request body is the JSON payload built here for the
// package-level prompt. Presumably the stubber compares it against what the
// corresponding InvokeModelWrapper method sends, so keep the two in sync
// (verify against the action's implementation). The stubbed response body is
// a canned model reply.
//
// Model IDs other than CLAUDE_MODEL_ID and TITAN_IMAGE_MODEL_ID yield a
// zero-value testtools.Stub.
//
// A failure to marshal either payload means the fixture itself is broken. In
// that case the function panics via log.Panicf instead of silently handing a
// nil body to the stubber.
func stubInvokeModel(modelId string) testtools.Stub {
	var (
		request, response []byte
		reqErr, respErr   error
	)
	switch modelId {
	case CLAUDE_MODEL_ID:
		// Claude prompt envelope and sampling parameters for the request.
		request, reqErr = json.Marshal(ClaudeRequest{
			Prompt:            "Human: " + prompt + "\n\nAssistant:",
			MaxTokensToSample: 200,
			Temperature:       0.5,
			StopSequences:     []string{"\n\nHuman:"},
		})
		response, respErr = json.Marshal(ClaudeResponse{
			Completion: "A fake response",
		})
	case TITAN_IMAGE_MODEL_ID:
		// Seed 0 mirrors the seed CallInvokeModelActions passes to
		// InvokeTitanImage; keep them in sync.
		request, reqErr = json.Marshal(TitanImageRequest{
			TaskType: "TEXT_IMAGE",
			TextToImageParams: TextToImageParams{
				Text: prompt,
			},
			ImageGenerationConfig: ImageGenerationConfig{
				NumberOfImages: 1,
				Quality:        "standard",
				CfgScale:       8.0,
				Height:         512,
				Width:          512,
				Seed:           0,
			},
		})
		response, respErr = json.Marshal(TitanImageResponse{
			Images: []string{"FakeBase64String=="},
		})
	default:
		// No canned payloads for this model; return an empty stub.
		return testtools.Stub{}
	}
	if reqErr != nil {
		log.Panicf("stubInvokeModel: marshaling request for %s: %v", modelId, reqErr)
	}
	if respErr != nil {
		log.Panicf("stubInvokeModel: marshaling response for %s: %v", modelId, respErr)
	}
	return stubs.StubInvokeModel(stubs.StubInvokeModelParams{
		Request:  request,
		Response: response,
		ModelId:  modelId,
		RaiseErr: nil,
	})
}