[C#]SourceGenerator實戰: 對任意物件使用await吧!!!

2022-10-21 06:02:45

[C#]SourceGenerator實戰: 對任意物件使用await吧!!!

前言

本文記錄一次簡單的 SourceGenerator 實戰,最終實現可以在程式碼中 await 任意型別物件,僅供娛樂,請勿在生產環境中使用!!!

關鍵技術:

  • SourceGenerator

  • Await anything

    • C#中的 async/await 最終由編譯器編譯為狀態機,其核心邏輯在於 await 物件需要實現符合要求的 GetAwaiter 方法,這個方法可以是 拓展方法
    • 參見官方部落格 await anything;

那麼要實現對任何物件的 await 我們的思路大概如下:

  1. 找到所有的 await 語法
  2. 檢查 await 的物件是否有 GetAwaiter 方法
  3. 為沒有 GetAwaiter 方法的物件生成 GetAwaiter 拓展方法

得益於 SourceGenerator 豐富的分析API,我們可以很容易的辦到這件事


實現源生成器

GetAwaiter拓展方法模板

我們先來實現一個可以讓 TargetType 支援 await 的拓展方法類別範本:

using System.Runtime.CompilerServices;

namespace System.Threading.Tasks
{
    public static class GetAwaiterExtension_TargetTypeName
    {
        public static TaskAwaiterFor_TargetTypeName GetAwaiter(this TargetType value)
        {
            return new TaskAwaiterFor_TargetTypeName(value);
        }

        public readonly struct TaskAwaiterFor_TargetTypeName : ICriticalNotifyCompletion, INotifyCompletion
        {
            private readonly TargetType _value;

            public bool IsCompleted { get; } = true;

            public TaskAwaiterFor_TargetTypeName(TargetType value)
            {
                _value = value;
            }

            public TargetType GetResult()
            {
                return _value;
            }

            public void OnCompleted(Action continuation)
            {
                continuation();
            }

            public void UnsafeOnCompleted(Action continuation)
            {
                continuation();
            }
        }
    }
}
  • 將型別放在名稱空間 System.Threading.Tasks 下,可以在使用的時候不需要額外的名稱空間參照;
  • 由於我們已經有了需要返回的結果值,所以 AwaiterIsCompleted 始終為 trueGetResult 直接返回結果即可;

分析所有 await 語法,並篩選出需要為其生成 GetAwaiter 方法的型別

  1. 先建立一個 IncrementalGenerator
    [Generator(LanguageNames.CSharp)]
    public class GetAwaiterIncrementalGenerator : IIncrementalGenerator
    {
        public void Initialize(IncrementalGeneratorInitializationContext context)
        {
        }
    }
    
  2. Initialize 方法中篩選目標型別
    /// 使用語法提供器篩選出所有的 `await` 語法,並獲取其型別
    var symbolProvider = context.SyntaxProvider.CreateSyntaxProvider((node, _) => node is AwaitExpressionSyntax //直接判斷節點是否為 `AwaitExpressionSyntax` 即可篩選出所有 await 表示式
                                                                     , TransformAwaitExpressionSyntax)  //從 await 表示式中解析出其尚不支援 await 的物件型別符號
                                               .Where(m => m is not null)   //篩選掉無效的項
                                               .WithComparer(SymbolEqualityComparer.Default);   //使用預設的符號比較器進行比較
    
  3. 直接使用表示式語法不太方便處理,我們實現表示式語法到型別符號的轉換方法 TransformAwaitExpressionSyntax
    private static ITypeSymbol? TransformAwaitExpressionSyntax(GeneratorSyntaxContext generatorSyntaxContext, CancellationToken cancellationToken)
    {
        //經過篩選,到達此處的節點一定是 AwaitExpressionSyntax
        var awaitExpressionSyntax = (AwaitExpressionSyntax)generatorSyntaxContext.Node;
    
        //如果 await 表示式語法的 await 物件仍然是 AwaitExpressionSyntax ,那麼跳過此條記錄
        //類似 "await await await 1;" 我們直接忽略前兩個 await 表示式
        if (awaitExpressionSyntax.Expression is AwaitExpressionSyntax)
        {
            return null;
        }
    
        //使用 `SemanticModel` 可以分析出更具體的符號資訊,比如型別,方法等
        //直接使用其提供的 `GetAwaitExpressionInfo` 可以從表示式語法獲取 await 的詳細資訊
        var awaitExpressionInfo = generatorSyntaxContext.SemanticModel.GetAwaitExpressionInfo(awaitExpressionSyntax);
    
        //判斷分析結果中此表示式是否包含 `GetAwaiter` 方法,如果不包含,那麼我們需要為其生成
        if (awaitExpressionInfo.GetAwaiterMethod is null)
        {
            //`SemanticModel` 的 GetTypeInfo 方法可以獲取一個表示式的型別符號資訊
            //返回 await 物件的型別符號
            return generatorSyntaxContext.SemanticModel.GetTypeInfo(awaitExpressionSyntax.Expression).Type;
        }
    
        return null;
    }
    

為所有目標型別生成 GetAwaiter 拓展方法

由於只需要為相同型別生成一次 GetAwaiter 方法,所以我們需要將型別符號去重之後進行生成

  • 直接將上面的 symbolProvider 傳遞給 RegisterSourceOutput 方法的話,每次只會處理一個型別符號,我們無法去重
  • 呼叫 symbolProviderCollect 方法,可以將前面步驟篩選出的所有型別符號作為一個集合進行處理

所以註冊原始碼生成器可以這樣寫:

context.RegisterSourceOutput(symbolProvider.Collect(),  //將篩選的結果作為整體傳遞
                            (ctx, input) =>
                            {
                                //遍歷去重後的型別符號
                                foreach (var item in input.Distinct(SymbolEqualityComparer.Default))
                                {
                                    //為每個去重後的型別生成 `GetAwaiter` 拓展方法
                                }
                            });

接下來使用之前寫的拓展方法模板生成每個型別的 GetAwaiter 拓展方法即可:

//獲取型別符號的完整存取型別名
var fullyClassName = item!.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat);
//獲取不包含無效符號的類名
var className = NormalizeClassName(fullyClassName);
//替換模板中的型別預留位置為當前處理的目標型別
var code = templateCode.Replace("TargetTypeName", className)
                       .Replace("TargetType", fullyClassName);

//如果目標型別不是公開型別,那麼拓展方法也應該不公開
if (item.DeclaredAccessibility != Accessibility.Public)
{
    code = code.Replace("public static class", "internal static class");
}

//將生成的程式碼新增到編譯中
ctx.AddSource($"GetAwaiterFor_{className}.g.cs", code);
//將型別名稱中不能作為類名的符號替換為_
private static string NormalizeClassName(string value)
{
    return value.Replace('.', '_')
                .Replace('<', '_')
                .Replace('>', '_')
                .Replace(' ', '_')
                .Replace(',', '_')
                .Replace(':', '_');
}

到這裡我們就實現了所有的功能點,新建專案並參照分析器就可以 await 任何物件了,效果大概如下:

  • 程式碼 - AwaitAnyObject.zip
  • 也可以直接安裝 NuGet 包 AwaitAnyObject 進行遊玩