diff --git a/README.md b/README.md index 35978a6..9953997 100644 --- a/README.md +++ b/README.md @@ -17,6 +17,7 @@ Source generator that helps register attribute marked services in the dependency - Module method registration - Duplicate Strategy - Skip,Replace,Append - Registration Strategy - Self, Implemented Interfaces, Self With Interfaces +- Decorator registration (`RegisterDecorator`) — no runtime dependencies ### Usage @@ -40,6 +41,7 @@ Place registration attribute on class. The class will be discovered and registe - `[RegisterScoped]` Marks the class as a scoped service - `[RegisterTransient]` Marks the class as a transient service - `[RegisterServices]` Marks the method to be called to register services +- `[RegisterDecorator]` Marks the class as a decorator around an existing service #### Attribute Properties @@ -217,6 +219,119 @@ public class ServiceFactoryKeyed : IServiceKeyed } ``` +#### Decorators + +Use the `RegisterDecorator` attribute to wrap an existing service registration without adding +any runtime dependencies. The generator emits all decoration helpers directly into the +consumer assembly. + +Decorators inherit the lifetime of the service they decorate. Apply multiple decorators by +ordering them with the `Order` property — lower values are innermost (applied first), higher +values are outermost (applied last). + +```c# +public interface IService { } + +[RegisterSingleton] +public class Service : IService { } + +[RegisterDecorator(Order = 1)] +public class LoggingDecorator : IService +{ + public LoggingDecorator(IService inner) { } +} + +[RegisterDecorator(Order = 2)] +public class CachingDecorator : IService +{ + public CachingDecorator(IService inner) { } +} +``` + +Resolution order for the sample above: `CachingDecorator → LoggingDecorator → Service`. + +##### Decorator Attribute Properties + +| Property | Description | +|--------------------|------------------------------------------------------------------------------------------------| +| ServiceType | The type of service to decorate. Required unless the generic attribute form is used. | +| ImplementationType | The decorator type. If not set, the class the attribute is on will be used. | +| ServiceKey | Decorate a specific keyed registration. Requires .NET 8+ Microsoft.Extensions.DependencyInjection. | +| AnyKey | When `true`, decorate every keyed registration of `ServiceType` regardless of its key. | +| Factory | Name of a static factory method that builds the decorator. | +| Order | Ordering within the decoration chain. Lower = innermost. | +| Tags | Comma/semicolon-delimited list of registration tags. | + +##### Keyed decoration + +Decorate a single keyed variant, or use `AnyKey` to decorate them all: + +```c# +[RegisterSingleton(ServiceKey = "alpha")] +public class AlphaService : IService { } + +[RegisterDecorator(AnyKey = true)] +public class LoggingDecorator : IService +{ + public LoggingDecorator(IService inner) { } +} +``` + +##### Factory-built decorators + +Provide a static factory on the decorator class for complex construction: + +```c# +[RegisterDecorator(Factory = nameof(Create))] +public class LoggingDecorator : IService +{ + public LoggingDecorator(IService inner) { } + + public static IService Create(IServiceProvider serviceProvider, IService inner) + => new LoggingDecorator(inner); +} +``` + +For keyed decorators the factory takes an additional `object?` parameter for the key: + +```c# +public static IService Create(IServiceProvider serviceProvider, object? serviceKey, IService inner) + => new LoggingDecorator(inner); +``` + +##### Open-generic decoration + +Open-generic decorators apply to every closed registration of the matching service type. +The generator supports decorating closed-generic registrations with an open-generic decorator +class; purely open-generic implementation registrations (e.g. `(IRepo<>, Repo<>)`) are not +decorated at runtime due to a Microsoft.Extensions.DependencyInjection limitation on factory +registrations for open generic service types. + +```c# +public interface IRepo { } + +[RegisterSingleton, StringRepo>] +public class StringRepo : IRepo { } + +[RegisterDecorator(ServiceType = typeof(IRepo<>))] +public class LoggingRepo : IRepo +{ + public LoggingRepo(IRepo inner) { } +} +``` + +##### Tags + +Decorators support the same tag-filtering as registrations: + +```c# +[RegisterDecorator(Tags = "FrontEnd")] +public class FrontEndLoggingDecorator : IService +{ + public FrontEndLoggingDecorator(IService inner) { } +} +``` + #### Register Method When the service registration is complex, use the `RegisterServices` attribute on a method that has a parameter of `IServiceCollection` or `ServiceCollection` diff --git a/src/Injectio.Attributes/RegisterDecoratorAttribute.cs b/src/Injectio.Attributes/RegisterDecoratorAttribute.cs new file mode 100644 index 0000000..8954583 --- /dev/null +++ b/src/Injectio.Attributes/RegisterDecoratorAttribute.cs @@ -0,0 +1,111 @@ +namespace Injectio.Attributes; + +/// +/// Attribute to indicate the target class should be registered as a decorator for an existing service. +/// The decorator wraps the previously registered service implementation and inherits its . +/// +/// Decorate IService with a logging wrapper +/// +/// [RegisterDecorator(ServiceType = typeof(IService))] +/// public class LoggingDecorator : IService +/// { +/// public LoggingDecorator(IService inner) { } +/// } +/// +/// +[AttributeUsage(AttributeTargets.Class, AllowMultiple = true, Inherited = false)] +[System.Diagnostics.Conditional("REGISTER_SERVICE_USAGES")] +public class RegisterDecoratorAttribute : Attribute +{ + /// + /// The of the service to decorate. + /// + public Type? ServiceType { get; set; } + + /// + /// The that implements the decorator. If not set, the class the attribute is on will be used. + /// + public Type? ImplementationType { get; set; } + + /// + /// Gets or sets the key of the keyed service to decorate. + /// Leave unset (and false) to decorate the non-keyed registration. + /// + public object? ServiceKey { get; set; } + + /// + /// When true, the decorator is applied to every keyed registration of , + /// regardless of its key. Equivalent to KeyedService.AnyKey. + /// + public bool AnyKey { get; set; } + + /// + /// Name of a static factory method to construct the decorator. + /// + /// + /// The factory signature must be (IServiceProvider, TService) -> TService for non-keyed services + /// or (IServiceProvider, object?, TService) -> TService for keyed services. + /// + public string? Factory { get; set; } + + /// + /// Gets or sets the order in which the decorator is applied. Lower values are applied first (innermost). + /// + public int Order { get; set; } + + /// + /// Gets or sets the comma delimited list of registration tags. + /// + public string? Tags { get; set; } +} + +#if NET7_0_OR_GREATER +/// +/// Attribute to indicate the target class should be registered as a decorator for . +/// +/// The type of the service to decorate. +/// +/// +/// [RegisterDecorator<IService>] +/// public class LoggingDecorator : IService +/// { +/// public LoggingDecorator(IService inner) { } +/// } +/// +/// +[AttributeUsage(AttributeTargets.Class, AllowMultiple = true, Inherited = false)] +[System.Diagnostics.Conditional("REGISTER_SERVICE_USAGES")] +public class RegisterDecoratorAttribute : RegisterDecoratorAttribute + where TService : class +{ + /// + /// Initializes a new instance of the class. + /// + public RegisterDecoratorAttribute() + { + ServiceType = typeof(TService); + } +} + +/// +/// Attribute to indicate the target class should be registered as a decorator for +/// using as the decorator implementation. +/// +/// The type of the service to decorate. +/// The type of the decorator implementation. +[AttributeUsage(AttributeTargets.Class, AllowMultiple = true, Inherited = false)] +[System.Diagnostics.Conditional("REGISTER_SERVICE_USAGES")] +public class RegisterDecoratorAttribute : RegisterDecoratorAttribute + where TService : class + where TImplementation : class, TService +{ + /// + /// Initializes a new instance of the class. + /// + public RegisterDecoratorAttribute() + { + ServiceType = typeof(TService); + ImplementationType = typeof(TImplementation); + } +} +#endif diff --git a/src/Injectio.Generators/AnalyzerReleases.Unshipped.md b/src/Injectio.Generators/AnalyzerReleases.Unshipped.md index 65956d6..9ecbdb3 100644 --- a/src/Injectio.Generators/AnalyzerReleases.Unshipped.md +++ b/src/Injectio.Generators/AnalyzerReleases.Unshipped.md @@ -14,3 +14,10 @@ INJ0006 | Usage | Warning | Factory method has invalid signature INJ0007 | Usage | Warning | Implementation does not implement service type INJ0008 | Usage | Warning | Implementation type is abstract INJ0009 | Usage | Warning | RegisterServices on non-static method in abstract class +INJ0010 | Usage | Warning | Decorator does not implement service type +INJ0011 | Usage | Warning | Decorator is missing service type +INJ0012 | Usage | Warning | Decorator has no constructor accepting the inner service +INJ0013 | Usage | Warning | Decorator factory method not found +INJ0014 | Usage | Warning | Decorator factory method has invalid signature +INJ0015 | Usage | Warning | Keyed decoration is not supported for open-generic services +INJ0016 | Usage | Warning | Decorator target service is not registered in this compilation diff --git a/src/Injectio.Generators/DecoratorRegistration.cs b/src/Injectio.Generators/DecoratorRegistration.cs new file mode 100644 index 0000000..c1796f8 --- /dev/null +++ b/src/Injectio.Generators/DecoratorRegistration.cs @@ -0,0 +1,12 @@ +namespace Injectio.Generators; + +public record DecoratorRegistration( + string DecoratorType, + string ServiceType, + string? ServiceKey, + bool IsAnyKey, + string? Factory, + int Order, + EquatableArray Tags, + bool IsOpenGeneric = false +); diff --git a/src/Injectio.Generators/DiagnosticDescriptors.cs b/src/Injectio.Generators/DiagnosticDescriptors.cs index 01c7157..f31812b 100644 --- a/src/Injectio.Generators/DiagnosticDescriptors.cs +++ b/src/Injectio.Generators/DiagnosticDescriptors.cs @@ -86,4 +86,69 @@ public static class DiagnosticDescriptors defaultSeverity: DiagnosticSeverity.Warning, isEnabledByDefault: true ); + + public static readonly DiagnosticDescriptor DecoratorDoesNotImplementService = new( + id: "INJ0010", + title: "Decorator does not implement service type", + messageFormat: "Decorator '{0}' does not implement or inherit from service type '{1}'", + category: Category, + defaultSeverity: DiagnosticSeverity.Warning, + isEnabledByDefault: true + ); + + public static readonly DiagnosticDescriptor DecoratorMissingServiceType = new( + id: "INJ0011", + title: "Decorator is missing service type", + messageFormat: "Decorator '{0}' must specify a ServiceType either via the generic attribute or the ServiceType property", + category: Category, + defaultSeverity: DiagnosticSeverity.Warning, + isEnabledByDefault: true + ); + + public static readonly DiagnosticDescriptor DecoratorMissingInnerConstructor = new( + id: "INJ0012", + title: "Decorator has no constructor accepting the inner service", + messageFormat: "Decorator '{0}' must expose a public constructor whose first parameter is of type '{1}' (or use Factory)", + category: Category, + defaultSeverity: DiagnosticSeverity.Warning, + isEnabledByDefault: true + ); + + public static readonly DiagnosticDescriptor DecoratorFactoryNotFound = new( + id: "INJ0013", + title: "Decorator factory method not found", + messageFormat: "Decorator factory method '{0}' was not found on type '{1}'", + category: Category, + defaultSeverity: DiagnosticSeverity.Warning, + isEnabledByDefault: true + ); + + public static readonly DiagnosticDescriptor DecoratorFactoryInvalidSignature = new( + id: "INJ0014", + title: "Decorator factory method has invalid signature", + messageFormat: "Decorator factory method '{0}' on type '{1}' must be static and accept (IServiceProvider, TService) for non-keyed or (IServiceProvider, object?, TService) for keyed decorators", + category: Category, + defaultSeverity: DiagnosticSeverity.Warning, + isEnabledByDefault: true + ); + + public static readonly DiagnosticDescriptor DecoratorOpenGenericKeyed = new( + id: "INJ0015", + title: "Keyed decoration is not supported for open-generic services", + messageFormat: "Decorator '{0}' targets open-generic service '{1}' and cannot be used with ServiceKey or AnyKey", + category: Category, + defaultSeverity: DiagnosticSeverity.Warning, + isEnabledByDefault: true, + customTags: [WellKnownDiagnosticTags.CompilationEnd] + ); + + public static readonly DiagnosticDescriptor DecoratorTargetNotRegistered = new( + id: "INJ0016", + title: "Decorator target service is not registered in this compilation", + messageFormat: "Decorator '{0}' targets service '{1}' but no matching registration was found; decoration will be skipped at runtime if the service is not registered elsewhere", + category: Category, + defaultSeverity: DiagnosticSeverity.Warning, + isEnabledByDefault: true, + customTags: [WellKnownDiagnosticTags.CompilationEnd] + ); } diff --git a/src/Injectio.Generators/KnownTypes.cs b/src/Injectio.Generators/KnownTypes.cs index 77efad4..ce1f0aa 100644 --- a/src/Injectio.Generators/KnownTypes.cs +++ b/src/Injectio.Generators/KnownTypes.cs @@ -23,6 +23,10 @@ public static class KnownTypes public const string ModuleAttributeTypeName = $"{ModuleAttributeShortName}Attribute"; public const string ModuleAttributeFullName = $"{AbstractionNamespace}.{ModuleAttributeTypeName}"; + public const string DecoratorAttributeShortName = "RegisterDecorator"; + public const string DecoratorAttributeTypeName = $"{DecoratorAttributeShortName}Attribute"; + public const string DecoratorAttributeFullName = $"{AbstractionNamespace}.{DecoratorAttributeTypeName}"; + public const string ServiceLifetimeSingletonShortName = "Singleton"; public const string ServiceLifetimeSingletonTypeName = $"ServiceLifetime.{ServiceLifetimeSingletonShortName}"; diff --git a/src/Injectio.Generators/ServiceRegistrationAnalyzer.cs b/src/Injectio.Generators/ServiceRegistrationAnalyzer.cs index dbec02e..40be119 100644 --- a/src/Injectio.Generators/ServiceRegistrationAnalyzer.cs +++ b/src/Injectio.Generators/ServiceRegistrationAnalyzer.cs @@ -20,7 +20,14 @@ public class ServiceRegistrationAnalyzer : DiagnosticAnalyzer DiagnosticDescriptors.FactoryMethodInvalidSignature, DiagnosticDescriptors.ServiceTypeMismatch, DiagnosticDescriptors.AbstractImplementationType, - DiagnosticDescriptors.RegisterServicesMethodOnAbstractClass); + DiagnosticDescriptors.RegisterServicesMethodOnAbstractClass, + DiagnosticDescriptors.DecoratorDoesNotImplementService, + DiagnosticDescriptors.DecoratorMissingServiceType, + DiagnosticDescriptors.DecoratorMissingInnerConstructor, + DiagnosticDescriptors.DecoratorFactoryNotFound, + DiagnosticDescriptors.DecoratorFactoryInvalidSignature, + DiagnosticDescriptors.DecoratorOpenGenericKeyed, + DiagnosticDescriptors.DecoratorTargetNotRegistered); public override void Initialize(AnalysisContext context) { @@ -29,6 +36,234 @@ public override void Initialize(AnalysisContext context) context.RegisterSymbolAction(AnalyzeMethod, SymbolKind.Method); context.RegisterSymbolAction(AnalyzeNamedType, SymbolKind.NamedType); + context.RegisterCompilationStartAction(AnalyzeCompilation); + } + + private static void AnalyzeCompilation(CompilationStartAnalysisContext context) + { + var decorators = new List(); + var registeredServices = new HashSet(StringComparer.Ordinal); + var hasModule = false; + + context.RegisterSymbolAction(symbolContext => + { + if (symbolContext.Symbol is IMethodSymbol methodSymbol) + { + foreach (var attribute in methodSymbol.GetAttributes()) + { + if (SymbolHelpers.IsMethodAttribute(attribute)) + { + lock (registeredServices) hasModule = true; + break; + } + } + return; + } + + if (symbolContext.Symbol is not INamedTypeSymbol classSymbol) + return; + + if (classSymbol.IsStatic) + return; + + var attributes = classSymbol.GetAttributes(); + + foreach (var attribute in attributes) + { + if (SymbolHelpers.IsDecoratorAttribute(attribute)) + { + var info = ExtractDecoratorInfo(classSymbol, attribute); + if (info != null) + lock (decorators) decorators.Add(info); + continue; + } + + if (!SymbolHelpers.IsKnownAttribute(attribute, out _)) + continue; + + CollectRegisteredServiceTypes(classSymbol, attribute, registeredServices); + } + }, SymbolKind.NamedType, SymbolKind.Method); + + context.RegisterCompilationEndAction(endContext => + { + foreach (var decorator in decorators) + { + // INJ0015 — open-generic + keyed combination + if (decorator.IsOpenGeneric && (decorator.HasServiceKey || decorator.IsAnyKey)) + { + endContext.ReportDiagnostic(Diagnostic.Create( + DiagnosticDescriptors.DecoratorOpenGenericKeyed, + decorator.Location, + decorator.DecoratorType, + decorator.ServiceType)); + } + + // INJ0016 — target not registered (only when no module could register it dynamically) + if (!hasModule && !decorator.HasServiceKey && !decorator.IsAnyKey) + { + if (!registeredServices.Contains(decorator.ServiceType)) + { + endContext.ReportDiagnostic(Diagnostic.Create( + DiagnosticDescriptors.DecoratorTargetNotRegistered, + decorator.Location, + decorator.DecoratorType, + decorator.ServiceType)); + } + } + } + }); + } + + private sealed class DecoratorInfo + { + public string DecoratorType { get; set; } = string.Empty; + public string ServiceType { get; set; } = string.Empty; + public bool IsOpenGeneric { get; set; } + public bool HasServiceKey { get; set; } + public bool IsAnyKey { get; set; } + public Location Location { get; set; } = Location.None; + } + + private static DecoratorInfo? ExtractDecoratorInfo(INamedTypeSymbol classSymbol, AttributeData attribute) + { + string? serviceType = null; + bool hasServiceKey = false; + bool isAnyKey = false; + bool isOpenGeneric = false; + + var attributeClass = attribute.AttributeClass; + if (attributeClass is { IsGenericType: true } && attributeClass.TypeArguments.Length >= 1) + { + var typeArgument = attributeClass.TypeArguments[0]; + if (typeArgument is INamedTypeSymbol namedService && namedService.IsGenericType && namedService.IsUnboundGenericType) + isOpenGeneric = true; + + serviceType = typeArgument.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat); + } + + foreach (var parameter in attribute.NamedArguments) + { + var name = parameter.Key; + var value = parameter.Value.Value; + + switch (name) + { + case "ServiceType": + if (value is INamedTypeSymbol svc) + { + if (svc.IsGenericType && svc.IsUnboundGenericType) + isOpenGeneric = true; + serviceType = svc.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat); + } + break; + case "ServiceKey": + hasServiceKey = value is not null; + break; + case "AnyKey": + if (value is bool b) + isAnyKey = b; + break; + } + } + + if (serviceType is null) + return null; + + var location = classSymbol.Locations.Length > 0 ? classSymbol.Locations[0] : Location.None; + + return new DecoratorInfo + { + DecoratorType = classSymbol.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat), + ServiceType = serviceType, + IsOpenGeneric = isOpenGeneric, + HasServiceKey = hasServiceKey, + IsAnyKey = isAnyKey, + Location = location, + }; + } + + private static void CollectRegisteredServiceTypes( + INamedTypeSymbol classSymbol, + AttributeData attribute, + HashSet registeredServices) + { + string? implementationType = null; + string? registrationStrategy = null; + var localServiceTypes = new List(); + + var attributeClass = attribute.AttributeClass; + if (attributeClass is { IsGenericType: true } && attributeClass.TypeArguments.Length == attributeClass.TypeParameters.Length) + { + for (var index = 0; index < attributeClass.TypeParameters.Length; index++) + { + var typeParameter = attributeClass.TypeParameters[index]; + var typeArgument = attributeClass.TypeArguments[index]; + + if (typeParameter.Name == "TService" || index == 0) + localServiceTypes.Add(typeArgument.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat)); + else if (typeParameter.Name == "TImplementation" || index == 1) + implementationType = typeArgument.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat); + } + } + + foreach (var parameter in attribute.NamedArguments) + { + var name = parameter.Key; + var value = parameter.Value.Value; + if (value is null) continue; + + switch (name) + { + case "ServiceType": + if (value is INamedTypeSymbol svc) + localServiceTypes.Add(svc.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat)); + break; + case "ImplementationType": + if (value is INamedTypeSymbol impl) + implementationType = impl.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat); + break; + case "Registration": + registrationStrategy = SymbolHelpers.ResolveRegistrationStrategy(value); + break; + } + } + + implementationType ??= classSymbol.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat); + + if (registrationStrategy == null && localServiceTypes.Count == 0) + registrationStrategy = KnownTypes.RegistrationStrategySelfWithProxyFactoryShortName; + + bool includeInterfaces = registrationStrategy is KnownTypes.RegistrationStrategyImplementedInterfacesShortName + or KnownTypes.RegistrationStrategySelfWithInterfacesShortName + or KnownTypes.RegistrationStrategySelfWithProxyFactoryShortName + or null; + + if (includeInterfaces) + { + foreach (var iface in classSymbol.AllInterfaces) + { + if (iface.ConstructedFrom.ToString() == "System.IEquatable") + continue; + + var unbound = SymbolHelpers.ToUnboundGenericType(iface); + localServiceTypes.Add(unbound.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat)); + } + } + + bool includeSelf = registrationStrategy is KnownTypes.RegistrationStrategySelfShortName + or KnownTypes.RegistrationStrategySelfWithInterfacesShortName + or KnownTypes.RegistrationStrategySelfWithProxyFactoryShortName + or null; + + if (includeSelf || localServiceTypes.Count == 0) + localServiceTypes.Add(implementationType); + + lock (registeredServices) + { + foreach (var t in localServiceTypes) + registeredServices.Add(t); + } } private static void AnalyzeMethod(SymbolAnalysisContext context) @@ -126,17 +361,249 @@ private static void AnalyzeNamedType(SymbolAnalysisContext context) foreach (var attribute in attributes) { - if (!SymbolHelpers.IsKnownAttribute(attribute, out _)) - continue; - var location = classSymbol.Locations.Length > 0 ? classSymbol.Locations[0] : Location.None; + if (SymbolHelpers.IsDecoratorAttribute(attribute)) + { + AnalyzeDecoratorAttribute(context, classSymbol, attribute, location); + continue; + } + + if (!SymbolHelpers.IsKnownAttribute(attribute, out _)) + continue; + AnalyzeRegistrationAttribute(context, classSymbol, attribute, location); } } + private static void AnalyzeDecoratorAttribute( + SymbolAnalysisContext context, + INamedTypeSymbol classSymbol, + AttributeData attribute, + Location location) + { + string? serviceTypeName = null; + INamedTypeSymbol? serviceTypeSymbol = null; + string? factory = null; + bool hasServiceKey = false; + bool isAnyKey = false; + + var attributeClass = attribute.AttributeClass; + if (attributeClass is { IsGenericType: true } && attributeClass.TypeArguments.Length >= 1) + { + if (attributeClass.TypeArguments[0] is INamedTypeSymbol serviceArg) + { + serviceTypeSymbol = serviceArg; + serviceTypeName = serviceArg.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat); + } + } + + foreach (var parameter in attribute.NamedArguments) + { + var name = parameter.Key; + var value = parameter.Value.Value; + + switch (name) + { + case "ServiceType": + if (value is INamedTypeSymbol svc) + { + serviceTypeSymbol = svc; + serviceTypeName = svc.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat); + } + break; + case "Factory": + factory = value?.ToString(); + break; + case "ServiceKey": + hasServiceKey = value is not null; + break; + case "AnyKey": + if (value is bool b) + isAnyKey = b; + break; + } + } + + // INJ0011 — missing service type + if (serviceTypeName is null) + { + context.ReportDiagnostic(Diagnostic.Create( + DiagnosticDescriptors.DecoratorMissingServiceType, + location, + classSymbol.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat))); + return; + } + + // INJ0010 — class does not implement service + var classTypeName = classSymbol.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat); + if (serviceTypeName != classTypeName) + { + var implementsService = false; + + foreach (var iface in classSymbol.AllInterfaces) + { + var ifaceName = iface.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat); + if (ifaceName == serviceTypeName) + { + implementsService = true; + break; + } + + var unboundIface = SymbolHelpers.ToUnboundGenericType(iface); + if (!SymbolEqualityComparer.Default.Equals(unboundIface, iface)) + { + var unboundName = unboundIface.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat); + if (unboundName == serviceTypeName) + { + implementsService = true; + break; + } + } + } + + if (!implementsService) + { + var baseType = classSymbol.BaseType; + while (baseType is not null) + { + var baseName = baseType.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat); + if (baseName == serviceTypeName) { implementsService = true; break; } + + var unboundBase = SymbolHelpers.ToUnboundGenericType(baseType); + if (!SymbolEqualityComparer.Default.Equals(unboundBase, baseType)) + { + var unboundBaseName = unboundBase.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat); + if (unboundBaseName == serviceTypeName) { implementsService = true; break; } + } + + baseType = baseType.BaseType; + } + } + + if (!implementsService) + { + context.ReportDiagnostic(Diagnostic.Create( + DiagnosticDescriptors.DecoratorDoesNotImplementService, + location, + classTypeName, + serviceTypeName)); + } + } + + // INJ0012/13/14 — constructor or factory validation + if (factory.HasValue()) + { + ValidateDecoratorFactory(context, classSymbol, factory!, hasServiceKey || isAnyKey, location); + } + else + { + var hasCompatibleCtor = false; + foreach (var ctor in classSymbol.InstanceConstructors) + { + if (ctor.DeclaredAccessibility == Accessibility.Private) continue; + if (ctor.Parameters.Length == 0) continue; + + var firstParamType = ctor.Parameters[0].Type.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat); + if (firstParamType == serviceTypeName) + { + hasCompatibleCtor = true; + break; + } + + if (ctor.Parameters[0].Type is INamedTypeSymbol paramNamed) + { + var unboundParam = SymbolHelpers.ToUnboundGenericType(paramNamed); + if (!SymbolEqualityComparer.Default.Equals(unboundParam, paramNamed)) + { + var unboundName = unboundParam.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat); + if (unboundName == serviceTypeName) + { + hasCompatibleCtor = true; + break; + } + } + } + + // any parameter matches? + foreach (var param in ctor.Parameters) + { + var paramType = param.Type.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat); + if (paramType == serviceTypeName) + { + hasCompatibleCtor = true; + break; + } + } + + if (hasCompatibleCtor) break; + } + + if (!hasCompatibleCtor && classSymbol.InstanceConstructors.Length > 0) + { + context.ReportDiagnostic(Diagnostic.Create( + DiagnosticDescriptors.DecoratorMissingInnerConstructor, + location, + classTypeName, + serviceTypeName)); + } + } + } + + private static void ValidateDecoratorFactory( + SymbolAnalysisContext context, + INamedTypeSymbol classSymbol, + string factoryMethodName, + bool isKeyed, + Location location) + { + var className = classSymbol.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat); + var members = classSymbol.GetMembers(factoryMethodName); + var factoryMethods = new List(); + + foreach (var member in members) + { + if (member is IMethodSymbol method) + factoryMethods.Add(method); + } + + if (factoryMethods.Count == 0) + { + context.ReportDiagnostic(Diagnostic.Create( + DiagnosticDescriptors.DecoratorFactoryNotFound, + location, + factoryMethodName, + className)); + return; + } + + var expectedParamCount = isKeyed ? 3 : 2; + + foreach (var method in factoryMethods) + { + if (!method.IsStatic) continue; + if (method.Parameters.Length != expectedParamCount) continue; + + if (!SymbolHelpers.IsServiceProvider(method.Parameters[0])) continue; + + if (isKeyed) + { + if (method.Parameters[1].Type.SpecialType != SpecialType.System_Object) continue; + // parameter[2] is the inner service — not strictly checked + } + + return; // valid overload found + } + + context.ReportDiagnostic(Diagnostic.Create( + DiagnosticDescriptors.DecoratorFactoryInvalidSignature, + location, + factoryMethodName, + className)); + } + private static void AnalyzeRegistrationAttribute( SymbolAnalysisContext context, INamedTypeSymbol classSymbol, diff --git a/src/Injectio.Generators/ServiceRegistrationContext.cs b/src/Injectio.Generators/ServiceRegistrationContext.cs index 2451cce..009f672 100644 --- a/src/Injectio.Generators/ServiceRegistrationContext.cs +++ b/src/Injectio.Generators/ServiceRegistrationContext.cs @@ -6,5 +6,6 @@ namespace Injectio.Generators; public record ServiceRegistrationContext( EquatableArray? ServiceRegistrations = null, - EquatableArray? ModuleRegistrations = null + EquatableArray? ModuleRegistrations = null, + EquatableArray? DecoratorRegistrations = null ); diff --git a/src/Injectio.Generators/ServiceRegistrationGenerator.cs b/src/Injectio.Generators/ServiceRegistrationGenerator.cs index d4be12c..7dcabdf 100644 --- a/src/Injectio.Generators/ServiceRegistrationGenerator.cs +++ b/src/Injectio.Generators/ServiceRegistrationGenerator.cs @@ -24,7 +24,9 @@ public void Initialize(IncrementalGeneratorInitializationContext context) ) .Where(static context => context is not null - && (context.ServiceRegistrations?.Count > 0 || context.ModuleRegistrations?.Count > 0) + && (context.ServiceRegistrations?.Count > 0 + || context.ModuleRegistrations?.Count > 0 + || context.DecoratorRegistrations?.Count > 0) ) .Collect() .WithTrackingName("Registrations"); @@ -64,6 +66,14 @@ private void ExecuteGeneration( .Where(m => m is not null) .ToArray(); + var decoratorRegistrations = source.Registrations + .SelectMany(m => m?.DecoratorRegistrations ?? Array.Empty()) + .Where(m => m is not null) + .OrderBy(m => m.ServiceType, StringComparer.Ordinal) + .ThenBy(m => m.Order) + .ThenBy(m => m.DecoratorType, StringComparer.Ordinal) + .ToArray(); + // compute extension method name var methodName = source.Options.MethodOptions?.Name; if (methodName.IsNullOrWhiteSpace()) @@ -75,12 +85,20 @@ private void ExecuteGeneration( var result = ServiceRegistrationWriter.GenerateExtensionClass( moduleRegistrations, serviceRegistrations, + decoratorRegistrations, source.Options.AssemblyName, methodName, methodInternal); // add source file sourceContext.AddSource("Injectio.g.cs", SourceText.From(result, Encoding.UTF8)); + + // emit decoration helper if any decorators discovered + if (decoratorRegistrations.Length > 0) + { + var decorationHelper = ServiceRegistrationWriter.GenerateDecorationHelper(); + sourceContext.AddSource("Injectio.Decoration.g.cs", SourceText.From(decorationHelper, Encoding.UTF8)); + } } private static bool SyntacticPredicate(SyntaxNode syntaxNode, CancellationToken cancellationToken) @@ -156,18 +174,148 @@ private static bool SyntacticPredicate(SyntaxNode syntaxNode, CancellationToken // support multiple register attributes on a class var registrations = new List(); + var decorators = new List(); foreach (var attribute in attributes) { + if (SymbolHelpers.IsDecoratorAttribute(attribute)) + { + var decorator = CreateDecoratorRegistration(classSymbol, attribute); + if (decorator is not null) + decorators.Add(decorator); + continue; + } + var registration = CreateServiceRegistration(classSymbol, attribute); if (registration is not null) registrations.Add(registration); } - if (registrations.Count == 0) + if (registrations.Count == 0 && decorators.Count == 0) return null; - return new ServiceRegistrationContext(ServiceRegistrations: registrations.ToArray()); + EquatableArray? serviceArray = registrations.Count > 0 + ? new EquatableArray(registrations.ToArray()) + : (EquatableArray?)null; + EquatableArray? decoratorArray = decorators.Count > 0 + ? new EquatableArray(decorators.ToArray()) + : (EquatableArray?)null; + + return new ServiceRegistrationContext( + ServiceRegistrations: serviceArray, + DecoratorRegistrations: decoratorArray); + } + + private static DecoratorRegistration? CreateDecoratorRegistration(INamedTypeSymbol classSymbol, AttributeData attribute) + { + string? serviceType = null; + string? implementationType = null; + string? serviceKey = null; + bool isAnyKey = false; + string? factory = null; + int order = 0; + var tags = new HashSet(); + bool isOpenGeneric = false; + + var attributeClass = attribute.AttributeClass; + if (attributeClass is { IsGenericType: true } && attributeClass.TypeArguments.Length == attributeClass.TypeParameters.Length) + { + for (var index = 0; index < attributeClass.TypeParameters.Length; index++) + { + var typeParameter = attributeClass.TypeParameters[index]; + var typeArgument = attributeClass.TypeArguments[index]; + + if (typeParameter.Name == "TService" || index == 0) + { + isOpenGeneric = isOpenGeneric || IsOpenGeneric(typeArgument as INamedTypeSymbol); + serviceType = typeArgument.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat); + } + else if (typeParameter.Name == "TImplementation" || index == 1) + { + isOpenGeneric = isOpenGeneric || IsOpenGeneric(typeArgument as INamedTypeSymbol); + implementationType = typeArgument.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat); + } + } + } + + foreach (var parameter in attribute.NamedArguments) + { + var name = parameter.Key; + var value = parameter.Value.Value; + + if (string.IsNullOrEmpty(name)) + continue; + + switch (name) + { + case "ServiceType": + if (value is INamedTypeSymbol serviceTypeSymbol) + { + isOpenGeneric = isOpenGeneric || IsOpenGeneric(serviceTypeSymbol); + serviceType = serviceTypeSymbol.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat); + } + else if (value != null) + { + serviceType = value.ToString(); + } + break; + case "ImplementationType": + if (value is INamedTypeSymbol implSymbol) + { + isOpenGeneric = isOpenGeneric || IsOpenGeneric(implSymbol); + implementationType = implSymbol.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat); + } + else if (value != null) + { + implementationType = value.ToString(); + } + break; + case "ServiceKey": + serviceKey = parameter.Value.ToCSharpString(); + break; + case "AnyKey": + if (value is bool anyKey) + isAnyKey = anyKey; + break; + case "Factory": + factory = value?.ToString(); + break; + case "Order": + if (value is int orderValue) + order = orderValue; + break; + case "Tags": + if (value is string tagsText) + { + foreach (var tag in tagsText.Split(',', ';')) + { + if (tag.HasValue()) + tags.Add(tag.Trim()); + } + } + break; + } + } + + if (implementationType.IsNullOrWhiteSpace()) + { + var unboundType = SymbolHelpers.ToUnboundGenericType(classSymbol); + isOpenGeneric = isOpenGeneric || IsOpenGeneric(unboundType); + implementationType = unboundType.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat); + } + + if (serviceType.IsNullOrWhiteSpace()) + return null; + + return new DecoratorRegistration( + DecoratorType: implementationType!, + ServiceType: serviceType!, + ServiceKey: serviceKey, + IsAnyKey: isAnyKey, + Factory: factory, + Order: order, + Tags: tags.ToArray(), + IsOpenGeneric: isOpenGeneric); } private static (bool isValid, bool hasTagCollection) ValidateMethod(IMethodSymbol methodSymbol) diff --git a/src/Injectio.Generators/ServiceRegistrationWriter.cs b/src/Injectio.Generators/ServiceRegistrationWriter.cs index a393807..a9e8951 100644 --- a/src/Injectio.Generators/ServiceRegistrationWriter.cs +++ b/src/Injectio.Generators/ServiceRegistrationWriter.cs @@ -11,6 +11,15 @@ public static string GenerateExtensionClass( string? assemblyName, string? methodName, string? methodInternal) + => GenerateExtensionClass(moduleRegistrations, serviceRegistrations, Array.Empty(), assemblyName, methodName, methodInternal); + + public static string GenerateExtensionClass( + IReadOnlyList moduleRegistrations, + IReadOnlyList serviceRegistrations, + IReadOnlyList decoratorRegistrations, + string? assemblyName, + string? methodName, + string? methodInternal) { var codeBuilder = new IndentedStringBuilder(); codeBuilder @@ -66,6 +75,11 @@ public static string GenerateExtensionClass( WriteRegistration(codeBuilder, serviceRegistration); } + foreach (var decoratorRegistration in decoratorRegistrations) + { + WriteDecorator(codeBuilder, decoratorRegistration); + } + codeBuilder .AppendLine("return serviceCollection;") .DecrementIndent() @@ -320,6 +334,366 @@ private static void WriteServiceGeneric( .AppendLine(); } + private static void WriteDecorator( + IndentedStringBuilder codeBuilder, + DecoratorRegistration decorator) + { + if (decorator.Tags.Count > 0) + { + codeBuilder + .Append("if (tagSet.Count == 0 || tagSet.Intersect(new[] { "); + + bool wroteTag = false; + foreach (var tag in decorator.Tags) + { + if (wroteTag) + codeBuilder.Append(", "); + + codeBuilder + .Append("\"") + .Append(tag) + .Append("\""); + + wroteTag = true; + } + + codeBuilder + .AppendLine(" }).Any())") + .AppendLine("{") + .IncrementIndent(); + } + + var serviceType = decorator.ServiceType; + var decoratorType = decorator.DecoratorType; + bool hasServiceKey = decorator.ServiceKey.HasValue(); + bool isKeyed = hasServiceKey || decorator.IsAnyKey; + + // resolve the service key expression passed to the helper + string keyExpression; + if (decorator.IsAnyKey) + keyExpression = "global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey"; + else if (hasServiceKey) + keyExpression = decorator.ServiceKey!; + else + keyExpression = "null"; + + if (decorator.IsOpenGeneric) + { + codeBuilder + .Append("global::Injectio.Internal.InjectioDecorationExtensions.DecorateOpenGeneric(") + .AppendLine() + .IncrementIndent() + .AppendLine("serviceCollection,") + .Append("typeof(") + .AppendIf("global::", !serviceType.StartsWith("global::")) + .Append(serviceType) + .AppendLine("),") + .Append("typeof(") + .AppendIf("global::", !decoratorType.StartsWith("global::")) + .Append(decoratorType) + .AppendLine(")") + .DecrementIndent() + .AppendLine(");") + .AppendLine(); + } + else if (isKeyed) + { + codeBuilder + .Append("global::Injectio.Internal.InjectioDecorationExtensions.DecorateKeyed<") + .AppendIf("global::", !serviceType.StartsWith("global::")) + .Append(serviceType) + .AppendLine(">(") + .IncrementIndent() + .AppendLine("serviceCollection,") + .Append(keyExpression) + .AppendLine(","); + + WriteDecoratorFactory(codeBuilder, decorator, isKeyed: true); + + codeBuilder + .AppendLine() + .DecrementIndent() + .AppendLine(");") + .AppendLine(); + } + else + { + codeBuilder + .Append("global::Injectio.Internal.InjectioDecorationExtensions.Decorate<") + .AppendIf("global::", !serviceType.StartsWith("global::")) + .Append(serviceType) + .AppendLine(">(") + .IncrementIndent() + .AppendLine("serviceCollection,"); + + WriteDecoratorFactory(codeBuilder, decorator, isKeyed: false); + + codeBuilder + .AppendLine() + .DecrementIndent() + .AppendLine(");") + .AppendLine(); + } + + if (decorator.Tags.Count > 0) + { + codeBuilder + .DecrementIndent() + .AppendLine("}") + .AppendLine(); + } + } + + private static void WriteDecoratorFactory( + IndentedStringBuilder codeBuilder, + DecoratorRegistration decorator, + bool isKeyed) + { + var serviceType = decorator.ServiceType; + var decoratorType = decorator.DecoratorType; + var qualifiedService = serviceType.StartsWith("global::") ? serviceType : "global::" + serviceType; + var qualifiedDecorator = decoratorType.StartsWith("global::") ? decoratorType : "global::" + decoratorType; + + if (decorator.Factory.HasValue()) + { + bool hasNamespace = decorator.Factory!.Contains("."); + var factoryTarget = hasNamespace ? decorator.Factory! : qualifiedDecorator + "." + decorator.Factory; + + if (isKeyed) + { + codeBuilder + .Append("static (serviceProvider, serviceKey, inner) => ") + .Append(factoryTarget) + .Append("(serviceProvider, serviceKey, inner)"); + } + else + { + codeBuilder + .Append("static (serviceProvider, inner) => ") + .Append(factoryTarget) + .Append("(serviceProvider, inner)"); + } + } + else + { + if (isKeyed) + { + codeBuilder + .Append("static (serviceProvider, serviceKey, inner) => global::Microsoft.Extensions.DependencyInjection.ActivatorUtilities.CreateInstance<") + .Append(qualifiedDecorator) + .Append(">(serviceProvider, inner)"); + } + else + { + codeBuilder + .Append("static (serviceProvider, inner) => global::Microsoft.Extensions.DependencyInjection.ActivatorUtilities.CreateInstance<") + .Append(qualifiedDecorator) + .Append(">(serviceProvider, inner)"); + } + } + } + + public static string GenerateDecorationHelper() + { + var codeBuilder = new IndentedStringBuilder(); + codeBuilder + .AppendLine("// ") + .AppendLine("#nullable enable") + .AppendLine() + .AppendLine("namespace Injectio.Internal") + .AppendLine("{") + .IncrementIndent() + .Append("[global::System.CodeDom.Compiler.GeneratedCodeAttribute(\"") + .Append(ThisAssembly.Product) + .Append("\", \"") + .Append(ThisAssembly.InformationalVersion) + .AppendLine("\")]") + .AppendLine("internal static class InjectioDecorationExtensions") + .AppendLine("{") + .IncrementIndent(); + + codeBuilder.AppendLines(DecorationHelperBody, skipFinalNewline: true); + + codeBuilder + .AppendLine() + .DecrementIndent() + .AppendLine("}") + .DecrementIndent() + .AppendLine("}"); + + return codeBuilder.ToString(); + } + + private const string DecorationHelperBody = """ +internal static global::Microsoft.Extensions.DependencyInjection.IServiceCollection Decorate( + global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, + global::System.Func decoratorFactory) + where TService : class +{ + var serviceType = typeof(TService); + + for (int i = 0; i < services.Count; i++) + { + var descriptor = services[i]; + if (descriptor.ServiceType != serviceType) + continue; + + if (IsKeyedDescriptor(descriptor)) + continue; + + var inner = CreateInnerFactory(descriptor); + var lifetime = descriptor.Lifetime; + + services[i] = new global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor( + serviceType, + sp => decoratorFactory(sp, (TService)inner(sp))!, + lifetime); + } + + return services; +} + +#if NET8_0_OR_GREATER +internal static global::Microsoft.Extensions.DependencyInjection.IServiceCollection DecorateKeyed( + global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, + object? serviceKey, + global::System.Func decoratorFactory) + where TService : class +{ + var serviceType = typeof(TService); + bool anyKey = ReferenceEquals(serviceKey, global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey); + + for (int i = 0; i < services.Count; i++) + { + var descriptor = services[i]; + if (descriptor.ServiceType != serviceType) + continue; + + if (!descriptor.IsKeyedService) + continue; + + if (!anyKey && !global::System.Collections.Generic.EqualityComparer.Default.Equals(descriptor.ServiceKey, serviceKey)) + continue; + + var originalKey = descriptor.ServiceKey; + var innerKeyed = CreateInnerKeyedFactory(descriptor); + var lifetime = descriptor.Lifetime; + + services[i] = new global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor( + serviceType, + originalKey, + (sp, key) => decoratorFactory(sp, key, (TService)innerKeyed(sp, key))!, + lifetime); + } + + return services; +} +#endif + +internal static global::Microsoft.Extensions.DependencyInjection.IServiceCollection DecorateOpenGeneric( + global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, + global::System.Type openServiceType, + global::System.Type openDecoratorType) +{ + if (!openServiceType.IsGenericTypeDefinition) + openServiceType = openServiceType.GetGenericTypeDefinition(); + + if (!openDecoratorType.IsGenericTypeDefinition) + openDecoratorType = openDecoratorType.GetGenericTypeDefinition(); + + for (int i = 0; i < services.Count; i++) + { + var descriptor = services[i]; + var serviceType = descriptor.ServiceType; + + if (!serviceType.IsGenericType) + continue; + + // skip truly open-generic descriptors; MS.DI forbids factory on open service types + // and replacing the implementation type would cause recursive resolution. + if (serviceType.IsGenericTypeDefinition) + continue; + + if (serviceType.GetGenericTypeDefinition() != openServiceType) + continue; + + var typeArgs = serviceType.GetGenericArguments(); + global::System.Type closedDecoratorType; + try + { + closedDecoratorType = openDecoratorType.MakeGenericType(typeArgs); + } + catch (global::System.ArgumentException) + { + continue; + } + +#if NET8_0_OR_GREATER + if (descriptor.IsKeyedService) + { + var originalKey = descriptor.ServiceKey; + var innerKeyed = CreateInnerKeyedFactory(descriptor); + var lifetime = descriptor.Lifetime; + + services[i] = new global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor( + serviceType, + originalKey, + (sp, key) => global::Microsoft.Extensions.DependencyInjection.ActivatorUtilities.CreateInstance(sp, closedDecoratorType, innerKeyed(sp, key))!, + lifetime); + continue; + } +#endif + + var inner = CreateInnerFactory(descriptor); + var nonKeyedLifetime = descriptor.Lifetime; + + services[i] = new global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor( + serviceType, + sp => global::Microsoft.Extensions.DependencyInjection.ActivatorUtilities.CreateInstance(sp, closedDecoratorType, inner(sp))!, + nonKeyedLifetime); + } + + return services; +} + +private static bool IsKeyedDescriptor(global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor descriptor) +{ +#if NET8_0_OR_GREATER + return descriptor.IsKeyedService; +#else + return false; +#endif +} + +private static global::System.Func CreateInnerFactory( + global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor descriptor) +{ + if (descriptor.ImplementationInstance is object instance) + return _ => instance; + + if (descriptor.ImplementationFactory is global::System.Func factory) + return factory; + + var implementationType = descriptor.ImplementationType ?? descriptor.ServiceType; + return sp => global::Microsoft.Extensions.DependencyInjection.ActivatorUtilities.CreateInstance(sp, implementationType); +} + +#if NET8_0_OR_GREATER +private static global::System.Func CreateInnerKeyedFactory( + global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor descriptor) +{ + if (descriptor.KeyedImplementationInstance is object keyedInstance) + return (_, _) => keyedInstance; + + if (descriptor.KeyedImplementationFactory is global::System.Func keyedFactory) + return keyedFactory; + + var implementationType = descriptor.KeyedImplementationType ?? descriptor.ServiceType; + return (sp, _) => global::Microsoft.Extensions.DependencyInjection.ActivatorUtilities.CreateInstance(sp, implementationType); +} +#endif +"""; + public static string GetServiceCollectionMethod(string duplicateStrategy) { return duplicateStrategy switch diff --git a/src/Injectio.Generators/SymbolHelpers.cs b/src/Injectio.Generators/SymbolHelpers.cs index 54cc867..04cfc4d 100644 --- a/src/Injectio.Generators/SymbolHelpers.cs +++ b/src/Injectio.Generators/SymbolHelpers.cs @@ -61,6 +61,19 @@ public static bool IsScopedAttribute(AttributeData attribute) }; } + public static bool IsDecoratorAttribute(AttributeData attribute) + { + return attribute?.AttributeClass is + { + Name: KnownTypes.DecoratorAttributeShortName or KnownTypes.DecoratorAttributeTypeName, + ContainingNamespace: + { + Name: "Attributes", + ContainingNamespace.Name: "Injectio" + } + }; + } + public static bool IsKnownAttribute(AttributeData attribute, out string serviceLifetime) { if (IsSingletonAttribute(attribute)) diff --git a/tests/Injectio.Acceptance.Tests/DecoratorTests.cs b/tests/Injectio.Acceptance.Tests/DecoratorTests.cs new file mode 100644 index 0000000..c16818d --- /dev/null +++ b/tests/Injectio.Acceptance.Tests/DecoratorTests.cs @@ -0,0 +1,48 @@ +using AwesomeAssertions; + +using Injectio.Acceptance.Tests.Services; + +using Microsoft.Extensions.DependencyInjection; + +namespace Injectio.Acceptance.Tests; + +[Collection(DependencyInjectionCollection.CollectionName)] +public class DecoratorTests(DependencyInjectionFixture fixture) : DependencyInjectionBase(fixture) +{ + [Fact] + public void ShouldResolveChainedDecoratorsInOrder() + { + var greeter = Services.GetRequiredService(); + + greeter.Should().BeOfType(); + greeter.Greet().Should().Be("caching(logging(base))"); + + var caching = (CachingGreeter)greeter; + caching.Inner.Should().BeOfType(); + + var logging = (LoggingGreeter)caching.Inner; + logging.Inner.Should().BeOfType(); + } + + [Fact] + public void ShouldDecorateClosedGenericViaOpenDecorator() + { + var repo = Services.GetRequiredService>(); + + repo.Should().BeOfType>(); + repo.Describe().Should().Be("logging(repo)"); + } + + [Fact] + public void ShouldDecorateEveryKeyedVariantWhenAnyKey() + { + var alpha = Services.GetRequiredKeyedService("alpha"); + var beta = Services.GetRequiredKeyedService("beta"); + + alpha.Should().BeOfType(); + alpha.Name.Should().Be("wrapped(alpha)"); + + beta.Should().BeOfType(); + beta.Name.Should().Be("wrapped(beta)"); + } +} diff --git a/tests/Injectio.Acceptance.Tests/Services/DecoratorService.cs b/tests/Injectio.Acceptance.Tests/Services/DecoratorService.cs new file mode 100644 index 0000000..a428490 --- /dev/null +++ b/tests/Injectio.Acceptance.Tests/Services/DecoratorService.cs @@ -0,0 +1,82 @@ +using Injectio.Attributes; + +namespace Injectio.Acceptance.Tests.Services; + +public interface IGreeter +{ + string Greet(); +} + +[RegisterSingleton] +public class BaseGreeter : IGreeter +{ + public string Greet() => "base"; +} + +[RegisterDecorator(Order = 1)] +public class LoggingGreeter : IGreeter +{ + public IGreeter Inner { get; } + + public LoggingGreeter(IGreeter inner) => Inner = inner; + + public string Greet() => $"logging({Inner.Greet()})"; +} + +[RegisterDecorator(Order = 2)] +public class CachingGreeter : IGreeter +{ + public IGreeter Inner { get; } + + public CachingGreeter(IGreeter inner) => Inner = inner; + + public string Greet() => $"caching({Inner.Greet()})"; +} + +public interface IRepo +{ + string Describe(); +} + +[RegisterSingleton, StringRepo>] +public class StringRepo : IRepo +{ + public string Describe() => "repo"; +} + +[RegisterDecorator(ServiceType = typeof(IRepo<>))] +public class LoggingRepo : IRepo +{ + public IRepo Inner { get; } + + public LoggingRepo(IRepo inner) => Inner = inner; + + public string Describe() => $"logging({Inner.Describe()})"; +} + +public interface IKeyedThing +{ + string Name { get; } +} + +[RegisterSingleton(ServiceKey = "alpha")] +public class AlphaThing : IKeyedThing +{ + public string Name => "alpha"; +} + +[RegisterSingleton(ServiceKey = "beta")] +public class BetaThing : IKeyedThing +{ + public string Name => "beta"; +} + +[RegisterDecorator(AnyKey = true)] +public class WrappedThing : IKeyedThing +{ + public IKeyedThing Inner { get; } + + public WrappedThing(IKeyedThing inner) => Inner = inner; + + public string Name => $"wrapped({Inner.Name})"; +} diff --git a/tests/Injectio.Tests/ServiceRegistrationDecoratorTests.cs b/tests/Injectio.Tests/ServiceRegistrationDecoratorTests.cs new file mode 100644 index 0000000..39d344b --- /dev/null +++ b/tests/Injectio.Tests/ServiceRegistrationDecoratorTests.cs @@ -0,0 +1,460 @@ +using System; +using System.Collections.Immutable; +using System.Linq; +using System.Threading.Tasks; + +using AwesomeAssertions; + +using Injectio.Attributes; +using Injectio.Generators; + +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.Diagnostics; +using Microsoft.Extensions.DependencyInjection; + +using VerifyXunit; + +using Xunit; + +namespace Injectio.Tests; + +public class ServiceRegistrationDecoratorTests +{ + [Fact] + public Task GenerateDecoratorSimple() + { + const string source = """ + using Injectio.Attributes; + + namespace Injectio.Sample; + + public interface IService { } + + [RegisterSingleton] + public class Service : IService { } + + [RegisterDecorator(ServiceType = typeof(IService))] + public class LoggingDecorator : IService + { + public LoggingDecorator(IService inner) { } + } + """; + + return Verify(source); + } + + [Fact] + public Task GenerateDecoratorGenericAttribute() + { + const string source = """ + using Injectio.Attributes; + + namespace Injectio.Sample; + + public interface IService { } + + [RegisterSingleton] + public class Service : IService { } + + [RegisterDecorator] + public class LoggingDecorator : IService + { + public LoggingDecorator(IService inner) { } + } + """; + + return Verify(source); + } + + [Fact] + public Task GenerateDecoratorChainedOrder() + { + const string source = """ + using Injectio.Attributes; + + namespace Injectio.Sample; + + public interface IService { } + + [RegisterSingleton] + public class Service : IService { } + + [RegisterDecorator(Order = 2)] + public class CachingDecorator : IService + { + public CachingDecorator(IService inner) { } + } + + [RegisterDecorator(Order = 1)] + public class LoggingDecorator : IService + { + public LoggingDecorator(IService inner) { } + } + """; + + return Verify(source); + } + + [Fact] + public Task GenerateDecoratorKeyed() + { + const string source = """ + using Injectio.Attributes; + + namespace Injectio.Sample; + + public interface IService { } + + [RegisterSingleton(ServiceKey = "Alpha")] + public class Service : IService { } + + [RegisterDecorator(ServiceKey = "Alpha")] + public class LoggingDecorator : IService + { + public LoggingDecorator(IService inner) { } + } + """; + + return Verify(source); + } + + [Fact] + public Task GenerateDecoratorAnyKey() + { + const string source = """ + using Injectio.Attributes; + + namespace Injectio.Sample; + + public interface IService { } + + [RegisterSingleton(ServiceKey = "Alpha")] + public class ServiceA : IService { } + + [RegisterSingleton(ServiceKey = "Beta")] + public class ServiceB : IService { } + + [RegisterDecorator(AnyKey = true)] + public class LoggingDecorator : IService + { + public LoggingDecorator(IService inner) { } + } + """; + + return Verify(source); + } + + [Fact] + public Task GenerateDecoratorFactory() + { + const string source = """ + using System; + using Injectio.Attributes; + + namespace Injectio.Sample; + + public interface IService { } + + [RegisterSingleton] + public class Service : IService { } + + [RegisterDecorator(Factory = nameof(Create))] + public class LoggingDecorator : IService + { + public LoggingDecorator(IService inner) { } + + public static IService Create(IServiceProvider serviceProvider, IService inner) + => new LoggingDecorator(inner); + } + """; + + return Verify(source); + } + + [Fact] + public Task GenerateDecoratorOpenGeneric() + { + const string source = """ + using Injectio.Attributes; + + namespace Injectio.Sample; + + public interface IRepo { } + + [RegisterSingleton(ServiceType = typeof(IRepo<>), ImplementationType = typeof(Repo<>))] + public class Repo : IRepo { } + + [RegisterDecorator(ServiceType = typeof(IRepo<>))] + public class LoggingRepo : IRepo + { + public LoggingRepo(IRepo inner) { } + } + """; + + return Verify(source); + } + + [Fact] + public Task GenerateDecoratorTags() + { + const string source = """ + using Injectio.Attributes; + + namespace Injectio.Sample; + + public interface IService { } + + [RegisterSingleton(Tags = "FrontEnd")] + public class Service : IService { } + + [RegisterDecorator(Tags = "FrontEnd")] + public class LoggingDecorator : IService + { + public LoggingDecorator(IService inner) { } + } + """; + + return Verify(source); + } + + // ------- Diagnostics ------- + + [Fact] + public async Task DiagnoseDecoratorDoesNotImplementService() + { + const string source = """ + using Injectio.Attributes; + + namespace Injectio.Sample; + + public interface IService { } + public interface IOther { } + + [RegisterSingleton] + public class Service : IService { } + + [RegisterDecorator(ServiceType = typeof(IService))] + public class BadDecorator : IOther { } + """; + + var diagnostics = await GetDiagnosticsAsync(source); + diagnostics.Should().Contain(d => d.Id == "INJ0010"); + } + + [Fact] + public async Task DiagnoseDecoratorMissingServiceType() + { + const string source = """ + using Injectio.Attributes; + + namespace Injectio.Sample; + + public interface IService { } + + [RegisterSingleton] + public class Service : IService { } + + [RegisterDecorator] + public class LoggingDecorator : IService + { + public LoggingDecorator(IService inner) { } + } + """; + + var diagnostics = await GetDiagnosticsAsync(source); + diagnostics.Should().Contain(d => d.Id == "INJ0011"); + } + + [Fact] + public async Task DiagnoseDecoratorFactoryNotFound() + { + const string source = """ + using Injectio.Attributes; + + namespace Injectio.Sample; + + public interface IService { } + + [RegisterSingleton] + public class Service : IService { } + + [RegisterDecorator(Factory = "Missing")] + public class LoggingDecorator : IService + { + public LoggingDecorator(IService inner) { } + } + """; + + var diagnostics = await GetDiagnosticsAsync(source); + diagnostics.Should().Contain(d => d.Id == "INJ0013"); + } + + [Fact] + public async Task DiagnoseDecoratorFactoryInvalidSignature() + { + const string source = """ + using System; + using Injectio.Attributes; + + namespace Injectio.Sample; + + public interface IService { } + + [RegisterSingleton] + public class Service : IService { } + + [RegisterDecorator(Factory = nameof(Create))] + public class LoggingDecorator : IService + { + public LoggingDecorator(IService inner) { } + + public IService Create() => this; + } + """; + + var diagnostics = await GetDiagnosticsAsync(source); + diagnostics.Should().Contain(d => d.Id == "INJ0014"); + } + + [Fact] + public async Task DiagnoseDecoratorOpenGenericKeyed() + { + const string source = """ + using Injectio.Attributes; + + namespace Injectio.Sample; + + public interface IRepo { } + + [RegisterSingleton(ServiceType = typeof(IRepo<>), ImplementationType = typeof(Repo<>))] + public class Repo : IRepo { } + + [RegisterDecorator(ServiceType = typeof(IRepo<>), ServiceKey = "X")] + public class LoggingRepo : IRepo + { + public LoggingRepo(IRepo inner) { } + } + """; + + var diagnostics = await GetDiagnosticsAsync(source); + diagnostics.Should().Contain(d => d.Id == "INJ0015"); + } + + [Fact] + public async Task DiagnoseDecoratorTargetNotRegistered() + { + const string source = """ + using Injectio.Attributes; + + namespace Injectio.Sample; + + public interface IService { } + + [RegisterDecorator] + public class LoggingDecorator : IService + { + public LoggingDecorator(IService inner) { } + } + """; + + var diagnostics = await GetDiagnosticsAsync(source); + diagnostics.Should().Contain(d => d.Id == "INJ0016"); + } + + [Fact] + public async Task NoDiagnosticForValidDecorator() + { + const string source = """ + using Injectio.Attributes; + + namespace Injectio.Sample; + + public interface IService { } + + [RegisterSingleton] + public class Service : IService { } + + [RegisterDecorator] + public class LoggingDecorator : IService + { + public LoggingDecorator(IService inner) { } + } + """; + + var diagnostics = await GetDiagnosticsAsync(source); + diagnostics.Should().BeEmpty(); + } + + private static Task Verify(string source) + { + var output = GetAllGeneratedOutput(source); + + return Verifier + .Verify(output) + .UseDirectory("Snapshots") + .ScrubLinesContaining("GeneratedCodeAttribute"); + } + + private static string GetAllGeneratedOutput(string source) + where T : IIncrementalGenerator, new() + { + var syntaxTree = CSharpSyntaxTree.ParseText(source); + var references = AppDomain.CurrentDomain.GetAssemblies() + .Where(assembly => !assembly.IsDynamic && !string.IsNullOrWhiteSpace(assembly.Location)) + .Select(assembly => MetadataReference.CreateFromFile(assembly.Location)) + .Concat(new[] + { + MetadataReference.CreateFromFile(typeof(T).Assembly.Location), + MetadataReference.CreateFromFile(typeof(RegisterServicesAttribute).Assembly.Location), + MetadataReference.CreateFromFile(typeof(IServiceCollection).Assembly.Location), + }); + + var compilation = CSharpCompilation.Create( + "Test.Generator", + new[] { syntaxTree }, + references, + new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary)); + + var originalTreeCount = compilation.SyntaxTrees.Length; + var generator = new T(); + + var driver = CSharpGeneratorDriver.Create(generator); + driver.RunGeneratorsAndUpdateCompilation(compilation, out var outputCompilation, out _); + + var generated = outputCompilation.SyntaxTrees + .Skip(originalTreeCount) + .Select(t => $"// {System.IO.Path.GetFileName(t.FilePath)}\n{t}") + .ToArray(); + + return string.Join("\n\n// ==========\n\n", generated); + } + + private static async Task> GetDiagnosticsAsync(string source) + { + var syntaxTree = CSharpSyntaxTree.ParseText(source); + var references = AppDomain.CurrentDomain.GetAssemblies() + .Where(assembly => !assembly.IsDynamic && !string.IsNullOrWhiteSpace(assembly.Location)) + .Select(assembly => MetadataReference.CreateFromFile(assembly.Location)) + .Concat(new[] + { + MetadataReference.CreateFromFile(typeof(ServiceRegistrationGenerator).Assembly.Location), + MetadataReference.CreateFromFile(typeof(RegisterServicesAttribute).Assembly.Location), + MetadataReference.CreateFromFile(typeof(IServiceCollection).Assembly.Location), + }); + + var compilation = CSharpCompilation.Create( + "Test.Diagnostics", + new[] { syntaxTree }, + references, + new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary)); + + var analyzer = new ServiceRegistrationAnalyzer(); + var compilationWithAnalyzers = compilation.WithAnalyzers(ImmutableArray.Create(analyzer)); + var diagnostics = await compilationWithAnalyzers.GetAnalyzerDiagnosticsAsync(); + + return diagnostics + .Where(d => d.Id.StartsWith("INJ")) + .ToImmutableArray(); + } +} diff --git a/tests/Injectio.Tests/Snapshots/ServiceRegistrationDecoratorTests.GenerateDecoratorAnyKey.verified.txt b/tests/Injectio.Tests/Snapshots/ServiceRegistrationDecoratorTests.GenerateDecoratorAnyKey.verified.txt new file mode 100644 index 0000000..91e6d55 --- /dev/null +++ b/tests/Injectio.Tests/Snapshots/ServiceRegistrationDecoratorTests.GenerateDecoratorAnyKey.verified.txt @@ -0,0 +1,222 @@ +// Injectio.g.cs +// +#nullable enable + +namespace Microsoft.Extensions.DependencyInjection +{ + /// + /// Extension methods for discovered service registrations + /// + public static class DiscoveredServicesExtensions + { + /// + /// Adds discovered services from Test.Generator to the specified service collection + /// + /// The service collection. + /// The service registration tags to include. + /// The service collection + public static global::Microsoft.Extensions.DependencyInjection.IServiceCollection AddTestGenerator(this global::Microsoft.Extensions.DependencyInjection.IServiceCollection serviceCollection, params string[]? tags) + { + var tagSet = new global::System.Collections.Generic.HashSet(tags ?? global::System.Linq.Enumerable.Empty()); + + global::Microsoft.Extensions.DependencyInjection.Extensions.ServiceCollectionDescriptorExtensions.TryAdd( + serviceCollection, + global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor.KeyedSingleton("Alpha") + ); + + global::Microsoft.Extensions.DependencyInjection.Extensions.ServiceCollectionDescriptorExtensions.TryAdd( + serviceCollection, + global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor.KeyedSingleton("Beta") + ); + + global::Injectio.Internal.InjectioDecorationExtensions.DecorateKeyed( + serviceCollection, + global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey, + static (serviceProvider, serviceKey, inner) => global::Microsoft.Extensions.DependencyInjection.ActivatorUtilities.CreateInstance(serviceProvider, inner) + ); + + return serviceCollection; + } + } +} + + +// ========== + +// Injectio.Decoration.g.cs +// +#nullable enable + +namespace Injectio.Internal +{ + internal static class InjectioDecorationExtensions + { + internal static global::Microsoft.Extensions.DependencyInjection.IServiceCollection Decorate( + global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, + global::System.Func decoratorFactory) + where TService : class + { + var serviceType = typeof(TService); + + for (int i = 0; i < services.Count; i++) + { + var descriptor = services[i]; + if (descriptor.ServiceType != serviceType) + continue; + + if (IsKeyedDescriptor(descriptor)) + continue; + + var inner = CreateInnerFactory(descriptor); + var lifetime = descriptor.Lifetime; + + services[i] = new global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor( + serviceType, + sp => decoratorFactory(sp, (TService)inner(sp))!, + lifetime); + } + + return services; + } + + #if NET8_0_OR_GREATER + internal static global::Microsoft.Extensions.DependencyInjection.IServiceCollection DecorateKeyed( + global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, + object? serviceKey, + global::System.Func decoratorFactory) + where TService : class + { + var serviceType = typeof(TService); + bool anyKey = ReferenceEquals(serviceKey, global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey); + + for (int i = 0; i < services.Count; i++) + { + var descriptor = services[i]; + if (descriptor.ServiceType != serviceType) + continue; + + if (!descriptor.IsKeyedService) + continue; + + if (!anyKey && !global::System.Collections.Generic.EqualityComparer.Default.Equals(descriptor.ServiceKey, serviceKey)) + continue; + + var originalKey = descriptor.ServiceKey; + var innerKeyed = CreateInnerKeyedFactory(descriptor); + var lifetime = descriptor.Lifetime; + + services[i] = new global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor( + serviceType, + originalKey, + (sp, key) => decoratorFactory(sp, key, (TService)innerKeyed(sp, key))!, + lifetime); + } + + return services; + } + #endif + + internal static global::Microsoft.Extensions.DependencyInjection.IServiceCollection DecorateOpenGeneric( + global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, + global::System.Type openServiceType, + global::System.Type openDecoratorType) + { + if (!openServiceType.IsGenericTypeDefinition) + openServiceType = openServiceType.GetGenericTypeDefinition(); + + if (!openDecoratorType.IsGenericTypeDefinition) + openDecoratorType = openDecoratorType.GetGenericTypeDefinition(); + + for (int i = 0; i < services.Count; i++) + { + var descriptor = services[i]; + var serviceType = descriptor.ServiceType; + + if (!serviceType.IsGenericType) + continue; + + // skip truly open-generic descriptors; MS.DI forbids factory on open service types + // and replacing the implementation type would cause recursive resolution. + if (serviceType.IsGenericTypeDefinition) + continue; + + if (serviceType.GetGenericTypeDefinition() != openServiceType) + continue; + + var typeArgs = serviceType.GetGenericArguments(); + global::System.Type closedDecoratorType; + try + { + closedDecoratorType = openDecoratorType.MakeGenericType(typeArgs); + } + catch (global::System.ArgumentException) + { + continue; + } + + #if NET8_0_OR_GREATER + if (descriptor.IsKeyedService) + { + var originalKey = descriptor.ServiceKey; + var innerKeyed = CreateInnerKeyedFactory(descriptor); + var lifetime = descriptor.Lifetime; + + services[i] = new global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor( + serviceType, + originalKey, + (sp, key) => global::Microsoft.Extensions.DependencyInjection.ActivatorUtilities.CreateInstance(sp, closedDecoratorType, innerKeyed(sp, key))!, + lifetime); + continue; + } + #endif + + var inner = CreateInnerFactory(descriptor); + var nonKeyedLifetime = descriptor.Lifetime; + + services[i] = new global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor( + serviceType, + sp => global::Microsoft.Extensions.DependencyInjection.ActivatorUtilities.CreateInstance(sp, closedDecoratorType, inner(sp))!, + nonKeyedLifetime); + } + + return services; + } + + private static bool IsKeyedDescriptor(global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor descriptor) + { + #if NET8_0_OR_GREATER + return descriptor.IsKeyedService; + #else + return false; + #endif + } + + private static global::System.Func CreateInnerFactory( + global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor descriptor) + { + if (descriptor.ImplementationInstance is object instance) + return _ => instance; + + if (descriptor.ImplementationFactory is global::System.Func factory) + return factory; + + var implementationType = descriptor.ImplementationType ?? descriptor.ServiceType; + return sp => global::Microsoft.Extensions.DependencyInjection.ActivatorUtilities.CreateInstance(sp, implementationType); + } + + #if NET8_0_OR_GREATER + private static global::System.Func CreateInnerKeyedFactory( + global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor descriptor) + { + if (descriptor.KeyedImplementationInstance is object keyedInstance) + return (_, _) => keyedInstance; + + if (descriptor.KeyedImplementationFactory is global::System.Func keyedFactory) + return keyedFactory; + + var implementationType = descriptor.KeyedImplementationType ?? descriptor.ServiceType; + return (sp, _) => global::Microsoft.Extensions.DependencyInjection.ActivatorUtilities.CreateInstance(sp, implementationType); + } + #endif + } +} diff --git a/tests/Injectio.Tests/Snapshots/ServiceRegistrationDecoratorTests.GenerateDecoratorChainedOrder.verified.txt b/tests/Injectio.Tests/Snapshots/ServiceRegistrationDecoratorTests.GenerateDecoratorChainedOrder.verified.txt new file mode 100644 index 0000000..48ac6a0 --- /dev/null +++ b/tests/Injectio.Tests/Snapshots/ServiceRegistrationDecoratorTests.GenerateDecoratorChainedOrder.verified.txt @@ -0,0 +1,228 @@ +// Injectio.g.cs +// +#nullable enable + +namespace Microsoft.Extensions.DependencyInjection +{ + /// + /// Extension methods for discovered service registrations + /// + public static class DiscoveredServicesExtensions + { + /// + /// Adds discovered services from Test.Generator to the specified service collection + /// + /// The service collection. + /// The service registration tags to include. + /// The service collection + public static global::Microsoft.Extensions.DependencyInjection.IServiceCollection AddTestGenerator(this global::Microsoft.Extensions.DependencyInjection.IServiceCollection serviceCollection, params string[]? tags) + { + var tagSet = new global::System.Collections.Generic.HashSet(tags ?? global::System.Linq.Enumerable.Empty()); + + global::Microsoft.Extensions.DependencyInjection.Extensions.ServiceCollectionDescriptorExtensions.TryAdd( + serviceCollection, + global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor.Singleton( + (serviceProvider) => global::Microsoft.Extensions.DependencyInjection.ServiceProviderServiceExtensions.GetRequiredService(serviceProvider) + ) + ); + + global::Microsoft.Extensions.DependencyInjection.Extensions.ServiceCollectionDescriptorExtensions.TryAdd( + serviceCollection, + global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor.Singleton() + ); + + global::Injectio.Internal.InjectioDecorationExtensions.Decorate( + serviceCollection, + static (serviceProvider, inner) => global::Microsoft.Extensions.DependencyInjection.ActivatorUtilities.CreateInstance(serviceProvider, inner) + ); + + global::Injectio.Internal.InjectioDecorationExtensions.Decorate( + serviceCollection, + static (serviceProvider, inner) => global::Microsoft.Extensions.DependencyInjection.ActivatorUtilities.CreateInstance(serviceProvider, inner) + ); + + return serviceCollection; + } + } +} + + +// ========== + +// Injectio.Decoration.g.cs +// +#nullable enable + +namespace Injectio.Internal +{ + internal static class InjectioDecorationExtensions + { + internal static global::Microsoft.Extensions.DependencyInjection.IServiceCollection Decorate( + global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, + global::System.Func decoratorFactory) + where TService : class + { + var serviceType = typeof(TService); + + for (int i = 0; i < services.Count; i++) + { + var descriptor = services[i]; + if (descriptor.ServiceType != serviceType) + continue; + + if (IsKeyedDescriptor(descriptor)) + continue; + + var inner = CreateInnerFactory(descriptor); + var lifetime = descriptor.Lifetime; + + services[i] = new global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor( + serviceType, + sp => decoratorFactory(sp, (TService)inner(sp))!, + lifetime); + } + + return services; + } + + #if NET8_0_OR_GREATER + internal static global::Microsoft.Extensions.DependencyInjection.IServiceCollection DecorateKeyed( + global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, + object? serviceKey, + global::System.Func decoratorFactory) + where TService : class + { + var serviceType = typeof(TService); + bool anyKey = ReferenceEquals(serviceKey, global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey); + + for (int i = 0; i < services.Count; i++) + { + var descriptor = services[i]; + if (descriptor.ServiceType != serviceType) + continue; + + if (!descriptor.IsKeyedService) + continue; + + if (!anyKey && !global::System.Collections.Generic.EqualityComparer.Default.Equals(descriptor.ServiceKey, serviceKey)) + continue; + + var originalKey = descriptor.ServiceKey; + var innerKeyed = CreateInnerKeyedFactory(descriptor); + var lifetime = descriptor.Lifetime; + + services[i] = new global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor( + serviceType, + originalKey, + (sp, key) => decoratorFactory(sp, key, (TService)innerKeyed(sp, key))!, + lifetime); + } + + return services; + } + #endif + + internal static global::Microsoft.Extensions.DependencyInjection.IServiceCollection DecorateOpenGeneric( + global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, + global::System.Type openServiceType, + global::System.Type openDecoratorType) + { + if (!openServiceType.IsGenericTypeDefinition) + openServiceType = openServiceType.GetGenericTypeDefinition(); + + if (!openDecoratorType.IsGenericTypeDefinition) + openDecoratorType = openDecoratorType.GetGenericTypeDefinition(); + + for (int i = 0; i < services.Count; i++) + { + var descriptor = services[i]; + var serviceType = descriptor.ServiceType; + + if (!serviceType.IsGenericType) + continue; + + // skip truly open-generic descriptors; MS.DI forbids factory on open service types + // and replacing the implementation type would cause recursive resolution. + if (serviceType.IsGenericTypeDefinition) + continue; + + if (serviceType.GetGenericTypeDefinition() != openServiceType) + continue; + + var typeArgs = serviceType.GetGenericArguments(); + global::System.Type closedDecoratorType; + try + { + closedDecoratorType = openDecoratorType.MakeGenericType(typeArgs); + } + catch (global::System.ArgumentException) + { + continue; + } + + #if NET8_0_OR_GREATER + if (descriptor.IsKeyedService) + { + var originalKey = descriptor.ServiceKey; + var innerKeyed = CreateInnerKeyedFactory(descriptor); + var lifetime = descriptor.Lifetime; + + services[i] = new global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor( + serviceType, + originalKey, + (sp, key) => global::Microsoft.Extensions.DependencyInjection.ActivatorUtilities.CreateInstance(sp, closedDecoratorType, innerKeyed(sp, key))!, + lifetime); + continue; + } + #endif + + var inner = CreateInnerFactory(descriptor); + var nonKeyedLifetime = descriptor.Lifetime; + + services[i] = new global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor( + serviceType, + sp => global::Microsoft.Extensions.DependencyInjection.ActivatorUtilities.CreateInstance(sp, closedDecoratorType, inner(sp))!, + nonKeyedLifetime); + } + + return services; + } + + private static bool IsKeyedDescriptor(global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor descriptor) + { + #if NET8_0_OR_GREATER + return descriptor.IsKeyedService; + #else + return false; + #endif + } + + private static global::System.Func CreateInnerFactory( + global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor descriptor) + { + if (descriptor.ImplementationInstance is object instance) + return _ => instance; + + if (descriptor.ImplementationFactory is global::System.Func factory) + return factory; + + var implementationType = descriptor.ImplementationType ?? descriptor.ServiceType; + return sp => global::Microsoft.Extensions.DependencyInjection.ActivatorUtilities.CreateInstance(sp, implementationType); + } + + #if NET8_0_OR_GREATER + private static global::System.Func CreateInnerKeyedFactory( + global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor descriptor) + { + if (descriptor.KeyedImplementationInstance is object keyedInstance) + return (_, _) => keyedInstance; + + if (descriptor.KeyedImplementationFactory is global::System.Func keyedFactory) + return keyedFactory; + + var implementationType = descriptor.KeyedImplementationType ?? descriptor.ServiceType; + return (sp, _) => global::Microsoft.Extensions.DependencyInjection.ActivatorUtilities.CreateInstance(sp, implementationType); + } + #endif + } +} diff --git a/tests/Injectio.Tests/Snapshots/ServiceRegistrationDecoratorTests.GenerateDecoratorFactory.verified.txt b/tests/Injectio.Tests/Snapshots/ServiceRegistrationDecoratorTests.GenerateDecoratorFactory.verified.txt new file mode 100644 index 0000000..7bb5a0a --- /dev/null +++ b/tests/Injectio.Tests/Snapshots/ServiceRegistrationDecoratorTests.GenerateDecoratorFactory.verified.txt @@ -0,0 +1,223 @@ +// Injectio.g.cs +// +#nullable enable + +namespace Microsoft.Extensions.DependencyInjection +{ + /// + /// Extension methods for discovered service registrations + /// + public static class DiscoveredServicesExtensions + { + /// + /// Adds discovered services from Test.Generator to the specified service collection + /// + /// The service collection. + /// The service registration tags to include. + /// The service collection + public static global::Microsoft.Extensions.DependencyInjection.IServiceCollection AddTestGenerator(this global::Microsoft.Extensions.DependencyInjection.IServiceCollection serviceCollection, params string[]? tags) + { + var tagSet = new global::System.Collections.Generic.HashSet(tags ?? global::System.Linq.Enumerable.Empty()); + + global::Microsoft.Extensions.DependencyInjection.Extensions.ServiceCollectionDescriptorExtensions.TryAdd( + serviceCollection, + global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor.Singleton( + (serviceProvider) => global::Microsoft.Extensions.DependencyInjection.ServiceProviderServiceExtensions.GetRequiredService(serviceProvider) + ) + ); + + global::Microsoft.Extensions.DependencyInjection.Extensions.ServiceCollectionDescriptorExtensions.TryAdd( + serviceCollection, + global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor.Singleton() + ); + + global::Injectio.Internal.InjectioDecorationExtensions.Decorate( + serviceCollection, + static (serviceProvider, inner) => global::Injectio.Sample.LoggingDecorator.Create(serviceProvider, inner) + ); + + return serviceCollection; + } + } +} + + +// ========== + +// Injectio.Decoration.g.cs +// +#nullable enable + +namespace Injectio.Internal +{ + internal static class InjectioDecorationExtensions + { + internal static global::Microsoft.Extensions.DependencyInjection.IServiceCollection Decorate( + global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, + global::System.Func decoratorFactory) + where TService : class + { + var serviceType = typeof(TService); + + for (int i = 0; i < services.Count; i++) + { + var descriptor = services[i]; + if (descriptor.ServiceType != serviceType) + continue; + + if (IsKeyedDescriptor(descriptor)) + continue; + + var inner = CreateInnerFactory(descriptor); + var lifetime = descriptor.Lifetime; + + services[i] = new global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor( + serviceType, + sp => decoratorFactory(sp, (TService)inner(sp))!, + lifetime); + } + + return services; + } + + #if NET8_0_OR_GREATER + internal static global::Microsoft.Extensions.DependencyInjection.IServiceCollection DecorateKeyed( + global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, + object? serviceKey, + global::System.Func decoratorFactory) + where TService : class + { + var serviceType = typeof(TService); + bool anyKey = ReferenceEquals(serviceKey, global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey); + + for (int i = 0; i < services.Count; i++) + { + var descriptor = services[i]; + if (descriptor.ServiceType != serviceType) + continue; + + if (!descriptor.IsKeyedService) + continue; + + if (!anyKey && !global::System.Collections.Generic.EqualityComparer.Default.Equals(descriptor.ServiceKey, serviceKey)) + continue; + + var originalKey = descriptor.ServiceKey; + var innerKeyed = CreateInnerKeyedFactory(descriptor); + var lifetime = descriptor.Lifetime; + + services[i] = new global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor( + serviceType, + originalKey, + (sp, key) => decoratorFactory(sp, key, (TService)innerKeyed(sp, key))!, + lifetime); + } + + return services; + } + #endif + + internal static global::Microsoft.Extensions.DependencyInjection.IServiceCollection DecorateOpenGeneric( + global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, + global::System.Type openServiceType, + global::System.Type openDecoratorType) + { + if (!openServiceType.IsGenericTypeDefinition) + openServiceType = openServiceType.GetGenericTypeDefinition(); + + if (!openDecoratorType.IsGenericTypeDefinition) + openDecoratorType = openDecoratorType.GetGenericTypeDefinition(); + + for (int i = 0; i < services.Count; i++) + { + var descriptor = services[i]; + var serviceType = descriptor.ServiceType; + + if (!serviceType.IsGenericType) + continue; + + // skip truly open-generic descriptors; MS.DI forbids factory on open service types + // and replacing the implementation type would cause recursive resolution. + if (serviceType.IsGenericTypeDefinition) + continue; + + if (serviceType.GetGenericTypeDefinition() != openServiceType) + continue; + + var typeArgs = serviceType.GetGenericArguments(); + global::System.Type closedDecoratorType; + try + { + closedDecoratorType = openDecoratorType.MakeGenericType(typeArgs); + } + catch (global::System.ArgumentException) + { + continue; + } + + #if NET8_0_OR_GREATER + if (descriptor.IsKeyedService) + { + var originalKey = descriptor.ServiceKey; + var innerKeyed = CreateInnerKeyedFactory(descriptor); + var lifetime = descriptor.Lifetime; + + services[i] = new global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor( + serviceType, + originalKey, + (sp, key) => global::Microsoft.Extensions.DependencyInjection.ActivatorUtilities.CreateInstance(sp, closedDecoratorType, innerKeyed(sp, key))!, + lifetime); + continue; + } + #endif + + var inner = CreateInnerFactory(descriptor); + var nonKeyedLifetime = descriptor.Lifetime; + + services[i] = new global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor( + serviceType, + sp => global::Microsoft.Extensions.DependencyInjection.ActivatorUtilities.CreateInstance(sp, closedDecoratorType, inner(sp))!, + nonKeyedLifetime); + } + + return services; + } + + private static bool IsKeyedDescriptor(global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor descriptor) + { + #if NET8_0_OR_GREATER + return descriptor.IsKeyedService; + #else + return false; + #endif + } + + private static global::System.Func CreateInnerFactory( + global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor descriptor) + { + if (descriptor.ImplementationInstance is object instance) + return _ => instance; + + if (descriptor.ImplementationFactory is global::System.Func factory) + return factory; + + var implementationType = descriptor.ImplementationType ?? descriptor.ServiceType; + return sp => global::Microsoft.Extensions.DependencyInjection.ActivatorUtilities.CreateInstance(sp, implementationType); + } + + #if NET8_0_OR_GREATER + private static global::System.Func CreateInnerKeyedFactory( + global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor descriptor) + { + if (descriptor.KeyedImplementationInstance is object keyedInstance) + return (_, _) => keyedInstance; + + if (descriptor.KeyedImplementationFactory is global::System.Func keyedFactory) + return keyedFactory; + + var implementationType = descriptor.KeyedImplementationType ?? descriptor.ServiceType; + return (sp, _) => global::Microsoft.Extensions.DependencyInjection.ActivatorUtilities.CreateInstance(sp, implementationType); + } + #endif + } +} diff --git a/tests/Injectio.Tests/Snapshots/ServiceRegistrationDecoratorTests.GenerateDecoratorGenericAttribute.verified.txt b/tests/Injectio.Tests/Snapshots/ServiceRegistrationDecoratorTests.GenerateDecoratorGenericAttribute.verified.txt new file mode 100644 index 0000000..123cbae --- /dev/null +++ b/tests/Injectio.Tests/Snapshots/ServiceRegistrationDecoratorTests.GenerateDecoratorGenericAttribute.verified.txt @@ -0,0 +1,223 @@ +// Injectio.g.cs +// +#nullable enable + +namespace Microsoft.Extensions.DependencyInjection +{ + /// + /// Extension methods for discovered service registrations + /// + public static class DiscoveredServicesExtensions + { + /// + /// Adds discovered services from Test.Generator to the specified service collection + /// + /// The service collection. + /// The service registration tags to include. + /// The service collection + public static global::Microsoft.Extensions.DependencyInjection.IServiceCollection AddTestGenerator(this global::Microsoft.Extensions.DependencyInjection.IServiceCollection serviceCollection, params string[]? tags) + { + var tagSet = new global::System.Collections.Generic.HashSet(tags ?? global::System.Linq.Enumerable.Empty()); + + global::Microsoft.Extensions.DependencyInjection.Extensions.ServiceCollectionDescriptorExtensions.TryAdd( + serviceCollection, + global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor.Singleton( + (serviceProvider) => global::Microsoft.Extensions.DependencyInjection.ServiceProviderServiceExtensions.GetRequiredService(serviceProvider) + ) + ); + + global::Microsoft.Extensions.DependencyInjection.Extensions.ServiceCollectionDescriptorExtensions.TryAdd( + serviceCollection, + global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor.Singleton() + ); + + global::Injectio.Internal.InjectioDecorationExtensions.Decorate( + serviceCollection, + static (serviceProvider, inner) => global::Microsoft.Extensions.DependencyInjection.ActivatorUtilities.CreateInstance(serviceProvider, inner) + ); + + return serviceCollection; + } + } +} + + +// ========== + +// Injectio.Decoration.g.cs +// +#nullable enable + +namespace Injectio.Internal +{ + internal static class InjectioDecorationExtensions + { + internal static global::Microsoft.Extensions.DependencyInjection.IServiceCollection Decorate( + global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, + global::System.Func decoratorFactory) + where TService : class + { + var serviceType = typeof(TService); + + for (int i = 0; i < services.Count; i++) + { + var descriptor = services[i]; + if (descriptor.ServiceType != serviceType) + continue; + + if (IsKeyedDescriptor(descriptor)) + continue; + + var inner = CreateInnerFactory(descriptor); + var lifetime = descriptor.Lifetime; + + services[i] = new global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor( + serviceType, + sp => decoratorFactory(sp, (TService)inner(sp))!, + lifetime); + } + + return services; + } + + #if NET8_0_OR_GREATER + internal static global::Microsoft.Extensions.DependencyInjection.IServiceCollection DecorateKeyed( + global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, + object? serviceKey, + global::System.Func decoratorFactory) + where TService : class + { + var serviceType = typeof(TService); + bool anyKey = ReferenceEquals(serviceKey, global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey); + + for (int i = 0; i < services.Count; i++) + { + var descriptor = services[i]; + if (descriptor.ServiceType != serviceType) + continue; + + if (!descriptor.IsKeyedService) + continue; + + if (!anyKey && !global::System.Collections.Generic.EqualityComparer.Default.Equals(descriptor.ServiceKey, serviceKey)) + continue; + + var originalKey = descriptor.ServiceKey; + var innerKeyed = CreateInnerKeyedFactory(descriptor); + var lifetime = descriptor.Lifetime; + + services[i] = new global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor( + serviceType, + originalKey, + (sp, key) => decoratorFactory(sp, key, (TService)innerKeyed(sp, key))!, + lifetime); + } + + return services; + } + #endif + + internal static global::Microsoft.Extensions.DependencyInjection.IServiceCollection DecorateOpenGeneric( + global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, + global::System.Type openServiceType, + global::System.Type openDecoratorType) + { + if (!openServiceType.IsGenericTypeDefinition) + openServiceType = openServiceType.GetGenericTypeDefinition(); + + if (!openDecoratorType.IsGenericTypeDefinition) + openDecoratorType = openDecoratorType.GetGenericTypeDefinition(); + + for (int i = 0; i < services.Count; i++) + { + var descriptor = services[i]; + var serviceType = descriptor.ServiceType; + + if (!serviceType.IsGenericType) + continue; + + // skip truly open-generic descriptors; MS.DI forbids factory on open service types + // and replacing the implementation type would cause recursive resolution. + if (serviceType.IsGenericTypeDefinition) + continue; + + if (serviceType.GetGenericTypeDefinition() != openServiceType) + continue; + + var typeArgs = serviceType.GetGenericArguments(); + global::System.Type closedDecoratorType; + try + { + closedDecoratorType = openDecoratorType.MakeGenericType(typeArgs); + } + catch (global::System.ArgumentException) + { + continue; + } + + #if NET8_0_OR_GREATER + if (descriptor.IsKeyedService) + { + var originalKey = descriptor.ServiceKey; + var innerKeyed = CreateInnerKeyedFactory(descriptor); + var lifetime = descriptor.Lifetime; + + services[i] = new global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor( + serviceType, + originalKey, + (sp, key) => global::Microsoft.Extensions.DependencyInjection.ActivatorUtilities.CreateInstance(sp, closedDecoratorType, innerKeyed(sp, key))!, + lifetime); + continue; + } + #endif + + var inner = CreateInnerFactory(descriptor); + var nonKeyedLifetime = descriptor.Lifetime; + + services[i] = new global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor( + serviceType, + sp => global::Microsoft.Extensions.DependencyInjection.ActivatorUtilities.CreateInstance(sp, closedDecoratorType, inner(sp))!, + nonKeyedLifetime); + } + + return services; + } + + private static bool IsKeyedDescriptor(global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor descriptor) + { + #if NET8_0_OR_GREATER + return descriptor.IsKeyedService; + #else + return false; + #endif + } + + private static global::System.Func CreateInnerFactory( + global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor descriptor) + { + if (descriptor.ImplementationInstance is object instance) + return _ => instance; + + if (descriptor.ImplementationFactory is global::System.Func factory) + return factory; + + var implementationType = descriptor.ImplementationType ?? descriptor.ServiceType; + return sp => global::Microsoft.Extensions.DependencyInjection.ActivatorUtilities.CreateInstance(sp, implementationType); + } + + #if NET8_0_OR_GREATER + private static global::System.Func CreateInnerKeyedFactory( + global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor descriptor) + { + if (descriptor.KeyedImplementationInstance is object keyedInstance) + return (_, _) => keyedInstance; + + if (descriptor.KeyedImplementationFactory is global::System.Func keyedFactory) + return keyedFactory; + + var implementationType = descriptor.KeyedImplementationType ?? descriptor.ServiceType; + return (sp, _) => global::Microsoft.Extensions.DependencyInjection.ActivatorUtilities.CreateInstance(sp, implementationType); + } + #endif + } +} diff --git a/tests/Injectio.Tests/Snapshots/ServiceRegistrationDecoratorTests.GenerateDecoratorKeyed.verified.txt b/tests/Injectio.Tests/Snapshots/ServiceRegistrationDecoratorTests.GenerateDecoratorKeyed.verified.txt new file mode 100644 index 0000000..c0f8db9 --- /dev/null +++ b/tests/Injectio.Tests/Snapshots/ServiceRegistrationDecoratorTests.GenerateDecoratorKeyed.verified.txt @@ -0,0 +1,217 @@ +// Injectio.g.cs +// +#nullable enable + +namespace Microsoft.Extensions.DependencyInjection +{ + /// + /// Extension methods for discovered service registrations + /// + public static class DiscoveredServicesExtensions + { + /// + /// Adds discovered services from Test.Generator to the specified service collection + /// + /// The service collection. + /// The service registration tags to include. + /// The service collection + public static global::Microsoft.Extensions.DependencyInjection.IServiceCollection AddTestGenerator(this global::Microsoft.Extensions.DependencyInjection.IServiceCollection serviceCollection, params string[]? tags) + { + var tagSet = new global::System.Collections.Generic.HashSet(tags ?? global::System.Linq.Enumerable.Empty()); + + global::Microsoft.Extensions.DependencyInjection.Extensions.ServiceCollectionDescriptorExtensions.TryAdd( + serviceCollection, + global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor.KeyedSingleton("Alpha") + ); + + global::Injectio.Internal.InjectioDecorationExtensions.DecorateKeyed( + serviceCollection, + "Alpha", + static (serviceProvider, serviceKey, inner) => global::Microsoft.Extensions.DependencyInjection.ActivatorUtilities.CreateInstance(serviceProvider, inner) + ); + + return serviceCollection; + } + } +} + + +// ========== + +// Injectio.Decoration.g.cs +// +#nullable enable + +namespace Injectio.Internal +{ + internal static class InjectioDecorationExtensions + { + internal static global::Microsoft.Extensions.DependencyInjection.IServiceCollection Decorate( + global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, + global::System.Func decoratorFactory) + where TService : class + { + var serviceType = typeof(TService); + + for (int i = 0; i < services.Count; i++) + { + var descriptor = services[i]; + if (descriptor.ServiceType != serviceType) + continue; + + if (IsKeyedDescriptor(descriptor)) + continue; + + var inner = CreateInnerFactory(descriptor); + var lifetime = descriptor.Lifetime; + + services[i] = new global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor( + serviceType, + sp => decoratorFactory(sp, (TService)inner(sp))!, + lifetime); + } + + return services; + } + + #if NET8_0_OR_GREATER + internal static global::Microsoft.Extensions.DependencyInjection.IServiceCollection DecorateKeyed( + global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, + object? serviceKey, + global::System.Func decoratorFactory) + where TService : class + { + var serviceType = typeof(TService); + bool anyKey = ReferenceEquals(serviceKey, global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey); + + for (int i = 0; i < services.Count; i++) + { + var descriptor = services[i]; + if (descriptor.ServiceType != serviceType) + continue; + + if (!descriptor.IsKeyedService) + continue; + + if (!anyKey && !global::System.Collections.Generic.EqualityComparer.Default.Equals(descriptor.ServiceKey, serviceKey)) + continue; + + var originalKey = descriptor.ServiceKey; + var innerKeyed = CreateInnerKeyedFactory(descriptor); + var lifetime = descriptor.Lifetime; + + services[i] = new global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor( + serviceType, + originalKey, + (sp, key) => decoratorFactory(sp, key, (TService)innerKeyed(sp, key))!, + lifetime); + } + + return services; + } + #endif + + internal static global::Microsoft.Extensions.DependencyInjection.IServiceCollection DecorateOpenGeneric( + global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, + global::System.Type openServiceType, + global::System.Type openDecoratorType) + { + if (!openServiceType.IsGenericTypeDefinition) + openServiceType = openServiceType.GetGenericTypeDefinition(); + + if (!openDecoratorType.IsGenericTypeDefinition) + openDecoratorType = openDecoratorType.GetGenericTypeDefinition(); + + for (int i = 0; i < services.Count; i++) + { + var descriptor = services[i]; + var serviceType = descriptor.ServiceType; + + if (!serviceType.IsGenericType) + continue; + + // skip truly open-generic descriptors; MS.DI forbids factory on open service types + // and replacing the implementation type would cause recursive resolution. + if (serviceType.IsGenericTypeDefinition) + continue; + + if (serviceType.GetGenericTypeDefinition() != openServiceType) + continue; + + var typeArgs = serviceType.GetGenericArguments(); + global::System.Type closedDecoratorType; + try + { + closedDecoratorType = openDecoratorType.MakeGenericType(typeArgs); + } + catch (global::System.ArgumentException) + { + continue; + } + + #if NET8_0_OR_GREATER + if (descriptor.IsKeyedService) + { + var originalKey = descriptor.ServiceKey; + var innerKeyed = CreateInnerKeyedFactory(descriptor); + var lifetime = descriptor.Lifetime; + + services[i] = new global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor( + serviceType, + originalKey, + (sp, key) => global::Microsoft.Extensions.DependencyInjection.ActivatorUtilities.CreateInstance(sp, closedDecoratorType, innerKeyed(sp, key))!, + lifetime); + continue; + } + #endif + + var inner = CreateInnerFactory(descriptor); + var nonKeyedLifetime = descriptor.Lifetime; + + services[i] = new global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor( + serviceType, + sp => global::Microsoft.Extensions.DependencyInjection.ActivatorUtilities.CreateInstance(sp, closedDecoratorType, inner(sp))!, + nonKeyedLifetime); + } + + return services; + } + + private static bool IsKeyedDescriptor(global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor descriptor) + { + #if NET8_0_OR_GREATER + return descriptor.IsKeyedService; + #else + return false; + #endif + } + + private static global::System.Func CreateInnerFactory( + global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor descriptor) + { + if (descriptor.ImplementationInstance is object instance) + return _ => instance; + + if (descriptor.ImplementationFactory is global::System.Func factory) + return factory; + + var implementationType = descriptor.ImplementationType ?? descriptor.ServiceType; + return sp => global::Microsoft.Extensions.DependencyInjection.ActivatorUtilities.CreateInstance(sp, implementationType); + } + + #if NET8_0_OR_GREATER + private static global::System.Func CreateInnerKeyedFactory( + global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor descriptor) + { + if (descriptor.KeyedImplementationInstance is object keyedInstance) + return (_, _) => keyedInstance; + + if (descriptor.KeyedImplementationFactory is global::System.Func keyedFactory) + return keyedFactory; + + var implementationType = descriptor.KeyedImplementationType ?? descriptor.ServiceType; + return (sp, _) => global::Microsoft.Extensions.DependencyInjection.ActivatorUtilities.CreateInstance(sp, implementationType); + } + #endif + } +} diff --git a/tests/Injectio.Tests/Snapshots/ServiceRegistrationDecoratorTests.GenerateDecoratorOpenGeneric.verified.txt b/tests/Injectio.Tests/Snapshots/ServiceRegistrationDecoratorTests.GenerateDecoratorOpenGeneric.verified.txt new file mode 100644 index 0000000..63e398f --- /dev/null +++ b/tests/Injectio.Tests/Snapshots/ServiceRegistrationDecoratorTests.GenerateDecoratorOpenGeneric.verified.txt @@ -0,0 +1,220 @@ +// Injectio.g.cs +// +#nullable enable + +namespace Microsoft.Extensions.DependencyInjection +{ + /// + /// Extension methods for discovered service registrations + /// + public static class DiscoveredServicesExtensions + { + /// + /// Adds discovered services from Test.Generator to the specified service collection + /// + /// The service collection. + /// The service registration tags to include. + /// The service collection + public static global::Microsoft.Extensions.DependencyInjection.IServiceCollection AddTestGenerator(this global::Microsoft.Extensions.DependencyInjection.IServiceCollection serviceCollection, params string[]? tags) + { + var tagSet = new global::System.Collections.Generic.HashSet(tags ?? global::System.Linq.Enumerable.Empty()); + + global::Microsoft.Extensions.DependencyInjection.Extensions.ServiceCollectionDescriptorExtensions.TryAdd( + serviceCollection, + global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor.Singleton( + typeof(global::Injectio.Sample.IRepo<>), + typeof(global::Injectio.Sample.Repo<>) + ) + ); + + global::Injectio.Internal.InjectioDecorationExtensions.DecorateOpenGeneric( + serviceCollection, + typeof(global::Injectio.Sample.IRepo<>), + typeof(global::Injectio.Sample.LoggingRepo<>) + ); + + return serviceCollection; + } + } +} + + +// ========== + +// Injectio.Decoration.g.cs +// +#nullable enable + +namespace Injectio.Internal +{ + internal static class InjectioDecorationExtensions + { + internal static global::Microsoft.Extensions.DependencyInjection.IServiceCollection Decorate( + global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, + global::System.Func decoratorFactory) + where TService : class + { + var serviceType = typeof(TService); + + for (int i = 0; i < services.Count; i++) + { + var descriptor = services[i]; + if (descriptor.ServiceType != serviceType) + continue; + + if (IsKeyedDescriptor(descriptor)) + continue; + + var inner = CreateInnerFactory(descriptor); + var lifetime = descriptor.Lifetime; + + services[i] = new global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor( + serviceType, + sp => decoratorFactory(sp, (TService)inner(sp))!, + lifetime); + } + + return services; + } + + #if NET8_0_OR_GREATER + internal static global::Microsoft.Extensions.DependencyInjection.IServiceCollection DecorateKeyed( + global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, + object? serviceKey, + global::System.Func decoratorFactory) + where TService : class + { + var serviceType = typeof(TService); + bool anyKey = ReferenceEquals(serviceKey, global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey); + + for (int i = 0; i < services.Count; i++) + { + var descriptor = services[i]; + if (descriptor.ServiceType != serviceType) + continue; + + if (!descriptor.IsKeyedService) + continue; + + if (!anyKey && !global::System.Collections.Generic.EqualityComparer.Default.Equals(descriptor.ServiceKey, serviceKey)) + continue; + + var originalKey = descriptor.ServiceKey; + var innerKeyed = CreateInnerKeyedFactory(descriptor); + var lifetime = descriptor.Lifetime; + + services[i] = new global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor( + serviceType, + originalKey, + (sp, key) => decoratorFactory(sp, key, (TService)innerKeyed(sp, key))!, + lifetime); + } + + return services; + } + #endif + + internal static global::Microsoft.Extensions.DependencyInjection.IServiceCollection DecorateOpenGeneric( + global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, + global::System.Type openServiceType, + global::System.Type openDecoratorType) + { + if (!openServiceType.IsGenericTypeDefinition) + openServiceType = openServiceType.GetGenericTypeDefinition(); + + if (!openDecoratorType.IsGenericTypeDefinition) + openDecoratorType = openDecoratorType.GetGenericTypeDefinition(); + + for (int i = 0; i < services.Count; i++) + { + var descriptor = services[i]; + var serviceType = descriptor.ServiceType; + + if (!serviceType.IsGenericType) + continue; + + // skip truly open-generic descriptors; MS.DI forbids factory on open service types + // and replacing the implementation type would cause recursive resolution. + if (serviceType.IsGenericTypeDefinition) + continue; + + if (serviceType.GetGenericTypeDefinition() != openServiceType) + continue; + + var typeArgs = serviceType.GetGenericArguments(); + global::System.Type closedDecoratorType; + try + { + closedDecoratorType = openDecoratorType.MakeGenericType(typeArgs); + } + catch (global::System.ArgumentException) + { + continue; + } + + #if NET8_0_OR_GREATER + if (descriptor.IsKeyedService) + { + var originalKey = descriptor.ServiceKey; + var innerKeyed = CreateInnerKeyedFactory(descriptor); + var lifetime = descriptor.Lifetime; + + services[i] = new global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor( + serviceType, + originalKey, + (sp, key) => global::Microsoft.Extensions.DependencyInjection.ActivatorUtilities.CreateInstance(sp, closedDecoratorType, innerKeyed(sp, key))!, + lifetime); + continue; + } + #endif + + var inner = CreateInnerFactory(descriptor); + var nonKeyedLifetime = descriptor.Lifetime; + + services[i] = new global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor( + serviceType, + sp => global::Microsoft.Extensions.DependencyInjection.ActivatorUtilities.CreateInstance(sp, closedDecoratorType, inner(sp))!, + nonKeyedLifetime); + } + + return services; + } + + private static bool IsKeyedDescriptor(global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor descriptor) + { + #if NET8_0_OR_GREATER + return descriptor.IsKeyedService; + #else + return false; + #endif + } + + private static global::System.Func CreateInnerFactory( + global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor descriptor) + { + if (descriptor.ImplementationInstance is object instance) + return _ => instance; + + if (descriptor.ImplementationFactory is global::System.Func factory) + return factory; + + var implementationType = descriptor.ImplementationType ?? descriptor.ServiceType; + return sp => global::Microsoft.Extensions.DependencyInjection.ActivatorUtilities.CreateInstance(sp, implementationType); + } + + #if NET8_0_OR_GREATER + private static global::System.Func CreateInnerKeyedFactory( + global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor descriptor) + { + if (descriptor.KeyedImplementationInstance is object keyedInstance) + return (_, _) => keyedInstance; + + if (descriptor.KeyedImplementationFactory is global::System.Func keyedFactory) + return keyedFactory; + + var implementationType = descriptor.KeyedImplementationType ?? descriptor.ServiceType; + return (sp, _) => global::Microsoft.Extensions.DependencyInjection.ActivatorUtilities.CreateInstance(sp, implementationType); + } + #endif + } +} diff --git a/tests/Injectio.Tests/Snapshots/ServiceRegistrationDecoratorTests.GenerateDecoratorSimple.verified.txt b/tests/Injectio.Tests/Snapshots/ServiceRegistrationDecoratorTests.GenerateDecoratorSimple.verified.txt new file mode 100644 index 0000000..123cbae --- /dev/null +++ b/tests/Injectio.Tests/Snapshots/ServiceRegistrationDecoratorTests.GenerateDecoratorSimple.verified.txt @@ -0,0 +1,223 @@ +// Injectio.g.cs +// +#nullable enable + +namespace Microsoft.Extensions.DependencyInjection +{ + /// + /// Extension methods for discovered service registrations + /// + public static class DiscoveredServicesExtensions + { + /// + /// Adds discovered services from Test.Generator to the specified service collection + /// + /// The service collection. + /// The service registration tags to include. + /// The service collection + public static global::Microsoft.Extensions.DependencyInjection.IServiceCollection AddTestGenerator(this global::Microsoft.Extensions.DependencyInjection.IServiceCollection serviceCollection, params string[]? tags) + { + var tagSet = new global::System.Collections.Generic.HashSet(tags ?? global::System.Linq.Enumerable.Empty()); + + global::Microsoft.Extensions.DependencyInjection.Extensions.ServiceCollectionDescriptorExtensions.TryAdd( + serviceCollection, + global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor.Singleton( + (serviceProvider) => global::Microsoft.Extensions.DependencyInjection.ServiceProviderServiceExtensions.GetRequiredService(serviceProvider) + ) + ); + + global::Microsoft.Extensions.DependencyInjection.Extensions.ServiceCollectionDescriptorExtensions.TryAdd( + serviceCollection, + global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor.Singleton() + ); + + global::Injectio.Internal.InjectioDecorationExtensions.Decorate( + serviceCollection, + static (serviceProvider, inner) => global::Microsoft.Extensions.DependencyInjection.ActivatorUtilities.CreateInstance(serviceProvider, inner) + ); + + return serviceCollection; + } + } +} + + +// ========== + +// Injectio.Decoration.g.cs +// +#nullable enable + +namespace Injectio.Internal +{ + internal static class InjectioDecorationExtensions + { + internal static global::Microsoft.Extensions.DependencyInjection.IServiceCollection Decorate( + global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, + global::System.Func decoratorFactory) + where TService : class + { + var serviceType = typeof(TService); + + for (int i = 0; i < services.Count; i++) + { + var descriptor = services[i]; + if (descriptor.ServiceType != serviceType) + continue; + + if (IsKeyedDescriptor(descriptor)) + continue; + + var inner = CreateInnerFactory(descriptor); + var lifetime = descriptor.Lifetime; + + services[i] = new global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor( + serviceType, + sp => decoratorFactory(sp, (TService)inner(sp))!, + lifetime); + } + + return services; + } + + #if NET8_0_OR_GREATER + internal static global::Microsoft.Extensions.DependencyInjection.IServiceCollection DecorateKeyed( + global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, + object? serviceKey, + global::System.Func decoratorFactory) + where TService : class + { + var serviceType = typeof(TService); + bool anyKey = ReferenceEquals(serviceKey, global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey); + + for (int i = 0; i < services.Count; i++) + { + var descriptor = services[i]; + if (descriptor.ServiceType != serviceType) + continue; + + if (!descriptor.IsKeyedService) + continue; + + if (!anyKey && !global::System.Collections.Generic.EqualityComparer.Default.Equals(descriptor.ServiceKey, serviceKey)) + continue; + + var originalKey = descriptor.ServiceKey; + var innerKeyed = CreateInnerKeyedFactory(descriptor); + var lifetime = descriptor.Lifetime; + + services[i] = new global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor( + serviceType, + originalKey, + (sp, key) => decoratorFactory(sp, key, (TService)innerKeyed(sp, key))!, + lifetime); + } + + return services; + } + #endif + + internal static global::Microsoft.Extensions.DependencyInjection.IServiceCollection DecorateOpenGeneric( + global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, + global::System.Type openServiceType, + global::System.Type openDecoratorType) + { + if (!openServiceType.IsGenericTypeDefinition) + openServiceType = openServiceType.GetGenericTypeDefinition(); + + if (!openDecoratorType.IsGenericTypeDefinition) + openDecoratorType = openDecoratorType.GetGenericTypeDefinition(); + + for (int i = 0; i < services.Count; i++) + { + var descriptor = services[i]; + var serviceType = descriptor.ServiceType; + + if (!serviceType.IsGenericType) + continue; + + // skip truly open-generic descriptors; MS.DI forbids factory on open service types + // and replacing the implementation type would cause recursive resolution. + if (serviceType.IsGenericTypeDefinition) + continue; + + if (serviceType.GetGenericTypeDefinition() != openServiceType) + continue; + + var typeArgs = serviceType.GetGenericArguments(); + global::System.Type closedDecoratorType; + try + { + closedDecoratorType = openDecoratorType.MakeGenericType(typeArgs); + } + catch (global::System.ArgumentException) + { + continue; + } + + #if NET8_0_OR_GREATER + if (descriptor.IsKeyedService) + { + var originalKey = descriptor.ServiceKey; + var innerKeyed = CreateInnerKeyedFactory(descriptor); + var lifetime = descriptor.Lifetime; + + services[i] = new global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor( + serviceType, + originalKey, + (sp, key) => global::Microsoft.Extensions.DependencyInjection.ActivatorUtilities.CreateInstance(sp, closedDecoratorType, innerKeyed(sp, key))!, + lifetime); + continue; + } + #endif + + var inner = CreateInnerFactory(descriptor); + var nonKeyedLifetime = descriptor.Lifetime; + + services[i] = new global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor( + serviceType, + sp => global::Microsoft.Extensions.DependencyInjection.ActivatorUtilities.CreateInstance(sp, closedDecoratorType, inner(sp))!, + nonKeyedLifetime); + } + + return services; + } + + private static bool IsKeyedDescriptor(global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor descriptor) + { + #if NET8_0_OR_GREATER + return descriptor.IsKeyedService; + #else + return false; + #endif + } + + private static global::System.Func CreateInnerFactory( + global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor descriptor) + { + if (descriptor.ImplementationInstance is object instance) + return _ => instance; + + if (descriptor.ImplementationFactory is global::System.Func factory) + return factory; + + var implementationType = descriptor.ImplementationType ?? descriptor.ServiceType; + return sp => global::Microsoft.Extensions.DependencyInjection.ActivatorUtilities.CreateInstance(sp, implementationType); + } + + #if NET8_0_OR_GREATER + private static global::System.Func CreateInnerKeyedFactory( + global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor descriptor) + { + if (descriptor.KeyedImplementationInstance is object keyedInstance) + return (_, _) => keyedInstance; + + if (descriptor.KeyedImplementationFactory is global::System.Func keyedFactory) + return keyedFactory; + + var implementationType = descriptor.KeyedImplementationType ?? descriptor.ServiceType; + return (sp, _) => global::Microsoft.Extensions.DependencyInjection.ActivatorUtilities.CreateInstance(sp, implementationType); + } + #endif + } +} diff --git a/tests/Injectio.Tests/Snapshots/ServiceRegistrationDecoratorTests.GenerateDecoratorTags.verified.txt b/tests/Injectio.Tests/Snapshots/ServiceRegistrationDecoratorTests.GenerateDecoratorTags.verified.txt new file mode 100644 index 0000000..01f4576 --- /dev/null +++ b/tests/Injectio.Tests/Snapshots/ServiceRegistrationDecoratorTests.GenerateDecoratorTags.verified.txt @@ -0,0 +1,231 @@ +// Injectio.g.cs +// +#nullable enable + +namespace Microsoft.Extensions.DependencyInjection +{ + /// + /// Extension methods for discovered service registrations + /// + public static class DiscoveredServicesExtensions + { + /// + /// Adds discovered services from Test.Generator to the specified service collection + /// + /// The service collection. + /// The service registration tags to include. + /// The service collection + public static global::Microsoft.Extensions.DependencyInjection.IServiceCollection AddTestGenerator(this global::Microsoft.Extensions.DependencyInjection.IServiceCollection serviceCollection, params string[]? tags) + { + var tagSet = new global::System.Collections.Generic.HashSet(tags ?? global::System.Linq.Enumerable.Empty()); + + if (tagSet.Count == 0 || tagSet.Intersect(new[] { "FrontEnd" }).Any()) + { + global::Microsoft.Extensions.DependencyInjection.Extensions.ServiceCollectionDescriptorExtensions.TryAdd( + serviceCollection, + global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor.Singleton( + (serviceProvider) => global::Microsoft.Extensions.DependencyInjection.ServiceProviderServiceExtensions.GetRequiredService(serviceProvider) + ) + ); + + global::Microsoft.Extensions.DependencyInjection.Extensions.ServiceCollectionDescriptorExtensions.TryAdd( + serviceCollection, + global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor.Singleton() + ); + + } + + if (tagSet.Count == 0 || tagSet.Intersect(new[] { "FrontEnd" }).Any()) + { + global::Injectio.Internal.InjectioDecorationExtensions.Decorate( + serviceCollection, + static (serviceProvider, inner) => global::Microsoft.Extensions.DependencyInjection.ActivatorUtilities.CreateInstance(serviceProvider, inner) + ); + + } + + return serviceCollection; + } + } +} + + +// ========== + +// Injectio.Decoration.g.cs +// +#nullable enable + +namespace Injectio.Internal +{ + internal static class InjectioDecorationExtensions + { + internal static global::Microsoft.Extensions.DependencyInjection.IServiceCollection Decorate( + global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, + global::System.Func decoratorFactory) + where TService : class + { + var serviceType = typeof(TService); + + for (int i = 0; i < services.Count; i++) + { + var descriptor = services[i]; + if (descriptor.ServiceType != serviceType) + continue; + + if (IsKeyedDescriptor(descriptor)) + continue; + + var inner = CreateInnerFactory(descriptor); + var lifetime = descriptor.Lifetime; + + services[i] = new global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor( + serviceType, + sp => decoratorFactory(sp, (TService)inner(sp))!, + lifetime); + } + + return services; + } + + #if NET8_0_OR_GREATER + internal static global::Microsoft.Extensions.DependencyInjection.IServiceCollection DecorateKeyed( + global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, + object? serviceKey, + global::System.Func decoratorFactory) + where TService : class + { + var serviceType = typeof(TService); + bool anyKey = ReferenceEquals(serviceKey, global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey); + + for (int i = 0; i < services.Count; i++) + { + var descriptor = services[i]; + if (descriptor.ServiceType != serviceType) + continue; + + if (!descriptor.IsKeyedService) + continue; + + if (!anyKey && !global::System.Collections.Generic.EqualityComparer.Default.Equals(descriptor.ServiceKey, serviceKey)) + continue; + + var originalKey = descriptor.ServiceKey; + var innerKeyed = CreateInnerKeyedFactory(descriptor); + var lifetime = descriptor.Lifetime; + + services[i] = new global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor( + serviceType, + originalKey, + (sp, key) => decoratorFactory(sp, key, (TService)innerKeyed(sp, key))!, + lifetime); + } + + return services; + } + #endif + + internal static global::Microsoft.Extensions.DependencyInjection.IServiceCollection DecorateOpenGeneric( + global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, + global::System.Type openServiceType, + global::System.Type openDecoratorType) + { + if (!openServiceType.IsGenericTypeDefinition) + openServiceType = openServiceType.GetGenericTypeDefinition(); + + if (!openDecoratorType.IsGenericTypeDefinition) + openDecoratorType = openDecoratorType.GetGenericTypeDefinition(); + + for (int i = 0; i < services.Count; i++) + { + var descriptor = services[i]; + var serviceType = descriptor.ServiceType; + + if (!serviceType.IsGenericType) + continue; + + // skip truly open-generic descriptors; MS.DI forbids factory on open service types + // and replacing the implementation type would cause recursive resolution. + if (serviceType.IsGenericTypeDefinition) + continue; + + if (serviceType.GetGenericTypeDefinition() != openServiceType) + continue; + + var typeArgs = serviceType.GetGenericArguments(); + global::System.Type closedDecoratorType; + try + { + closedDecoratorType = openDecoratorType.MakeGenericType(typeArgs); + } + catch (global::System.ArgumentException) + { + continue; + } + + #if NET8_0_OR_GREATER + if (descriptor.IsKeyedService) + { + var originalKey = descriptor.ServiceKey; + var innerKeyed = CreateInnerKeyedFactory(descriptor); + var lifetime = descriptor.Lifetime; + + services[i] = new global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor( + serviceType, + originalKey, + (sp, key) => global::Microsoft.Extensions.DependencyInjection.ActivatorUtilities.CreateInstance(sp, closedDecoratorType, innerKeyed(sp, key))!, + lifetime); + continue; + } + #endif + + var inner = CreateInnerFactory(descriptor); + var nonKeyedLifetime = descriptor.Lifetime; + + services[i] = new global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor( + serviceType, + sp => global::Microsoft.Extensions.DependencyInjection.ActivatorUtilities.CreateInstance(sp, closedDecoratorType, inner(sp))!, + nonKeyedLifetime); + } + + return services; + } + + private static bool IsKeyedDescriptor(global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor descriptor) + { + #if NET8_0_OR_GREATER + return descriptor.IsKeyedService; + #else + return false; + #endif + } + + private static global::System.Func CreateInnerFactory( + global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor descriptor) + { + if (descriptor.ImplementationInstance is object instance) + return _ => instance; + + if (descriptor.ImplementationFactory is global::System.Func factory) + return factory; + + var implementationType = descriptor.ImplementationType ?? descriptor.ServiceType; + return sp => global::Microsoft.Extensions.DependencyInjection.ActivatorUtilities.CreateInstance(sp, implementationType); + } + + #if NET8_0_OR_GREATER + private static global::System.Func CreateInnerKeyedFactory( + global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor descriptor) + { + if (descriptor.KeyedImplementationInstance is object keyedInstance) + return (_, _) => keyedInstance; + + if (descriptor.KeyedImplementationFactory is global::System.Func keyedFactory) + return keyedFactory; + + var implementationType = descriptor.KeyedImplementationType ?? descriptor.ServiceType; + return (sp, _) => global::Microsoft.Extensions.DependencyInjection.ActivatorUtilities.CreateInstance(sp, implementationType); + } + #endif + } +}