.NET Task 揭祕(3)async 與 AsyncMethodBuilder

2023-03-16 06:00:40

前言

本文為系列部落格

  1. 什麼是 Task
  2. Task 的回撥執行與 await
  3. async 與 AsyncMethodBuilder(本文)
  4. 總結與常見誤區(TODO)

上文我們學習了 await 這個語法糖背後的實現,瞭解了 await 這個關鍵詞是如何去等待 Task 的完成並獲取 Task 執行結果。並且我們還實現了一個簡單的 awaitable 型別,它可以讓我們自定義 await 的行為。

class FooAwaitable<TResult>
{
    // 回撥,簡化起見,未將其包裹到 TaskContinuation 這樣的容器裡
    private Action _continuation;

    private TResult _result;
    
    private Exception _exception;

    private volatile bool _completed;

    public bool IsCompleted => _completed;

    // Awaitable 中的關鍵部分,提供 GetAwaiter 方法
    public FooAwaiter<TResult> GetAwaiter() => new FooAwaiter<TResult>(this);

    public void Run(Func<TResult> func)
    {
        new Thread(() =>
        {
            var result = func();
            TrySetResult(result);
        })
        {
            IsBackground = true
        }.Start();
    }

    private bool AddFooContinuation(Action action)
    {
        if (_completed)
        {
            return false;
        }
        _continuation += action;
        return true;
    }

    internal void TrySetResult(TResult result)
    {
        _result = result;
        _completed = true;
        _continuation?.Invoke();
    }
    
    internal void TrySetException(Exception exception)
    {
        _exception = exception;
        _completed = true;
        _continuation?.Invoke();
    }

    // 1 實現 ICriticalNotifyCompletion
    public struct FooAwaiter<TResult> : ICriticalNotifyCompletion
    {
        private readonly FooAwaitable<TResult> _fooAwaitable;
        
        // 2 實現 IsCompleted 屬性
        public bool IsCompleted => _fooAwaitable.IsCompleted;

        public FooAwaiter(FooAwaitable<TResult> fooAwaitable)
        {
            _fooAwaitable = fooAwaitable;
        }

        public void OnCompleted(Action continuation)
        {
            Console.WriteLine("FooAwaiter.OnCompleted");
            if (_fooAwaitable.AddFooContinuation(continuation))
            {
                Console.WriteLine("FooAwaiter.OnCompleted: added continuation");
            }
            else
            {
                Console.WriteLine("FooAwaiter.OnCompleted: already completed, invoking continuation");
                continuation();
            }
        }

        public void UnsafeOnCompleted(Action continuation)
        {
            Console.WriteLine("FooAwaiter.UnsafeOnCompleted");
            if (_fooAwaitable.AddFooContinuation(continuation))
            {
                Console.WriteLine("FooAwaiter.UnsafeOnCompleted: added continuation");
            }
            else
            {
                Console.WriteLine("FooAwaiter.UnsafeOnCompleted: already completed, invoking continuation");
                continuation();
            }
        }

        // 3. 實現 GetResult 方法
        public TResult GetResult()
        {
            if (_fooAwaitable._exception != null)
            {
                // 4. 如果 awaitable 中有異常,則丟擲
                throw _fooAwaitable._exception;
            }
            Console.WriteLine("FooAwaiter.GetResult");
            return _fooAwaitable._result;
        }
    }
}

如果在一個方法中使用了 await,那麼這個方法就必須新增 async 修飾符。並且這個方法的返回型別通常是 Task 或者 其它 runtime 裡定義的 awaitable 型別。

int foo = await FooAsync();
Console.WriteLine(foo); // 1

async Task<int> FooAsync()
{
    await Task.Delay(1000);
    return 1;
}

問題1: 上面的程式碼中,FooAsync 方法是一個非同步方法,它的返回型別是 Task。但程式碼中的 await FooAsync() 並不會返回 Task,而是返回 int。這是為什麼呢?

如果我們把 FooAsync 的返回值改成我們自己實現的 awaitable 型別,編譯器會報錯:

問題2: 明明我們可以在 FooAwaitable 範例上使用 await 關鍵詞,為什麼把它作為 FooAsync 的返回型別就會報錯呢?且提示它不是一個 task-like 型別?

實際上我們在上篇文章實現的 awaitable 型別 FooAwaitable,只是支援了 await 關鍵詞,並不是一個完整的 task-like 型別。

而上面兩個問題的答案就是本文要講的內容:AsyncMethodBuilder

AsyncMethodBuilder 介紹

AsyncMethodBuilder 是狀態機的重要組成部分

參照上一篇文章介紹狀態機的程式碼:

class Program
{
    static async Task Main(string[] args)
    {
        var a = 1;
        Console.WriteLine(await FooAsync(a));
    }

    static async Task<int> FooAsync(int a)
    {
        int b = 2;
        int c = await BarAsync();
        return a + b + c;
    }

    static async Task<int> BarAsync()
    {
        await Task.Delay(100);
        return 3;
    }
}

由 FooAsync 編譯成的 IL 程式碼經整理後的等效 C# 程式碼如下:

using System;
using System.Runtime.CompilerServices;
using System.Threading.Tasks;

class Program
{
    static async Task Main(string[] args)
    {
        var a = 1;
        Console.WriteLine(await FooAsync(a));
    }

    static Task<int> FooAsync(int a)
    {
        var stateMachine = new FooStateMachine
        {
            _asyncTaskMethodBuilder = AsyncTaskMethodBuilder<int>.Create(),
    
            _state = -1, // 初始化狀態
            _a = a // 將實參拷貝到狀態機欄位
        };
        // 開始執行狀態機
        stateMachine._asyncTaskMethodBuilder.Start(ref stateMachine);
        return stateMachine._asyncTaskMethodBuilder.Task;
    }

    static async Task<int> BarAsync()
    {
        await Task.Delay(100);
        return 3;
    }

    public class FooStateMachine : IAsyncStateMachine
    {
        // 方法的引數和區域性變數被編譯會欄位
        public int _a;
        public AsyncTaskMethodBuilder<int> _asyncTaskMethodBuilder;
        private int _b;

        private int _c;

        // -1: 初始化狀態
        // 0: 等到 Task 執行完成
        // -2: 狀態機執行完成
        public int _state;

        private TaskAwaiter<int> _taskAwaiter;

        public void MoveNext()
        {
            var result = 0;
            TaskAwaiter<int> taskAwaiter;
            try
            {
                // 狀態不是0,代表 Task 未完成
                if (_state != 0)
                {
                    // 初始化區域性變數
                    _b = 2;

                    taskAwaiter = Program.BarAsync().GetAwaiter();
                    if (!taskAwaiter.IsCompleted)
                    {
                        // state: -1 => 0,非同步等待 Task 完成
                        _state = 0;
                        _taskAwaiter = taskAwaiter;
                        var stateMachine = this;
                        // 內部會呼叫 將 stateMachine.MoveNext 註冊為 Task 的回撥
                        _asyncTaskMethodBuilder.AwaitUnsafeOnCompleted(ref taskAwaiter, ref stateMachine);
                        return;
                    }
                }
                else
                {
                    taskAwaiter = _taskAwaiter;
                    // TaskAwaiter 是個結構體,這邊相當於是個清空 _taskAwaiter 欄位的操作
                    _taskAwaiter = new TaskAwaiter<int>();
                    // state: 0 => -1,狀態機恢復到初始化狀態
                    _state = -1;
                }

                _c = taskAwaiter.GetResult();
                result = _a + _b + _c;
            }
            catch (Exception e)
            {
                // state: any => -2,狀態機執行完成
                _state = -2;
                _asyncTaskMethodBuilder.SetException(e);
                return;
            }

            // state: -1 => -2,狀態機執行完成
            _state = -2;
            // 將 result 設定為 FooAsync 方法的返回值
            _asyncTaskMethodBuilder.SetResult(result);
        }

        public void SetStateMachine(IAsyncStateMachine stateMachine)
        {
        }
    }
}

在編譯器生成的狀態機類中,我們可以看到一個名為 _asyncTaskMethodBuilder 的欄位,它的型別是 AsyncTaskMethodBuilder<int>。
這個 AsyncTaskMethodBuilder 就是 Task所繫結的 AsyncMethodBuilder。

AsyncMethodBuilder 的結構

以 AsyncTaskMethodBuilder<TResult> 為例,我們來看下 AsyncMethodBuilder 的結構:

public struct AsyncTaskMethodBuilder<TResult>
{
    // 儲存最後作為返回值的 Task
    private Task<TResult>? m_task;

    // 建立一個 AsyncTaskMethodBuilder
    public static AsyncTaskMethodBuilder<TResult> Create() => default;

    // 開始執行 AsyncTaskMethodBuilder 及其繫結的狀態機 
    public void Start<TStateMachine>(ref TStateMachine stateMachine) where TStateMachine : IAsyncStateMachine =>
        AsyncMethodBuilderCore.Start(ref stateMachine);

    // 繫結狀態機,但編譯器的編譯結果不會呼叫
    public void SetStateMachine(IAsyncStateMachine stateMachine) =>
        AsyncMethodBuilderCore.SetStateMachine(stateMachine, m_task);

    // 將狀態機的 MoveNext 方法註冊為 async方法 內 await 的 Task 的回撥
    public void AwaitOnCompleted<TAwaiter, TStateMachine>(
        ref TAwaiter awaiter, ref TStateMachine stateMachine)
        where TAwaiter : INotifyCompletion
        where TStateMachine : IAsyncStateMachine =>
        AwaitOnCompleted(ref awaiter, ref stateMachine, ref m_task);

    // 同上,參考前一篇文章講 UnsafeOnCompleted 和 OnCompleted 的區別
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    public void AwaitUnsafeOnCompleted<TAwaiter, TStateMachine>(
        ref TAwaiter awaiter, ref TStateMachine stateMachine)
        where TAwaiter : ICriticalNotifyCompletion
        where TStateMachine : IAsyncStateMachine =>
        AwaitUnsafeOnCompleted(ref awaiter, ref stateMachine, ref m_task);

    public Task<TResult> Task
    {
        get => m_task ?? InitializeTaskAsPromise();
    }

    public void SetResult(TResult result)
    {
        if (m_task is null)
        {
            m_task = Threading.Tasks.Task.FromResult(result);
        }
        else
        {
            SetExistingTaskResult(m_task, result);
        }
    }

    public void SetException(Exception exception) => SetException(exception, ref m_task);
}

非泛型的 Task 對應的 AsyncMethodBuilder 是 AsyncTaskMethodBuilder,它的結構與泛型的 AsyncTaskMethodBuilder<TResult> 類似,但因為最終返回的 Task 沒有執行結果,它的 SetResult 只是為了標記 Task 的完成狀態並觸發 Task 的回撥。

public struct AsyncTaskMethodBuilder
{
    private Task<VoidTaskResult>? m_task;

    public static AsyncTaskMethodBuilder Create() => default;


    public void Start<TStateMachine>(ref TStateMachine stateMachine) where TStateMachine : IAsyncStateMachine =>
        AsyncMethodBuilderCore.Start(ref stateMachine);

    public void SetStateMachine(IAsyncStateMachine stateMachine) =>
        AsyncMethodBuilderCore.SetStateMachine(stateMachine, task: null);

    public void AwaitOnCompleted<TAwaiter, TStateMachine>(
        ref TAwaiter awaiter, ref TStateMachine stateMachine)
        where TAwaiter : INotifyCompletion
        where TStateMachine : IAsyncStateMachine =>
        AsyncTaskMethodBuilder<VoidTaskResult>.AwaitOnCompleted(ref awaiter, ref stateMachine, ref m_task);

    public void AwaitUnsafeOnCompleted<TAwaiter, TStateMachine>(
        ref TAwaiter awaiter, ref TStateMachine stateMachine)
        where TAwaiter : ICriticalNotifyCompletion
        where TStateMachine : IAsyncStateMachine =>
        AsyncTaskMethodBuilder<VoidTaskResult>.AwaitUnsafeOnCompleted(ref awaiter, ref stateMachine, ref m_task);

    public Task Task
    {
        get => m_task ?? InitializeTaskAsPromise();
    }

    public void SetResult()
    {
        if (m_task is null)
        {
            m_task = Task.s_cachedCompleted;
        }
        else
        {
            AsyncTaskMethodBuilder<VoidTaskResult>.SetExistingTaskResult(m_task, default!);
        }
    }

    public void SetException(Exception exception) =>
        AsyncTaskMethodBuilder<VoidTaskResult>.SetException(exception, ref m_task);
}            

AsyncMethodBuilder 功能分析

AsyncTaskMethodBuilder 在 FooAsync 方法的執行過程中,起到了以下作用:

  1. 對內:關聯狀態機和狀態機執行的上下文,管理狀態機的生命週期。
  2. 對外:構建一個 Task 物件,作為非同步方法的返回值,並會觸發該 Task 執行的完成或異常。

為了方便說明,下文我們將 FooAsync 方法返回的 Task 稱為 FooTask,BarAsync 方法返回的 Task 稱為 BarTask。

對狀態機的生命週期進行管理

狀態機通過 _asyncTaskMethodBuilder.Start 方法來啟動且其 MoveNext 方式是通過 _asyncTaskMethodBuilder.AwaitUnsafeOnCompleted 方法來註冊為 BarTask 的回撥的。

對 async 方法的返回值進行包裝

_asyncTaskMethodBuilder 是用來構建一個 Task 物件,_asyncTaskMethodBuilder 的 Task 屬性就是 FooAsync 方法返回的 FooTask。通過 _asyncTaskMethodBuilder 的 SetResult 方法,我們可以設定 FooTask 的執行結果, 通過 SetException 方法,我們可以設定 FooTask 的異常。

小結

一個 AsyncMethodBuilder 是由下面幾個部分組成的:

  1. 一個 Task 物件,作為非同步方法的返回值。
  2. Create 方法,用來建立 AsyncMethodBuilder。
  3. Start 方法,用來啟動狀態機。
  4. AwaitOnCompleted/AwaitUnsafeOnCompleted 方法,用來將狀態機的 MoveNext 方法註冊為 async方法 內 await 的 Task 的回撥。
  5. SetResult/SetException 方法,用來標記 Task 的完成狀態並觸發 Task 的回撥。
  6. SetStateMachine 方法,用來關聯狀態機,不常用,編譯結果也不會呼叫。

async void

為了讓 async 方法適配傳統的事件回撥,C# 引入了 async void 的概念。

var foo = new Foo();
foo.OnSayHello += FooAsync;
foo.SayHello();

Console.ReadLine();

async void FooAsync(object sender, EventArgs e)
{
    var args = e as SayHelloEventArgs;
    await Task.Delay(1000);
    Console.WriteLine(args.Message);
}

class Foo
{
    public event EventHandler OnSayHello;

    public void SayHello()
    {
        OnSayHello.Invoke(this, new SayHelloEventArgs { Message = "Hello" });
    }
}

class SayHelloEventArgs : EventArgs
{
    public string Message { get; set; }
}

async void 也有一個對應的 AsyncVoidMethodBuilder。

    public struct AsyncVoidMethodBuilder
    {
        // AsyncVoidMethodBuilder 是對 AsyncTaskMethodBuilder 的封裝
        private AsyncTaskMethodBuilder _builder;

        public static AsyncVoidMethodBuilder Create()
        {
            // ...
        }

        public void Start<TStateMachine>(ref TStateMachine stateMachine) where TStateMachine : IAsyncStateMachine =>
            AsyncMethodBuilderCore.Start(ref stateMachine);

        public void SetStateMachine(IAsyncStateMachine stateMachine) =>
            _builder.SetStateMachine(stateMachine);

        public void AwaitOnCompleted<TAwaiter, TStateMachine>(
            ref TAwaiter awaiter, ref TStateMachine stateMachine)
            where TAwaiter : INotifyCompletion
            where TStateMachine : IAsyncStateMachine =>
            _builder.AwaitOnCompleted(ref awaiter, ref stateMachine);

        public void AwaitUnsafeOnCompleted<TAwaiter, TStateMachine>(
            ref TAwaiter awaiter, ref TStateMachine stateMachine)
            where TAwaiter : ICriticalNotifyCompletion
            where TStateMachine : IAsyncStateMachine =>
            _builder.AwaitUnsafeOnCompleted(ref awaiter, ref stateMachine);

        public void SetResult()
        {
            // 僅僅是做 runtime 的一些狀態標記
        }

        public void SetException(Exception exception)
        {
            // 這個異常只能通過 TaskScheduler.UnobservedTaskException 事件來捕獲
        }

        // 因為沒有返回值,這個 Task 不對外暴露
        private Task Task => _builder.Task;
    }

自定義 AsyncMethodBuilder

自定義一個 AsyncMethodBuilder,不需要實現任意介面,只需要實現上面說的那 6 個主要組成部分,編譯器就能夠正常編譯。

awaitable 繫結 AsyncMethodBuilder 的方式有兩種:

  1. 在 awaitable 型別上新增 AsyncMethodBuilderAttribute 來繫結 AsyncMethodBuilder。
  2. 在 async 方法上新增 AsyncMethodBuilderAttribute 來繫結 AsyncMethodBuilder,用來覆蓋 awaitable 型別上的 AsyncMethodBuilderAttribute(前提是 awaitable 型別上有 AsyncMethodBuilderAttribute)。
struct FooAsyncMethodBuilder<TResult>
{
    private FooAwaitable<TResult> _awaitable;

    // 1. 定義 Task 屬性
    public FooAwaitable<TResult> Task
    {
        get
        {
            Console.WriteLine("FooAsyncMethodBuilder.Task");
            return _awaitable;
        }
    }
    
    // 2. 定義 Create 方法
    public static FooAsyncMethodBuilder<TResult> Create()
    {
        Console.WriteLine("FooAsyncMethodBuilder.Create");
        var awaitable = new FooAwaitable<TResult>();
        var builder = new FooAsyncMethodBuilder<TResult>
        {
            _awaitable = awaitable,
        };
        return builder;
    }
    
    // 3. 定義 Start 方法
    public void Start<TStateMachine>(ref TStateMachine stateMachine)
        where TStateMachine : IAsyncStateMachine
    {
        Console.WriteLine("FooAsyncMethodBuilder.Start");
        stateMachine.MoveNext();
    }

    
    // 4. 定義 AwaitOnCompleted/AwaitUnsafeOnCompleted 方法
    
    // 如果 awaiter 實現了 INotifyCompletion 介面,就呼叫 AwaitOnCompleted 方法
    public void AwaitOnCompleted<TAwaiter, TStateMachine>(ref TAwaiter awaiter, ref TStateMachine stateMachine)
        where TAwaiter : INotifyCompletion
        where TStateMachine : IAsyncStateMachine
    {
        Console.WriteLine("FooAsyncMethodBuilder.AwaitOnCompleted");
        awaiter.OnCompleted(stateMachine.MoveNext);
    }

    [SecuritySafeCritical]
    public void AwaitUnsafeOnCompleted<TAwaiter, TStateMachine>(
        ref TAwaiter awaiter,
        ref TStateMachine stateMachine)
        where TAwaiter : ICriticalNotifyCompletion
        where TStateMachine : IAsyncStateMachine
    {
        Console.WriteLine("FooAsyncMethodBuilder.AwaitUnsafeOnCompleted");
        awaiter.UnsafeOnCompleted(stateMachine.MoveNext);
    }

    // 5. 定義 SetResult/SetException 方法
    public void SetResult(TResult result)
    {
        Console.WriteLine("FooAsyncMethodBuilder.SetResult");
        _awaitable.TrySetResult(result);
    }
    
    public void SetException(Exception exception)
    {
        Console.WriteLine("FooAsyncMethodBuilder.SetException");
        _awaitable.TrySetException(exception);
    }
    
    // 6. 定義 SetStateMachine 方法,雖然編譯器不會呼叫,但是編譯器要求必須有這個方法
    public void SetStateMachine(IAsyncStateMachine stateMachine)
    {
        Console.WriteLine("FooAsyncMethodBuilder.SetStateMachine");
    }
}

// 7. 通過 AsyncMethodBuilderAttribute 繫結 FooAsyncMethodBuilder
[AsyncMethodBuilder(typeof(FooAsyncMethodBuilder<>))]
class FooAwaitable<TResult>
{
    // ...
}
Console.WriteLine("await Foo1Async()");
int foo1= await Foo1Async();
Console.WriteLine("Foo1Async() result: " + foo1);
Console.WriteLine();

Console.WriteLine("await Foo2Async()");

int foo2 = await Foo2Async();
Console.WriteLine("Foo2Async() result: " + foo2);
Console.WriteLine();

Console.WriteLine("await FooExceptionAsync()");
try
{
    await FooExceptionAsync();
}
catch (Exception e)
{
    Console.WriteLine(e.Message);
}

async FooAwaitable<int> Foo1Async()
{
    await Task.Delay(1000);
    return 1;
}

// 覆蓋預設的 AsyncMethodBuilder,使用 FooAsyncMethodBuilder2
// 本文省略了 FooAsyncMethodBuilder2 的定義,可以參考上面的 FooAsyncMethodBuilder
[AsyncMethodBuilder(typeof(FooAsyncMethodBuilder2<>))]
async FooAwaitable<int> Foo2Async()
{
    await Task.Delay(1000);
    return 2;
}

執行結果:

await Foo1Async()
FooAsyncMethodBuilder.Create
FooAsyncMethodBuilder.Start
FooAsyncMethodBuilder.AwaitUnsafeOnCompleted
FooAsyncMethodBuilder.Task
FooAwaiter.UnsafeOnCompleted
FooAwaiter.UnsafeOnCompleted: added continuation
FooAsyncMethodBuilder.SetResult
FooAwaiter.GetResult
Foo1Async() result: 1

await Foo2Async()
FooAsyncMethodBuilder2.Create
FooAsyncMethodBuilder2.Start
FooAsyncMethodBuilder2.AwaitUnsafeOnCompleted
FooAsyncMethodBuilder2.Task
FooAwaiter.UnsafeOnCompleted
FooAwaiter.UnsafeOnCompleted: added continuation
FooAsyncMethodBuilder2.SetResult
FooAwaiter.GetResult
Foo2Async() result: 2

await FooExceptionAsync()
FooAsyncMethodBuilder.Create
FooAsyncMethodBuilder.Start
FooAsyncMethodBuilder.AwaitUnsafeOnCompleted
FooAsyncMethodBuilder.Task
FooAwaiter.UnsafeOnCompleted
FooAwaiter.UnsafeOnCompleted: added continuation
FooAsyncMethodBuilder.SetException
Exception from FooExceptionAsync

在方法上新增 AsyncMethodBuilderAttribute 的功能是後來才新增的,通過這個功能,可以覆蓋 awaitable 型別上的 AsyncMethodBuilderAttribute,以便進行效能優化。例如 .NET 6 開始提供的 PoolingAsyncValueTaskMethodBuilder,對原始的 AsyncValueTaskMethodBuilder 進行了池化處理,可以通過在方法上新增 AsyncMethodBuilderAttribute 來使用。

歡迎關注個人技術公眾號