Co-rank Function 實作與 Basic Parallel Merge Kernel

重點總覽 (Overview)

項目 內容 關鍵點
co-rank function 輸入 rank k、兩個 sorted 陣列 → 回傳 co-rank i j = k - i 由 caller 推導
核心不變式 (invariant) i + j = k 全程維持 任何 marker 調整都不能破壞
搜尋演算法 binary search(也可 higher-radix) O(log₂N),N = max(m, n)
接受條件 A[i-1] <= B[j]B[j-1] < A[i] <= / < 不對稱 → 維持 stability(A 優先)
搜尋上界 i = min(k, m)j = k - i i 不可超過 A 大小或 k
搜尋下界 i_low = max(0, k-n)j_low = max(0, k-m) 縮小搜尋範圍、加速
basic kernel 每 thread 呼叫 co_rank 兩次(k_curr、k_next)→ merge_sequential 純 global memory
弱點 merge 與 co_rank 的存取未 coalesced 浪費 memory bandwidth → 下一節 tiling 解決
Important

Merge 與先前所有 pattern 最大的不同:每個 thread 的輸入範圍無法由 index 計算得出,而是依賴實際資料值 (dynamic input data identification)。co-rank function 就是用來「反查」輸入範圍的工具。


Co-rank Function 定義與不變式 (Co-rank Function & the i+j=k Invariant)

// 函式簽章:只回傳 i,caller 自行算 j = k - i
int co_rank(int k, int* A, int m, int* B, int n);
Output C (m+n=9), 以 rank k 切給 3 threads (各 3 個):
          k0=0      k1=3      k2=6      k3=9
   C: [ . . . | . . . | . . . ]
       thread0   thread1  thread2

co_rank(k) 把每個邊界 rank 映射到 (i, j),永遠滿足 i+j=k:
   A=[1 7 8 9 10]   ->  i0=0   i1=2   i2=5  (=m)
   B=[7 10 10 12]   ->  j0=0   j1=1   j2=1

thread t 合併 A[i_t .. i_{t+1}-1] 與 B[j_t .. j_{t+1}-1]  ->  C[k_t .. k_{t+1}-1]
   ex. thread1: A[2..4]=(8,9,10), B[1..0]=(空)  ->  C[3..5]=(8,9,10)
Tip

一個 thread 要同時呼叫 co_rank(k_curr)co_rank(k_next):前者給輸入子陣列的起點,後者給終點。長度 = i_next - i_curr(A)與 j_next - j_curr(B)。


Binary-Search 實作 (Binary-Search Implementation)

int co_rank(int k, int* A, int m, int* B, int n) {
    int i      = (k < m) ? k : m;            // 上界: i = min(k, m)
    int j      = k - i;                      // 維持 i + j = k
    int i_low  = (0 > (k - n)) ? 0 : (k - n);// 下界: i_low = max(0, k-n)
    int j_low  = (0 > (k - m)) ? 0 : (k - m);// 下界: j_low = max(0, k-m)
    int delta;
    bool active = true;
    while (active) {
        if (i > 0 && j < n && A[i-1] > B[j]) {          // i 太大
            delta = (i - i_low + 1) >> 1;               // ceil((i-i_low)/2)
            j_low = j;  j = j + delta;  i = i - delta;  // 往小調 i, 補大 j
        } else if (j > 0 && i < m && B[j-1] >= A[i]) {  // j 太大
            delta = (j - j_low + 1) >> 1;
            i_low = i;  i = i + delta;  j = j - delta;  // 往小調 j, 補大 i
        } else {
            active = false;                             // A[i-1]<=B[j] 且 B[j-1]<A[i]
        }
    }
    return i;
}

邊界初始化為何不只是 0?

變數 理由
i (上界) min(k, m) i 不能超過 A 的大小 m,也不能超過 k
i_low (下界) max(0, k-n) B 最多貢獻 n 個元素,故 A 至少要貢獻 k-n
j_low (下界) max(0, k-m) 對稱論證:B 至少貢獻 k-m
Warning

i_low/j_low 設成 0 仍正確,只是搜尋較慢。只有當 k > n(或 k > m)時,下界才會大於 0 並真正縮小範圍。

Trace:thread 1 呼叫 co_rank(3, A, 5, B, 4)

A=[1,7,8,9,10]B=[7,10,10,12]

iter i_low i j j_low A[i-1] B[j] B[j-1] A[i] 判定 delta
0 0 3 0 0 8 7 8 > 7i 太大 (3-0+1)>>1 = 2
1 0 1 2 0 1 10 10 7 10 >= 7j 太大 (2-0+1)>>1 = 1
2 1 2 1 0 7 10 7 8 7<=107<8接受

回傳 i = 2,caller 推得 j = k - i = 3 - 2 = 1

Important

接受條件不對稱A[i-1] <= B[j](用 <=)但 B[j-1] < A[i](用 <)。這是為了維持 stability——平手時 A 元素先放。對應到 code 裡第二個 if 用 B[j-1] >= A[i] 判定「j 太大」。若把比較寫成對稱,會破壞穩定性。

Tip

deltaceil+1 再右移)而非 floor,確保每一步至少縮減 1,避免區間卡住、保證 while-loop 終止。


Basic Parallel Merge Kernel (Basic Global-Memory Kernel)

假設 A、B、C 全在 global memory。每個 thread 算出自己的輸出區段,再用 co_rank 反查輸入區段,最後跑 merge_sequential

__global__ void merge_basic_kernel(int* A, int m, int* B, int n, int* C) {
    int tid = blockIdx.x * blockDim.x + threadIdx.x;
    int elems = ceil((m + n) / (float)(blockDim.x * gridDim.x));
    int k_curr = tid * elems;                       // 本 thread 輸出起點 rank
    int k_next = min((tid + 1) * elems, m + n);     // 下一 thread 起點 (=本 thread 終點)

    int i_curr = co_rank(k_curr, A, m, B, n);       // 第 1 次: 輸入起點
    int i_next = co_rank(k_next, A, m, B, n);       // 第 2 次: 輸入終點
    int j_curr = k_curr - i_curr;
    int j_next = k_next - i_next;

    // 各 thread 獨立合併自己的子陣列
    merge_sequential(&A[i_curr], i_next - i_curr,
                     &B[j_curr], j_next - j_curr,
                     &C[k_curr]);
}
Warning

總輸出數 m+n 不一定是 thread 數的倍數 → k_next 必須 min(..., m+n) 做 clamp,否則越界。

範例對照(3 threads)

thread k_curr k_next (i_curr,j_curr) (i_next,j_next) merge 呼叫
0 0 3 (0,0) (2,1) merge(&A[0],2, &B[0],1, &C[0])
1 3 6 (2,1) (5,1) merge(&A[2],3, &B[1],0, &C[3])
2 6 9 (5,1) (5,4) merge(&A[5],0, &B[1],3, &C[6])

thread 1 的 n=0:輸出 C[3..5] 全部來自 A(無 B 元素)——展示輸入用量完全由資料值決定。


Coalescing 弱點 (Coalescing Weaknesses)

basic kernel 簡潔但記憶體存取效率差,兩個來源:

  1. merge_sequential 階段:warp 內相鄰 thread 讀/寫的不是相鄰位址。
  2. co_rank 階段:binary search 為不規則 (irregular) 存取,幾乎不可能 coalesced,且直接打在 global memory。
merge_sequential 第一個 iteration (上面範例):
warp lanes:    t0     t1     t2
   read:      A[0]   A[2]   B[0]    <- 位址不連續 (stride != 1) => 未 coalesced
   write:     C[0]   C[3]   C[6]    <- stride = elems => 未 coalesced
比較項 Basic kernel (本節) Tiled kernel (下一節)
資料位置 A/B/C 全在 global memory 先 coalesced 載入 shared memory
co_rank 存取 每 thread 在 global memory 做 binary search block 層級先算,再於 shared memory 搜尋
合併存取 未 coalesced tile 載入 coalesced
程式複雜度 較高(buffer 管理)
Tip

改善策略選的是 PMPP Ch.6 三招中的第三招:用 coalesced 方式把資料搬進 shared memory,再在 shared memory 做不規則存取。詳見 12-Merge/03-Tiled-and-Circular-Buffer-Merge-Kernels


關鍵公式 (Key Formulas)

名稱 公式
co-rank 不變式 i + j = k
搜尋複雜度 O(log₂ N)N = max(m, n)
sequential merge(每 thread) O((m+n) / #threads)
上界初始化 i = min(k, m)j = k - i
下界初始化 i_low = max(0, k-n)j_low = max(0, k-m)
半分步長 delta = (i - i_low + 1) >> 1 = ceil((i - i_low)/2)
每 thread 輸出量 elems = ceil((m+n)/(blockDim.x * gridDim.x))
輸出邊界 rank k_curr = tid*elemsk_next = min((tid+1)*elems, m+n)

考試/面試重點 (Exam / Test Patterns)

情境 / 關鍵字 答案 / 技巧
「co_rank 回傳什麼?」 只回傳 ij = k - i 由 caller 推導
「co_rank 的不變式」 i + j = k,每步調整 marker 都維持
「為何用 binary search?」 兩輸入皆 sorted → 可在 O(log₂N) 內找 co-rank(也可 higher-radix)
「接受條件為何不對稱(<= vs <)?」 維持 stability:平手時 A 元素先放入 C
「i_low / j_low 不設 0 的好處?」 縮小搜尋範圍加速;當 k>n 時 i 至少為 k-n,下界才生效
「delta 為何用 ceil?」 保證每步至少縮 1,避免區間停滯、確保終止
「basic kernel 為何呼叫 co_rank 兩次?」 k_curr 取輸入起點、k_next 取輸入終點;長度 = next−curr
「k_next 為何要 clamp?」 m+n 不一定整除 thread 數,避免越界
「basic kernel 最大缺點?」 merge 與 co_rank 存取未 coalesced,浪費 bandwidth
「哪幾個 thread 在 global memory 做 binary search?(Ex.4a)」 basic kernel 中每一個 thread 都做 → 全部 thread
「merge 與其他 pattern 本質差異?」 輸入範圍資料相依,需 dynamic input data identification