通過宏封裝實現std::format編譯期檢查引數數量是否一致

2022-08-29 18:02:33

背景

std::format在傳引數量少於格式串所需引數數量時,會丟擲異常。而在大部分的應用場景下,引數數量不一致提供編譯報錯更加合適,可以促進我們更早發現問題並進行改正。

最終效果

// 測試輸出介面。
template <typename... T>
void Print(const std::string& _Fmt, const T&... _Args)
{
    cout << std::vformat(_Fmt, std::make_format_args(_Args...)) << endl;
}

// 封裝宏,實現引數數量一致的檢查
#define PRINT(fmt, ...) \
    do { static_assert(GetFormatStringArgsNum(fmt) == decltype(VariableArgsNumHelper(__VA_ARGS__))::value, "Invalid format string or mismatched number of arguments"); Print(fmt, __VA_ARGS__); } while(0)

int main()
{
    PRINT("{}", "hello");
    PRINT("{} {}", "hello");

    return 0;
}

上例程式碼中,使用PRINT宏封裝了Print函數,後續使用PRINT進行控制檯輸出,如果出現引數數量不一致,將產生編譯報錯:Invalid format string or mismatched number of arguments

所用技術

  1. 靜態斷言: static_assert

  2. 格式串引數數量獲取: GetFormatStringArgsNum,該介面宣告為constexpr,從而獲得編譯期執行的能力。其實現大致為遍歷字串,檢查其中{}的數量。

  3. 傳引數量的獲取: 由於使用宏進行封裝,最後其實就是需要獲得__VA_ARGS__中附帶了幾個引數,網上可以搜到各種解決方案,這裡採用的是宣告一個模板函數,模板函數返回integral_constant結構體,其對不同的引數數量,自動生成不同的結構體型別,之後使用decltype(VariableArgsNumHelper(__VA_ARGS__))獲得返回值型別,並從返回值型別中獲得代表引數數量的常數值,由於執行期用不到該函數,因此只提供宣告,不提供實現。

整體程式碼

#include <iostream>
#include <string>
#include <format>
using namespace std;

constexpr int GetFormatStringArgsNum(const std::string& fmt)
{
	enum STATE
	{
		NORMAL,			// 正在解析普通串
		REPLACEMENT,	// 正在解析大括號中的內容
	};

	// 按標準規定,格式串中要麼都指定引數編號,要麼都不指定
	// 原文:
	// The arg-ids in a format string must all be present or all be omitted. 
	// Mixing manual and automatic indexing is an error.
	enum RULE
	{
		UNKNOWN,		// 格式串規則
		SPECIFIEDID,	// 指定編號,如{0}
		UNSPECIFIEDID,	// 不指定編號,如{}
	};

	// 指定引數編號的最大值
	const int MAX_ARGS_NUM = 10000;
	// 初始狀態
	STATE state = NORMAL;
	// 初始規則
	RULE rule = UNKNOWN;
	// 當前引數編號
	int nIndex = -1;
	// 引數數量
	int nArgsNum = 0;
	for (int i = 0; i < fmt.size(); ++i)
	{
		switch (state)
		{
		case NORMAL:
		{
			// 普通串解析時,遇到左大括號或右大括號,才有可能改變狀態
			if (fmt[i] == '{')
			{
				if (i + 1 < fmt.size() && fmt[i + 1] == '{')
				{
					// 遇到 {{,則將他們視為普通字元
					++i;
				}
				else
				{
					// 進入替換串狀態
					state = REPLACEMENT;
				}
			}
			else if (fmt[i] == '}')
			{
				++i;
				if (i >= fmt.size() || fmt[i] != '}')
				{
					// 普通串解析狀態,遇上右大括號時,只有當接下來也是右大括號時,才屬於合法串
					return -1;
				}
			}
		}
		break;
		case REPLACEMENT:
		{
			// 替換串狀態下,正常只會遇到右大括號、數位、冒號,其他符號均為錯誤
			if (fmt[i] == '}')
			{
				// 遇到右大括號,則進入普通串解析狀態,這裡不考慮}},正常{} 中間不應該出現}
				state = NORMAL;

				// 如果之前某個{} 已經指定引數編號,則所有引數都應該指定編號
				if (rule == SPECIFIEDID)
				{
					// 如果這個{} 不指定編號,則視為非法格式串
					if (nIndex == -1)
					{
						return -1;
					}
					// 在指定編號的情況下,可變引數的數量至少要比編號大1
					nArgsNum = std::max(nArgsNum, nIndex + 1);
					// 重置當前編號
					nIndex = -1;
				}
				else
				{
					// 如果當前規則未明或者當前規則為不指定編號,則引數數量進行自增。
					state = NORMAL;
					rule = UNSPECIFIEDID;
					++nArgsNum;
				}
			}
			else if (fmt[i] >= '0' && fmt[i] <= '9')
			{
				// 遇到數位,說明指定了引數編號
				if (rule == UNSPECIFIEDID)
				{
					// 如果當前規則已明確為不指定編號,則視為非法格式串
					return -1;
				}
				else
				{
					// 否則,將當前規則改為指定編號,並維護當前編號
					rule = SPECIFIEDID;
					if (nIndex == -1)
					{
						nIndex = 0;
					}

					nIndex = nIndex * 10 + (fmt[i] - '0');
					if (nIndex >= MAX_ARGS_NUM)
					{
						// 當前編號大於最大上限,則直接視為非法格式串
						return -1;
					}
				}
			}
			else if (fmt[i] == ':')
			{
				// 遇到冒號,說明接下來是格式串規則,直接跳過
				for (; i + 1 < fmt.size() && fmt[i + 1] != '}'; ++i)
				{
					;
				}
			}
			else
			{
				// 解析替換串時,遇上其他字元,均將格式串視為非法。
				return -1;
			}
		}
		break;
		}
	}

	// 最終狀態必須為普通串解析狀態。
	return state == NORMAL ? nArgsNum : -1;
}

// 可變引數數量輔助器
template <typename ... Args>
std::integral_constant<std::size_t, sizeof...(Args)> VariableArgsNumHelper(const Args  & ...);

// 測試輸出介面。
template <typename... T>
void Print(const std::string& _Fmt, const T&... _Args)
{
	cout << std::vformat(_Fmt, std::make_format_args(_Args...)) << endl;
}

// 封裝宏,實現引數數量一致的檢查
#define PRINT(fmt, ...) \
    do { static_assert(GetFormatStringArgsNum(fmt) == decltype(VariableArgsNumHelper(__VA_ARGS__))::value, "Invalid format string or mismatched number of arguments"); Print(fmt, __VA_ARGS__); } while(0)


int main()
{
	PRINT("{} {}", "hello");

	return 0;
}