Skip to content

Commit ae914a2

Browse files
committed
only create use-case and model agreement once per account
1 parent 552ace4 commit ae914a2

File tree

1 file changed

+29
-26
lines changed
  • v2/internal/attacktechniques/aws/impact/bedrock-invoke-model

1 file changed

+29
-26
lines changed

v2/internal/attacktechniques/aws/impact/bedrock-invoke-model/main.go

Lines changed: 29 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -108,43 +108,46 @@ func findAndEnableModelToUse(awsConnection aws.Config) (string, error) {
108108
return "", errors.New("Bedrock model " + modelToUse + " is not available in the current region. Try setting AWS_REGION=us-east-1 instead")
109109
}
110110
if availability.EntitlementAvailability != "AVAILABLE" {
111-
err := enableModel(awsConnection, modelToUse)
111+
err := enableModel(awsConnection, modelToUse, availability)
112112
if err != nil {
113113
return "", fmt.Errorf("unable to enable model: %w", err)
114114
}
115115
}
116116
return modelToUse, nil
117117
}
118118

119-
func enableModel(awsConnection aws.Config, modelId string) error {
119+
func enableModel(awsConnection aws.Config, modelId string, availability *BedrockModelAvailability) error {
120120
log.Println("Enabling model " + modelId)
121121

122-
// Need to create a use-case request for Anthropic models, c.f.
123-
if strings.HasPrefix(modelId, "anthropic.") {
124-
_, err := PutUseCaseForModelAccess(awsConnection, &BedrockUseCaseRequest{
125-
CompanyName: "test",
126-
CompanyWebsite: "https://test.com",
127-
IntendedUsers: "0",
128-
IndustryOption: "Government",
129-
OtherIndustryOption: "",
130-
UseCases: "None of the Above. test",
131-
})
132-
if err != nil {
133-
return fmt.Errorf("unable to put use case for model access: %w", err)
122+
// Need to create a use-case request for Anthropic models
123+
// AgreementAvailability is account-wide (not region-specific). If a use-case was put for the model once in the account, it will be available in all regions, and we'll only need to call PutFoundationModelEntitlement in further region
124+
if availability.AgreementAvailability.Status != "AVAILABLE" {
125+
if strings.HasPrefix(modelId, "anthropic.") && availability.AgreementAvailability.Status != "AVAILABLE" {
126+
_, err := PutUseCaseForModelAccess(awsConnection, &BedrockUseCaseRequest{
127+
CompanyName: "test",
128+
CompanyWebsite: "https://test.com",
129+
IntendedUsers: "0",
130+
IndustryOption: "Government",
131+
OtherIndustryOption: "",
132+
UseCases: "None of the Above. test",
133+
})
134+
if err != nil {
135+
return fmt.Errorf("unable to put use case for model access: %w", err)
136+
}
134137
}
135-
}
136138

137-
offerToken, err := ListFoundationModelAgreementOffers(awsConnection, modelId)
138-
if err != nil {
139-
return fmt.Errorf("unable to list agreement offers: %w", err)
140-
}
139+
offerToken, err := ListFoundationModelAgreementOffers(awsConnection, modelId)
140+
if err != nil {
141+
return fmt.Errorf("unable to list agreement offers: %w", err)
142+
}
141143

142-
_, err = CreateFoundationModelAgreement(awsConnection, modelId, offerToken)
143-
if err != nil {
144-
return fmt.Errorf("unable to create model agreement: %w", err)
144+
_, err = CreateFoundationModelAgreement(awsConnection, modelId, offerToken)
145+
if err != nil {
146+
return fmt.Errorf("unable to create model agreement: %w", err)
147+
}
145148
}
146149

147-
_, err = PutFoundationModelEntitlement(awsConnection, modelId)
150+
_, err := PutFoundationModelEntitlement(awsConnection, modelId)
148151
if err != nil {
149152
return fmt.Errorf("unable to put model entitlement: %w", err)
150153
}
@@ -156,7 +159,7 @@ func enableModel(awsConnection aws.Config, modelId string) error {
156159

157160
}
158161

159-
type BedrockModelAvailabilityResponse struct {
162+
type BedrockModelAvailability struct {
160163
RegionAvailability string `json:"regionAvailability"`
161164
AgreementAvailability struct {
162165
Status string `json:"status"`
@@ -166,7 +169,7 @@ type BedrockModelAvailabilityResponse struct {
166169

167170
// GetFoundationModelAvailability retrieves model availability information.
168171
// Note: At the time of writing, this function is not available in the AWS SDK for Go v2
169-
func GetFoundationModelAvailability(cfg aws.Config, model string) (*BedrockModelAvailabilityResponse, error) {
172+
func GetFoundationModelAvailability(cfg aws.Config, model string) (*BedrockModelAvailability, error) {
170173
region := cfg.Region
171174

172175
host := fmt.Sprintf("bedrock.%s.amazonaws.com", region)
@@ -205,7 +208,7 @@ func GetFoundationModelAvailability(cfg aws.Config, model string) (*BedrockModel
205208
}
206209

207210
fmt.Println(string(body))
208-
var result BedrockModelAvailabilityResponse
211+
var result BedrockModelAvailability
209212
if err := json.Unmarshal(body, &result); err != nil {
210213
return nil, errors.New("Error unmarshalling response body: " + err.Error())
211214
}

0 commit comments

Comments
 (0)