@@ -108,43 +108,46 @@ func findAndEnableModelToUse(awsConnection aws.Config) (string, error) {
108108 return "" , errors .New ("Bedrock model " + modelToUse + " is not available in the current region. Try setting AWS_REGION=us-east-1 instead" )
109109 }
110110 if availability .EntitlementAvailability != "AVAILABLE" {
111- err := enableModel (awsConnection , modelToUse )
111+ err := enableModel (awsConnection , modelToUse , availability )
112112 if err != nil {
113113 return "" , fmt .Errorf ("unable to enable model: %w" , err )
114114 }
115115 }
116116 return modelToUse , nil
117117}
118118
119- func enableModel (awsConnection aws.Config , modelId string ) error {
119+ func enableModel (awsConnection aws.Config , modelId string , availability * BedrockModelAvailability ) error {
120120 log .Println ("Enabling model " + modelId )
121121
122- // Need to create a use-case request for Anthropic models, c.f.
123- if strings .HasPrefix (modelId , "anthropic." ) {
124- _ , err := PutUseCaseForModelAccess (awsConnection , & BedrockUseCaseRequest {
125- CompanyName : "test" ,
126- CompanyWebsite : "https://test.com" ,
127- IntendedUsers : "0" ,
128- IndustryOption : "Government" ,
129- OtherIndustryOption : "" ,
130- UseCases : "None of the Above. test" ,
131- })
132- if err != nil {
133- return fmt .Errorf ("unable to put use case for model access: %w" , err )
122+ // Need to create a use-case request for Anthropic models
123+ // AgreementAvailability is account-wide (not region-specific). If a use-case was put for the model once in the account, it will be available in all regions, and we'll only need to call PutFoundationModelEntitlement in further region
124+ if availability .AgreementAvailability .Status != "AVAILABLE" {
125+ if strings .HasPrefix (modelId , "anthropic." ) && availability .AgreementAvailability .Status != "AVAILABLE" {
126+ _ , err := PutUseCaseForModelAccess (awsConnection , & BedrockUseCaseRequest {
127+ CompanyName : "test" ,
128+ CompanyWebsite : "https://test.com" ,
129+ IntendedUsers : "0" ,
130+ IndustryOption : "Government" ,
131+ OtherIndustryOption : "" ,
132+ UseCases : "None of the Above. test" ,
133+ })
134+ if err != nil {
135+ return fmt .Errorf ("unable to put use case for model access: %w" , err )
136+ }
134137 }
135- }
136138
137- offerToken , err := ListFoundationModelAgreementOffers (awsConnection , modelId )
138- if err != nil {
139- return fmt .Errorf ("unable to list agreement offers: %w" , err )
140- }
139+ offerToken , err := ListFoundationModelAgreementOffers (awsConnection , modelId )
140+ if err != nil {
141+ return fmt .Errorf ("unable to list agreement offers: %w" , err )
142+ }
141143
142- _ , err = CreateFoundationModelAgreement (awsConnection , modelId , offerToken )
143- if err != nil {
144- return fmt .Errorf ("unable to create model agreement: %w" , err )
144+ _ , err = CreateFoundationModelAgreement (awsConnection , modelId , offerToken )
145+ if err != nil {
146+ return fmt .Errorf ("unable to create model agreement: %w" , err )
147+ }
145148 }
146149
147- _ , err = PutFoundationModelEntitlement (awsConnection , modelId )
150+ _ , err : = PutFoundationModelEntitlement (awsConnection , modelId )
148151 if err != nil {
149152 return fmt .Errorf ("unable to put model entitlement: %w" , err )
150153 }
@@ -156,7 +159,7 @@ func enableModel(awsConnection aws.Config, modelId string) error {
156159
157160}
158161
159- type BedrockModelAvailabilityResponse struct {
162+ type BedrockModelAvailability struct {
160163 RegionAvailability string `json:"regionAvailability"`
161164 AgreementAvailability struct {
162165 Status string `json:"status"`
@@ -166,7 +169,7 @@ type BedrockModelAvailabilityResponse struct {
166169
167170// GetFoundationModelAvailability retrieves model availability information.
168171// Note: At the time of writing, this function is not available in the AWS SDK for Go v2
169- func GetFoundationModelAvailability (cfg aws.Config , model string ) (* BedrockModelAvailabilityResponse , error ) {
172+ func GetFoundationModelAvailability (cfg aws.Config , model string ) (* BedrockModelAvailability , error ) {
170173 region := cfg .Region
171174
172175 host := fmt .Sprintf ("bedrock.%s.amazonaws.com" , region )
@@ -205,7 +208,7 @@ func GetFoundationModelAvailability(cfg aws.Config, model string) (*BedrockModel
205208 }
206209
207210 fmt .Println (string (body ))
208- var result BedrockModelAvailabilityResponse
211+ var result BedrockModelAvailability
209212 if err := json .Unmarshal (body , & result ); err != nil {
210213 return nil , errors .New ("Error unmarshalling response body: " + err .Error ())
211214 }
0 commit comments