淺談 C++ 模板 & 泛化 (媽媽再也不用擔心我不會用 std::sort 了)

2022-12-25 21:00:19

基礎複習

先上個對 int 型別陣列的插入排序:

void insertionSort_01(int* seq, int firstIndex, int lastIndex) {
    for (int j = firstIndex + 1; j <= lastIndex; ++j) {
        int key = seq[j];
        int i = j - 1;
        while (i >= firstIndex && key < seq[i]) {
            seq[i + 1] = seq[i];
            --i;
        }
        seq[i + 1] = key;
    }
}
  • 提出問題: 如果想排 double 型別陣列怎麼辦?

可以過載一個 double 版本:

void insertionSort_01_b(double* seq, int firstIndex, int lastIndex) {
    ...
}

當然, 更好的方式是利用 C++ 的模板泛化元素型別:

template<class ElemType>
void insertionSort_02(ElemType* seq, int firstIndex, int lastIndex) {
    ...
}

步入正題

接著提出兩個問題:

  • 1 是否一定要求升序排列
  • 2 ElemType 物件是否一定能使用 operator<

為解決問題 1, 我們可以額外寫個降序排列版本:

template<class ElemType>
void insertionSort_02_b(ElemType* seq, int firstIndex, int lastIndex) {
    for (...) {
        ...
        // Change {<} to {>} when comparing {key} and {seq[i]}:
        while (i >= firstIndex && key > seq[i]) {
            ...
        }
        ...
    }
}

對於問題 2, 我們舉個例子.
現有:

struct MyStruct
{
    int aa;
    int bb;
};

MyStruct arr_MyStruct[4] = { {1,4},{3,1},{9,-1},{12,0} };

要求對 arr_MyStruct 中的元素以 MyStruct::aa 排序.
對於 C++ 新手來說, 這是一個比較難解決的問題, 也是問題 2 聚焦的關鍵.

對問題 1 的處理中, 我們將 "比較" 這個謂語 (predicate)operator< 替換為 opeartor>;
這給了我們一些提示: 是否可以像我們用模板來泛化元素型別那樣泛化謂語?

提出概念: 函數物件 (function object)
定義類 bad_greater:

// Omit the definition of class <MyStruct>.
struct bad_greater {
    // {operator()} should be defined as a const method,
    // in order to make it available to <const bad_greater> instances.
    bool operator()(const MyStruct& left, const MyStruct& right) const { return left.aa > right.aa; }
};

bad_greater 所建立的範例為函數物件, 可以參考以下使用案例:

// Omit the definition of class <MyStruct>.
MyStruct arr_MyStruct[4] = { {1,4},{3,1},{9,-1},{12,0} };
bad_greater compare;
std::cout << compare(arr_MyStruct[0], arr_MyStruct[1]) << std::endl;
// Use anonymous instance:
std::cout << bad_greater()(arr_MyStruct[0], arr_MyStruct[1]) << std::endl;

bad_greater 之所以 bad, 是因為其唯獨提供對類 MyStruct 範例的比較.
定義一個模板類 good_less 並對 MyStruct 偏特化以解決這個問題:

// Omit the definition of class <MyStruct>.
template<class T>
struct good_less
{
    bool operator()(const T& left, const T& right) const { return left < right; }
};

template<>
struct good_less<MyStruct>
{
    bool operator()(const MyStruct& left, const MyStruct& right) const { return left.aa < right.aa; }
};

有了函數物件, 我們可以泛化演演算法中的謂語:

template<class ElemType, class _Pred>
void insertionSort_03(ElemType* seq, int firstIndex, int lastIndex, const _Pred& compare) {
    for (...) {
        ...
        while (... && compare(key, seq[i])) {
            ...;
        }
        ...
    }
}

呼叫函數 insertionSort_03() 時, 我們要注意, 編譯器能直接根據傳入引數推斷模板的範例化型別; 因此無需提供額外的模板類引數:

// Omit the definition of class <MyStruct>.
// Omit the definition of class <good_less>.
// Omit the definition of class <good_greater>.

MyStruct arr_MyStruct[4] = { {1,4},{3,1},{9,-1},{12,0} };
// Ascending order:
insertionSort_03(arr_MyStruct, 0, 3, good_less<MyStruct>());
// Descending order:
insertionSort_03(arr_MyStruct, 0, 3, good_greater<MyStruct>());

// Also works for array with orther types:
double arr_double[4] = { 1,9.1,0.9,-3.1 };
insertionSort_03(arr_MyStruct, 0, 3, good_greater<double>());

std::sort() 的升降序排序

std::sort() 和我們的 insertionSort_03() 一樣泛化的謂語, 而且 STL 還提供了類 std::greaterstd::less 等用於定義函數物件.
升降序的使用方法參考以下程式碼:

#include <algorithm>
#include <functional>

double arr_double[4] = { 1,9.1,0.9,-3.1 };
// Ascending order:
std::sort(arr_double, arr_double + 4);
// Ascending order:
std::sort(arr_double, arr_double + 4, std::less<double>());
// Descending order:
std::sort(arr_double, arr_double + 4, std::greater<double>()));

你可能會問: 為什麼第一個例子不用和之前說的一樣, 傳入一個函數物件?
這沒什麼高深的, 在 C++14 之前, 其實只是額外提供了一個只有兩個引數的函數過載而已.
給個差不多的虛擬碼出來:

std::sort(seq_begin, seq_end){
    std::sort(seq_begin, seq_end, std::less());
}

C++14 之後在謂語類和 std::sort() 的定義上用了點小 trick, 下面給點啟發性的例子 (如果不感興趣, 你可以跳過這段):

template<class T = void>
struct less
{
    template<class T>
    bool operator()(const T& a, const T& b) const { return a < b; }
};
template<..., class _Pred = less<void>>
void sort(..., const _Pred& compare = {}) {
    ...
}

說簡單點, 就是 less 給了一個預設模板範例化型別 void; 而真正進行比較的 operator() 又是一個模板. 呼叫 sort 時, 不用考慮第三個引數 (函數對像) 具體是什麼型別, 反正 operator() 在比較時會自行範例化.
可以參考以下使用案例:

// Under C++ 14 (or later) standard.
#include <algorithm>
#include <functional>

double arr_double[4] = { 1,9.1,0.9,-3.1 };
std::sort(arr_double, arr_double + 4); // std::less<void>
std::sort(arr_double, arr_double + 4, std::less()); // std::less<void>
std::sort(arr_double, arr_double + 4, std::less<double>()); // std::less<double>

int arr_int[4] = { 1,3,4,0 };
std::sort(arr_double, arr_double + 4, std::less()); // std::less<int>

std::sort() 排其他型別範例

如果看懂了前面的內容, 想必你也能夠猜出來怎麼實現這個問題了.
注意, std::less 之類的謂語型別說到底就是結構體, 和我們上面實現的 good_less 沒啥區別. 所以如果我們還是要排序上文提到的 MyStruct 陣列:

// Omit the definition of class <MyStruct>.
// Omit the definition of class <good_less>.
// Omit the definition of class <good_greater>.
#include <algorithm>

MyStruct arr_MyStruct[4] = { {1,4},{3,1},{9,-1},{12,0} };
// Ascending order:
std::sort(arr_MyStruct, arr_MyStruct + 4, good_less<MyStruct>());
// Descending order:
std::sort(arr_MyStruct, arr_MyStruct + 4, good_greater<MyStruct>());

統一指標和迭代器

作為一個 STL 使用者, 難免會遇到指標與迭代器不統一的問題. 例如以下例子:

// Use pointer:
int arr_int[] = ...;
std::sort(arr_int, ...);

// Use iterator:
std::vector<int> arr_vector = ...;
std::sort(arr_vector.begin(), ...);

解決方式之一是統一泛化指標型別和迭代器型別, 這裡把它們都當作類 _RandIt .
我們還是以最開始的 insertionSort 為例, 給出示範程式碼.
需要注意的是, 通過迭代器和指標獲取元素型別 (用來定義 key )時, decltype 會保留解除參照 (dereference) 後留下的參照 & (也就是說 decltype(arr_int[0]) 得到的型別不是 int 而是 int& ); 因此需要呼叫 std::remove_reference 來刪除型別中的參照.

using index = long long;


template<class _RandIt, class _Pr = std::less<void>>
void insertionSort(_RandIt seq, index firstIndex, index lastIndex, const _Pr& comp = {}) {
    for (index j = firstIndex + 1; j <= lastIndex; ++j) {
        typename std::remove_reference<decltype(*seq)>::type key = seq[j];
        index i = j - 1;
        while (i >= firstIndex && comp(key, seq[i])) {
            seq[i + 1] = seq[i];
            --i;
        }
        seq[i + 1] = key;
    }
}

再給個歸併排序的程式碼吧! 就說到這裡 (計組完全沒學, 寄).

using index = long long;

template<class _RandIt, class _Pr>
void merge(_RandIt seq, index subAFirst, index subALast, index subBLast,
    auto MAX, auto MIN, const _Pr& comp) {
    auto END = comp(1, 2) ? MAX : MIN;

    size_t sizeSubA = subALast - subAFirst + 2;
    size_t sizeSubB = subBLast - subALast + 1;

    auto subA = new typename std::remove_reference<decltype(*seq)>::type[sizeSubA];
    std::copy(seq + subAFirst, seq + subALast + 1, subA);
    subA[sizeSubA - 1] = END;

    auto subB = new typename std::remove_reference<decltype(*seq)>::type[sizeSubB];
    std::copy(seq + subALast + 1, seq + subBLast + 1, subB);
    subB[sizeSubB - 1] = END;

    // Merge two subsequences to origin {seq[subAFirst : subBLast]}:
    for (index k = subAFirst, i = 0, j = 0; k <= subBLast; ++k) {
        if (i >= sizeSubA || j >= sizeSubB) return;
        // Merge:
        if (comp(subA[i], subB[j])) {
            seq[k] = subA[i]; ++i;
        } else {
            seq[k] = subB[j]; ++j;
        }
    }

    delete[] subA;
    delete[] subB;
}

template<class _RandIt, class _Pr = std::less<void>>
void mergeSort(_RandIt seq, index firstIndex, index lastIndex,
    auto MAX, auto MIN, const _Pr& comp = {}) {
    if (firstIndex >= lastIndex) return;
    index mid = (firstIndex + lastIndex) / 2;
    mergeSort(seq, firstIndex, mid, MAX, MIN, comp);
    mergeSort(seq, mid + 1, lastIndex, MAX, MIN, comp);
    merge(seq, firstIndex, mid, lastIndex, MAX, MIN, comp);
}