diff --git a/src/BlazorUI/Bit.BlazorUI.SourceGenerators/AutoInject/AutoInjectHelper.cs b/src/BlazorUI/Bit.BlazorUI.SourceGenerators/AutoInject/AutoInjectHelper.cs index edccd6af80..a34c9116e2 100644 --- a/src/BlazorUI/Bit.BlazorUI.SourceGenerators/AutoInject/AutoInjectHelper.cs +++ b/src/BlazorUI/Bit.BlazorUI.SourceGenerators/AutoInject/AutoInjectHelper.cs @@ -6,7 +6,7 @@ namespace Bit.BlazorUI.SourceGenerators.AutoInject; -public static class AutoInjectHelper +internal static class AutoInjectHelper { public static readonly string AutoInjectAttributeFullName = "Microsoft.Extensions.DependencyInjection.AutoInjectAttribute"; //typeof(AutoInjectAttribute).FullName; diff --git a/src/BlazorUI/Bit.BlazorUI.SourceGenerators/AutoInject/AutoInjectMember.cs b/src/BlazorUI/Bit.BlazorUI.SourceGenerators/AutoInject/AutoInjectMember.cs new file mode 100644 index 0000000000..c3dc31c4e2 --- /dev/null +++ b/src/BlazorUI/Bit.BlazorUI.SourceGenerators/AutoInject/AutoInjectMember.cs @@ -0,0 +1,3 @@ +namespace Bit.BlazorUI.SourceGenerators.AutoInject; + +internal readonly record struct AutoInjectMember(string Name, string TypeDisplay, bool IsField, bool IsNullable); diff --git a/src/BlazorUI/Bit.BlazorUI.SourceGenerators/AutoInject/AutoInjectNormalClassHandler.cs b/src/BlazorUI/Bit.BlazorUI.SourceGenerators/AutoInject/AutoInjectNormalClassHandler.cs index 9eaa4adb3d..d54d22b25f 100644 --- a/src/BlazorUI/Bit.BlazorUI.SourceGenerators/AutoInject/AutoInjectNormalClassHandler.cs +++ b/src/BlazorUI/Bit.BlazorUI.SourceGenerators/AutoInject/AutoInjectNormalClassHandler.cs @@ -2,109 +2,89 @@ using System.Linq; using System.Text; using Bit.SourceGenerators; -using Microsoft.CodeAnalysis; namespace Bit.BlazorUI.SourceGenerators.AutoInject; -public static class AutoInjectNormalClassHandler +internal static class AutoInjectNormalClassHandler { - public static string? Generate(INamedTypeSymbol? attributeSymbol, INamedTypeSymbol? classSymbol, IReadOnlyCollection eligibleMembers) + public static string? Generate( + string classNamespace, + string classNameForCode, + string className, + IReadOnlyCollection directMembers, + IReadOnlyCollection baseMembers) { - if (classSymbol is null) - { - return null; - } - - if (AutoInjectHelper.IsContainingSymbolEqualToContainingNamespace(classSymbol) is false) - { - return null; - } - - string classNamespace = classSymbol.ContainingNamespace.ToDisplayString(); - - IReadOnlyCollection baseEligibleMembers = AutoInjectHelper.GetBaseClassEligibleMembers(classSymbol, attributeSymbol); - IReadOnlyCollection sortedMembers = eligibleMembers.OrderBy(o => o.Name).ToList(); + var sortedMembers = directMembers.OrderBy(o => o.Name).ToList(); string source = $@" namespace {classNamespace} {{ - public partial class {AutoInjectHelper.GenerateClassName(classSymbol)} + public partial class {classNameForCode} {{ - {GenerateConstructor(classSymbol, sortedMembers, baseEligibleMembers)} + {GenerateConstructor(className, sortedMembers, baseMembers)} }} }}"; return source; } - private static string GenerateConstructor(INamedTypeSymbol classSymbol, IReadOnlyCollection eligibleMembers, IReadOnlyCollection baseEligibleMembers) + private static string GenerateConstructor(string className, IReadOnlyCollection directMembers, IReadOnlyCollection baseMembers) { string generateConstructor = $@" [global::System.CodeDom.Compiler.GeneratedCode(""Bit.SourceGenerators"",""{BitSourceGeneratorUtil.GetPackageVersion()}"")] [global::System.Diagnostics.DebuggerNonUserCode] [global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage] -{"\t\t"}public {classSymbol.Name}({GenerateConstructorParameters(eligibleMembers, baseEligibleMembers)}){PassParametersToBaseClass(baseEligibleMembers)} +{"\t\t"}public {className}({GenerateConstructorParameters(directMembers, baseMembers)}){PassParametersToBaseClass(baseMembers)} {"\t\t"}{{ -{AssignedInjectedParametersToMembers(eligibleMembers)} +{AssignMembersFromParameters(directMembers)} {"\t\t"}}} "; return generateConstructor; } - private static string PassParametersToBaseClass(IReadOnlyCollection baseEligibleMembers) + private static string PassParametersToBaseClass(IReadOnlyCollection baseMembers) { - if (baseEligibleMembers.Any() is false) + if (baseMembers.Any() is false) return string.Empty; StringBuilder baseConstructor = new(); - baseConstructor.Append(": base("); - foreach (ISymbol symbol in baseEligibleMembers) + foreach (var member in baseMembers) { - baseConstructor.Append($@"{'\n'}{"\t\t\t\t\t\t"}autoInjected{AutoInjectHelper.FormatMemberName(symbol.Name)},"); + baseConstructor.Append($@"{'\n'}{"\t\t\t\t\t\t"}autoInjected{AutoInjectHelper.FormatMemberName(member.Name)},"); } baseConstructor.Length--; - baseConstructor.Append(')'); return baseConstructor.ToString(); } - private static string AssignedInjectedParametersToMembers(IReadOnlyCollection eligibleMembers) + private static string AssignMembersFromParameters(IReadOnlyCollection directMembers) { StringBuilder stringBuilder = new(); - foreach (ISymbol symbol in eligibleMembers) + foreach (var member in directMembers) { if (stringBuilder.Length > 0) { stringBuilder.Append('\n'); } stringBuilder.Append("\t\t\t") - .Append($@"{symbol.Name} = autoInjected{AutoInjectHelper.FormatMemberName(symbol.Name)};"); + .Append($@"{member.Name} = autoInjected{AutoInjectHelper.FormatMemberName(member.Name)};"); } return stringBuilder.ToString(); } - private static string GenerateConstructorParameters(IReadOnlyCollection eligibleMembers, IReadOnlyCollection baseEligibleMembers) + private static string GenerateConstructorParameters(IReadOnlyCollection directMembers, IReadOnlyCollection baseMembers) { StringBuilder stringBuilder = new(); - List members = new(eligibleMembers.Count + baseEligibleMembers.Count); + var allMembers = directMembers.Concat(baseMembers).OrderBy(o => o.Name).ToList(); - members.AddRange(eligibleMembers); - members.AddRange(baseEligibleMembers); - members = members.OrderBy(o => o.Name).ToList(); - - foreach (ISymbol member in members) + foreach (var member in allMembers) { - if (member is IFieldSymbol fieldSymbol) - stringBuilder.Append( - $@"{'\n'}{"\t\t\t"}{fieldSymbol.Type} autoInjected{AutoInjectHelper.FormatMemberName(fieldSymbol.Name)},"); - - if (member is IPropertySymbol propertySymbol) - stringBuilder.Append( - $@"{'\n'}{"\t\t\t"}{propertySymbol.Type} autoInjected{AutoInjectHelper.FormatMemberName(propertySymbol.Name)},"); + var nullValue = member.IsNullable ? " = null" : string.Empty; + stringBuilder.Append($@"{'\n'}{"}\t\t\t"}{member.TypeDisplay} autoInjected{AutoInjectHelper.FormatMemberName(member.Name)}{nullValue},"); } stringBuilder.Length--; @@ -112,3 +92,4 @@ private static string GenerateConstructorParameters(IReadOnlyCollection return stringBuilder.ToString(); } } + diff --git a/src/BlazorUI/Bit.BlazorUI.SourceGenerators/AutoInject/AutoInjectRazorComponentHandler.cs b/src/BlazorUI/Bit.BlazorUI.SourceGenerators/AutoInject/AutoInjectRazorComponentHandler.cs index 7fd5c9e74c..c91d9816c8 100644 --- a/src/BlazorUI/Bit.BlazorUI.SourceGenerators/AutoInject/AutoInjectRazorComponentHandler.cs +++ b/src/BlazorUI/Bit.BlazorUI.SourceGenerators/AutoInject/AutoInjectRazorComponentHandler.cs @@ -1,60 +1,43 @@ using System.Collections.Generic; -using System.Linq; using System.Text; using Bit.SourceGenerators; -using Microsoft.CodeAnalysis; namespace Bit.BlazorUI.SourceGenerators.AutoInject; -public static class AutoInjectRazorComponentHandler +internal static class AutoInjectRazorComponentHandler { - public static string? Generate(INamedTypeSymbol? classSymbol, IReadOnlyCollection eligibleMembers) + public static string? Generate( + string classNamespace, + string classNameForCode, + IReadOnlyCollection directMembers) { - if (classSymbol is null) - { - return null; - } - - if (AutoInjectHelper.IsContainingSymbolEqualToContainingNamespace(classSymbol) is false) - { - return null; - } - - string classNamespace = classSymbol.ContainingNamespace.ToDisplayString(); - - IReadOnlyCollection sortedMembers = eligibleMembers.OrderBy(o => o.Name).ToList(); - string source = $@" using Microsoft.AspNetCore.Components; using System.ComponentModel; namespace {classNamespace} {{ - public partial class {AutoInjectHelper.GenerateClassName(classSymbol)} + public partial class {classNameForCode} {{ - {GenerateInjectableProperties(sortedMembers)} + {GenerateInjectableProperties(directMembers)} }} }}"; return source; } - private static string GenerateInjectableProperties(IReadOnlyCollection eligibleMembers) + private static string GenerateInjectableProperties(IReadOnlyCollection members) { StringBuilder stringBuilder = new StringBuilder(); - foreach (ISymbol member in eligibleMembers) + foreach (var member in members) { - if (member is IFieldSymbol fieldSymbol) - stringBuilder.Append(GenerateProperty(fieldSymbol.Type, fieldSymbol.Name)); - - if (member is IPropertySymbol propertySymbol) - stringBuilder.Append(GenerateProperty(propertySymbol.Type, propertySymbol.Name)); + stringBuilder.Append(GenerateProperty(member.TypeDisplay, member.Name)); } return stringBuilder.ToString(); } - private static string GenerateProperty(ITypeSymbol @type, string name) + private static string GenerateProperty(string typeDisplay, string name) { return $@" [global::System.CodeDom.Compiler.GeneratedCode(""Bit.SourceGenerators"",""{BitSourceGeneratorUtil.GetPackageVersion()}"")] @@ -62,6 +45,7 @@ private static string GenerateProperty(ITypeSymbol @type, string name) [global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage] {"\t\t"}[Inject] {"\t\t"}[EditorBrowsable(EditorBrowsableState.Never)] -{"\t\t"}private {@type} ____{AutoInjectHelper.FormatMemberName(name)} {{ get => {name}; set => {name} = value; }}"; +{"\t\t"}private {typeDisplay} ____{AutoInjectHelper.FormatMemberName(name)} {{ get => {name}; set => {name} = value; }}"; } } + diff --git a/src/BlazorUI/Bit.BlazorUI.SourceGenerators/AutoInject/AutoInjectSourceGenerator.cs b/src/BlazorUI/Bit.BlazorUI.SourceGenerators/AutoInject/AutoInjectSourceGenerator.cs index 1f85ec78b2..a20459f3d4 100644 --- a/src/BlazorUI/Bit.BlazorUI.SourceGenerators/AutoInject/AutoInjectSourceGenerator.cs +++ b/src/BlazorUI/Bit.BlazorUI.SourceGenerators/AutoInject/AutoInjectSourceGenerator.cs @@ -1,8 +1,8 @@ -using System; -using System.Collections.Generic; -using System.IO; +using System.Collections.Generic; +using System.Collections.Immutable; using System.Linq; using System.Text; +using System.Threading; using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.CSharp.Syntax; @@ -11,119 +11,323 @@ namespace Bit.BlazorUI.SourceGenerators.AutoInject; [Generator] -public class AutoInjectSourceGenerator : ISourceGenerator +public class AutoInjectSourceGenerator : IIncrementalGenerator { - private static int counter; - private static readonly DiagnosticDescriptor NonPartialClassError = new(id: "BITGEN001", - title: "The class needs to be partial", - messageFormat: "{0} is not partial. The AutoInject attribute needs to be used only in partial classes.", - category: "Bit.SourceGenerators", - DiagnosticSeverity.Error, - isEnabledByDefault: true); - - public void Initialize(GeneratorInitializationContext context) + private static readonly DiagnosticDescriptor NonPartialClassError = new( + id: "BITBLAZORUIGEN001", + title: "The class needs to be partial", + messageFormat: "{0} is not partial. The AutoInject attribute needs to be used only in partial classes.", + category: "Bit.SourceGenerators", + DiagnosticSeverity.Error, + isEnabledByDefault: true); + + public void Initialize(IncrementalGeneratorInitializationContext context) { - context.RegisterForSyntaxNotifications(() => new AutoInjectSyntaxReceiver()); + // Provider 1: fields and properties directly annotated with [AutoInject] + var directMemberProvider = context.SyntaxProvider + .ForAttributeWithMetadataName( + AutoInjectHelper.AutoInjectAttributeFullName, + predicate: static (node, _) => node is FieldDeclarationSyntax or PropertyDeclarationSyntax, + transform: static (ctx, ct) => TransformDirectMember(ctx, ct)) + .Where(static e => e is not null) + .Select(static (e, _) => e!.Value); + + // Provider 2: classes whose base type uses [AutoInject] but they don't (including non-partial, to report diagnostic) + var derivedClassProvider = context.SyntaxProvider + .CreateSyntaxProvider( + predicate: static (node, _) => node is ClassDeclarationSyntax, + transform: static (ctx, ct) => TransformDerivedClass(ctx, ct)) + .Where(static e => e is not null) + .Select(static (e, _) => e!.Value); + + var combined = directMemberProvider.Collect() + .Combine(derivedClassProvider.Collect()); + + context.RegisterSourceOutput(combined, static (spc, inputs) => Execute(spc, inputs.Left, inputs.Right)); } - public void Execute(GeneratorExecutionContext context) + // ── Data models ────────────────────────────────────────────────────────── + + private readonly record struct LocationInfo( + string FilePath, + int SpanStart, + int SpanLength, + int StartLine, + int StartChar, + int EndLine, + int EndChar); + + private readonly record struct DirectEntry( + string ContainingTypeFullName, + string ClassName, + string ClassNameForCode, + string ClassNamespace, + AutoInjectClassType ClassType, + bool IsPartial, + LocationInfo? ClassLocation, + AutoInjectMember Member, + // Base class members encoded as "F:name:type|P:name:type" for structural equality + string EncodedBaseMembers); + + private readonly record struct DerivedEntry( + string ContainingTypeFullName, + string ClassName, + string ClassNameForCode, + string ClassNamespace, + AutoInjectClassType ClassType, + bool IsPartial, + LocationInfo? ClassLocation, + string EncodedBaseMembers); + + // ── Transforms ─────────────────────────────────────────────────────────── + + private static DirectEntry? TransformDirectMember(GeneratorAttributeSyntaxContext ctx, CancellationToken ct) { - if (context.SyntaxContextReceiver is not AutoInjectSyntaxReceiver receiver) - return; + var symbol = ctx.TargetSymbol; + if (symbol is not (IFieldSymbol or IPropertySymbol)) return null; - INamedTypeSymbol? attributeSymbol = context.Compilation.GetTypeByMetadataName(AutoInjectHelper.AutoInjectAttributeFullName); + var containingType = symbol.ContainingType; + if (containingType is null) return null; - foreach (IGrouping group in receiver.EligibleMembers.GroupBy(f => f.ContainingType, SymbolEqualityComparer.Default)) + // Filter out nested types + if (!containingType.ContainingSymbol.Equals(containingType.ContainingNamespace, SymbolEqualityComparer.Default)) + return null; + + var attrSymbol = ctx.SemanticModel.Compilation.GetTypeByMetadataName(AutoInjectHelper.AutoInjectAttributeFullName); + + AutoInjectMember member; + if (symbol is IFieldSymbol f) + member = new AutoInjectMember(f.Name, f.Type.ToDisplayString(), IsField: true, IsNullable: f.NullableAnnotation is NullableAnnotation.Annotated); + else { - if (IsClassIsPartial(context, group.Key) is false) - return; + var p = (IPropertySymbol)symbol; + member = new AutoInjectMember(p.Name, p.Type.ToDisplayString(), IsField: false, IsNullable: p.NullableAnnotation is NullableAnnotation.Annotated); + } - string? partialClassSource = GenerateSource(attributeSymbol, group.Key, group.ToList()); + var baseMembers = attrSymbol is null + ? (IReadOnlyCollection)new List() + : AutoInjectHelper.GetBaseClassEligibleMembers(containingType, attrSymbol); - if (string.IsNullOrEmpty(partialClassSource) is false) + var isPartial = IsSymbolPartial(containingType); + var classType = IsRazorComponent(containingType) ? AutoInjectClassType.RazorComponent : AutoInjectClassType.NormalClass; + + LocationInfo? classLocation = null; + foreach (var syntaxRef in containingType.DeclaringSyntaxReferences) + { + if (syntaxRef.GetSyntax() is ClassDeclarationSyntax classDecl) { - context.AddSource($"{group.Key.Name}_{counter++}_autoInject.g.cs", SourceText.From(partialClassSource!, Encoding.UTF8)); + classLocation = GetLocationInfo(classDecl.Identifier); + break; } } - foreach (var @class in receiver.EligibleClassesWithBaseClassUsedAutoInject) + return new DirectEntry( + ContainingTypeFullName: containingType.ToDisplayString(), + ClassName: containingType.Name, + ClassNameForCode: AutoInjectHelper.GenerateClassName(containingType), + ClassNamespace: containingType.ContainingNamespace.ToDisplayString(), + ClassType: classType, + IsPartial: isPartial, + ClassLocation: classLocation, + Member: member, + EncodedBaseMembers: EncodeMembers(baseMembers)); + } + + private static DerivedEntry? TransformDerivedClass(GeneratorSyntaxContext ctx, CancellationToken ct) + { + var classDecl = (ClassDeclarationSyntax)ctx.Node; + var classSymbol = ctx.SemanticModel.GetDeclaredSymbol(classDecl, ct); + if (classSymbol is null) return null; + + if (classSymbol.BaseType is null) return null; + if (classSymbol.BaseType.ToDisplayString() == "System.Object") return null; + + // Filter out nested types + if (!classSymbol.ContainingSymbol.Equals(classSymbol.ContainingNamespace, SymbolEqualityComparer.Default)) + return null; + + var attrFqn = AutoInjectHelper.AutoInjectAttributeFullName; + + var attrSymbol = ctx.SemanticModel.Compilation.GetTypeByMetadataName(attrFqn); + if (attrSymbol is null) return null; + + var baseMembers = AutoInjectHelper.GetBaseClassEligibleMembers(classSymbol, attrSymbol); + if (baseMembers.Count == 0) return null; + + var isCurrentClassUseAutoInject = classSymbol + .GetMembers() + .Any(m => (m.Kind == SymbolKind.Field || m.Kind == SymbolKind.Property) && + m.GetAttributes().Any(a => a.AttributeClass?.ToDisplayString() == attrFqn)); + + // Let the direct-member provider handle classes that have their own [AutoInject] members + if (isCurrentClassUseAutoInject) return null; + var classType = IsRazorComponent(classSymbol) ? AutoInjectClassType.RazorComponent : AutoInjectClassType.NormalClass; + + return new DerivedEntry( + ContainingTypeFullName: classSymbol.ToDisplayString(), + ClassName: classSymbol.Name, + ClassNameForCode: AutoInjectHelper.GenerateClassName(classSymbol), + ClassNamespace: classSymbol.ContainingNamespace.ToDisplayString(), + ClassType: classType, + IsPartial: IsSymbolPartial(classSymbol), + ClassLocation: GetLocationInfo(classDecl.Identifier), + EncodedBaseMembers: EncodeMembers(baseMembers)); + } + + // ── Code generation ─────────────────────────────────────────────────────── + + private static void Execute( + SourceProductionContext spc, + ImmutableArray directEntries, + ImmutableArray derivedEntries) + { + // Group direct entries by class + var directGroups = directEntries + .GroupBy(e => e.ContainingTypeFullName) + .ToDictionary(g => g.Key, g => g.ToList()); + + // Emit one file per class that has direct [AutoInject] members + foreach (var kvp in directGroups) { - if (IsClassIsPartial(context, @class) is false) - return; + var fullName = kvp.Key; + var entries = kvp.Value; + var first = entries[0]; + + if (!first.IsPartial) + { + var loc = first.ClassLocation.HasValue ? ToLocation(first.ClassLocation.Value) : Location.None; + spc.ReportDiagnostic(Diagnostic.Create(NonPartialClassError, loc, first.ClassName)); + continue; + } - if (IsClassIsPartial(context, @class.BaseType!) is false) - return; + var directMembers = entries.Select(e => e.Member).OrderBy(m => m.Name).ToList(); + var baseMembers = DecodeMembers(first.EncodedBaseMembers); - string? partialClassSource = GenerateSource(attributeSymbol, @class, new List()); + string? source = first.ClassType == AutoInjectClassType.RazorComponent + ? AutoInjectRazorComponentHandler.Generate(first.ClassNamespace, first.ClassNameForCode, directMembers) + : AutoInjectNormalClassHandler.Generate(first.ClassNamespace, first.ClassNameForCode, first.ClassName, directMembers, baseMembers); - if (string.IsNullOrEmpty(partialClassSource) is false) + if (!string.IsNullOrEmpty(source)) { - context.AddSource($"{@class.Name}_{counter++}_autoInject.g.cs", SourceText.From(partialClassSource!, Encoding.UTF8)); + var hintName = $"{EscapeForHint(fullName)}_autoInject.g.cs"; + spc.AddSource(hintName, SourceText.From(source!, Encoding.UTF8)); } } - } - private static bool IsClassIsPartial(GeneratorExecutionContext context, INamedTypeSymbol @class) - { - var syntaxReferences = @class.DeclaringSyntaxReferences; - foreach (var refrence in syntaxReferences) + // Emit one file per derived class (pass-through constructor / empty inject list) + foreach (var entry in derivedEntries) { - var classDeclarationSyntax = (ClassDeclarationSyntax)refrence.GetSyntax(); - var classHasPartial = classDeclarationSyntax.Modifiers.Any(o => o.IsKind(SyntaxKind.PartialKeyword)); - if (classHasPartial is false) + // Skip if already handled by the direct provider + if (directGroups.ContainsKey(entry.ContainingTypeFullName)) continue; + + if (!entry.IsPartial) + { + var loc = entry.ClassLocation.HasValue ? ToLocation(entry.ClassLocation.Value) : Location.None; + spc.ReportDiagnostic(Diagnostic.Create(NonPartialClassError, loc, entry.ClassName)); + continue; + } + + var baseMembers = DecodeMembers(entry.EncodedBaseMembers); + var empty = new List(); + + string? source = entry.ClassType == AutoInjectClassType.RazorComponent + ? AutoInjectRazorComponentHandler.Generate(entry.ClassNamespace, entry.ClassNameForCode, empty) + : AutoInjectNormalClassHandler.Generate(entry.ClassNamespace, entry.ClassNameForCode, entry.ClassName, empty, baseMembers); + + if (!string.IsNullOrEmpty(source)) { - context.ReportDiagnostic(Diagnostic.Create(NonPartialClassError, classDeclarationSyntax.GetLocation(), @class.Name)); - return false; + var hintName = $"{EscapeForHint(entry.ContainingTypeFullName)}_autoInject.g.cs"; + spc.AddSource(hintName, SourceText.From(source!, Encoding.UTF8)); } } + } + + // ── Helpers ─────────────────────────────────────────────────────────────── - return true; + private static bool IsRazorComponent(INamedTypeSymbol @class) + { + // Use interface check only — avoids File.Exists() I/O which is forbidden in incremental transforms + return @class.AllInterfaces.Any(o => o.ToDisplayString() == "Microsoft.AspNetCore.Components.IComponent"); } - private static string? GenerateSource(INamedTypeSymbol? attributeSymbol, INamedTypeSymbol? classSymbol, IReadOnlyCollection eligibleMembers) + private static bool IsSymbolPartial(INamedTypeSymbol classSymbol) { - AutoInjectClassType env = FigureOutTypeOfEnvironment(classSymbol); - return env switch + foreach (var syntaxRef in classSymbol.DeclaringSyntaxReferences) { - AutoInjectClassType.NormalClass => AutoInjectNormalClassHandler.Generate(attributeSymbol, classSymbol, eligibleMembers), - AutoInjectClassType.RazorComponent => AutoInjectRazorComponentHandler.Generate(classSymbol, eligibleMembers), - _ => string.Empty - }; + if (syntaxRef.GetSyntax() is ClassDeclarationSyntax cls && + cls.Modifiers.Any(m => m.IsKind(SyntaxKind.PartialKeyword))) + return true; + } + return false; } - private static AutoInjectClassType FigureOutTypeOfEnvironment(INamedTypeSymbol? @class) + private static string EncodeMembers(IEnumerable members) { - if (@class is null) - throw new ArgumentNullException(nameof(@class)); - - if (IsClassIsRazorComponent(@class)) - return AutoInjectClassType.RazorComponent; - else - return AutoInjectClassType.NormalClass; + var sb = new StringBuilder(); + foreach (var m in members) + { + if (sb.Length > 0) + { + sb.Append('|'); + } + + if (m is IFieldSymbol f) + sb.Append('F').Append('\t').Append(f.Name).Append('\t').Append(f.Type.ToDisplayString()).Append('\t').Append(f.NullableAnnotation is NullableAnnotation.Annotated ? '1' : '0'); + else if (m is IPropertySymbol p) + sb.Append('P').Append('\t').Append(p.Name).Append('\t').Append(p.Type.ToDisplayString()).Append('\t').Append(p.NullableAnnotation is NullableAnnotation.Annotated ? '1' : '0'); + } + return sb.ToString(); } - private static bool IsClassIsRazorComponent(INamedTypeSymbol @class) + private static List DecodeMembers(string encoded) { - bool isInheritIComponent = @class.AllInterfaces.Any(o => o.ToDisplayString() == "Microsoft.AspNetCore.Components.IComponent"); + var result = new List(); + if (string.IsNullOrEmpty(encoded)) return result; - if (isInheritIComponent) - return true; + foreach (var part in encoded.Split('|')) + { + // format: "F\tname\ttype\tnullable" (4 tab-separated fields) + var fields = part.Split('\t'); + if (fields.Length < 4) continue; + var kind = fields[0][0]; + var name = fields[1]; + var typeDisplay = fields[2]; + var isNullable = fields[3] == "1"; + result.Add(new AutoInjectMember(name, typeDisplay, IsField: kind == 'F', IsNullable: isNullable)); + } - var classFilePaths = @class.Locations - .Where(o => o.SourceTree is not null) - .Select(o => o.SourceTree?.FilePath) - .ToList(); + return result; + } - string razorFileName = $"{@class.Name}.razor"; + private static string EscapeForHint(string fullyQualifiedName) + { + return fullyQualifiedName.Replace('<', '[').Replace('>', ']').Replace(' ', '_'); + } - foreach (var path in classFilePaths) - { - string directoryPath = Path.GetDirectoryName(path) ?? string.Empty; - string filePath = Path.Combine(directoryPath, razorFileName); - if (File.Exists(filePath)) - return true; - } + private static LocationInfo? GetLocationInfo(SyntaxToken token) + { + var location = token.GetLocation(); + + if (location.SourceTree is null) return null; - return false; + var lineSpan = location.GetLineSpan(); + + return new LocationInfo( + FilePath: location.SourceTree.FilePath, + SpanStart: location.SourceSpan.Start, + SpanLength: location.SourceSpan.Length, + StartLine: lineSpan.StartLinePosition.Line, + StartChar: lineSpan.StartLinePosition.Character, + EndLine: lineSpan.EndLinePosition.Line, + EndChar: lineSpan.EndLinePosition.Character); } + + private static Location ToLocation(LocationInfo info) + => Location.Create( + info.FilePath, + new TextSpan(info.SpanStart, info.SpanLength), + new LinePositionSpan( + new LinePosition(info.StartLine, info.StartChar), + new LinePosition(info.EndLine, info.EndChar))); } + diff --git a/src/BlazorUI/Bit.BlazorUI.SourceGenerators/AutoInject/AutoInjectSyntaxReceiver.cs b/src/BlazorUI/Bit.BlazorUI.SourceGenerators/AutoInject/AutoInjectSyntaxReceiver.cs deleted file mode 100644 index b69e5017a6..0000000000 --- a/src/BlazorUI/Bit.BlazorUI.SourceGenerators/AutoInject/AutoInjectSyntaxReceiver.cs +++ /dev/null @@ -1,96 +0,0 @@ -using System.Collections.ObjectModel; -using System.Linq; -using Microsoft.CodeAnalysis; -using Microsoft.CodeAnalysis.CSharp; -using Microsoft.CodeAnalysis.CSharp.Syntax; - -namespace Bit.BlazorUI.SourceGenerators.AutoInject; - -public class AutoInjectSyntaxReceiver : ISyntaxContextReceiver -{ - public Collection EligibleMembers { get; } = new(); - public Collection EligibleClassesWithBaseClassUsedAutoInject { get; } = new(); - - public void OnVisitSyntaxNode(GeneratorSyntaxContext context) - { - MarkEligibleFields(context); - MarkEligibleProperties(context); - MarkEligibleClasses(context); - } - - private void MarkEligibleClasses(GeneratorSyntaxContext context) - { - if (context.Node is not ClassDeclarationSyntax classDeclarationSyntax) - return; - - var classSymbol = context.SemanticModel.GetDeclaredSymbol(classDeclarationSyntax); - - if (classSymbol == null) - return; - - if (classSymbol.BaseType == null) - return; - - if (classSymbol.BaseType.ToDisplayString() == "System.Object") - return; - - var isBaseTypeUseAutoInject = classSymbol.BaseType - .GetMembers() - .Any(m => - (m.Kind == SymbolKind.Field || m.Kind == SymbolKind.Property) && - m.GetAttributes() - .Any(a => a.AttributeClass != null && - a.AttributeClass.ToDisplayString() == AutoInjectHelper.AutoInjectAttributeFullName)); - - var isCurrentClassUseAutoInject = classSymbol - .GetMembers() - .Any(m => - (m.Kind == SymbolKind.Field || m.Kind == SymbolKind.Property) && - m.GetAttributes() - .Any(a => a.AttributeClass != null && - a.AttributeClass.ToDisplayString() == AutoInjectHelper.AutoInjectAttributeFullName)); - - if (isBaseTypeUseAutoInject && (isCurrentClassUseAutoInject is false)) - EligibleClassesWithBaseClassUsedAutoInject.Add(classSymbol); - } - - private void MarkEligibleFields(GeneratorSyntaxContext context) - { - if (context.Node is not FieldDeclarationSyntax fieldDeclarationSyntax || fieldDeclarationSyntax.AttributeLists.Any() is false) - return; - - if (fieldDeclarationSyntax.Parent is not ClassDeclarationSyntax classDeclarationSyntax || classDeclarationSyntax is null) - return; - - foreach (VariableDeclaratorSyntax variable in fieldDeclarationSyntax.Declaration.Variables) - { - var fieldSymbol = ModelExtensions.GetDeclaredSymbol(context.SemanticModel, variable) as IFieldSymbol; - if (fieldSymbol is not null && - fieldSymbol.GetAttributes() - .Any(ad => ad.AttributeClass is not null && - ad.AttributeClass.ToDisplayString() == AutoInjectHelper.AutoInjectAttributeFullName)) - { - EligibleMembers.Add(fieldSymbol); - } - } - } - - private void MarkEligibleProperties(GeneratorSyntaxContext context) - { - if (context.Node is not PropertyDeclarationSyntax propertyDeclarationSyntax || propertyDeclarationSyntax.AttributeLists.Count <= 0) - return; - - if (propertyDeclarationSyntax.Parent is not ClassDeclarationSyntax classDeclarationSyntax || classDeclarationSyntax is null) - return; - - var propertySymbol = context.SemanticModel.GetDeclaredSymbol(propertyDeclarationSyntax); - - if (propertySymbol is null) - return; - - if (propertySymbol.GetAttributes().Any(ad => ad.AttributeClass is not null && ad.AttributeClass.ToDisplayString() == AutoInjectHelper.AutoInjectAttributeFullName)) - { - EligibleMembers.Add(propertySymbol); - } - } -} diff --git a/src/BlazorUI/Bit.BlazorUI.SourceGenerators/Component/BlazorParameter.cs b/src/BlazorUI/Bit.BlazorUI.SourceGenerators/Component/BlazorParameter.cs index ee6c8ee916..9b07ac7a49 100644 --- a/src/BlazorUI/Bit.BlazorUI.SourceGenerators/Component/BlazorParameter.cs +++ b/src/BlazorUI/Bit.BlazorUI.SourceGenerators/Component/BlazorParameter.cs @@ -1,16 +1,16 @@ -using Microsoft.CodeAnalysis; +namespace Bit.BlazorUI.SourceGenerators.Component; -namespace Bit.BlazorUI.SourceGenerators.Component; - -public class BlazorParameter(IPropertySymbol propertySymbol, bool resetClassBuilder, bool resetStyleBuilder, bool isTwoWayBound) -{ - public IPropertySymbol PropertySymbol { get; set; } = propertySymbol; - - public bool IsTwoWayBound { get; set; } = isTwoWayBound; - - public bool ResetClassBuilder { get; set; } = resetClassBuilder; - public bool ResetStyleBuilder { get; set; } = resetStyleBuilder; - - public string? CallOnSetMethodName { get; set; } - public string? CallOnSetAsyncMethodName { get; set; } -} +internal readonly record struct BlazorParameter( + string ContainingTypeFullName, + string ClassName, + string ClassNameForCode, + string ClassNamespace, + bool IsBaseTypeComponentBase, + bool InheritsFromBitComponentBase, + string PropertyName, + string PropertyType, + bool ResetClassBuilder, + bool ResetStyleBuilder, + bool IsTwoWayBound, + string? CallOnSetMethodName, + string? CallOnSetAsyncMethodName); diff --git a/src/BlazorUI/Bit.BlazorUI.SourceGenerators/Component/ComponentSourceGenerator.cs b/src/BlazorUI/Bit.BlazorUI.SourceGenerators/Component/ComponentSourceGenerator.cs index a9235a7c53..aca17c628d 100644 --- a/src/BlazorUI/Bit.BlazorUI.SourceGenerators/Component/ComponentSourceGenerator.cs +++ b/src/BlazorUI/Bit.BlazorUI.SourceGenerators/Component/ComponentSourceGenerator.cs @@ -1,41 +1,108 @@ -using System.Linq; +using System.Collections.Generic; +using System.Collections.Immutable; +using System.Linq; using System.Text; -using System.Collections.Generic; +using System.Threading; using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; using Microsoft.CodeAnalysis.Text; namespace Bit.BlazorUI.SourceGenerators.Component; [Generator] -public class ComponentSourceGenerator : ISourceGenerator +public class ComponentSourceGenerator : IIncrementalGenerator { - public void Initialize(GeneratorInitializationContext context) + public void Initialize(IncrementalGeneratorInitializationContext context) { - context.RegisterForSyntaxNotifications(() => new ComponentSyntaxContextReceiver()); + var parameterProvider = context.SyntaxProvider + .ForAttributeWithMetadataName( + "Microsoft.AspNetCore.Components.ParameterAttribute", + predicate: static (node, _) => IsPartialClassProperty(node), + transform: static (ctx, ct) => ExtractBlazorParameter(ctx, ct)) + .Where(static p => p is not null) + .Select(static (p, _) => p!.Value); + + var cascadingProvider = context.SyntaxProvider + .ForAttributeWithMetadataName( + "Microsoft.AspNetCore.Components.CascadingParameterAttribute", + predicate: static (node, _) => IsPartialClassProperty(node), + transform: static (ctx, ct) => ExtractBlazorParameter(ctx, ct)) + .Where(static p => p is not null) + .Select(static (p, _) => p!.Value); + + var combined = parameterProvider.Collect() + .Combine(cascadingProvider.Collect()) + .Select(static (pair, _) => pair.Left.AddRange(pair.Right)); + + context.RegisterSourceOutput(combined, static (spc, parameters) => Execute(spc, parameters)); } - public void Execute(GeneratorExecutionContext context) + private static bool IsPartialClassProperty(SyntaxNode node) { - if (context.SyntaxContextReceiver is not ComponentSyntaxContextReceiver receiver) return; + return node is PropertyDeclarationSyntax prop && + prop.Parent is ClassDeclarationSyntax cls && + cls.Modifiers.Any(m => m.IsKind(SyntaxKind.PartialKeyword)); + } - foreach (var parametersGroup in receiver.Parameters.GroupBy(p => p.PropertySymbol.ContainingType, SymbolEqualityComparer.Default)) - { - var parameters = parametersGroup.ToList(); + private static BlazorParameter? ExtractBlazorParameter(GeneratorAttributeSyntaxContext ctx, CancellationToken ct) + { + if (ctx.TargetSymbol is not IPropertySymbol prop) return null; + + var containingType = prop.ContainingType; + if (containingType is null) return null; + + if (containingType.GetMembers().Any(m => m.Name == "SetParametersAsync")) return null; + + var attrs = prop.GetAttributes(); + var resetClassBuilder = attrs.Any(a => a.AttributeClass?.ToDisplayString() == "Bit.BlazorUI.ResetClassBuilderAttribute"); + var resetStyleBuilder = attrs.Any(a => a.AttributeClass?.ToDisplayString() == "Bit.BlazorUI.ResetStyleBuilderAttribute"); + var isTwoWayBound = attrs.Any(a => a.AttributeClass?.ToDisplayString() == "Bit.BlazorUI.TwoWayBoundAttribute"); + + var callOnSetAttr = attrs.SingleOrDefault(a => a.AttributeClass?.ToDisplayString() == "Bit.BlazorUI.CallOnSetAttribute"); + var callOnSetName = callOnSetAttr?.ConstructorArguments.FirstOrDefault().Value as string; - if (parametersGroup.Key == null) continue; + var callOnSetAsyncAttr = attrs.SingleOrDefault(a => a.AttributeClass?.ToDisplayString() == "Bit.BlazorUI.CallOnSetAsyncAttribute"); + var callOnSetAsyncName = callOnSetAsyncAttr?.ConstructorArguments.FirstOrDefault().Value as string; - string classSource = GeneratePartialClass((INamedTypeSymbol)parametersGroup.Key, parameters); - context.AddSource($"{parametersGroup.Key.Name}_SetParametersAsync.AutoGenerated.cs", SourceText.From(classSource, Encoding.UTF8)); + var classNameForCode = BuildClassNameForCode(containingType); + var isBaseTypeComponentBase = containingType.BaseType?.ToDisplayString() == "Microsoft.AspNetCore.Components.ComponentBase"; + var inheritsFromBit = InheritsFromBitComponentBase(containingType); + + return new BlazorParameter( + ContainingTypeFullName: containingType.ToDisplayString(), + ClassName: containingType.Name, + ClassNameForCode: classNameForCode, + ClassNamespace: containingType.ContainingNamespace.ToDisplayString(), + IsBaseTypeComponentBase: isBaseTypeComponentBase, + InheritsFromBitComponentBase: inheritsFromBit, + PropertyName: prop.Name, + PropertyType: prop.Type.ToDisplayString(), + ResetClassBuilder: resetClassBuilder, + ResetStyleBuilder: resetStyleBuilder, + IsTwoWayBound: isTwoWayBound, + CallOnSetMethodName: callOnSetName, + CallOnSetAsyncMethodName: callOnSetAsyncName); + } + + private static void Execute(SourceProductionContext spc, ImmutableArray parameters) + { + foreach (var group in parameters.GroupBy(p => p.ContainingTypeFullName)) + { + var list = group.ToList(); + var first = list[0]; + string source = GeneratePartialClass(first, list); + spc.AddSource($"{first.ClassName}_SetParametersAsync.AutoGenerated.cs", SourceText.From(source, Encoding.UTF8)); } } - private static string GeneratePartialClass(INamedTypeSymbol classSymbol, List parameters) + private static string GeneratePartialClass(BlazorParameter classInfo, List parameters) { - var namespaceName = classSymbol.ContainingNamespace.ToDisplayString(); - var className = GetClassName(classSymbol); + var namespaceName = classInfo.ClassNamespace; + var className = classInfo.ClassNameForCode; var twoWayParameters = parameters.Where(p => p.IsTwoWayBound).ToArray(); - var isBaseTypeComponentBase = classSymbol.BaseType?.ToDisplayString() == "Microsoft.AspNetCore.Components.ComponentBase"; - var doesSupporteParametersViewCache = InheritsFromBitComponentBase(classSymbol); + var isBaseTypeComponentBase = classInfo.IsBaseTypeComponentBase; + var doesSupporteParametersViewCache = classInfo.InheritsFromBitComponentBase; StringBuilder builder = new StringBuilder($@"using System; using System.Threading.Tasks; @@ -52,9 +119,8 @@ public partial class {className} builder.AppendLine(""); foreach (var par in twoWayParameters) { - var sym = par.PropertySymbol; - builder.AppendLine($" private bool {sym.Name}HasBeenSet;"); - builder.AppendLine($" [Parameter] public EventCallback<{sym.Type.ToDisplayString()}> {sym.Name}Changed {{ get; set; }}"); + builder.AppendLine($" private bool {par.PropertyName}HasBeenSet;"); + builder.AppendLine($" [Parameter] public EventCallback<{par.PropertyType}> {par.PropertyName}Changed {{ get; set; }}"); } if (twoWayParameters.Length > 0) builder.AppendLine(""); builder.AppendLine($@" [global::System.Diagnostics.DebuggerNonUserCode] @@ -64,7 +130,7 @@ public override async Task SetParametersAsync(ParameterView parameters) builder.AppendLine($" __assignedParameters.Clear();"); foreach (var par in twoWayParameters) { - builder.AppendLine($" {par.PropertySymbol.Name}HasBeenSet = false;"); + builder.AppendLine($" {par.PropertyName}HasBeenSet = false;"); } if (doesSupporteParametersViewCache) { @@ -80,10 +146,9 @@ public override async Task SetParametersAsync(ParameterView parameters) builder.AppendLine(" {"); foreach (var par in parameters) { - var sym = par.PropertySymbol; - var paramName = sym.Name; + var paramName = par.PropertyName; var varName = $"@{paramName.ToLower()}"; - var paramType = sym.Type.ToDisplayString(); + var paramType = par.PropertyType; builder.AppendLine($" case nameof({paramName}):"); builder.AppendLine($" __assignedParameters.Add(nameof({paramName}));"); if (par.IsTwoWayBound) @@ -116,11 +181,11 @@ public override async Task SetParametersAsync(ParameterView parameters) builder.AppendLine(" break;"); if (par.IsTwoWayBound) { - paramName = $"{paramName}Changed"; - varName = $"@{paramName.ToLower()}"; - builder.AppendLine($" case nameof({paramName}):"); - builder.AppendLine($" var {varName} = parameter.Value is null ? default! : (EventCallback<{sym.Type.ToDisplayString()}>)parameter.Value;"); - builder.AppendLine($" {paramName} = {varName};"); + var changedName = $"{paramName}Changed"; + var changedVarName = $"@{changedName.ToLower()}"; + builder.AppendLine($" case nameof({changedName}):"); + builder.AppendLine($" var {changedVarName} = parameter.Value is null ? default! : (EventCallback<{paramType}>)parameter.Value;"); + builder.AppendLine($" {changedName} = {changedVarName};"); builder.AppendLine(" parametersDictionary.Remove(parameter.Key);"); builder.AppendLine(" break;"); } @@ -156,8 +221,8 @@ public bool HasNotBeenSet(string name) if (twoWayParameters.Length > 0) builder.AppendLine(""); foreach (var par in twoWayParameters) { - var paramName = par.PropertySymbol.Name; - var paramType = par.PropertySymbol.Type.ToDisplayString(); + var paramName = par.PropertyName; + var paramType = par.PropertyType; builder.AppendLine($@" [global::System.Diagnostics.DebuggerNonUserCode] [global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage] public async Task Assign{paramName}({paramType} value) @@ -194,31 +259,21 @@ public bool HasNotBeenSet(string name) return builder.ToString(); } - private static string GetClassName(INamedTypeSymbol classSymbol) + private static string BuildClassNameForCode(INamedTypeSymbol classSymbol) { - StringBuilder sbName = new StringBuilder(classSymbol.Name); - if (classSymbol.IsGenericType) { - sbName.Append('<'); - sbName.Append(string.Join(", ", classSymbol.TypeArguments.Select(s => s.Name))); - sbName.Append('>'); + var typeArgs = string.Join(", ", classSymbol.TypeArguments.Select(s => s.Name)); + return $"{classSymbol.Name}<{typeArgs}>"; } - - return sbName.ToString(); + return classSymbol.Name; } private static bool InheritsFromBitComponentBase(INamedTypeSymbol? typeSymbol) { - if (typeSymbol is null) - return false; - - if (typeSymbol.TypeKind is not TypeKind.Class) - return false; - - if (typeSymbol.Name == "BitComponentBase") - return true; - + if (typeSymbol is null) return false; + if (typeSymbol.TypeKind is not TypeKind.Class) return false; + if (typeSymbol.Name == "BitComponentBase") return true; return InheritsFromBitComponentBase(typeSymbol.BaseType); } } diff --git a/src/BlazorUI/Bit.BlazorUI.SourceGenerators/Component/ComponentSyntaxContextReceiver.cs b/src/BlazorUI/Bit.BlazorUI.SourceGenerators/Component/ComponentSyntaxContextReceiver.cs deleted file mode 100644 index ff6164ec6e..0000000000 --- a/src/BlazorUI/Bit.BlazorUI.SourceGenerators/Component/ComponentSyntaxContextReceiver.cs +++ /dev/null @@ -1,57 +0,0 @@ -using System.Linq; -using System.Collections.Generic; -using Microsoft.CodeAnalysis; -using Microsoft.CodeAnalysis.CSharp; -using Microsoft.CodeAnalysis.CSharp.Syntax; - -namespace Bit.BlazorUI.SourceGenerators.Component; - -public class ComponentSyntaxContextReceiver : ISyntaxContextReceiver -{ - public IList Parameters { get; } = []; - - public void OnVisitSyntaxNode(GeneratorSyntaxContext context) - { - if (context.Node is not PropertyDeclarationSyntax propertyDeclarationSyntax || !propertyDeclarationSyntax.AttributeLists.Any()) return; - - var parent = propertyDeclarationSyntax.Parent; - - if (parent is null || parent.IsKind(SyntaxKind.ClassDeclaration) is false) return; - - var classDeclarationSyntax = (ClassDeclarationSyntax?)parent; - - if (classDeclarationSyntax?.Modifiers.Any(k => k.IsKind(SyntaxKind.PartialKeyword)) is false) return; - - var propertySymbol = context.SemanticModel.GetDeclaredSymbol(propertyDeclarationSyntax); - - if (propertySymbol is null) return; - - var type = propertySymbol.ContainingType; - - if (type == null) return; - - if (type.GetMembers().Any(m => m.Name == "SetParametersAsync")) return; - - var attributes = propertySymbol.GetAttributes(); - - if (attributes.Any(ad => ad.AttributeClass?.ToDisplayString() == "Microsoft.AspNetCore.Components.ParameterAttribute" || - ad.AttributeClass?.ToDisplayString() == "Microsoft.AspNetCore.Components.CascadingParameterAttribute")) - { - var resetClassBuilder = attributes.Any(a => a.AttributeClass?.ToDisplayString() == "Bit.BlazorUI.ResetClassBuilderAttribute"); - var resetStyleBuilder = attributes.Any(a => a.AttributeClass?.ToDisplayString() == "Bit.BlazorUI.ResetStyleBuilderAttribute"); - var isTwoWayBound = attributes.Any(a => a.AttributeClass?.ToDisplayString() == "Bit.BlazorUI.TwoWayBoundAttribute"); - - var parameter = new BlazorParameter(propertySymbol, resetClassBuilder, resetStyleBuilder, isTwoWayBound); - - var callOnSetAttribute = attributes.SingleOrDefault(a => a.AttributeClass?.ToDisplayString() == "Bit.BlazorUI.CallOnSetAttribute"); - var callOnSetName = callOnSetAttribute?.ConstructorArguments.FirstOrDefault().Value as string; - parameter.CallOnSetMethodName = callOnSetName; - - var callOnSetAsyncAttribute = attributes.SingleOrDefault(a => a.AttributeClass?.ToDisplayString() == "Bit.BlazorUI.CallOnSetAsyncAttribute"); - var callOnSetAsyncName = callOnSetAsyncAttribute?.ConstructorArguments.FirstOrDefault().Value as string; - parameter.CallOnSetAsyncMethodName = callOnSetAsyncName; - - Parameters.Add(parameter); - } - } -} diff --git a/src/BlazorUI/Bit.BlazorUI.SourceGenerators/IsExternalInit.cs b/src/BlazorUI/Bit.BlazorUI.SourceGenerators/IsExternalInit.cs new file mode 100644 index 0000000000..b22297d3ac --- /dev/null +++ b/src/BlazorUI/Bit.BlazorUI.SourceGenerators/IsExternalInit.cs @@ -0,0 +1,5 @@ +// Polyfill required for record types and init-only setters when targeting netstandard2.0. +namespace System.Runtime.CompilerServices +{ + internal static class IsExternalInit { } +} diff --git a/src/SourceGenerators/Bit.SourceGenerators/AutoInject/AutoInjectHelper.cs b/src/SourceGenerators/Bit.SourceGenerators/AutoInject/AutoInjectHelper.cs index 588cffeb1e..b00ad47fd5 100644 --- a/src/SourceGenerators/Bit.SourceGenerators/AutoInject/AutoInjectHelper.cs +++ b/src/SourceGenerators/Bit.SourceGenerators/AutoInject/AutoInjectHelper.cs @@ -6,7 +6,7 @@ namespace Bit.SourceGenerators; -public static class AutoInjectHelper +internal static class AutoInjectHelper { public static readonly string AutoInjectAttributeFullName = "Microsoft.Extensions.DependencyInjection.AutoInjectAttribute"; //typeof(AutoInjectAttribute).FullName; diff --git a/src/SourceGenerators/Bit.SourceGenerators/AutoInject/AutoInjectMember.cs b/src/SourceGenerators/Bit.SourceGenerators/AutoInject/AutoInjectMember.cs new file mode 100644 index 0000000000..d53ee7bcdd --- /dev/null +++ b/src/SourceGenerators/Bit.SourceGenerators/AutoInject/AutoInjectMember.cs @@ -0,0 +1,3 @@ +namespace Bit.SourceGenerators; + +internal readonly record struct AutoInjectMember(string Name, string TypeDisplay, bool IsField, bool IsNullable); diff --git a/src/SourceGenerators/Bit.SourceGenerators/AutoInject/AutoInjectNormalClassHandler.cs b/src/SourceGenerators/Bit.SourceGenerators/AutoInject/AutoInjectNormalClassHandler.cs index d6c01d9396..430f10db99 100644 --- a/src/SourceGenerators/Bit.SourceGenerators/AutoInject/AutoInjectNormalClassHandler.cs +++ b/src/SourceGenerators/Bit.SourceGenerators/AutoInject/AutoInjectNormalClassHandler.cs @@ -1,119 +1,90 @@ using System.Collections.Generic; using System.Linq; using System.Text; -using Microsoft.CodeAnalysis; namespace Bit.SourceGenerators; -public static class AutoInjectNormalClassHandler +internal static class AutoInjectNormalClassHandler { - public static string? Generate(INamedTypeSymbol? attributeSymbol, INamedTypeSymbol? classSymbol, IReadOnlyCollection eligibleMembers) + public static string? Generate( + string classNamespace, + string classNameForCode, + string className, + IReadOnlyCollection directMembers, + IReadOnlyCollection baseMembers) { - if (classSymbol is null) - { - return null; - } - - if (AutoInjectHelper.IsContainingSymbolEqualToContainingNamespace(classSymbol) is false) - { - return null; - } - - string classNamespace = classSymbol.ContainingNamespace.ToDisplayString(); - - IReadOnlyCollection baseEligibleMembers = AutoInjectHelper.GetBaseClassEligibleMembers(classSymbol, attributeSymbol); - IReadOnlyCollection sortedMembers = eligibleMembers.OrderBy(o => o.Name).ToList(); + var sortedMembers = directMembers.OrderBy(o => o.Name).ToList(); string source = $@" namespace {classNamespace} {{ - public partial class {AutoInjectHelper.GenerateClassName(classSymbol)} + public partial class {classNameForCode} {{ - {GenerateConstructor(classSymbol, sortedMembers, baseEligibleMembers)} + {GenerateConstructor(className, sortedMembers, baseMembers)} }} }}"; return source; } - private static string GenerateConstructor(INamedTypeSymbol classSymbol, IReadOnlyCollection eligibleMembers, IReadOnlyCollection baseEligibleMembers) + private static string GenerateConstructor(string className, IReadOnlyCollection directMembers, IReadOnlyCollection baseMembers) { string generateConstructor = $@" [global::System.CodeDom.Compiler.GeneratedCode(""Bit.SourceGenerators"",""{BitSourceGeneratorUtil.GetPackageVersion()}"")] [global::System.Diagnostics.DebuggerNonUserCode] [global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage] -{"\t\t"}public {classSymbol.Name}({GenerateConstructorParameters(eligibleMembers, baseEligibleMembers)}){PassParametersToBaseClass(baseEligibleMembers)} +{"\t\t"}public {className}({GenerateConstructorParameters(directMembers, baseMembers)}){PassParametersToBaseClass(baseMembers)} {"\t\t"}{{ -{AssignedInjectedParametersToMembers(eligibleMembers)} +{AssignMembersFromParameters(directMembers)} {"\t\t"}}} "; return generateConstructor; } - private static string PassParametersToBaseClass(IReadOnlyCollection baseEligibleMembers) + private static string PassParametersToBaseClass(IReadOnlyCollection baseMembers) { - if (baseEligibleMembers.Any() is false) + if (baseMembers.Any() is false) return string.Empty; StringBuilder baseConstructor = new(); - baseConstructor.Append(": base("); - foreach (ISymbol symbol in baseEligibleMembers) + foreach (var member in baseMembers) { - baseConstructor.Append($@"{'\n'}{"\t\t\t\t\t\t"}autoInjected{AutoInjectHelper.FormatMemberName(symbol.Name)},"); + baseConstructor.Append($@"{'\n'}{"\t\t\t\t\t\t"}autoInjected{AutoInjectHelper.FormatMemberName(member.Name)},"); } baseConstructor.Length--; - baseConstructor.Append(')'); return baseConstructor.ToString(); } - private static string AssignedInjectedParametersToMembers(IReadOnlyCollection eligibleMembers) + private static string AssignMembersFromParameters(IReadOnlyCollection directMembers) { StringBuilder stringBuilder = new(); - foreach (ISymbol symbol in eligibleMembers) + foreach (var member in directMembers) { if (stringBuilder.Length > 0) { stringBuilder.Append('\n'); } stringBuilder.Append("\t\t\t") - .Append($@"{symbol.Name} = autoInjected{AutoInjectHelper.FormatMemberName(symbol.Name)};"); + .Append($@"{member.Name} = autoInjected{AutoInjectHelper.FormatMemberName(member.Name)};"); } return stringBuilder.ToString(); } - private static string GenerateConstructorParameters(IReadOnlyCollection eligibleMembers, IReadOnlyCollection baseEligibleMembers) + private static string GenerateConstructorParameters(IReadOnlyCollection directMembers, IReadOnlyCollection baseMembers) { StringBuilder stringBuilder = new(); - List members = new(eligibleMembers.Count + baseEligibleMembers.Count); + var allMembers = directMembers.Concat(baseMembers).OrderBy(o => o.Name).ToList(); - members.AddRange(eligibleMembers); - members.AddRange(baseEligibleMembers); - members = members.OrderBy(o => o.Name).ToList(); - - foreach (ISymbol member in members) + foreach (var member in allMembers) { - if (member is IFieldSymbol fieldSymbol) - { - var isNullable = fieldSymbol.NullableAnnotation is NullableAnnotation.Annotated; - var nullValue = isNullable ? " = null" : string.Empty; - - stringBuilder.Append( - $@"{'\n'}{"\t\t\t"}{fieldSymbol.Type} autoInjected{AutoInjectHelper.FormatMemberName(fieldSymbol.Name)} {nullValue},"); - } - - if (member is IPropertySymbol propertySymbol) - { - var isNullable = propertySymbol.NullableAnnotation is NullableAnnotation.Annotated; - var nullValue = isNullable ? " = null" : string.Empty; - - stringBuilder.Append( - $@"{'\n'}{"\t\t\t"}{propertySymbol.Type} autoInjected{AutoInjectHelper.FormatMemberName(propertySymbol.Name)} {nullValue},"); - } + var nullValue = member.IsNullable ? " = null" : string.Empty; + stringBuilder.Append( + $@"{'\n'}{"\t\t\t"}{member.TypeDisplay} autoInjected{AutoInjectHelper.FormatMemberName(member.Name)} {nullValue},"); } stringBuilder.Length--; diff --git a/src/SourceGenerators/Bit.SourceGenerators/AutoInject/AutoInjectRazorComponentHandler.cs b/src/SourceGenerators/Bit.SourceGenerators/AutoInject/AutoInjectRazorComponentHandler.cs index fd1c072cd2..b570f050e0 100644 --- a/src/SourceGenerators/Bit.SourceGenerators/AutoInject/AutoInjectRazorComponentHandler.cs +++ b/src/SourceGenerators/Bit.SourceGenerators/AutoInject/AutoInjectRazorComponentHandler.cs @@ -1,59 +1,42 @@ using System.Collections.Generic; -using System.Linq; using System.Text; -using Microsoft.CodeAnalysis; namespace Bit.SourceGenerators; -public static class AutoInjectRazorComponentHandler +internal static class AutoInjectRazorComponentHandler { - public static string? Generate(INamedTypeSymbol? classSymbol, IReadOnlyCollection eligibleMembers) + public static string? Generate( + string classNamespace, + string classNameForCode, + IReadOnlyCollection directMembers) { - if (classSymbol is null) - { - return null; - } - - if (AutoInjectHelper.IsContainingSymbolEqualToContainingNamespace(classSymbol) is false) - { - return null; - } - - string classNamespace = classSymbol.ContainingNamespace.ToDisplayString(); - - IReadOnlyCollection sortedMembers = eligibleMembers.OrderBy(o => o.Name).ToList(); - string source = $@" using Microsoft.AspNetCore.Components; using System.ComponentModel; namespace {classNamespace} {{ - public partial class {AutoInjectHelper.GenerateClassName(classSymbol)} + public partial class {classNameForCode} {{ - {GenerateInjectableProperties(sortedMembers)} + {GenerateInjectableProperties(directMembers)} }} }}"; return source; } - private static string GenerateInjectableProperties(IReadOnlyCollection eligibleMembers) + private static string GenerateInjectableProperties(IReadOnlyCollection members) { StringBuilder stringBuilder = new StringBuilder(); - foreach (ISymbol member in eligibleMembers) + foreach (var member in members) { - if (member is IFieldSymbol fieldSymbol) - stringBuilder.Append(GenerateProperty(fieldSymbol.Type, fieldSymbol.Name)); - - if (member is IPropertySymbol propertySymbol) - stringBuilder.Append(GenerateProperty(propertySymbol.Type, propertySymbol.Name)); + stringBuilder.Append(GenerateProperty(member.TypeDisplay, member.Name)); } return stringBuilder.ToString(); } - private static string GenerateProperty(ITypeSymbol @type, string name) + private static string GenerateProperty(string typeDisplay, string name) { return $@" [global::System.CodeDom.Compiler.GeneratedCode(""Bit.SourceGenerators"",""{BitSourceGeneratorUtil.GetPackageVersion()}"")] @@ -61,6 +44,6 @@ private static string GenerateProperty(ITypeSymbol @type, string name) [global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage] {"\t\t"}[Inject] {"\t\t"}[EditorBrowsable(EditorBrowsableState.Never)] -{"\t\t"}private {@type} ____{AutoInjectHelper.FormatMemberName(name)} {{ get => {name}; set => {name} = value; }}"; +{"\t\t"}private {typeDisplay} ____{AutoInjectHelper.FormatMemberName(name)} {{ get => {name}; set => {name} = value; }}"; } } diff --git a/src/SourceGenerators/Bit.SourceGenerators/AutoInject/AutoInjectSourceGenerator.cs b/src/SourceGenerators/Bit.SourceGenerators/AutoInject/AutoInjectSourceGenerator.cs index 17f697ff45..9cb2034fd0 100644 --- a/src/SourceGenerators/Bit.SourceGenerators/AutoInject/AutoInjectSourceGenerator.cs +++ b/src/SourceGenerators/Bit.SourceGenerators/AutoInject/AutoInjectSourceGenerator.cs @@ -1,8 +1,8 @@ -using System; -using System.Collections.Generic; -using System.IO; +using System.Collections.Generic; +using System.Collections.Immutable; using System.Linq; using System.Text; +using System.Threading; using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.CSharp.Syntax; @@ -11,119 +11,328 @@ namespace Bit.SourceGenerators; [Generator] -public class AutoInjectSourceGenerator : ISourceGenerator +public class AutoInjectSourceGenerator : IIncrementalGenerator { - private static int counter; - private static readonly DiagnosticDescriptor NonPartialClassError = new DiagnosticDescriptor(id: "BITGEN001", - title: "The class needs to be partial", - messageFormat: "{0} is not partial. The AutoInject attribute needs to be used only in partial classes.", - category: "Bit.SourceGenerators", - DiagnosticSeverity.Error, - isEnabledByDefault: true); - - public void Initialize(GeneratorInitializationContext context) + private static readonly DiagnosticDescriptor NonPartialClassError = new( + id: "BITGEN001", + title: "The class needs to be partial", + messageFormat: "{0} is not partial. The AutoInject attribute needs to be used only in partial classes.", + category: "Bit.SourceGenerators", + DiagnosticSeverity.Error, + isEnabledByDefault: true); + + public void Initialize(IncrementalGeneratorInitializationContext context) { - context.RegisterForSyntaxNotifications(() => new AutoInjectSyntaxReceiver()); + // Provider 1: fields and properties directly annotated with [AutoInject] + var directMemberProvider = context.SyntaxProvider + .ForAttributeWithMetadataName( + AutoInjectHelper.AutoInjectAttributeFullName, + predicate: static (node, _) => node is FieldDeclarationSyntax or PropertyDeclarationSyntax, + transform: static (ctx, ct) => TransformDirectMember(ctx, ct)) + .Where(static e => e is not null) + .Select(static (e, _) => e!.Value); + + // Provider 2: classes whose base type uses [AutoInject] but they don't (including non-partial, to report diagnostic) + var derivedClassProvider = context.SyntaxProvider + .CreateSyntaxProvider( + predicate: static (node, _) => node is ClassDeclarationSyntax, + transform: static (ctx, ct) => TransformDerivedClass(ctx, ct)) + .Where(static e => e is not null) + .Select(static (e, _) => e!.Value); + + var combined = directMemberProvider.Collect() + .Combine(derivedClassProvider.Collect()); + + context.RegisterSourceOutput(combined, static (spc, inputs) => Execute(spc, inputs.Left, inputs.Right)); } - public void Execute(GeneratorExecutionContext context) + // ── Data models ────────────────────────────────────────────────────────── + + private readonly record struct LocationInfo( + string FilePath, + int SpanStart, + int SpanLength, + int StartLine, + int StartChar, + int EndLine, + int EndChar); + + private readonly record struct DirectEntry( + string ContainingTypeFullName, + string ClassName, + string ClassNameForCode, + string ClassNamespace, + AutoInjectClassType ClassType, + bool IsPartial, + AutoInjectMember Member, + // Base class members encoded as "F\tname\ttype\tnullable|..." for structural equality + string EncodedBaseMembers, + LocationInfo? ClassLocation); + + private readonly record struct DerivedEntry( + string ContainingTypeFullName, + string ClassName, + string ClassNameForCode, + string ClassNamespace, + AutoInjectClassType ClassType, + bool IsPartial, + string EncodedBaseMembers, + LocationInfo? ClassLocation); + + // ── Transforms ─────────────────────────────────────────────────────────── + + private static DirectEntry? TransformDirectMember(GeneratorAttributeSyntaxContext ctx, CancellationToken ct) { - if (context.SyntaxContextReceiver is not AutoInjectSyntaxReceiver receiver) - return; + var symbol = ctx.TargetSymbol; + if (symbol is not (IFieldSymbol or IPropertySymbol)) return null; - INamedTypeSymbol? attributeSymbol = context.Compilation.GetTypeByMetadataName(AutoInjectHelper.AutoInjectAttributeFullName); + var containingType = symbol.ContainingType; + if (containingType is null) return null; - foreach (IGrouping group in receiver.EligibleMembers.GroupBy(f => f.ContainingType, SymbolEqualityComparer.Default)) + // Filter out nested types + if (!containingType.ContainingSymbol.Equals(containingType.ContainingNamespace, SymbolEqualityComparer.Default)) + return null; + + var attrSymbol = ctx.SemanticModel.Compilation.GetTypeByMetadataName(AutoInjectHelper.AutoInjectAttributeFullName); + + AutoInjectMember member; + if (symbol is IFieldSymbol f) + member = new AutoInjectMember(f.Name, f.Type.ToDisplayString(), IsField: true, IsNullable: f.NullableAnnotation is NullableAnnotation.Annotated); + else { - if (IsClassIsPartial(context, group.Key) is false) - return; + var p = (IPropertySymbol)symbol; + member = new AutoInjectMember(p.Name, p.Type.ToDisplayString(), IsField: false, IsNullable: p.NullableAnnotation is NullableAnnotation.Annotated); + } - string? partialClassSource = GenerateSource(attributeSymbol, group.Key, group.ToList()); + var baseMembers = attrSymbol is null + ? (IReadOnlyCollection)new List() + : AutoInjectHelper.GetBaseClassEligibleMembers(containingType, attrSymbol); - if (string.IsNullOrEmpty(partialClassSource) is false) + var isPartial = IsSymbolPartial(containingType); + var classType = IsRazorComponent(containingType) ? AutoInjectClassType.RazorComponent : AutoInjectClassType.NormalClass; + + LocationInfo? classLocation = null; + foreach (var syntaxRef in containingType.DeclaringSyntaxReferences) + { + if (syntaxRef.GetSyntax() is ClassDeclarationSyntax classDecl) { - context.AddSource($"{group.Key.Name}_{counter++}_autoInject.g.cs", SourceText.From(partialClassSource!, Encoding.UTF8)); + classLocation = GetLocationInfo(classDecl.Identifier); + break; } } - foreach (var @class in receiver.EligibleClassesWithBaseClassUsedAutoInject) - { - if (IsClassIsPartial(context, @class) is false) - return; + return new DirectEntry( + ContainingTypeFullName: containingType.ToDisplayString(), + ClassName: containingType.Name, + ClassNameForCode: AutoInjectHelper.GenerateClassName(containingType), + ClassNamespace: containingType.ContainingNamespace.ToDisplayString(), + ClassType: classType, + IsPartial: isPartial, + Member: member, + EncodedBaseMembers: EncodeMembers(baseMembers), + ClassLocation: classLocation); + } - if (IsClassIsPartial(context, @class.BaseType!) is false) - return; + private static DerivedEntry? TransformDerivedClass(GeneratorSyntaxContext ctx, CancellationToken ct) + { + var classDecl = (ClassDeclarationSyntax)ctx.Node; + var classSymbol = ctx.SemanticModel.GetDeclaredSymbol(classDecl, ct); + if (classSymbol is null) return null; + + if (classSymbol.BaseType is null) return null; + if (classSymbol.BaseType.ToDisplayString() == "System.Object") return null; + + // Filter out nested types + if (!classSymbol.ContainingSymbol.Equals(classSymbol.ContainingNamespace, SymbolEqualityComparer.Default)) + return null; + + var attrFqn = AutoInjectHelper.AutoInjectAttributeFullName; - string? partialClassSource = GenerateSource(attributeSymbol, @class, new List()); + var attrSymbol = ctx.SemanticModel.Compilation.GetTypeByMetadataName(attrFqn); + if (attrSymbol is null) return null; - if (string.IsNullOrEmpty(partialClassSource) is false) + var baseMembers = AutoInjectHelper.GetBaseClassEligibleMembers(classSymbol, attrSymbol); + if (baseMembers.Count == 0) return null; + + var isCurrentClassUseAutoInject = classSymbol + .GetMembers() + .Any(m => (m.Kind == SymbolKind.Field || m.Kind == SymbolKind.Property) && + m.GetAttributes().Any(a => a.AttributeClass?.ToDisplayString() == attrFqn)); + + // Let the direct-member provider handle classes that have their own [AutoInject] members + if (isCurrentClassUseAutoInject) return null; + var classType = IsRazorComponent(classSymbol) ? AutoInjectClassType.RazorComponent : AutoInjectClassType.NormalClass; + + LocationInfo? classLocation = null; + foreach (var syntaxRef in classSymbol.DeclaringSyntaxReferences) + { + if (syntaxRef.GetSyntax() is ClassDeclarationSyntax classDecl2) { - context.AddSource($"{@class.Name}_{counter++}_autoInject.g.cs", SourceText.From(partialClassSource!, Encoding.UTF8)); + classLocation = GetLocationInfo(classDecl2.Identifier); + break; } } + + return new DerivedEntry( + ContainingTypeFullName: classSymbol.ToDisplayString(), + ClassName: classSymbol.Name, + ClassNameForCode: AutoInjectHelper.GenerateClassName(classSymbol), + ClassNamespace: classSymbol.ContainingNamespace.ToDisplayString(), + ClassType: classType, + IsPartial: IsSymbolPartial(classSymbol), + EncodedBaseMembers: EncodeMembers(baseMembers), + ClassLocation: classLocation); } - private static bool IsClassIsPartial(GeneratorExecutionContext context, INamedTypeSymbol @class) + // ── Code generation ─────────────────────────────────────────────────────── + + private static void Execute( + SourceProductionContext spc, + ImmutableArray directEntries, + ImmutableArray derivedEntries) { - var syntaxReferences = @class.DeclaringSyntaxReferences; - foreach (var refrence in syntaxReferences) + // Group direct entries by class + var directGroups = directEntries + .GroupBy(e => e.ContainingTypeFullName) + .ToDictionary(g => g.Key, g => g.ToList()); + + // Emit one file per class that has direct [AutoInject] members + foreach (var kvp in directGroups) + { + var fullName = kvp.Key; + var entries = kvp.Value; + var first = entries[0]; + + if (!first.IsPartial) + { + var loc = first.ClassLocation.HasValue ? ToLocation(first.ClassLocation.Value) : Location.None; + spc.ReportDiagnostic(Diagnostic.Create(NonPartialClassError, loc, first.ClassName)); + continue; + } + + var directMembers = entries.Select(e => e.Member).OrderBy(m => m.Name).ToList(); + var baseMembers = DecodeMembers(first.EncodedBaseMembers); + + string? source = first.ClassType == AutoInjectClassType.RazorComponent + ? AutoInjectRazorComponentHandler.Generate(first.ClassNamespace, first.ClassNameForCode, directMembers) + : AutoInjectNormalClassHandler.Generate(first.ClassNamespace, first.ClassNameForCode, first.ClassName, directMembers, baseMembers); + + if (!string.IsNullOrEmpty(source)) + { + var hintName = $"{EscapeForHint(fullName)}_autoInject.g.cs"; + spc.AddSource(hintName, SourceText.From(source!, Encoding.UTF8)); + } + } + + // Emit one file per derived class (pass-through constructor / empty inject list) + foreach (var entry in derivedEntries) { - var classDeclarationSyntax = (ClassDeclarationSyntax)refrence.GetSyntax(); - var classHasPartial = classDeclarationSyntax.Modifiers.Any(o => o.IsKind(SyntaxKind.PartialKeyword)); - if (classHasPartial is false) + // Skip if already handled by the direct provider + if (directGroups.ContainsKey(entry.ContainingTypeFullName)) continue; + + if (!entry.IsPartial) { - context.ReportDiagnostic(Diagnostic.Create(NonPartialClassError, classDeclarationSyntax.GetLocation(), @class.Name)); - return false; + var loc = entry.ClassLocation.HasValue ? ToLocation(entry.ClassLocation.Value) : Location.None; + spc.ReportDiagnostic(Diagnostic.Create(NonPartialClassError, loc, entry.ClassName)); + continue; + } + + var baseMembers = DecodeMembers(entry.EncodedBaseMembers); + var empty = new List(); + + string? source = entry.ClassType == AutoInjectClassType.RazorComponent + ? AutoInjectRazorComponentHandler.Generate(entry.ClassNamespace, entry.ClassNameForCode, empty) + : AutoInjectNormalClassHandler.Generate(entry.ClassNamespace, entry.ClassNameForCode, entry.ClassName, empty, baseMembers); + + if (!string.IsNullOrEmpty(source)) + { + var hintName = $"{EscapeForHint(entry.ContainingTypeFullName)}_autoInject.g.cs"; + spc.AddSource(hintName, SourceText.From(source!, Encoding.UTF8)); } } + } - return true; + // ── Helpers ─────────────────────────────────────────────────────────────── + + private static bool IsRazorComponent(INamedTypeSymbol @class) + { + // Use interface check only — avoids File.Exists() I/O which is forbidden in incremental transforms + return @class.AllInterfaces.Any(o => o.ToDisplayString() == "Microsoft.AspNetCore.Components.IComponent"); } - private static string? GenerateSource(INamedTypeSymbol? attributeSymbol, INamedTypeSymbol? classSymbol, IReadOnlyCollection eligibleMembers) + private static bool IsSymbolPartial(INamedTypeSymbol classSymbol) { - AutoInjectClassType env = FigureOutTypeOfEnvironment(classSymbol); - return env switch + foreach (var syntaxRef in classSymbol.DeclaringSyntaxReferences) { - AutoInjectClassType.NormalClass => AutoInjectNormalClassHandler.Generate(attributeSymbol, classSymbol, eligibleMembers), - AutoInjectClassType.RazorComponent => AutoInjectRazorComponentHandler.Generate(classSymbol, eligibleMembers), - _ => string.Empty - }; + if (syntaxRef.GetSyntax() is ClassDeclarationSyntax cls && + cls.Modifiers.Any(m => m.IsKind(SyntaxKind.PartialKeyword))) + return true; + } + return false; } - private static AutoInjectClassType FigureOutTypeOfEnvironment(INamedTypeSymbol? @class) + // Format per member: "F\tname\ttype\t0" or "P\tname\ttype\t1", separated by "|" + // Tab (\t) is used to separate fields; type display strings never contain tabs. + private static string EncodeMembers(IEnumerable members) { - if (@class is null) - throw new ArgumentNullException(nameof(@class)); - - if (IsClassIsRazorComponent(@class)) - return AutoInjectClassType.RazorComponent; - else - return AutoInjectClassType.NormalClass; + var sb = new StringBuilder(); + foreach (var m in members) + { + if (sb.Length > 0) sb.Append('|'); + if (m is IFieldSymbol f) + sb.Append('F').Append('\t').Append(f.Name).Append('\t').Append(f.Type.ToDisplayString()).Append('\t').Append(f.NullableAnnotation is NullableAnnotation.Annotated ? '1' : '0'); + else if (m is IPropertySymbol p) + sb.Append('P').Append('\t').Append(p.Name).Append('\t').Append(p.Type.ToDisplayString()).Append('\t').Append(p.NullableAnnotation is NullableAnnotation.Annotated ? '1' : '0'); + } + return sb.ToString(); } - private static bool IsClassIsRazorComponent(INamedTypeSymbol @class) + private static List DecodeMembers(string encoded) { - bool isInheritIComponent = @class.AllInterfaces.Any(o => o.ToDisplayString() == "Microsoft.AspNetCore.Components.IComponent"); + var result = new List(); + if (string.IsNullOrEmpty(encoded)) return result; + + foreach (var part in encoded.Split('|')) + { + // format: "F\tname\ttype\tnullable" (4 tab-separated fields) + var fields = part.Split('\t'); + if (fields.Length < 4) continue; + var kind = fields[0][0]; + var name = fields[1]; + var typeDisplay = fields[2]; + var isNullable = fields[3] == "1"; + result.Add(new AutoInjectMember(name, typeDisplay, IsField: kind == 'F', IsNullable: isNullable)); + } - if (isInheritIComponent) - return true; + return result; + } - var classFilePaths = @class.Locations - .Where(o => o.SourceTree is not null) - .Select(o => o.SourceTree?.FilePath) - .ToList(); + private static string EscapeForHint(string fullyQualifiedName) + => fullyQualifiedName.Replace('<', '[').Replace('>', ']').Replace(' ', '_'); - string razorFileName = $"{@class.Name}.razor"; + private static LocationInfo? GetLocationInfo(SyntaxToken token) + { + var location = token.GetLocation(); - foreach (var path in classFilePaths) - { - string directoryPath = Path.GetDirectoryName(path) ?? string.Empty; - string filePath = Path.Combine(directoryPath, razorFileName); - if (File.Exists(filePath)) - return true; - } + if (location.SourceTree is null) return null; - return false; + var lineSpan = location.GetLineSpan(); + + return new LocationInfo( + FilePath: location.SourceTree.FilePath, + SpanStart: location.SourceSpan.Start, + SpanLength: location.SourceSpan.Length, + StartLine: lineSpan.StartLinePosition.Line, + StartChar: lineSpan.StartLinePosition.Character, + EndLine: lineSpan.EndLinePosition.Line, + EndChar: lineSpan.EndLinePosition.Character); } + + private static Location ToLocation(LocationInfo info) + => Location.Create( + info.FilePath, + new TextSpan(info.SpanStart, info.SpanLength), + new LinePositionSpan( + new LinePosition(info.StartLine, info.StartChar), + new LinePosition(info.EndLine, info.EndChar))); } diff --git a/src/SourceGenerators/Bit.SourceGenerators/AutoInject/AutoInjectSyntaxReceiver.cs b/src/SourceGenerators/Bit.SourceGenerators/AutoInject/AutoInjectSyntaxReceiver.cs deleted file mode 100644 index c6afd50a48..0000000000 --- a/src/SourceGenerators/Bit.SourceGenerators/AutoInject/AutoInjectSyntaxReceiver.cs +++ /dev/null @@ -1,104 +0,0 @@ -using System; -using System.Collections.ObjectModel; -using System.Linq; -using Microsoft.CodeAnalysis; -using Microsoft.CodeAnalysis.CSharp; -using Microsoft.CodeAnalysis.CSharp.Syntax; - -namespace Bit.SourceGenerators; - -public class AutoInjectSyntaxReceiver : ISyntaxContextReceiver -{ - public Collection EligibleMembers { get; } = []; - public Collection EligibleClassesWithBaseClassUsedAutoInject { get; } = []; - - public void OnVisitSyntaxNode(GeneratorSyntaxContext context) - { - try - { - MarkEligibleFields(context); - MarkEligibleProperties(context); - MarkEligibleClasses(context); - } - catch (Exception exp) - { - throw new InvalidOperationException($"Error processing {context.Node.SyntaxTree.FilePath}: {exp}", exp); - } - } - - private void MarkEligibleClasses(GeneratorSyntaxContext context) - { - if (context.Node is not ClassDeclarationSyntax classDeclarationSyntax) - return; - - var classSymbol = context.SemanticModel.GetDeclaredSymbol(classDeclarationSyntax); - - if (classSymbol == null) - return; - - if (classSymbol.BaseType == null) - return; - - if (classSymbol.BaseType.ToDisplayString() == "System.Object") - return; - - var isBaseTypeUseAutoInject = classSymbol.BaseType - .GetMembers() - .Any(m => - (m.Kind == SymbolKind.Field || m.Kind == SymbolKind.Property) && - m.GetAttributes() - .Any(a => a.AttributeClass != null && - a.AttributeClass.ToDisplayString() == AutoInjectHelper.AutoInjectAttributeFullName)); - - var isCurrentClassUseAutoInject = classSymbol - .GetMembers() - .Any(m => - (m.Kind == SymbolKind.Field || m.Kind == SymbolKind.Property) && - m.GetAttributes() - .Any(a => a.AttributeClass != null && - a.AttributeClass.ToDisplayString() == AutoInjectHelper.AutoInjectAttributeFullName)); - - if (isBaseTypeUseAutoInject && (isCurrentClassUseAutoInject is false)) - EligibleClassesWithBaseClassUsedAutoInject.Add(classSymbol); - } - - private void MarkEligibleFields(GeneratorSyntaxContext context) - { - if (context.Node is not FieldDeclarationSyntax fieldDeclarationSyntax || fieldDeclarationSyntax.AttributeLists.Any() is false) - return; - - if (fieldDeclarationSyntax.Parent is not ClassDeclarationSyntax classDeclarationSyntax || classDeclarationSyntax is null) - return; - - foreach (VariableDeclaratorSyntax variable in fieldDeclarationSyntax.Declaration.Variables) - { - var fieldSymbol = ModelExtensions.GetDeclaredSymbol(context.SemanticModel, variable) as IFieldSymbol; - if (fieldSymbol is not null && - fieldSymbol.GetAttributes() - .Any(ad => ad.AttributeClass is not null && - ad.AttributeClass.ToDisplayString() == AutoInjectHelper.AutoInjectAttributeFullName)) - { - EligibleMembers.Add(fieldSymbol); - } - } - } - - private void MarkEligibleProperties(GeneratorSyntaxContext context) - { - if (context.Node is not PropertyDeclarationSyntax propertyDeclarationSyntax || propertyDeclarationSyntax.AttributeLists.Count <= 0) - return; - - if (propertyDeclarationSyntax.Parent is not ClassDeclarationSyntax classDeclarationSyntax || classDeclarationSyntax is null) - return; - - var propertySymbol = context.SemanticModel.GetDeclaredSymbol(propertyDeclarationSyntax); - - if (propertySymbol is null) - return; - - if (propertySymbol.GetAttributes().Any(ad => ad.AttributeClass is not null && ad.AttributeClass.ToDisplayString() == AutoInjectHelper.AutoInjectAttributeFullName)) - { - EligibleMembers.Add(propertySymbol); - } - } -} diff --git a/src/SourceGenerators/Bit.SourceGenerators/Blazor/BitProperty.cs b/src/SourceGenerators/Bit.SourceGenerators/Blazor/BitProperty.cs index e1ec3043f9..ed049b09a0 100644 --- a/src/SourceGenerators/Bit.SourceGenerators/Blazor/BitProperty.cs +++ b/src/SourceGenerators/Bit.SourceGenerators/Blazor/BitProperty.cs @@ -1,15 +1,10 @@ -using Microsoft.CodeAnalysis; +namespace Bit.SourceGenerators; -namespace Bit.SourceGenerators; - -public class BitProperty -{ - public BitProperty(IPropertySymbol propertySymbol, bool isTwoWayBoundProperty) - { - PropertySymbol = propertySymbol; - IsTwoWayBoundProperty = isTwoWayBoundProperty; - } - - public IPropertySymbol PropertySymbol { get; set; } - public bool IsTwoWayBoundProperty { get; set; } -} +internal readonly record struct BitProperty( + string ContainingTypeFullName, + string ClassName, + string ClassNameForCode, + string ClassNamespace, + bool IsBaseTypeComponentBase, + string PropertyName, + string PropertyType); diff --git a/src/SourceGenerators/Bit.SourceGenerators/Blazor/BlazorParameterPropertySyntaxReceiver.cs b/src/SourceGenerators/Bit.SourceGenerators/Blazor/BlazorParameterPropertySyntaxReceiver.cs deleted file mode 100644 index 577255efb9..0000000000 --- a/src/SourceGenerators/Bit.SourceGenerators/Blazor/BlazorParameterPropertySyntaxReceiver.cs +++ /dev/null @@ -1,49 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Linq; -using Microsoft.CodeAnalysis; -using Microsoft.CodeAnalysis.CSharp; -using Microsoft.CodeAnalysis.CSharp.Syntax; - -namespace Bit.SourceGenerators; - -public class BlazorParameterPropertySyntaxReceiver : ISyntaxContextReceiver -{ - public IList Properties { get; } = []; - - public void OnVisitSyntaxNode(GeneratorSyntaxContext context) - { - if (context.Node is PropertyDeclarationSyntax propertyDeclarationSyntax - && propertyDeclarationSyntax.AttributeLists.Any()) - { - try - { - var classDeclarationSyntax = propertyDeclarationSyntax.Parent as ClassDeclarationSyntax; - - if (classDeclarationSyntax?.Modifiers.Any(k => k.IsKind(SyntaxKind.PartialKeyword)) is false) - return; - - var propertySymbol = context.SemanticModel.GetDeclaredSymbol(propertyDeclarationSyntax); - - if (propertySymbol is null) return; - - var type = propertySymbol.ContainingType; - - if (type == null) return; - - if (type.GetMembers().Any(m => m.Name == "SetParametersAsync")) return; - - - if (propertySymbol.GetAttributes().Any(ad => ad.AttributeClass?.ToDisplayString() == "Microsoft.AspNetCore.Components.ParameterAttribute" - || ad.AttributeClass?.ToDisplayString() == "Microsoft.AspNetCore.Components.CascadingParameterAttribute")) - { - Properties.Add(propertySymbol); - } - } - catch (Exception exp) - { - throw new InvalidOperationException($"Error processing property {propertyDeclarationSyntax.Identifier.Text} in {propertyDeclarationSyntax.SyntaxTree.FilePath}", exp); - } - } - } -} diff --git a/src/SourceGenerators/Bit.SourceGenerators/Blazor/BlazorSetParametersSourceGenerator.cs b/src/SourceGenerators/Bit.SourceGenerators/Blazor/BlazorSetParametersSourceGenerator.cs index 60f414ab60..ffe7578479 100644 --- a/src/SourceGenerators/Bit.SourceGenerators/Blazor/BlazorSetParametersSourceGenerator.cs +++ b/src/SourceGenerators/Bit.SourceGenerators/Blazor/BlazorSetParametersSourceGenerator.cs @@ -1,40 +1,88 @@ using System.Collections.Generic; +using System.Collections.Immutable; using System.Linq; using System.Text; +using System.Threading; using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; using Microsoft.CodeAnalysis.Text; namespace Bit.SourceGenerators; [Generator] -public class BlazorSetParametersSourceGenerator : ISourceGenerator +public class BlazorSetParametersSourceGenerator : IIncrementalGenerator { - public void Execute(GeneratorExecutionContext context) + public void Initialize(IncrementalGeneratorInitializationContext context) { - if (context.SyntaxContextReceiver is not BlazorParameterPropertySyntaxReceiver receiver) - return; + var parameterProvider = context.SyntaxProvider + .ForAttributeWithMetadataName( + "Microsoft.AspNetCore.Components.ParameterAttribute", + predicate: static (node, _) => IsPartialClassProperty(node), + transform: static (ctx, ct) => ExtractBitProperty(ctx, ct)) + .Where(static p => p is not null) + .Select(static (p, _) => p!.Value); + + var cascadingProvider = context.SyntaxProvider + .ForAttributeWithMetadataName( + "Microsoft.AspNetCore.Components.CascadingParameterAttribute", + predicate: static (node, _) => IsPartialClassProperty(node), + transform: static (ctx, ct) => ExtractBitProperty(ctx, ct)) + .Where(static p => p is not null) + .Select(static (p, _) => p!.Value); + + var combined = parameterProvider.Collect() + .Combine(cascadingProvider.Collect()) + .Select(static (pair, _) => pair.Left.AddRange(pair.Right)); + + context.RegisterSourceOutput(combined, static (spc, properties) => Execute(spc, properties)); + } - foreach (var group in receiver.Properties.GroupBy(symbol => symbol.ContainingType, SymbolEqualityComparer.Default)) - { - var properties = group.Select(p => new BitProperty(p, false)).ToList(); - //CheckTwoWayBoundParameter(properties); + private static bool IsPartialClassProperty(SyntaxNode node) + { + return node is PropertyDeclarationSyntax prop && + prop.Parent is (ClassDeclarationSyntax or RecordDeclarationSyntax) and TypeDeclarationSyntax typeDecl && + typeDecl.Modifiers.Any(m => m.IsKind(SyntaxKind.PartialKeyword)); + } - if (group.Key == null) continue; + private static BitProperty? ExtractBitProperty(GeneratorAttributeSyntaxContext ctx, CancellationToken ct) + { + if (ctx.TargetSymbol is not IPropertySymbol prop) return null; - string classSource = GeneratePartialClassToOverrideSetParameters((INamedTypeSymbol)group.Key, properties); - context.AddSource($"{group.Key.Name}_SetParametersAsync.AutoGenerated.cs", SourceText.From(classSource, Encoding.UTF8)); - } + var containingType = prop.ContainingType; + if (containingType is null) return null; + + if (containingType.GetMembers().Any(m => m.Name == "SetParametersAsync")) return null; + + var classNameForCode = BuildClassNameForCode(containingType); + var isBaseTypeComponentBase = containingType.BaseType?.ToDisplayString() == "Microsoft.AspNetCore.Components.ComponentBase"; + + return new BitProperty( + ContainingTypeFullName: containingType.ToDisplayString(), + ClassName: containingType.Name, + ClassNameForCode: classNameForCode, + ClassNamespace: containingType.ContainingNamespace.ToDisplayString(), + IsBaseTypeComponentBase: isBaseTypeComponentBase, + PropertyName: prop.Name, + PropertyType: prop.Type.ToDisplayString()); } - public void Initialize(GeneratorInitializationContext context) + private static void Execute(SourceProductionContext spc, ImmutableArray properties) { - context.RegisterForSyntaxNotifications(() => new BlazorParameterPropertySyntaxReceiver()); + foreach (var group in properties.GroupBy(p => p.ContainingTypeFullName)) + { + var list = group.ToList(); + var first = list[0]; + string source = GeneratePartialClassToOverrideSetParameters(first, list); + spc.AddSource($"{EscapeForHint(first.ContainingTypeFullName)}_SetParametersAsync.AutoGenerated.cs", SourceText.From(source, Encoding.UTF8)); + } } - private static string GeneratePartialClassToOverrideSetParameters(INamedTypeSymbol classSymbol, List properties) + private static string GeneratePartialClassToOverrideSetParameters(BitProperty classInfo, List properties) { - string namespaceName = classSymbol.ContainingNamespace.ToDisplayString(); - bool isBase = classSymbol.BaseType?.ToDisplayString() == "Microsoft.AspNetCore.Components.ComponentBase"; + var namespaceName = classInfo.ClassNamespace; + var className = classInfo.ClassNameForCode; + var isBase = classInfo.IsBaseTypeComponentBase; StringBuilder source = new StringBuilder($@"using System; using System.Threading.Tasks; @@ -44,7 +92,7 @@ private static string GeneratePartialClassToOverrideSetParameters(INamedTypeSymb namespace {namespaceName} {{ - public partial class {GetClassName(classSymbol)} + public partial class {className} {{ [global::System.CodeDom.Compiler.GeneratedCode(""Bit.SourceGenerators"",""{BitSourceGeneratorUtil.GetPackageVersion()}"")] [global::System.Diagnostics.DebuggerNonUserCode] @@ -52,24 +100,15 @@ public partial class {GetClassName(classSymbol)} public override Task SetParametersAsync(ParameterView parameters) {{ "); - //foreach (var property in properties.Where(p => p.IsTwoWayBoundProperty)) - //{ - // source.AppendLine($" {property.PropertySymbol.Name}HasBeenSet = false;"); - //} source.AppendLine(" foreach (var parameter in parameters)"); source.AppendLine(" {"); source.AppendLine(" switch (parameter.Name)"); source.AppendLine(" {"); - // create cases for each property foreach (var bitProperty in properties) { - source.AppendLine($" case nameof({bitProperty.PropertySymbol.Name}):"); - //if (bitProperty.IsTwoWayBoundProperty) - //{ - // source.AppendLine($" {bitProperty.PropertySymbol.Name}HasBeenSet = true;"); - //} - source.AppendLine($" {bitProperty.PropertySymbol.Name} = parameter.Value is null ? default! : ({bitProperty.PropertySymbol.Type.ToDisplayString()})parameter.Value;"); + source.AppendLine($" case nameof({bitProperty.PropertyName}):"); + source.AppendLine($" {bitProperty.PropertyName} = parameter.Value is null ? default! : ({bitProperty.PropertyType})parameter.Value;"); source.AppendLine(" break;"); } @@ -92,27 +131,16 @@ public override Task SetParametersAsync(ParameterView parameters) return source.ToString(); } - private static string GetClassName(INamedTypeSymbol classSymbol) + private static string BuildClassNameForCode(INamedTypeSymbol classSymbol) { - StringBuilder sbName = new StringBuilder(classSymbol.Name); - if (classSymbol.IsGenericType) { - sbName.Append('<'); - sbName.Append(string.Join(", ", classSymbol.TypeArguments.Select(s => s.Name))); - sbName.Append('>'); + var typeArgs = string.Join(", ", classSymbol.TypeArguments.Select(s => s.Name)); + return $"{classSymbol.Name}<{typeArgs}>"; } - - return sbName.ToString(); + return classSymbol.Name; } - //private static void CheckTwoWayBoundParameter(List properties) - //{ - // foreach (var item in properties) - // { - // var propName = $"{item.PropertySymbol.Name}Changed"; - // var propType = $"Microsoft.AspNetCore.Components.EventCallback<{item.PropertySymbol.Type.ToDisplayString()}>"; - // item.IsTwoWayBoundProperty = properties.Any(p => p.PropertySymbol.Name == propName && p.PropertySymbol.Type.ToDisplayString() == propType); - // } - //} + private static string EscapeForHint(string fullyQualifiedName) + => fullyQualifiedName.Replace('<', '[').Replace('>', ']').Replace(' ', '_'); } diff --git a/src/SourceGenerators/Bit.SourceGenerators/HttpClientProxy/ActionParameter.cs b/src/SourceGenerators/Bit.SourceGenerators/HttpClientProxy/ActionParameter.cs deleted file mode 100644 index 5e7561332a..0000000000 --- a/src/SourceGenerators/Bit.SourceGenerators/HttpClientProxy/ActionParameter.cs +++ /dev/null @@ -1,13 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Text; -using Microsoft.CodeAnalysis; - -namespace Bit.SourceGenerators; - -public class ActionParameter -{ - public string Name { get; set; } = default!; - - public ITypeSymbol Type { get; set; } = default!; -} diff --git a/src/SourceGenerators/Bit.SourceGenerators/HttpClientProxy/ControllerAction.cs b/src/SourceGenerators/Bit.SourceGenerators/HttpClientProxy/ControllerAction.cs deleted file mode 100644 index 2949fcd5fa..0000000000 --- a/src/SourceGenerators/Bit.SourceGenerators/HttpClientProxy/ControllerAction.cs +++ /dev/null @@ -1,29 +0,0 @@ -using System.Collections.Generic; -using Microsoft.CodeAnalysis; - -namespace Bit.SourceGenerators; - -public class ControllerAction -{ - public IMethodSymbol Method { get; set; } = default!; - - public ITypeSymbol ReturnType { get; set; } = default!; - - public bool DoesReturnSomething => ReturnType.ToDisplayString() is not "System.Threading.Tasks.Task" or "System.Threading.Tasks.ValueTask"; - - public bool DoesReturnString => DoesReturnSomething && ReturnType.ToDisplayString() is "System.Threading.Tasks.Task" or "System.Threading.Tasks.ValueTask"; - - public bool DoesReturnIAsyncEnumerable => DoesReturnSomething && ReturnType.ToDisplayString().Contains("IAsyncEnumerable"); - - public string HttpMethod { get; set; } = default!; - - public string Url { get; set; } = default!; - - public List Parameters { get; set; } = []; - - public ActionParameter? BodyParameter { get; set; } - - public bool HasCancellationToken => string.IsNullOrEmpty(CancellationTokenParameterName) is false; - - public string? CancellationTokenParameterName { get; set; } -} diff --git a/src/SourceGenerators/Bit.SourceGenerators/HttpClientProxy/ControllerEntry.cs b/src/SourceGenerators/Bit.SourceGenerators/HttpClientProxy/ControllerEntry.cs new file mode 100644 index 0000000000..7aab390d3e --- /dev/null +++ b/src/SourceGenerators/Bit.SourceGenerators/HttpClientProxy/ControllerEntry.cs @@ -0,0 +1,18 @@ +namespace Bit.SourceGenerators; + +/// +/// Flat, structurally equatable representation of one IAppController interface captured during +/// the incremental generator transform phase. All Roslyn symbol data is serialised to strings +/// so that the record struct has correct value-based equality for incremental caching. +/// +/// Encoding separators (ASCII control characters — never appear in C# identifiers or type names): +/// \x1E RS – between action records +/// \x1F US – between fields inside one action record +/// \x1D GS – between parameters inside one action record +/// \x1C FS – between sub-fields inside one parameter entry +/// +internal readonly record struct ControllerEntry( + string SymbolDisplay, // e.g. "IMyController" + string SymbolDisplayNoNull, // same without nullable annotation + string ClassName, // e.g. "MyController" (used as class name of the generated proxy) + string EncodedActions); // all action data encoded with the separators above diff --git a/src/SourceGenerators/Bit.SourceGenerators/HttpClientProxy/HttpClientProxySourceGenerator.cs b/src/SourceGenerators/Bit.SourceGenerators/HttpClientProxy/HttpClientProxySourceGenerator.cs index a89fff1b14..3c22dfd3b1 100644 --- a/src/SourceGenerators/Bit.SourceGenerators/HttpClientProxy/HttpClientProxySourceGenerator.cs +++ b/src/SourceGenerators/Bit.SourceGenerators/HttpClientProxy/HttpClientProxySourceGenerator.cs @@ -1,95 +1,248 @@ using System; using System.Collections.Generic; +using System.Collections.Immutable; using System.Linq; using System.Text; +using System.Threading; +using System.Web; +using DoLess.UriTemplates; using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using Microsoft.CodeAnalysis.Text; namespace Bit.SourceGenerators; [Generator] -public class HttpClientProxySourceGenerator : ISourceGenerator +public class HttpClientProxySourceGenerator : IIncrementalGenerator { - public void Initialize(GeneratorInitializationContext context) + // ASCII control-character separators (never appear in C# identifiers or type display strings) + private const char ActionSep = '\x1E'; // RS – between action records + private const char FieldSep = '\x1F'; // US – between fields inside one action record + private const char ParamSep = '\x1D'; // GS – between parameters inside one action record + private const char SubFieldSep = '\x1C'; // FS – between sub-fields inside one parameter entry + + public void Initialize(IncrementalGeneratorInitializationContext context) { - context.RegisterForSyntaxNotifications(() => new HttpClientProxySyntaxReceiver()); + var controllerProvider = context.SyntaxProvider + .CreateSyntaxProvider( + predicate: static (node, _) => node is InterfaceDeclarationSyntax iface && + iface.BaseList is not null && + iface.BaseList.Types.Any(t => t.Type.ToString() == "IAppController"), + transform: static (ctx, ct) => TransformController(ctx, ct)) + .Where(static c => c is not null) + .Select(static (c, _) => c!.Value); + + context.RegisterSourceOutput(controllerProvider.Collect(), static (spc, controllers) => Execute(spc, controllers)); } - public void Execute(GeneratorExecutionContext context) + // ── Transform ───────────────────────────────────────────────────────────── + + private static ControllerEntry? TransformController(GeneratorSyntaxContext ctx, CancellationToken ct) { - if (context.SyntaxContextReceiver is not HttpClientProxySyntaxReceiver receiver || receiver.IControllers.Any() is false) + var interfaceDecl = (InterfaceDeclarationSyntax)ctx.Node; + var model = ctx.SemanticModel; + var controllerSymbol = model.GetDeclaredSymbol(interfaceDecl, ct) as ITypeSymbol; + if (controllerSymbol is null) return null; + if (!controllerSymbol.IsIController()) return null; + + var controllerName = controllerSymbol.Name[1..].Replace("Controller", string.Empty); + + var route = controllerSymbol + .GetAttributes() + .FirstOrDefault(a => a.AttributeClass?.Name.StartsWith("Route") is true)? + .ConstructorArguments + .FirstOrDefault() + .Value? + .ToString() + ?.Replace("[controller]", controllerName) ?? string.Empty; + + var stringSpecialType = model.Compilation.GetSpecialType(SpecialType.System_String); + + var actionBuilders = new List(); + + foreach (var method in controllerSymbol.GetMembers().OfType().Where(m => m.MethodKind == MethodKind.Ordinary)) { - return; + ct.ThrowIfCancellationRequested(); + + var httpMethod = method.GetHttpMethod(); + + // Build URL from route template + var actionSpecificRoute = method + .GetAttributes() + .FirstOrDefault(a => a.AttributeClass?.Name.StartsWith("Route") is true)? + .ConstructorArguments + .FirstOrDefault() + .Value? + .ToString() + ?.Replace("[controller]", controllerName) + ?.Replace("~/", string.Empty); + + var uriTemplate = UriTemplate.For( + $"{actionSpecificRoute ?? route}{method.GetAttributes() + .FirstOrDefault(a => a.AttributeClass?.Name.StartsWith("Http") is true)? + .ConstructorArguments.FirstOrDefault().Value?.ToString()}" + .Replace("[action]", method.Name)); + + var rawParameters = method.Parameters.Select(y => (y.Name, Type: y.Type)).ToList(); + foreach (var (pName, _) in rawParameters) + uriTemplate.WithParameter(pName, $"{{{pName}}}"); + + string url = HttpUtility.UrlDecode(uriTemplate.ExpandToString()).TrimEnd('/'); + + var ctParam = rawParameters.FirstOrDefault(p => p.Type.ToDisplayString() == "System.Threading.CancellationToken"); + var ctName = ctParam == default ? null : ctParam.Name; + + var bodyParam = rawParameters.FirstOrDefault(p => + p.Type.ToDisplayString() is not "System.Threading.CancellationToken" && + !url.Contains($"{{{p.Name}}}")); + + var returnType = method.ReturnType; + var returnDisplay = returnType.ToDisplayString(); + bool doesReturnSomething = returnDisplay is not ("System.Threading.Tasks.Task" or "System.Threading.Tasks.ValueTask"); + bool doesReturnString = doesReturnSomething && returnDisplay is "System.Threading.Tasks.Task" or "System.Threading.Tasks.ValueTask"; + bool doesReturnIAsyncEnum = doesReturnSomething && returnDisplay.Contains("IAsyncEnumerable"); + var returnUnderlyingNoNull = returnType.GetUnderlyingType().ToDisplayString(NullableFlowState.None); + + // Encode parameters: "name\x1CfullType\x1CtypeNoNull\x1CisString" joined by \x1D + var encodedParams = string.Join( + ParamSep.ToString(), + rawParameters.Select(p => + $"{p.Name}{SubFieldSep}{p.Type.ToDisplayString()}{SubFieldSep}{p.Type.ToDisplayString(NullableFlowState.None)}{SubFieldSep}{(SymbolEqualityComparer.Default.Equals(p.Type, stringSpecialType) ? "1" : "0")}")); + + // Action fields joined by \x1F + actionBuilders.Add(string.Join( + FieldSep.ToString(), + method.Name, + returnDisplay, + returnUnderlyingNoNull, + doesReturnSomething ? "1" : "0", + doesReturnString ? "1" : "0", + doesReturnIAsyncEnum ? "1" : "0", + httpMethod, + url, + ctName is not null ? "1" : "0", + ctName ?? "", + encodedParams, + bodyParam == default ? "" : bodyParam.Name, + bodyParam == default ? "" : bodyParam.Type.ToDisplayString(NullableFlowState.None))); } + return new ControllerEntry( + SymbolDisplay: controllerSymbol.ToDisplayString(), + SymbolDisplayNoNull: controllerSymbol.ToDisplayString(NullableFlowState.None), + ClassName: controllerSymbol.Name[1..], + EncodedActions: string.Join(ActionSep.ToString(), actionBuilders)); + } + + // ── Code generation ─────────────────────────────────────────────────────── + + private static void Execute(SourceProductionContext spc, ImmutableArray controllers) + { + if (controllers.IsEmpty) return; + StringBuilder generatedClasses = new(); - foreach (var iController in receiver.IControllers) + foreach (var controller in controllers) { StringBuilder generatedMethods = new(); - foreach (var action in iController.Actions) + foreach (var actionEncoded in controller.EncodedActions.Split(ActionSep)) { - string parameters = string.Join(", ", action.Parameters.Select(p => $"{p.Type.ToDisplayString()} {p.Name}")); + if (string.IsNullOrEmpty(actionEncoded)) continue; - var hasQueryString = action.Url.Contains('?'); + var fields = actionEncoded.Split(FieldSep); + // fields[0] methodName + // fields[1] returnTypeDisplay + // fields[2] returnTypeUnderlyingNoNull + // fields[3] doesReturnSomething + // fields[4] doesReturnString + // fields[5] doesReturnIAsyncEnumerable + // fields[6] httpMethod + // fields[7] url + // fields[8] hasCancellationToken + // fields[9] ctParamName + // fields[10] encodedParams + // fields[11] bodyParamName + // fields[12] bodyParamTypeNoNull - List jsonReadParametersList = []; - if (action.DoesReturnSomething && action.DoesReturnString is false) - { - jsonReadParametersList.Add($"options.GetTypeInfo<{action.ReturnType.GetUnderlyingType().ToDisplayString()}>()"); - } - if (action.HasCancellationToken) + if (fields.Length < 13) continue; + + var methodName = fields[0]; + var returnTypeDisplay = fields[1]; + var returnUnderlyingNoNull = fields[2]; + var doesReturnSomething = fields[3] == "1"; + var doesReturnString = fields[4] == "1"; + var doesReturnIAsyncEnum = fields[5] == "1"; + var httpMethod = fields[6]; + var url = fields[7]; + var hasCt = fields[8] == "1"; + var ctName = fields[9]; + var bodyParamName = string.IsNullOrEmpty(fields[11]) ? null : fields[11]; + var bodyParamTypeNoNull = string.IsNullOrEmpty(fields[12]) ? null : fields[12]; + + // Decode parameters + var parameters = new List<(string Name, string TypeDisplay, string TypeDisplayNoNull, bool IsString)>(); + if (!string.IsNullOrEmpty(fields[10])) { - jsonReadParametersList.Add(action.CancellationTokenParameterName!); + foreach (var pEnc in fields[10].Split(ParamSep)) + { + var sf = pEnc.Split(SubFieldSep); + if (sf.Length < 4) continue; + parameters.Add((sf[0], sf[1], sf[2], sf[3] == "1")); + } } + + string parameterList = string.Join(", ", parameters.Select(p => $"{p.TypeDisplay} {p.Name}")); + + List jsonReadParametersList = new(); + if (doesReturnSomething && !doesReturnString) + jsonReadParametersList.Add($"options.GetTypeInfo<{returnUnderlyingNoNull}>()"); + if (hasCt) + jsonReadParametersList.Add(ctName!); var jsonReadParameters = string.Join(", ", jsonReadParametersList); var requestOptions = new StringBuilder(); - requestOptions.AppendLine($"__request.Options.TryAdd(\"IControllerType\", typeof({iController.Symbol.ToDisplayString(NullableFlowState.None)}));"); - requestOptions.AppendLine($"__request.Options.TryAdd(\"ActionName\", \"{action.Method.Name}\");"); + requestOptions.AppendLine($"__request.Options.TryAdd(\"IControllerType\", typeof({controller.SymbolDisplayNoNull}));"); + requestOptions.AppendLine($"__request.Options.TryAdd(\"ActionName\", \"{methodName}\");"); requestOptions.AppendLine($@"__request.Options.TryAdd(""ActionParametersInfo"", new Dictionary {{ - {string.Join(", ", action.Parameters.Select(p => $"{{ \"{p.Name}\", typeof({p.Type.ToDisplayString(NullableFlowState.None)}) }}"))} + {string.Join(", ", parameters.Select(p => $"{{ \"{p.Name}\", typeof({p.TypeDisplayNoNull}) }}"))} }});"); - if (action.BodyParameter is not null) - { - requestOptions.AppendLine($"__request.Options.TryAdd(\"RequestType\", typeof({action.BodyParameter.Type.ToDisplayString(NullableFlowState.None)}));"); - } - if (action.DoesReturnSomething) - { - requestOptions.AppendLine($"__request.Options.TryAdd(\"ResponseType\", typeof({action.ReturnType.GetUnderlyingType().ToDisplayString(NullableFlowState.None)}));"); - } - - var stringType = context.Compilation.GetSpecialType(SpecialType.System_String); + if (bodyParamName is not null) + requestOptions.AppendLine($"__request.Options.TryAdd(\"RequestType\", typeof({bodyParamTypeNoNull}));"); + if (doesReturnSomething) + requestOptions.AppendLine($"__request.Options.TryAdd(\"ResponseType\", typeof({returnUnderlyingNoNull}));"); - var encodeStringRouteParameters = string.Join(Environment.NewLine, action.Parameters - .Where(p => SymbolEqualityComparer.Default.Equals(p.Type, stringType)) - .Select(p => $"{p.Name} = Uri.EscapeDataString(Uri.UnescapeDataString({p.Name} ?? string.Empty));")); + var encodeStringRouteParameters = string.Join( + Environment.NewLine, + parameters + .Where(p => p.IsString) + .Select(p => $"{p.Name} = Uri.EscapeDataString(Uri.UnescapeDataString({p.Name} ?? string.Empty));")); generatedMethods.AppendLine($@" - public async {action.ReturnType.ToDisplayString()} {action.Method.Name}({parameters}) + public async {returnTypeDisplay} {methodName}({parameterList}) {{ {encodeStringRouteParameters} - {$@"var __url = $""{action.Url}"";"} + {$@"var __url = $""{url}"";"} var dynamicQS = GetDynamicQueryString(); if (dynamicQS is not null) {{ - __url += {(action.Url.Contains('?') ? "'&'" : "'?'")} + dynamicQS; + __url += {(url.Contains('?') ? "'&'" : "'?'")} + dynamicQS; }} - {(action.DoesReturnSomething ? $@"return (await prerenderStateService.GetValue(__url, async () => + {(doesReturnSomething ? $@"return (await prerenderStateService.GetValue(__url, async () => {{" : string.Empty)} - using var __request = new HttpRequestMessage(HttpMethod.{action.HttpMethod}, __url); + using var __request = new HttpRequestMessage(HttpMethod.{httpMethod}, __url); {requestOptions} - {(action.BodyParameter is not null ? $@"__request.Content = JsonContent.Create({action.BodyParameter.Name}, options.GetTypeInfo<{action.BodyParameter.Type.ToDisplayString()}>());" : string.Empty)} - {(action.DoesReturnIAsyncEnumerable ? "" : "using ")}var __response = await httpClient.SendAsync(__request, HttpCompletionOption.ResponseHeadersRead {(action.HasCancellationToken ? $", {action.CancellationTokenParameterName}" : string.Empty)}); - {(action.DoesReturnSomething ? ($"return {(action.DoesReturnIAsyncEnumerable ? "" : "await")} __response.Content.{(action.DoesReturnIAsyncEnumerable ? "ReadFromJsonAsAsyncEnumerable" : action.DoesReturnString ? "ReadAsStringAsync" : "ReadFromJsonAsync")}({jsonReadParameters});" + + {(bodyParamName is not null ? $@"__request.Content = JsonContent.Create({bodyParamName}, options.GetTypeInfo<{bodyParamTypeNoNull}>());" : string.Empty)} + {(doesReturnIAsyncEnum ? "" : "using ")}var __response = await httpClient.SendAsync(__request, HttpCompletionOption.ResponseHeadersRead {(hasCt ? $", {ctName}" : string.Empty)}); + {(doesReturnSomething ? ($"return {(doesReturnIAsyncEnum ? "" : "await")} __response.Content.{(doesReturnIAsyncEnum ? "ReadFromJsonAsAsyncEnumerable" : doesReturnString ? "ReadAsStringAsync" : "ReadFromJsonAsync")}({jsonReadParameters});" + $"}}))!;") : string.Empty)} }} "); } generatedClasses.AppendLine($@" - internal class {iController.ClassName}(HttpClient httpClient, JsonSerializerOptions options, IPrerenderStateService prerenderStateService) : AppControllerBase, {iController.Symbol.ToDisplayString()} + internal class {controller.ClassName}(HttpClient httpClient, JsonSerializerOptions options, IPrerenderStateService prerenderStateService) : AppControllerBase, {controller.SymbolDisplay} {{ {generatedMethods} }}"); @@ -110,7 +263,7 @@ public static class IHttpClientServiceCollectionExtensions {{ public static void AddTypedHttpClients(this IServiceCollection services) {{ -{string.Join(Environment.NewLine, receiver.IControllers.Select(i => $" services.TryAddTransient<{i.Symbol.ToDisplayString()}, {i.ClassName}>();"))} +{string.Join(Environment.NewLine, controllers.Select(i => $" services.TryAddTransient<{i.SymbolDisplay}, {i.ClassName}>();"))} }} internal class AppControllerBase @@ -144,6 +297,6 @@ public void AddQueryStrings(Dictionary queryString) }} "); - context.AddSource($"HttpClientProxy.cs", finalSource.ToString()); + spc.AddSource("HttpClientProxy.cs", SourceText.From(finalSource.ToString(), Encoding.UTF8)); } } diff --git a/src/SourceGenerators/Bit.SourceGenerators/HttpClientProxy/HttpClientProxySyntaxReceiver.cs b/src/SourceGenerators/Bit.SourceGenerators/HttpClientProxy/HttpClientProxySyntaxReceiver.cs deleted file mode 100644 index 25fdba1d80..0000000000 --- a/src/SourceGenerators/Bit.SourceGenerators/HttpClientProxy/HttpClientProxySyntaxReceiver.cs +++ /dev/null @@ -1,101 +0,0 @@ -using System; -using System.Web; -using System.Linq; -using System.Collections.Generic; -using Microsoft.CodeAnalysis; -using Microsoft.CodeAnalysis.CSharp.Syntax; -using DoLess.UriTemplates; - -namespace Bit.SourceGenerators; - -public class HttpClientProxySyntaxReceiver : ISyntaxContextReceiver -{ - public List IControllers { get; } = []; - - public void OnVisitSyntaxNode(GeneratorSyntaxContext syntaxNode) - { - try - { - if (syntaxNode.Node is InterfaceDeclarationSyntax interfaceDeclarationSyntax - && interfaceDeclarationSyntax.BaseList is not null - && interfaceDeclarationSyntax.BaseList.Types.Any(t => t.Type.ToString() == "IAppController")) - { - var model = syntaxNode.SemanticModel.Compilation.GetSemanticModel(interfaceDeclarationSyntax.SyntaxTree); - var controllerSymbol = (ITypeSymbol)model.GetDeclaredSymbol(interfaceDeclarationSyntax)!; - bool isController = controllerSymbol.IsIController(); - - if (isController) - { - var controllerName = controllerSymbol.Name[1..].Replace("Controller", string.Empty); - - var route = controllerSymbol - .GetAttributes() - .FirstOrDefault(a => a.AttributeClass?.Name.StartsWith("Route") is true)? - .ConstructorArguments - .FirstOrDefault() - .Value? - .ToString() - ?.Replace("[controller]", controllerName) ?? string.Empty; - - var actions = controllerSymbol.GetMembers() - .OfType() - .Where(m => m.MethodKind == MethodKind.Ordinary) - .Select(m => new ControllerAction - { - Method = m, - ReturnType = m.ReturnType, - HttpMethod = m.GetHttpMethod(), - Url = m.Name, - Parameters = m.Parameters.Select(y => new ActionParameter - { - Name = y.Name, - Type = y.Type - }).ToList() - }).ToList(); - - foreach (var action in actions) - { - var actionSpecificRoute = action.Method - .GetAttributes() - .FirstOrDefault(a => a.AttributeClass?.Name.StartsWith("Route") is true)? - .ConstructorArguments - .FirstOrDefault() - .Value? - .ToString() - ?.Replace("[controller]", controllerName) - ?.Replace("~/", string.Empty); // https://stackoverflow.com/a/34712201 - - var uriTemplate = UriTemplate.For($"{actionSpecificRoute ?? route}{action.Method.GetAttributes() - .FirstOrDefault(a => a.AttributeClass?.Name.StartsWith("Http") is true)? - .ConstructorArguments.FirstOrDefault().Value?.ToString()}".Replace("[action]", action.Method.Name)); - - foreach (var parameter in action.Parameters) - { - uriTemplate.WithParameter(parameter.Name, $"{{{parameter.Name}}}"); - } - - string url = HttpUtility.UrlDecode(uriTemplate.ExpandToString()).TrimEnd('/'); - - // if there is a parameter that is not a cancellation token and is not in the route template, then it is the body parameter - action.BodyParameter = action.Parameters.FirstOrDefault(p => p.Type.ToDisplayString() is not "System.Threading.CancellationToken" && url.Contains($"{{{p.Name}}}") is false); - action.CancellationTokenParameterName = action.Parameters.FirstOrDefault(p => p.Type.ToDisplayString() == "System.Threading.CancellationToken")?.Name; - action.Url = url; - } - - IControllers.Add(new IController - { - Actions = actions, - Name = controllerName, - ClassName = controllerSymbol.Name[1..], - Symbol = controllerSymbol, - Syntax = interfaceDeclarationSyntax - }); - } - } - } - catch (Exception exp) - { - throw new InvalidOperationException($"Error processing {syntaxNode.Node.SyntaxTree.FilePath}: {exp}", exp); - } - } -} diff --git a/src/SourceGenerators/Bit.SourceGenerators/HttpClientProxy/IController.cs b/src/SourceGenerators/Bit.SourceGenerators/HttpClientProxy/IController.cs deleted file mode 100644 index c6718b6cf4..0000000000 --- a/src/SourceGenerators/Bit.SourceGenerators/HttpClientProxy/IController.cs +++ /dev/null @@ -1,19 +0,0 @@ -using System.Collections.Generic; -using Microsoft.CodeAnalysis.CSharp.Syntax; -using Microsoft.CodeAnalysis; - -namespace Bit.SourceGenerators; - -public class IController -{ - public string Name { get; set; } = default!; - - public string ClassName { get; set; } = default!; - - public ITypeSymbol Symbol { get; set; } = default!; - - public InterfaceDeclarationSyntax Syntax { get; set; } = default!; - - public List Actions { get; set; } = []; -} - diff --git a/src/SourceGenerators/Bit.SourceGenerators/IsExternalInit.cs b/src/SourceGenerators/Bit.SourceGenerators/IsExternalInit.cs new file mode 100644 index 0000000000..b22297d3ac --- /dev/null +++ b/src/SourceGenerators/Bit.SourceGenerators/IsExternalInit.cs @@ -0,0 +1,5 @@ +// Polyfill required for record types and init-only setters when targeting netstandard2.0. +namespace System.Runtime.CompilerServices +{ + internal static class IsExternalInit { } +}