排序基礎與平行 Radix Sort 核心 (Sorting Foundations & Parallel Radix Sort)

重點總覽 (Overview)

項目 內容 備註
排序定義 輸出為 nondecreasing/nonincreasing 順序,且為 input 的 permutation 兩條件缺一不可
Stable vs Unstable stable 保留相同 key 的原始相對順序 多 key cascaded sort 必須 stable
Comparison vs Noncomparison comparison-based 下界 O(N log N);noncomparison 可更快 radix = noncomparison;merge = comparison
Radix sort 本質 依 radix(base)把 key 分配到 buckets,逐 digit 重複,每 iteration stable 1-bit radix → N-bit key 需 N 次 iteration
平行化粒度 iteration 之間序列;單一 iteration 內平行 host 每個 digit 呼叫一次 kernel
平行策略 one-thread-per-key:每 thread 算自己 key 的 destination index 後寫出 寫出是 scatter
核心工具 exclusive scan 計算 # ones before 並順帶得到 # ones total
Important

本章把 radix sort 拆成「單一 1-bit iteration kernel」。整體排序 = host 端對每個 bit 依序呼叫此 kernel。後續優化(coalescing、多 bit radix、coarsening)見 13-Sorting/02-Optimizing-Radix-Sort


排序基礎與分類 (Sorting Foundations & Classification)

任何 sorting algorithm 必須滿足兩個條件:

  1. 順序性:輸出為 nondecreasing(每個元素不小於前一個)或 nonincreasing(每個元素不大於前一個)。
  2. 保全性:輸出是 input 的 permutation(不增不減,只重排)。

常見以 (key, value) tuple 的形式排序,只依 key field 比較。例如以 income 為 key 把 [(30,150),(32,80),(22,45),(29,80)] 排成 nonincreasing → [(30,150),(32,80),(29,80),(22,45)]

Stable vs Unstable

類型 定義 用途
Stable 相同 key 時,保留原始出現順序 支援 multi-key cascaded sort(先排 secondary key,再排 primary key,第二次排序保留第一次的順序)
Unstable 不保證相同 key 的相對順序 單一 key 排序足夠時可用
Tip

上例中 (32,80) 必須排在 (29,80) 前面,因為原始輸入它在前。這正是 radix sort 每個 iteration 必須 stable 的原因——後一個 digit 的排序要靠前一個 digit 已建立的順序。

Comparison vs Noncomparison

類型 複雜度 代表 通用性
Comparison-based 不可能優於 O(N log N)(必須做最少次比較) merge sort、quicksort 換 comparison operator 即可適配任意 key 類型
Noncomparison-based 部分可優於 O(N log N) radix sort 通常只適用特定 key 類型(如 integer)
Warning

O(N log N) 是 comparison-based 的理論下界,不是 radix sort 的下界。radix sort 之所以能突破,正因為它不靠兩兩比較,而是靠 digit 分桶——代價是無法泛化到任意 key 類型。


Radix Sort 演算法 (Radix Sort Algorithm)

Radix sort = noncomparison-based,依 radix value(positional numeral system 的 base)把 key 分配到 buckets。若 key 有多個 digit,則對每個 digit 重複分配,每次 iteration 都stable(保留前一 iteration 在桶內的順序)。

1-bit iteration 範例(LSB 開始)

input :  3   5   4   1   7   2   6   0      (index 0..7)
LSB   :  1   1   0   1   1   0   0   0      (extract bit)
                 │
        ┌────────┴────────┐
        ▼                  ▼
   0 bucket           1 bucket
  (LSB == 0)         (LSB == 1)
   4  2  6  0         3  5  1  7            ← 桶內順序 = 原始順序 (stable!)
        │                  │
        └────────┬─────────┘
                 ▼
output:  4   2   6   0   3   5   1   7      (依 LSB 排好)
Important

桶內必須保序 (stability):第 2 個 iteration 以「output 變 input、考慮第 2 個 LSB」進行;由於前次順序被保留,output 此時已依低 2 個 bit排序(00 < 01 < 10 < 11)。依此類推,跑完所有 bit 後即完全排序。

iteration 間是序列依賴的

iter0(LSB) ─→ iter1 ─→ iter2 ─→ iter3(MSB)   每步輸入 = 前一步輸出
  [bit0]       [bit1]    [bit2]     [bit3]      ← 序列,不可平行

每個 iteration 都依賴前一個 iteration 的完整結果,故 iteration 之間必須序列執行。平行的機會在單一 iteration 之內


平行 Radix Sort 核心 (Parallel Radix Sort Kernel)

策略:one-thread-per-key——每個 thread 負責 input list 中一個 key,找出該 key 在 output list 的 destination index,再寫出去。

input :  3    5    4    1    7    2    6    0
         │    │    │    │    │    │    │    │     ← 每個 thread 負責下方的 key
         t0   t1   t2   t3   t4   t5   t6   t7
        └──block 0──┘   └──block 1──┘            (示意:每 block 4 threads)

Destination index 的推導

關鍵:thread 要知道自己的 key 落 0 bucket 還是 1 bucket,以及前面有幾個 key。令 i = key index(input 中的位置)。

落 0 bucket 的 key(destination of a 0):

destination = # zeros before
            = (# keys before) - (# ones before)
            = i - (# ones before)

落 1 bucket 的 key(destination of a 1,所有 0 必須排在所有 1 之前):

destination = (# zeros in total) + (# ones before)
            = (# keys in total - # ones in total) + (# ones before)
            = inputSize - (# ones in total) + (# ones before)

唯一非平凡的量是 # ones before(自己之前有幾個 key 落 1 bucket)——這正是 exclusive scan 的工作;而 # ones in total 可由 scan 的副產品取得。

用 exclusive scan 計算 # ones before

把每個 key 的 bit(0/1)當作 scan 的輸入,做 exclusive scan,結果每格 = 該位置之前所有 bit 的和 = # ones before:

i           :  0   1   2   3   4   5   6   7
bits (LSB)  :  1   1   0   1   1   0   0   0
                                                  ── exclusive scan ──▶
# ones before: 0   1   2   2   3   4   4   4      total ones = 4
                                                  (bits[N] 或 scan 副產品)

key=3 i=0 bit=1 → dst = N - tot + before = 8-4+0 = 4
key=5 i=1 bit=1 → dst = 8-4+1 = 5
key=4 i=2 bit=0 → dst = i - before     = 2-2 = 0
key=1 i=3 bit=1 → dst = 8-4+2 = 6
key=7 i=4 bit=1 → dst = 8-4+3 = 7
key=2 i=5 bit=0 → dst = 5-4 = 1
key=6 i=6 bit=0 → dst = 6-4 = 2
key=0 i=7 bit=0 → dst = 7-4 = 3
                                          output[dst] = key  (scatter)
output:  4   2   6   0   3   5   1   7

Kernel code(Fig. 13.4,單一 iteration)

// 對第 iter 個 bit 做一次 1-bit radix sort iteration
__global__ void radix_sort_iter(unsigned int* input, unsigned int* output,
                                unsigned int* bits, unsigned int N,
                                unsigned int iter) {
    unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;   // 03 該 thread 的 key index
    unsigned int key, bit;
    if (i < N) {                                              // 04 邊界檢查
        key = input[i];                                       // 06 載入 key
        bit = (key >> iter) & 1;                              // 07 取出第 iter 個 bit
        bits[i] = bit;                                        // 08 寫入 bits 供 scan
    }
    exclusiveScan(bits, N);                                   // 10 全 grid 協作 exclusive scan
    if (i < N) {
        unsigned int numOnesBefore = bits[i];                // 12 # ones before
        unsigned int numOnesTotal  = bits[N];                // 13 # ones total(scan 副產品)
        unsigned int dst = (bit == 0) ? (i - numOnesBefore)              // 14 落 0 bucket
                                      : (N - numOnesTotal + numOnesBefore);// 15 落 1 bucket
        output[dst] = key;                                   // 16 scatter 寫到 destination
    }
}
Warning

exclusiveScan 必須放在邊界檢查之外(line 10 在 if (i<N) 之外)。原因:scan 過程中 thread 可能要做 barrier synchronization,必須所有 thread 都 active,否則 inactive thread 缺席會造成 barrier 死結 / 結果錯誤。

Tip

要在整個 grid 同步做 scan,有兩條路:

  1. 用類似 11-Prefix-Sum-Scan/03-Arbitrary-Length-and-Single-Pass-Scansingle-pass / grid-wide 技術,維持單一 kernel launch。
  2. 拆成 3 個 kernel:scan 前處理 → 獨立 scan kernel → scan 後處理。代價是每個 iteration 變成 3 次 grid launch(而非 1 次)。

此基礎核心的缺陷(承接後續優化)

output[dst] = keyscatter 寫出:consecutive thread index 不一定寫到 consecutive memory location(t0→0桶、t1→1桶、t2→0桶…交錯),導致 memory coalescing 很差,每個 warp 需發多次 memory request。解法(block-level local sort + shared memory)見 13-Sorting/02-Optimizing-Radix-Sort


考試/面試重點 (Exam / Test Patterns)

情境 / 關鍵字 答案 / 技巧
「排序的兩個必要條件」 (1) nondecreasing/nonincreasing 順序;(2) 是 input 的 permutation
「stable sort 有何用」 保留相同 key 原始順序 → 支援 multi-key cascaded sort(先排 secondary 再排 primary)
「為何 radix 每 iteration 要 stable」 後一 digit 的排序靠前一 digit 已建立的桶內順序;不 stable 則整體錯誤
「comparison-based 複雜度下界」 O(N log N);radix(noncomparison)可更快但只適用特定 key 類型
「N-bit key、1-bit radix 需幾次 iteration」 N 次;更大 radix(r-bit)減少 iteration 數
「iteration 之間能否平行」 不能(序列依賴前一步完整結果);平行在單一 iteration 內
「destination of a 0」 i - (# ones before)
「destination of a 1」 inputSize - (# ones total) + (# ones before)
# ones before 怎麼算」 對 bit 陣列做 exclusive scan;# ones total 為其副產品
「為何 scan 在邊界檢查外」 scan 內含 barrier sync,需所有 thread active
「LSD vs MSD radix sort」 本章為 LSD(LSB→MSB);MSD 從高位分桶遞迴,適合超大序列
「為何基礎 kernel 寫出慢」 寫出是 scatter,consecutive thread 寫非連續位址 → uncoalesced