Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ public class AzureAppConfigurationKeyVaultOptions
internal TimeSpan? DefaultSecretRefreshInterval = null;
internal bool IsKeyVaultRefreshConfigured = false;

/// <summary>
/// Flag to indicate whether Key Vault references should be resolved in parallel. Disabled by default.
/// </summary>
public bool ParallelSecretResolutionEnabled { get; set; }

/// <summary>
/// Sets the credentials used to authenticate to key vaults that have no registered <see cref="SecretClient"/>.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,12 +137,12 @@ internal IEnumerable<IKeyValueAdapter> Adapters
/// <summary>
/// Flag to indicate whether Key Vault options have been configured.
/// </summary>
internal bool IsKeyVaultConfigured { get; private set; } = false;
internal bool IsKeyVaultConfigured { get; private set; }

/// <summary>
/// Flag to indicate whether Key Vault secret values will be refreshed automatically.
/// </summary>
internal bool IsKeyVaultRefreshConfigured { get; private set; } = false;
internal bool IsKeyVaultRefreshConfigured { get; private set; }

/// <summary>
/// Indicates all feature flag features used by the application.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -625,16 +625,19 @@ private async Task<Dictionary<string, string>> PrepareData(Dictionary<string, Co
_requestTracingOptions.ResetAiConfigurationTracing();
}

foreach (KeyValuePair<string, ConfigurationSetting> kvp in data)
foreach (IKeyValueAdapter adapter in _options.Adapters)
{
IEnumerable<KeyValuePair<string, string>> keyValuePairs = null;
await adapter.PreloadAsync(data.Values, _logger, cancellationToken).ConfigureAwait(false);
}

foreach (KeyValuePair<string, ConfigurationSetting> kvp in data)
{
if (_requestTracingEnabled && _requestTracingOptions != null)
{
_requestTracingOptions.UpdateAiConfigurationTracing(kvp.Value.ContentType);
}

keyValuePairs = await ProcessAdapters(kvp.Value, cancellationToken).ConfigureAwait(false);
IEnumerable<KeyValuePair<string, string>> keyValuePairs = await ProcessAdapters(kvp.Value, cancellationToken).ConfigureAwait(false);

foreach (KeyValuePair<string, string> kv in keyValuePairs)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,101 @@ public bool NeedsRefresh()
return _secretProvider.ShouldRefreshKeyVaultSecrets();
}

public async Task PreloadAsync(IEnumerable<ConfigurationSetting> settings, Logger logger, CancellationToken cancellationToken)
{
if (settings == null)
{
return;
}

HashSet<Uri> seen = null;
List<(KeyVaultSecretIdentifier Identifier, ConfigurationSetting Setting, string SecretRefUri)> toFetch = null;

foreach (ConfigurationSetting setting in settings)
{
if (!CanProcess(setting))
{
continue;
}

string secretRefUri = ParseSecretReferenceUri(setting);

if (string.IsNullOrEmpty(secretRefUri) ||
!Uri.TryCreate(secretRefUri, UriKind.Absolute, out Uri secretUri) ||
!KeyVaultSecretIdentifier.TryCreate(secretUri, out KeyVaultSecretIdentifier secretIdentifier))
{
throw CreateKeyVaultReferenceException("Invalid Key vault secret identifier.", setting, null, secretRefUri);
}

seen = seen ?? new HashSet<Uri>();

if (!seen.Add(secretIdentifier.SourceId))
{
continue;
}

toFetch = toFetch ?? new List<(KeyVaultSecretIdentifier, ConfigurationSetting, string)>();
toFetch.Add((secretIdentifier, setting, secretRefUri));
}

if (toFetch == null)
{
return;
}

if (_secretProvider.IsParallelSecretResolutionEnabled)
{
using (var throttle = new SemaphoreSlim(KeyVaultConstants.MaxParallelSecretResolution))
{
var tasks = new Task[toFetch.Count];

for (int i = 0; i < toFetch.Count; i++)
{
(KeyVaultSecretIdentifier identifier, ConfigurationSetting setting, string secretRefUri) = toFetch[i];
tasks[i] = PreloadSecretAsync(identifier, setting, secretRefUri, throttle, logger, cancellationToken);
}

await Task.WhenAll(tasks).ConfigureAwait(false);
}
}
else
{
foreach ((KeyVaultSecretIdentifier identifier, ConfigurationSetting setting, string secretRefUri) in toFetch)
{
await PreloadSecretAsync(identifier, setting, secretRefUri, throttle: null, logger, cancellationToken).ConfigureAwait(false);
}
}
}

private async Task PreloadSecretAsync(KeyVaultSecretIdentifier identifier, ConfigurationSetting setting, string secretRefUri, SemaphoreSlim throttle, Logger logger, CancellationToken cancellationToken)
{
if (throttle != null)
{
await throttle.WaitAsync(cancellationToken).ConfigureAwait(false);
}

try
{
await _secretProvider.GetSecretValue(identifier, setting.Key, setting.Label, logger, cancellationToken).ConfigureAwait(false);
}
catch (OperationCanceledException)
{
throw;
}
catch (Exception e) when (e is UnauthorizedAccessException || (e.Source?.Equals(AzureIdentityAssemblyName, StringComparison.OrdinalIgnoreCase) ?? false))
{
throw CreateKeyVaultReferenceException(e.Message, setting, e, secretRefUri);
}
catch (Exception e) when (e is RequestFailedException || ((e as AggregateException)?.InnerExceptions?.All(e => e is RequestFailedException) ?? false))
{
throw CreateKeyVaultReferenceException("Key vault error.", setting, e, secretRefUri);
}
finally
{
throttle?.Release();
}
}

private string ParseSecretReferenceUri(ConfigurationSetting setting)
{
string secretRefUri = null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
using Azure.Security.KeyVault.Secrets;
using Microsoft.Extensions.Configuration.AzureAppConfiguration.Extensions;
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;

Expand All @@ -14,16 +14,16 @@ namespace Microsoft.Extensions.Configuration.AzureAppConfiguration.AzureKeyVault
internal class AzureKeyVaultSecretProvider
{
private readonly AzureAppConfigurationKeyVaultOptions _keyVaultOptions;
private readonly IDictionary<string, SecretClient> _secretClients;
private readonly Dictionary<Uri, CachedKeyVaultSecret> _cachedKeyVaultSecrets;
private Uri _nextRefreshSourceId;

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The reason we remove the _nextRefreshSourceId and _nextRefreshTime is to simplify the implementation. Otherwise, we have to add lock here, because they will be updated in the GetSecretValue code path (the SetSecretInCache call)

The benefit of maintaining the _netRefresh pair is that the ShouldRefreshKeyVaultSecrets call will be an O(1) operation. The implementation in this PR makes it an O(n) operation. But I think there won't be too many cached secrets and iterate the whole concurrent dictionary is not that expensive. So I think it is fine.

cc @avanigupta @jimmyca15

private DateTimeOffset? _nextRefreshTime;
private readonly ConcurrentDictionary<string, SecretClient> _secretClients;
private readonly ConcurrentDictionary<Uri, CachedKeyVaultSecret> _cachedKeyVaultSecrets;

public bool IsParallelSecretResolutionEnabled => _keyVaultOptions.ParallelSecretResolutionEnabled;

public AzureKeyVaultSecretProvider(AzureAppConfigurationKeyVaultOptions keyVaultOptions = null)
{
_keyVaultOptions = keyVaultOptions ?? new AzureAppConfigurationKeyVaultOptions();
_cachedKeyVaultSecrets = new Dictionary<Uri, CachedKeyVaultSecret>();
_secretClients = new Dictionary<string, SecretClient>(StringComparer.OrdinalIgnoreCase);
_cachedKeyVaultSecrets = new ConcurrentDictionary<Uri, CachedKeyVaultSecret>();
_secretClients = new ConcurrentDictionary<string, SecretClient>(StringComparer.OrdinalIgnoreCase);

if (_keyVaultOptions.SecretClients != null)
{
Expand Down Expand Up @@ -52,6 +52,7 @@ public async Task<string> GetSecretValue(KeyVaultSecretIdentifier secretIdentifi
throw new UnauthorizedAccessException("No key vault credential or secret resolver callback configured, and no matching secret client could be found.");
}

CachedKeyVaultSecret updatedCachedSecret = null;
bool success = false;

try
Expand All @@ -68,55 +69,44 @@ public async Task<string> GetSecretValue(KeyVaultSecretIdentifier secretIdentifi
secretValue = await _keyVaultOptions.SecretResolver(secretIdentifier.SourceId).ConfigureAwait(false);
}

cachedSecret = new CachedKeyVaultSecret(secretValue, secretIdentifier.SourceId);
updatedCachedSecret = new CachedKeyVaultSecret(secretValue, secretIdentifier.SourceId);
success = true;
}
finally
{
SetSecretInCache(secretIdentifier.SourceId, key, cachedSecret, success);
SetSecretInCache(secretIdentifier.SourceId, key, updatedCachedSecret, success);
}

return secretValue;
}

public bool ShouldRefreshKeyVaultSecrets()
{
return _nextRefreshTime.HasValue && _nextRefreshTime.Value < DateTimeOffset.UtcNow;
}

public void ClearCache()
{
var sourceIdsToRemove = new List<Uri>();

var utcNow = DateTimeOffset.UtcNow;

foreach (KeyValuePair<Uri, CachedKeyVaultSecret> secret in _cachedKeyVaultSecrets)
{
if (secret.Value.LastRefreshTime + RefreshConstants.MinimumSecretRefreshInterval < utcNow)
if (secret.Value.RefreshAt.HasValue && secret.Value.RefreshAt.Value < DateTimeOffset.UtcNow)
{
sourceIdsToRemove.Add(secret.Key);
return true;
}
}

foreach (Uri sourceId in sourceIdsToRemove)
{
_cachedKeyVaultSecrets.Remove(sourceId);
}
return false;
}
Comment thread
linglingye001 marked this conversation as resolved.

if (_cachedKeyVaultSecrets.Any())
public void ClearCache()
{
foreach (KeyValuePair<Uri, CachedKeyVaultSecret> secret in _cachedKeyVaultSecrets)
{
UpdateNextRefreshableSecretFromCache();
if (secret.Value.LastRefreshTime + RefreshConstants.MinimumSecretRefreshInterval < DateTimeOffset.UtcNow)
{
_cachedKeyVaultSecrets.TryRemove(secret.Key, out _);
}
}
}

public void RemoveSecretFromCache(Uri sourceId)
{
_cachedKeyVaultSecrets.Remove(sourceId);

if (sourceId == _nextRefreshSourceId)
{
UpdateNextRefreshableSecretFromCache();
}
_cachedKeyVaultSecrets.TryRemove(sourceId, out _);
}

private SecretClient GetSecretClient(Uri secretUri)
Expand All @@ -133,14 +123,12 @@ private SecretClient GetSecretClient(Uri secretUri)
return null;
}

client = new SecretClient(
new Uri(secretUri.GetLeftPart(UriPartial.Authority)),
_keyVaultOptions.Credential,
_keyVaultOptions.ClientOptions);

_secretClients.Add(keyVaultId, client);

return client;
return _secretClients.GetOrAdd(
keyVaultId,
_ => new SecretClient(
new Uri(secretUri.GetLeftPart(UriPartial.Authority)),
_keyVaultOptions.Credential,
_keyVaultOptions.ClientOptions));
}

private void SetSecretInCache(Uri sourceId, string key, CachedKeyVaultSecret cachedSecret, bool success = true)
Expand All @@ -152,37 +140,6 @@ private void SetSecretInCache(Uri sourceId, string key, CachedKeyVaultSecret cac

UpdateCacheExpirationTimeForSecret(key, cachedSecret, success);
_cachedKeyVaultSecrets[sourceId] = cachedSecret;

if (sourceId == _nextRefreshSourceId)
{
UpdateNextRefreshableSecretFromCache();
}
else if ((cachedSecret.RefreshAt.HasValue && _nextRefreshTime.HasValue && cachedSecret.RefreshAt.Value < _nextRefreshTime.Value)
|| (cachedSecret.RefreshAt.HasValue && !_nextRefreshTime.HasValue))
{
_nextRefreshSourceId = sourceId;
_nextRefreshTime = cachedSecret.RefreshAt.Value;
}
}

private void UpdateNextRefreshableSecretFromCache()
{
_nextRefreshSourceId = null;
_nextRefreshTime = DateTimeOffset.MaxValue;

foreach (KeyValuePair<Uri, CachedKeyVaultSecret> secret in _cachedKeyVaultSecrets)
{
if (secret.Value.RefreshAt.HasValue && secret.Value.RefreshAt.Value < _nextRefreshTime)
{
_nextRefreshTime = secret.Value.RefreshAt;
_nextRefreshSourceId = secret.Key;
}
}

if (_nextRefreshTime == DateTimeOffset.MaxValue)
{
_nextRefreshTime = null;
}
}

private void UpdateCacheExpirationTimeForSecret(string key, CachedKeyVaultSecret cachedSecret, bool success)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,7 @@ internal class KeyVaultConstants
public const string ContentType = "application/vnd.microsoft.appconfig.keyvaultref+json";

public const string SecretReferenceUriJsonPropertyName = "uri";

public const int MaxParallelSecretResolution = 16;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,11 @@ public void OnConfigUpdated()
return;
}

public Task PreloadAsync(IEnumerable<ConfigurationSetting> settings, Logger logger, CancellationToken cancellationToken)
{
return Task.CompletedTask;
}

private List<KeyValuePair<string, string>> ProcessDotnetSchemaFeatureFlag(FeatureFlag featureFlag, ConfigurationSetting setting, Uri endpoint)
{
var keyValues = new List<KeyValuePair<string, string>>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ internal interface IKeyValueAdapter
{
Task<IEnumerable<KeyValuePair<string, string>>> ProcessKeyValue(ConfigurationSetting setting, Uri endpoint, Logger logger, CancellationToken cancellationToken);

Task PreloadAsync(IEnumerable<ConfigurationSetting> settings, Logger logger, CancellationToken cancellationToken);

bool CanProcess(ConfigurationSetting setting);

void OnChangeDetected(ConfigurationSetting setting = null);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,5 +80,10 @@ public bool NeedsRefresh()
{
return false;
}

public Task PreloadAsync(IEnumerable<ConfigurationSetting> settings, Logger logger, CancellationToken cancellationToken)
{
return Task.CompletedTask;
}
}
}
Loading
Loading