Skip to content

Commit 2f245bb

Browse files
authored
fix: auxiliary token refresh and use sig in acceptance tests (#991)
* initial implementation, start of tests * Add tests + update test env to use sig * ci fixes + small improvements * Update tests to directly call UseAuxiliaryTokenPolicy, incorporate small feedback
1 parent 7d3866f commit 2f245bb

File tree

13 files changed

+266
-76
lines changed

13 files changed

+266
-76
lines changed

pkg/auth/auxiliarytoken.go

Lines changed: 36 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,11 @@ limitations under the License.
1717
package auth
1818

1919
import (
20-
"context"
2120
"encoding/json"
2221
"fmt"
2322
"net/http"
23+
"sync"
24+
"time"
2425

2526
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
2627
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
@@ -36,22 +37,47 @@ var _ policy.Policy = &AuxiliaryTokenPolicy{}
3637
// AuxiliaryTokenPolicy provides a custom policy used to authenticate
3738
// with shared node image galleries.
3839
type AuxiliaryTokenPolicy struct {
39-
Token string
40+
Token azcore.AccessToken
41+
url string
42+
scope string
43+
client AuxiliaryTokenServer
44+
lock sync.Mutex
45+
}
46+
47+
func (p *AuxiliaryTokenPolicy) GetAuxiliaryToken() error {
48+
p.lock.Lock()
49+
defer p.lock.Unlock()
50+
// If the token is uninitialized or close to expiration, fetch a new one
51+
currentTime := time.Now()
52+
if p.Token.ExpiresOn.IsZero() || p.Token.RefreshOn.Before(currentTime) || p.Token.ExpiresOn.Before(currentTime.Add(5*time.Minute)) {
53+
newToken, err := getAuxiliaryToken(p.client, p.url, p.scope)
54+
if err != nil {
55+
return err
56+
}
57+
p.Token = newToken
58+
}
59+
return nil
4060
}
4161

4262
func (p *AuxiliaryTokenPolicy) Do(req *policy.Request) (*http.Response, error) {
43-
req.Raw().Header.Add("x-ms-authorization-auxiliary", "Bearer "+p.Token)
63+
err := p.GetAuxiliaryToken()
64+
if err != nil {
65+
log.FromContext(req.Raw().Context()).Error(err, "Failed to get auxiliary token")
66+
return nil, err
67+
}
68+
req.Raw().Header.Add("x-ms-authorization-auxiliary", "Bearer "+p.Token.Token)
4469
return req.Next()
4570
}
4671

47-
func NewAuxiliaryTokenPolicy(ctx context.Context, client AuxiliaryTokenServer, url string, scope string) (*AuxiliaryTokenPolicy, error) {
48-
token, err := getAuxiliaryToken(client, url, scope)
49-
if err != nil {
50-
return nil, fmt.Errorf("failed to get auxiliary token: %w", err)
72+
func NewAuxiliaryTokenPolicy(client AuxiliaryTokenServer, url string, scope string) *AuxiliaryTokenPolicy {
73+
auxPolicy := AuxiliaryTokenPolicy{
74+
Token: azcore.AccessToken{},
75+
url: url,
76+
scope: scope,
77+
client: client,
78+
lock: sync.Mutex{},
5179
}
52-
auxPolicy := AuxiliaryTokenPolicy{Token: token.Token}
53-
log.FromContext(ctx).V(1).Info("Will use auxiliary token policy for creating virtual machines")
54-
return &auxPolicy, nil
80+
return &auxPolicy
5581
}
5682

5783
func getAuxiliaryToken(client AuxiliaryTokenServer, url string, scope string) (azcore.AccessToken, error) {

pkg/cloudprovider/suite_test.go

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ import (
5353
)
5454

5555
var ctx context.Context
56+
var testOptions *options.Options
5657
var stop context.CancelFunc
5758
var env *coretest.Environment
5859
var azureEnv *test.Environment
@@ -75,7 +76,8 @@ func TestCloudProvider(t *testing.T) {
7576
var _ = BeforeSuite(func() {
7677
env = coretest.NewEnvironment(coretest.WithCRDs(apis.CRDs...), coretest.WithCRDs(v1alpha1.CRDs...), coretest.WithFieldIndexers(coretest.NodeProviderIDFieldIndexer(ctx)))
7778
ctx = coreoptions.ToContext(ctx, coretest.Options())
78-
ctx = options.ToContext(ctx, test.Options())
79+
testOptions = test.Options()
80+
ctx = options.ToContext(ctx, testOptions)
7981
ctx, stop = context.WithCancel(ctx)
8082
azureEnv = test.NewEnvironment(ctx, env)
8183
fakeClock = clock.NewFakeClock(time.Now())
@@ -91,11 +93,12 @@ var _ = AfterSuite(func() {
9193
})
9294

9395
var _ = BeforeEach(func() {
96+
testOptions = test.Options()
9497
ctx = coreoptions.ToContext(ctx, coretest.Options())
95-
ctx = options.ToContext(ctx, test.Options())
98+
ctx = options.ToContext(ctx, testOptions)
9699

97100
nodeClass = test.AKSNodeClass()
98-
test.ApplyDefaultStatus(nodeClass, env)
101+
test.ApplyDefaultStatus(nodeClass, env, testOptions.UseSIG)
99102

100103
nodePool = coretest.NodePool(karpv1.NodePool{
101104
Spec: karpv1.NodePoolSpec{

pkg/controllers/nodeclaim/garbagecollection/suite_test.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ import (
5555
)
5656

5757
var ctx context.Context
58+
var testOptions *options.Options
5859
var env *coretest.Environment
5960
var azureEnv *test.Environment
6061
var fakeClock *clock.FakeClock
@@ -74,7 +75,8 @@ func TestAPIs(t *testing.T) {
7475

7576
var _ = BeforeSuite(func() {
7677
ctx = coreoptions.ToContext(ctx, coretest.Options())
77-
ctx = options.ToContext(ctx, test.Options())
78+
testOptions = test.Options()
79+
ctx = options.ToContext(ctx, testOptions)
7880
env = coretest.NewEnvironment(coretest.WithCRDs(apis.CRDs...), coretest.WithCRDs(v1alpha1.CRDs...))
7981
// ctx, stop = context.WithCancel(ctx)
8082
azureEnv = test.NewEnvironment(ctx, env)
@@ -94,7 +96,7 @@ var _ = AfterSuite(func() {
9496

9597
var _ = BeforeEach(func() {
9698
nodeClass = test.AKSNodeClass()
97-
test.ApplyDefaultStatus(nodeClass, env)
99+
test.ApplyDefaultStatus(nodeClass, env, testOptions.UseSIG)
98100

99101
nodePool = coretest.NodePool(karpv1.NodePool{
100102
Spec: karpv1.NodePoolSpec{

pkg/fake/auxiliarytokenserver.go

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,26 @@ var _ auth.AuxiliaryTokenServer = &AuxiliaryTokenServer{}
4040

4141
type AuxiliaryTokenServer struct {
4242
AuxiliaryTokenBehavior
43+
Token azcore.AccessToken
44+
}
45+
46+
// NewAuxiliaryTokenServer creates a new AuxiliaryTokenServer with the given token.
47+
func NewAuxiliaryTokenServer(token string, expiresOn time.Time, refreshOn time.Time) *AuxiliaryTokenServer {
48+
return &AuxiliaryTokenServer{
49+
Token: azcore.AccessToken{
50+
Token: token,
51+
ExpiresOn: expiresOn,
52+
RefreshOn: refreshOn,
53+
},
54+
}
55+
}
56+
57+
func (c *AuxiliaryTokenServer) SetToken(token string, expiresOn time.Time, refreshOn time.Time) {
58+
c.Token = azcore.AccessToken{
59+
Token: token,
60+
ExpiresOn: expiresOn,
61+
RefreshOn: refreshOn,
62+
}
4363
}
4464

4565
// Reset must be called between tests otherwise tests will pollute each other.
@@ -61,10 +81,7 @@ func (c *AuxiliaryTokenServer) Do(req *http.Request) (*http.Response, error) {
6181
return resp, nil
6282
}
6383

64-
token := azcore.AccessToken{
65-
Token: "fake-token",
66-
ExpiresOn: time.Now().Add(1 * time.Hour),
67-
}
84+
token := c.Token
6885
tokenBytes, _ := json.Marshal(token)
6986
resp.StatusCode = http.StatusOK
7087
resp.Body = io.NopCloser(bytes.NewReader(tokenBytes))

pkg/fake/auxiliarytokenserver_test.go

Lines changed: 30 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -17,62 +17,55 @@ limitations under the License.
1717
package fake
1818

1919
import (
20-
"context"
20+
"net/http"
21+
"net/url"
2122
"testing"
23+
"time"
2224

2325
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
2426
"github.com/Azure/karpenter-provider-azure/pkg/auth"
25-
armopts "github.com/Azure/karpenter-provider-azure/pkg/utils/opts"
2627
"github.com/stretchr/testify/assert"
2728
)
2829

2930
func Test_AddAuxiliaryTokenPolicyClientOptions(t *testing.T) {
31+
defaultToken := azcore.AccessToken{
32+
Token: "test-token",
33+
ExpiresOn: time.Now().Add(1 * time.Hour),
34+
RefreshOn: time.Now().Add(5 * time.Second),
35+
}
3036
tests := []struct {
31-
name string
32-
expected azcore.AccessToken
33-
wantErr bool
34-
errString string
35-
url string
36-
scope string
37+
name string
38+
userAgent string
39+
statusCode int
3740
}{
3841
{
39-
name: "url is not set",
40-
wantErr: true,
41-
errString: "access token server URL is not set",
42-
url: "",
43-
scope: "anything",
44-
},
45-
{
46-
name: "scope is not set",
47-
wantErr: true,
48-
errString: "access token scope is not set",
49-
url: "anything",
50-
scope: "",
42+
name: "default",
43+
userAgent: auth.GetUserAgentExtension(),
44+
statusCode: http.StatusOK,
5145
},
5246
{
53-
name: "default",
54-
wantErr: false,
55-
url: "http://test-url.com",
56-
scope: "test-scope",
47+
name: "wrong user agent",
48+
userAgent: "wrong-user-agent",
49+
statusCode: http.StatusUnauthorized,
5750
},
5851
}
59-
tokenServer := &AuxiliaryTokenServer{}
52+
tokenServer := &AuxiliaryTokenServer{Token: defaultToken}
6053
for _, tt := range tests {
6154
t.Run(tt.name, func(t *testing.T) {
62-
clientOpts := armopts.DefaultArmOpts()
63-
vmClientOpts := *clientOpts
64-
auxPolicy, err := auth.NewAuxiliaryTokenPolicy(context.Background(), tokenServer, tt.url, tt.scope)
65-
if (err != nil) != tt.wantErr {
66-
t.Errorf("getAuxiliaryToken() error = %v, wantErr: %v", err, tt.wantErr)
67-
return
55+
request := &http.Request{
56+
Method: http.MethodGet,
57+
URL: &url.URL{Path: "/"},
58+
Header: http.Header{
59+
"User-Agent": []string{tt.userAgent},
60+
},
6861
}
69-
vmClientOpts.ClientOptions.PerRetryPolicies = append(vmClientOpts.ClientOptions.PerRetryPolicies, auxPolicy)
70-
if tt.wantErr {
71-
assert.ErrorContains(t, err, tt.errString)
72-
} else {
73-
assert.NotEqual(t, clientOpts.ClientOptions.PerRetryPolicies, vmClientOpts.ClientOptions.PerRetryPolicies)
62+
resp, err := tokenServer.Do(request)
63+
if err != nil {
64+
t.Errorf("Unexpected error %v", err)
65+
return
7466
}
67+
assert.Equal(t, tt.statusCode, resp.StatusCode, "Expected status code %d, got %d", tt.statusCode, resp.StatusCode)
68+
tokenServer.Reset()
7569
})
76-
tokenServer.Reset()
7770
}
7871
}

pkg/fake/virtualmachinesapi.go

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ import (
2929
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
3030
"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
3131
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute"
32+
"github.com/Azure/karpenter-provider-azure/pkg/auth"
3233
"github.com/Azure/karpenter-provider-azure/pkg/providers/instance"
3334
"github.com/samber/lo"
3435
)
@@ -74,6 +75,7 @@ type VirtualMachinesAPI struct {
7475
// TODO: document the implications of embedding vs. not embedding the interface here
7576
// instance.VirtualMachinesAPI // - this is the interface we are mocking.
7677
VirtualMachinesBehavior
78+
AuxiliaryTokenPolicy *auth.AuxiliaryTokenPolicy
7779
}
7880

7981
// Reset must be called between tests otherwise tests will pollute each other.
@@ -88,7 +90,24 @@ func (c *VirtualMachinesAPI) Reset() {
8890
})
8991
}
9092

91-
func (c *VirtualMachinesAPI) BeginCreateOrUpdate(_ context.Context, resourceGroupName string, vmName string, parameters armcompute.VirtualMachine, options *armcompute.VirtualMachinesClientBeginCreateOrUpdateOptions) (*runtime.Poller[armcompute.VirtualMachinesClientCreateOrUpdateResponse], error) {
93+
// UseAuxiliaryTokenPolicy simulates AuxiliaryTokenPolicy.Do() method being called at the beginning of each API call
94+
// This is useful for testing scenarios where the auxiliary token is required for the API call to succeed.
95+
// If the AuxiliaryTokenPolicy is not set (USE_SIG: false), this method does nothing and returns nil.
96+
func (c *VirtualMachinesAPI) UseAuxiliaryTokenPolicy() error {
97+
if c.AuxiliaryTokenPolicy != nil {
98+
request, _ := runtime.NewRequest(context.Background(), "GET", "http://example.com")
99+
if _, err := c.AuxiliaryTokenPolicy.Do(request); err != nil {
100+
// req.Next() returns this if there are no more policies.
101+
if err.Error() == "no more policies" {
102+
return nil
103+
}
104+
return err
105+
}
106+
}
107+
return nil
108+
}
109+
110+
func (c *VirtualMachinesAPI) BeginCreateOrUpdate(ctx context.Context, resourceGroupName string, vmName string, parameters armcompute.VirtualMachine, options *armcompute.VirtualMachinesClientBeginCreateOrUpdateOptions) (*runtime.Poller[armcompute.VirtualMachinesClientCreateOrUpdateResponse], error) {
92111
// gather input parameters (may get rid of this with multiple mocked function signatures to reflect common patterns)
93112
input := &VirtualMachineCreateOrUpdateInput{
94113
ResourceGroupName: resourceGroupName,
@@ -99,6 +118,9 @@ func (c *VirtualMachinesAPI) BeginCreateOrUpdate(_ context.Context, resourceGrou
99118
// BeginCreateOrUpdate should fail, if the vm exists in the cache, and we are attempting to change properties for zone
100119

101120
return c.VirtualMachineCreateOrUpdateBehavior.Invoke(input, func(input *VirtualMachineCreateOrUpdateInput) (*armcompute.VirtualMachinesClientCreateOrUpdateResponse, error) {
121+
if err := c.UseAuxiliaryTokenPolicy(); err != nil {
122+
return nil, getAuthTokenError(err)
123+
}
102124
//if input.ResourceGroupName == "" {
103125
// return nil, errors.New("ResourceGroupName is required")
104126
//}
@@ -160,6 +182,9 @@ func (c *VirtualMachinesAPI) BeginUpdate(_ context.Context, resourceGroupName st
160182
Options: options,
161183
}
162184
return c.VirtualMachineUpdateBehavior.Invoke(input, func(input *VirtualMachineUpdateInput) (*armcompute.VirtualMachinesClientUpdateResponse, error) {
185+
if err := c.UseAuxiliaryTokenPolicy(); err != nil {
186+
return nil, getAuthTokenError(err)
187+
}
163188
id := MkVMID(input.ResourceGroupName, input.VMName)
164189

165190
instance, ok := c.Instances.Load(id)
@@ -203,6 +228,9 @@ func (c *VirtualMachinesAPI) Get(_ context.Context, resourceGroupName string, vm
203228
Options: options,
204229
}
205230
return c.VirtualMachineGetBehavior.Invoke(input, func(input *VirtualMachineGetInput) (armcompute.VirtualMachinesClientGetResponse, error) {
231+
if err := c.UseAuxiliaryTokenPolicy(); err != nil {
232+
return armcompute.VirtualMachinesClientGetResponse{}, getAuthTokenError(err)
233+
}
206234
instance, ok := c.Instances.Load(MkVMID(input.ResourceGroupName, input.VMName))
207235
if !ok {
208236
return armcompute.VirtualMachinesClientGetResponse{}, &azcore.ResponseError{ErrorCode: errors.ResourceNotFound}
@@ -220,6 +248,9 @@ func (c *VirtualMachinesAPI) BeginDelete(_ context.Context, resourceGroupName st
220248
Options: options,
221249
}
222250
return c.VirtualMachineDeleteBehavior.Invoke(input, func(input *VirtualMachineDeleteInput) (*armcompute.VirtualMachinesClientDeleteResponse, error) {
251+
if err := c.UseAuxiliaryTokenPolicy(); err != nil {
252+
return &armcompute.VirtualMachinesClientDeleteResponse{}, getAuthTokenError(err)
253+
}
223254
c.Instances.Delete(MkVMID(input.ResourceGroupName, input.VMName))
224255
return &armcompute.VirtualMachinesClientDeleteResponse{}, nil
225256
})
@@ -233,3 +264,12 @@ func MkVMID(resourceGroupName string, vmName string) string {
233264
const idFormat = "/subscriptions/subscriptionID/resourceGroups/%s/providers/Microsoft.Compute/virtualMachines/%s"
234265
return fmt.Sprintf(idFormat, resourceGroupName, vmName)
235266
}
267+
268+
func getAuthTokenError(err error) *azcore.ResponseError {
269+
return &azcore.ResponseError{
270+
ErrorCode: "AuthenticationFailed",
271+
RawResponse: &http.Response{
272+
Body: createSDKErrorBody("AuthenticationFailed", err.Error()),
273+
},
274+
}
275+
}

pkg/providers/imagefamily/nodeimage_test.go

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ func getExpectedTestSIGImages(imageFamily string, version string, kubernetesVers
9191

9292
var _ = Describe("NodeImageProvider tests", func() {
9393
var (
94+
testOptions *options.Options
9495
communityImageVersionsAPI *fake.CommunityGalleryImageVersionsAPI
9596

9697
nodeImageProvider imagefamily.NodeImageProvider
@@ -100,7 +101,8 @@ var _ = Describe("NodeImageProvider tests", func() {
100101

101102
BeforeEach(func() {
102103
ctx = coreoptions.ToContext(ctx, coretest.Options())
103-
ctx = options.ToContext(ctx, test.Options())
104+
testOptions = test.Options()
105+
ctx = options.ToContext(ctx, testOptions)
104106

105107
communityImageVersionsAPI = &fake.CommunityGalleryImageVersionsAPI{}
106108
cigImageVersionTest := cigImageVersion
@@ -110,7 +112,7 @@ var _ = Describe("NodeImageProvider tests", func() {
110112
kubernetesVersion = lo.Must(env.KubernetesInterface.Discovery().ServerVersion()).String()
111113

112114
nodeClass = test.AKSNodeClass()
113-
test.ApplyDefaultStatus(nodeClass, env)
115+
test.ApplyDefaultStatus(nodeClass, env, testOptions.UseSIG)
114116
})
115117

116118
Context("List CIG Images", func() {
@@ -161,7 +163,7 @@ var _ = Describe("NodeImageProvider tests", func() {
161163

162164
Context("List SIG Images", func() {
163165
BeforeEach(func() {
164-
testOptions := options.FromContext(ctx)
166+
testOptions = options.FromContext(ctx)
165167
testOptions.UseSIG = true
166168
testOptions.SIGSubscriptionID = sigSubscription
167169
testOptions.SIGAccessTokenScope = "http://valid-scope.com/.default"

0 commit comments

Comments
 (0)