@@ -29,6 +29,7 @@ import (
2929 "github.com/Azure/azure-sdk-for-go/sdk/azcore"
3030 "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
3131 "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute"
32+ "github.com/Azure/karpenter-provider-azure/pkg/auth"
3233 "github.com/Azure/karpenter-provider-azure/pkg/providers/instance"
3334 "github.com/samber/lo"
3435)
@@ -74,6 +75,7 @@ type VirtualMachinesAPI struct {
7475 // TODO: document the implications of embedding vs. not embedding the interface here
7576 // instance.VirtualMachinesAPI // - this is the interface we are mocking.
7677 VirtualMachinesBehavior
78+ AuxiliaryTokenPolicy * auth.AuxiliaryTokenPolicy
7779}
7880
7981// Reset must be called between tests otherwise tests will pollute each other.
@@ -88,7 +90,24 @@ func (c *VirtualMachinesAPI) Reset() {
8890 })
8991}
9092
91- func (c * VirtualMachinesAPI ) BeginCreateOrUpdate (_ context.Context , resourceGroupName string , vmName string , parameters armcompute.VirtualMachine , options * armcompute.VirtualMachinesClientBeginCreateOrUpdateOptions ) (* runtime.Poller [armcompute.VirtualMachinesClientCreateOrUpdateResponse ], error ) {
93+ // UseAuxiliaryTokenPolicy simulates AuxiliaryTokenPolicy.Do() method being called at the beginning of each API call
94+ // This is useful for testing scenarios where the auxiliary token is required for the API call to succeed.
95+ // If the AuxiliaryTokenPolicy is not set (USE_SIG: false), this method does nothing and returns nil.
96+ func (c * VirtualMachinesAPI ) UseAuxiliaryTokenPolicy () error {
97+ if c .AuxiliaryTokenPolicy != nil {
98+ request , _ := runtime .NewRequest (context .Background (), "GET" , "http://example.com" )
99+ if _ , err := c .AuxiliaryTokenPolicy .Do (request ); err != nil {
100+ // req.Next() returns this if there are no more policies.
101+ if err .Error () == "no more policies" {
102+ return nil
103+ }
104+ return err
105+ }
106+ }
107+ return nil
108+ }
109+
110+ func (c * VirtualMachinesAPI ) BeginCreateOrUpdate (ctx context.Context , resourceGroupName string , vmName string , parameters armcompute.VirtualMachine , options * armcompute.VirtualMachinesClientBeginCreateOrUpdateOptions ) (* runtime.Poller [armcompute.VirtualMachinesClientCreateOrUpdateResponse ], error ) {
92111 // gather input parameters (may get rid of this with multiple mocked function signatures to reflect common patterns)
93112 input := & VirtualMachineCreateOrUpdateInput {
94113 ResourceGroupName : resourceGroupName ,
@@ -99,6 +118,9 @@ func (c *VirtualMachinesAPI) BeginCreateOrUpdate(_ context.Context, resourceGrou
99118 // BeginCreateOrUpdate should fail, if the vm exists in the cache, and we are attempting to change properties for zone
100119
101120 return c .VirtualMachineCreateOrUpdateBehavior .Invoke (input , func (input * VirtualMachineCreateOrUpdateInput ) (* armcompute.VirtualMachinesClientCreateOrUpdateResponse , error ) {
121+ if err := c .UseAuxiliaryTokenPolicy (); err != nil {
122+ return nil , getAuthTokenError (err )
123+ }
102124 //if input.ResourceGroupName == "" {
103125 // return nil, errors.New("ResourceGroupName is required")
104126 //}
@@ -160,6 +182,9 @@ func (c *VirtualMachinesAPI) BeginUpdate(_ context.Context, resourceGroupName st
160182 Options : options ,
161183 }
162184 return c .VirtualMachineUpdateBehavior .Invoke (input , func (input * VirtualMachineUpdateInput ) (* armcompute.VirtualMachinesClientUpdateResponse , error ) {
185+ if err := c .UseAuxiliaryTokenPolicy (); err != nil {
186+ return nil , getAuthTokenError (err )
187+ }
163188 id := MkVMID (input .ResourceGroupName , input .VMName )
164189
165190 instance , ok := c .Instances .Load (id )
@@ -203,6 +228,9 @@ func (c *VirtualMachinesAPI) Get(_ context.Context, resourceGroupName string, vm
203228 Options : options ,
204229 }
205230 return c .VirtualMachineGetBehavior .Invoke (input , func (input * VirtualMachineGetInput ) (armcompute.VirtualMachinesClientGetResponse , error ) {
231+ if err := c .UseAuxiliaryTokenPolicy (); err != nil {
232+ return armcompute.VirtualMachinesClientGetResponse {}, getAuthTokenError (err )
233+ }
206234 instance , ok := c .Instances .Load (MkVMID (input .ResourceGroupName , input .VMName ))
207235 if ! ok {
208236 return armcompute.VirtualMachinesClientGetResponse {}, & azcore.ResponseError {ErrorCode : errors .ResourceNotFound }
@@ -220,6 +248,9 @@ func (c *VirtualMachinesAPI) BeginDelete(_ context.Context, resourceGroupName st
220248 Options : options ,
221249 }
222250 return c .VirtualMachineDeleteBehavior .Invoke (input , func (input * VirtualMachineDeleteInput ) (* armcompute.VirtualMachinesClientDeleteResponse , error ) {
251+ if err := c .UseAuxiliaryTokenPolicy (); err != nil {
252+ return & armcompute.VirtualMachinesClientDeleteResponse {}, getAuthTokenError (err )
253+ }
223254 c .Instances .Delete (MkVMID (input .ResourceGroupName , input .VMName ))
224255 return & armcompute.VirtualMachinesClientDeleteResponse {}, nil
225256 })
@@ -233,3 +264,12 @@ func MkVMID(resourceGroupName string, vmName string) string {
233264 const idFormat = "/subscriptions/subscriptionID/resourceGroups/%s/providers/Microsoft.Compute/virtualMachines/%s"
234265 return fmt .Sprintf (idFormat , resourceGroupName , vmName )
235266}
267+
268+ func getAuthTokenError (err error ) * azcore.ResponseError {
269+ return & azcore.ResponseError {
270+ ErrorCode : "AuthenticationFailed" ,
271+ RawResponse : & http.Response {
272+ Body : createSDKErrorBody ("AuthenticationFailed" , err .Error ()),
273+ },
274+ }
275+ }
0 commit comments