Stanford CS336 Day 9 - Kernels, Triton

今日進度:6. Kernels, Triton
今日花費時數:4

筆記

Benchmarking & Profiling

在撰寫高效能的深度學習或 GPU 程式碼時,最重要的原則之一就是永遠要對程式碼進行基準測試與效能分析。與其花費數小時盲目優化自認為是瓶頸的程式碼,不如依靠測量工具精確找出需要優化的地方。

基準測試 (Benchmarking)

基準測試是用來測量整個 operation 或 function 端到端執行的實際時間(wall-clock time)。它可以告訴我們程式執行得有多快,但在測量 PyTorch 與 GPU 程式碼時,必須避開幾個常見的陷阱:

  1. 必須執行「暖身」 (Warm-up iterations): 當 PyTorch 第一次將任務派發到 GPU 時,背景會進行機器碼編譯、初始化等前置作業。如果直接測量第一次執行,我們測量到的其實是啟動開銷。因此,必須先跑幾次暖身迭代,才能準確測量程式在穩定狀態下的實際執行速度
  2. 務必強制同步 (Synchronization): CPU 與 GPU 是兩個獨立運作的計算單元。在執行程式時,CPU 將 CUDA kernel 發射 (launch) 給 GPU 後,就會繼續往下跑,不會停下來等待 GPU 算完。如果不加上同步指令,測量到的就只是「CPU 派發任務的時間」,而非「GPU 實際運算的時間」。因此,在計時的開始與結束前,必須呼叫 torch.cuda.synchronize(),確保 CPU 等待 GPU 完成所有工作,讓兩者狀態一致。
  3. 多次測量取平均值: 單次測量容易受到 GPU 溫度(散熱屬性)等因素影響而產生波動。正確的做法是執行多次測量,然後回傳平均值。

下方是做 benchmark 的程式碼範例

def benchmark(description: str, run: Callable, num_warmups: int = 1, num_trials: int = 3):
    """Benchmark `func` by running it `num_trials`, and return all the times."""
    # Warmup: first times might be slower due to compilation, things not cached.
    # Since we will run the kernel multiple times, the timing that matters is steady state.
    for _ in range(num_warmups):
        run()
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # Wait for CUDA threads to finish (important!)

    # Time it for real now!
    times: list[float] = [] # @inspect times, @inspect description
    for trial in range(num_trials):  # Do it multiple times to capture variance
        start_time = time.time()

        run()  # Actually perform computation
        if torch.cuda.is_available():
            torch.cuda.synchronize()  # Wait for CUDA threads to finish (important!)

        end_time = time.time()
        times.append((end_time - start_time) * 1000)

    mean_time = mean(times)
    return mean_time

實務上除了自己手寫迴圈,也可以直接使用官方提供的 torch.utils.benchmark,它在底層已經自動幫開發者處理好了上述的暖身、同步與多次測量平均等防護機制。

效能分析 (Profiling)

基準測試屬於較粗粒度的工具,只能告訴你程式「很慢」,但無法告訴你「為什麼慢」。效能分析則是細粒度的工具,能深入函式內部,讓你看到時間具體花費在哪些底層調用上

  1. PyTorch 內建 Profiler

    • 適用場景: 針對基本操作提供良好的分析,且不需離開 Python 環境。
    • 底層執行細節: 即使在 Python 中只是呼叫簡單的 a + b 或矩陣相乘,Profiler 能揭露其底層的冰山一角。例如,它會先呼叫 PyTorch 的 C++ 介面 (aten wrapper),再派發給特定的 CUDA kernel(如矩陣相乘使用的 cutlass 函式庫),最後還有 kernel launch 與設備同步的時間。
    • 分析資源分配: Profiler 能清楚顯示 CPU 執行時間與 GPU (CUDA) 執行時間的佔比,協助我們判斷計算是受限於記憶體還是運算量。
    • 動態核心派發 (Dynamic Dispatch): Profiler 顯示 PyTorch 會根據資料規模選擇最適合的底層實作。例如大矩陣 (dim=2048) 會呼叫 cutlass,小矩陣 (dim=128) 則可能派發給 xmma 等不同的底層函式。
    • 揭露操作的本質 (Composite vs. Fused):
      • 複合操作: 像是 torch.cdist,底層會被拆解成多個 kernel(乘法、次方、加總)分別執行。
      • 融合操作: 像是 gelusoftmax,雖然包含多步數學運算,但在底層是由單一的 fused kernel 完成,免去了頻繁的記憶體傳輸。
    • 火焰圖與視覺化: 透過設定 with_stack=True,Profiler 可以輸出堆疊追蹤並繪製成 SVG 火焰圖,幫助視覺化複雜模型(如 MLP)的時間花費。

    下方是使用 Pytorch Profiler 進行 profiling 的範例

    from torch.profiler import ProfilerActivity
    
    def profile(description: str, run: Callable, num_warmups: int = 1, with_stack: bool = False):
        # Warmup
        for _ in range(num_warmups):
            run()
        if torch.cuda.is_available():
            torch.cuda.synchronize()  # Wait for CUDA threads to finish (important!)
    
        # Run the code with the profiler
        with torch.profiler.profile(
                activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
                # Output stack trace for visualization
                with_stack=with_stack,
                # Needed to export stack trace for visualization
                experimental_config=torch._C._profiler._ExperimentalConfig(verbose=True)) as prof:
            run()
            if torch.cuda.is_available():
                torch.cuda.synchronize()  # Wait for CUDA threads to finish (important!)
    
        # Print out table
        table = prof.key_averages().table(sort_by="cuda_time_total",
                                          max_name_column_width=80,
                                          row_limit=10)
        #text(f"## {description}")
        #text(table, verbatim=True)
    
        # Write stack trace visualization
        if with_stack:
            text_path = f"var/stacks_{description}.txt"
            svg_path = f"var/stacks_{description}.svg"
            prof.export_stacks(text_path, "self_cuda_time_total")
    
        return table
    
  2. NVIDIA Nsight Systems (NSYS)

    當面對像是 MLP 這種較複雜的模型時,PyTorch Profiler 的視覺化可能不足以解釋所有時間的去向,這時就需要使用像 NVIDIA NSYS 這種進階的硬體效能分析工具。

    • 視覺化時間軸: 它能同時呈現 CPU 與 GPU 的執行狀態。我們可以清楚看到 CPU 的執行進度往往遠遠超前 GPU。CPU 會預先將後續的任務丟進佇列裡讓 GPU 消化,這就是為什麼儘管 Python 本身效能不高,卻不會成為整體效能瓶頸的原因。
    • NVTX 標註功能: 我們可以在 Python 程式碼中加入標註(例如 nvtx.range_pushwith nvtx.range()),這樣在分析器的介面中就能精準對應出特定區塊(例如定義模型、特定的訓練步驟)的硬體時間花費。
    • 觀察記憶體與初始化開銷: NSYS 還能追蹤記憶體使用量。透過圖表可以清楚看到模型建構時記憶體如何分配,以及發現程式剛啟動時,往往有極高的時間比例(有時高達數秒)純粹耗費在載入函式庫與 JIT 編譯上。
    • Kernel 執行時間總結: 在 NSYS 介面中,可以選取特定範圍(例如排除暖身步數後的穩定區間),統整出該區間內所有 Kernel 的執行總時間,精準找出最耗時的算子。
    • 找出隱藏的效能殺手: 透過進階 profiler,我們可以發現一些直覺難以察覺的瓶頸。例如,在訓練迴圈中頻繁印出 loss (print(loss)) 會對效能造成巨大打擊。因為要印出 loss,CPU 就必須停下來等待 GPU 把 loss 算完傳回來,這會觸發強制同步 (cudaStreamSynchronize),中斷了 CPU 提前派發任務的機制,進而導致硬體閒置與效能低落。

    下方是使用 NSYS 進行 profiling 的範例

    import torch.cuda.nvtx as nvtx
    
    def run_mlp(dim: int, num_layers: int, batch_size: int, num_steps: int, use_optimizer: bool = False):
        """Run forward and backward passes through an MLP.
        
        Args:
            dim: Dimension of each layer
            num_layers: Number of linear+GeLU layers
            batch_size: Number of samples to process at once
            num_steps: Number of forward/backward iterations
            use_optimizer: Whether to use Adam optimizer for weight updates
        """
        # Define a model (with random weights)
        with nvtx.range("define_model"):
            model = MLP(dim, num_layers).to(get_device())
        
        # Initialize optimizer if requested
        optimizer = torch.optim.Adam(model.parameters()) if use_optimizer else None
    
        # Define an input (random)
        with nvtx.range("define_input"):
            x = torch.randn(batch_size, dim, device=get_device())
    
        # Run the model `num_steps` times
        for step in range(num_steps):
            if step > 10:
                # start profiling after 10 warmup iterations
                torch.cuda.cudart().cudaProfilerStart()
    
            nvtx.range_push(f"step_{step}")
            
            # Zero gradients
            if use_optimizer:
                optimizer.zero_grad()
            else:
                model.zero_grad(set_to_none=True)
    
            # Forward
            with nvtx.range("forward"):
                y = model(x).mean()
    
            # Backward
            with nvtx.range("backward"):
                y.backward()
    
            # Optimizer step if enabled
            if use_optimizer:
                with nvtx.range("optimizer_step"):
                    #print(f"Step {step}, loss: {y.item():.6f}")
                    optimizer.step()
            
            nvtx.range_pop()
    
    

Kernel Fusion (算子融合) 的動機與原理

在知道如何利用 Benchmarking 與 Profiling 找出效能瓶頸後,我們常會發現:許多操作的效能低落並非因為計算量太大,而是受限於記憶體頻寬 (memory-bound)。為了解決這個問題,最重要的技術就是 kernel fusion

核心概念

Horace He 在其部落格文章中提出了一個非常生動的譬喻來解釋硬體的記憶體階層:

  • Warehouse = DRAM (Global Memory): 容量極大,但存取速度極慢。
  • Factory = SRAM / Registers (SM 內部的本地記憶體): 容量很小,但處理速度極快。

Unfused 的操作陷阱: 如果我們在 Python 中將操作分開寫(例如:先做加法,再做乘法),每一次的數學運算都需要:從倉庫搬資料到工廠 (read)在工廠計算 (compute)將結果搬回倉庫 (write)。當有多個連續操作時,這個昂貴的「搬運成本」會被重複支付無數次。

Fused 的優勢: 如果我們將這些操作「融合」成單一的 kernel,我們只需要:把資料從倉庫搬進工廠一次在工廠內部連續完成所有計算最後只把最終結果搬回倉庫一次。這大幅減少了記憶體的讀寫次數。

實例證明:GeLU activation function

透過實作神經網路常用的 GeLU (Gaussian Error Linear Unit) 來展示融合的巨大威力。GeLU 包含了乘法、加法、次方與 tanh 等多步數學操作。

  • 實作方法一:手動拆解實作
    • 寫法:

      def manual_gelu(x: torch.Tensor):
          return 0.5 * x * (1 + torch.tanh(0.79788456 * (x + 0.044715 * x * x * x)))
      
    • Profiling 觀察: Profiler 顯示這段程式碼在底層啟動了大量獨立的 CUDA kernels(因為每個加法或乘法都是一個獨立的 kernel)。

    • 效能結果: 執行時間高達 8.1 毫秒

  • 實作方法二:PyTorch 原生實作
    • 寫法: 直接呼叫 torch.nn.functional.gelu

      def pytorch_gelu(x: torch.Tensor):
          # Use the tanh approximation to match our implementation
          return torch.nn.functional.gelu(x, approximate="tanh")
      
    • Profiling 觀察: 雖然數學步驟一樣多,但 Profiler 顯示它在底層只啟動了一個高度優化的 CUDA kernel

    • 效能結果: 執行時間僅需 1.1 毫秒,比手動拆解版快了將近 8 倍

透過 benchmarking 與 profiling,我們得出一個重要結論:即使計算結果在數值上完全相同,底層是否將操作「融合 (fuse)」,對效能有著決定性的影響

這也是為甚麼我們接下來要學習如何親自使用 C++ CUDA、Triton 或是利用 torch.compile 來撰寫自己的 fused kernel。實務上,現代的 JIT 編譯器已經能自動幫我們處理許多基礎優化,但當我們發明了「全新的網路架構」或遇到編譯器無法最佳化的「複雜自定義操作」(例如 Flash Attention 的底層優化)時,親自撰寫並融合 kernel 便是榨出極致效能的關鍵。

CUDA Kernels

為了解決記憶體頻寬瓶頸,並實現 kernel fusion,最直接的方法就是使用 C++ 與 CUDA API 親自撰寫能在 GPU 上執行的程式碼。

CUDA 的執行階層模型 (Execution Hierarchy Model)

CUDA 是 C/C++ 的擴充,提供了管理 GPU 的 API。在設計 CUDA 程式時,最簡化的概念就是:我們只要寫出針對單一元素 i 的操作邏輯 f(i),CUDA 就會自動平行地為所有 i 執行這個操作

為了達成高度平行化,CUDA 採用了階層式的執行架構:

  • Grid: 一個 grid 代表一次 kernel 啟動的總任務,它包含了多個 thread blocks。
  • Thread Block: 包含了一群 threads。同一個 block 必定會被排程分配到同一個 streaming multiprocessor 上執行。同一個 block 內的 threads 可以共享極快的 shared memory 並進行同步。
  • thread: 最小的實際運算單元。

定位機制: 由於所有 thread 執行的程式碼邏輯都是一樣的,我們必須利用內建的變數 (blockIdx, blockDim, threadIdx) 進行座標計算,來決定當前這個 thread 應該處理記憶體中的哪一筆資料。

手寫 CUDA GeLU 程式碼

我們將實際的 CUDA 實作拆分為兩個主要部分來理解:負責在 CPU 上派發任務的 Wrapper function,以及真正在 GPU 內部計算的 kernel。

#include <math.h>
#include <torch/extension.h>
#include <c10/cuda/CUDAException.h>

// The __global__ keyword distinguishes this as a CUDA kernel function that runs on the GPU
__global__ void gelu_kernel(float* in, float* out, int num_elements) {
    // Get the index into the tensor (block start position + relative thread offset)
    int i = blockIdx.x * blockDim.x + threadIdx.x;

    if (i < num_elements) {  // To handle the case when n < numBlocks * blockDim (boundary protection)
        // Do the actual computation (Fused Computation: intermediate results stay in fast registers)
        out[i] = 0.5 * in[i] * (1.0 + tanh(0.79788456 * (in[i] + 0.044715 * in[i] * in[i] * in[i])));
    }
}

inline unsigned int cdiv(unsigned int a, unsigned int b) {
    // Compute ceil(a / b) to ensure all elements are covered by the blocks
    return (a + b - 1) / b;
}

// CPU wrapper function that orchestrates the launch of the kernel
torch::Tensor gelu(torch::Tensor x) {
    TORCH_CHECK(x.device().is_cuda());
    // Ensure x lives in a contiguous block of memory (e.g., watch out for transposed tensors)
    TORCH_CHECK(x.is_contiguous());

    // Allocate empty tensor (saves operations compared to using zeros_like)
    torch::Tensor y = torch::empty_like(x);

    // Determine grid (elements divided into blocks)
    int num_elements = x.numel();
    int block_size = 1024;  // Number of threads
    int num_blocks = cdiv(num_elements, block_size);

    // Launch the kernel (passing memory pointers, not the tensor values)
    gelu_kernel<<<num_blocks, block_size>>>(x.data_ptr<float>(), y.data_ptr<float>(), num_elements);
    C10_CUDA_KERNEL_LAUNCH_CHECK();  // Catch errors immediately (requires CUDA_LAUNCH_BLOCKING=1)

    return y;
}

CPU 端 Wrapper function 解析:

  • 連續記憶體檢查 (x.is_contiguous()):這一步非常重要!因為我們在 kernel 中會進行指標座標計算,必須假設張量 x 在底層記憶體中是一塊連續的空間。
  • 高效分配輸出空間 (torch::empty_like(x)):故意不使用 zeros_like 來避免浪費效能,因為我們計算時必定會覆寫每一個位置的數值。
  • 計算 grid 尺寸與無條件進位 (cdiv):我們將總元素數量除以 block size (1024) 並使用無條件進位 (cdiv),確保最後尾數的元素也被分配到獨立的 block 中處理。
  • Launch kernel (<<<num_blocks, block_size>>>):這對角括號是 CUDA 特有的語法,用來定義 grid 與 block 的執行配置,並傳入記憶體指標 (data_ptr)。

GPU 端 CUDA kernel 解析:

  • 帶有 __global__ 關鍵字:代表它是由 CPU 呼叫,但完全運行在 GPU 上的程式碼
  • 找出自己負責的座標 (int i = ...):公式為 blockIdx.x * blockDim.x + threadIdx.x,亦即「找出所屬 block 的起始位置,再加上該 block 中的相對偏移量」。
  • 關鍵的邊界保護 (if (i < num_elements)):因為分配 block 時使用了無條件進位,最後一個 block 尾端通常會有多餘的 threads。加上此條件可防止它們存取到非法的越界記憶體。
  • 執行算子融合 (fused computation):在 if 區塊內,我們連續執行所有數學運算。因為這些全部發生在同一個 kernel 內的一個 thread 中,所有中間結果都直接存在 GPU 最底層的暫存器 (registers) 內,直到最終結果出爐才寫回 DRAM。

在 Python 中編譯與呼叫 CUDA Kernel

寫好 C++ 與 CUDA 程式碼後,我們不需要離開 Python 環境或使用命令列去手動編譯。我們可以利用 PyTorch 提供的 load_inline 工具,在 Python 執行期間動態編譯並載入我們的 kernel。

import os
import torch
from torch.utils.cpp_extension import load_inline

def create_cuda_gelu():    
    # Set CUDA_LAUNCH_BLOCKING so that if there are errors, CUDA will tell you what went wrong.
    os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

    # The `load_inline` function makes it convenient to write CUDA code and bind it to a Python module for immediate use.

    # CUDA code: has the full logic
    cuda_gelu_src = open("gelu.cu").read()

    # C++ code: defines the gelu function
    cpp_gelu_src = "torch::Tensor gelu(torch::Tensor x);"

    # Compile the CUDA code and bind it to a Python module.
    ensure_directory_exists("var/cuda_gelu")
    if not torch.cuda.is_available():
        return None
        
    module = load_inline(
        cuda_sources=[cuda_gelu_src],
        cpp_sources=[cpp_gelu_src],
        functions=["gelu"],
        extra_cflags=["-O2"],
        verbose=True,
        name="inline_gelu",
        build_directory="var/cuda_gelu",
    )

    cuda_gelu = getattr(module, "gelu")
    return cuda_gelu
    
my_cuda_gelu = create_cuda_gelu()
x = torch.randn(16384, device="cuda")
y = my_cuda_gelu(x) 
  • 環境變數 CUDA_LAUNCH_BLOCKING = "1" 這在開發 CUDA 程式時非常重要!因為 CPU 與 GPU 是非同步執行的,設定這個環境變數可以強制同步,當錯誤發生時,CUDA 才能正確回報錯誤發生的位置與原因,否則除錯會是一場災難。(備註:但強制同步會帶來額外的 runtime 效能開銷,因此僅建議在 debug 階段開啟此設定。)

效能結果與觀察 (以 GeLU 為例)

我們將手寫的 CUDA GeLU 版本與前面的版本進行 Benchmark 與 Profiling 對比:

  • 手動拆解 Python 版:8.1 毫秒
  • 手寫 CUDA 版 (fused):1.84 毫秒
  • PyTorch 原生版 (fused):1.1 毫秒

Profiler 清楚顯示,我們成功實現了 kernel fusion!手寫的 CUDA 版本在底層只啟動了單一個 kernel,並且佔用了 100% 的 GPU 計算時間,不再有頻繁的讀寫開銷。

侷限性

雖然我們的 CUDA 實作比 Python 手動拆解版快了超過 4 倍,但仍略慢於 PyTorch 官方高度優化的 1.1 毫秒版本。 此外,像 GeLU 這種 element-wise operations 在 CUDA 中相對簡單,因為各個 thread 之間互不依賴。但如果我們要實作像矩陣相乘或 softmax(需要跨元素加總的 reduction 操作),就需要多個數值之間的互動,這時就必須自行管理複雜的 shared memory 與 thread 之間的同步機制。

Triton kernels

寫 C++ CUDA 雖然能達到極致效能,但對於大多數深度學習研究者來說,開發門檻較高且難以除錯。由 OpenAI 於 2021 年開發的 Triton(一種基於 Python 的領域特定語言)提供了一個完美的平衡點:讓開發者能留在熟悉的 Python 環境中,卻能寫出效能媲美手寫 CUDA 的 GPU 核心

Triton 的核心優勢與抽象化模型

  • 從 ”thread” 提升至 ”block” 等級的思考: 在手寫 CUDA 時,我們必須以單一 thread 為單位來計算座標 (blockIdx, threadIdx)。而在 Triton 中,我們不需要管理單一 thread,而是直接以 thread block 為單位進行思考與撰寫。Triton 編譯器會自動在底層幫我們進行「執行緒粗化 (thread coarsening)」等優化。
  • 自動處理繁瑣的硬體細節: Triton 最強大的地方在於,編譯器會自動幫我們接管許多在 CUDA 中需要手動調校的複雜機制:
    • 記憶體合併 (Memory Coalescing): 自動將相鄰的記憶體讀取請求打包(例如一次抓取 4 個連續元素),最大化 DRAM 頻寬利用率。
    • 共享記憶體管理 (Shared Memory Management): 在處理需要跨元素溝通的操作時,Triton 會自動分配與管理 streaming multiprocessor 內部的共享記憶體。
    • Streaming Mutiprocessors 內部排程 (Scheduling within SMs): 自動處理 SM 內部的執行緒同步與排程。
    • (註:唯一的例外是「跨 SM 的排程 (Scheduling across SMs)」仍需要手動設計 Grid 的分配)
  • 強大的debug能力: 因為 Triton 程式碼本質上就是 Python,我們可以直接在 Python 中逐行執行進行 debug,這在傳統 CUDA 開發中是極度困難的。

What does Triton offer?

CUDA Triton
Memory coalescing (transfer from DRAM) manual automatic
Shared memory management manual automatic
Scheduling within SMs manual automatic
Scheduling across SMs manual manual

深入解析:Triton 版 GeLU 程式碼實作

我們將程式碼分為在 GPU 上執行的 triton_gelu_kernel 以及在 CPU 上呼叫它的 Wrapper triton_gelu

import torch
import triton
import triton.language as tl
import os

# ========== GPU Triton Kernel (GeLU) ==========
@triton.jit
def triton_gelu_kernel(x_ptr, y_ptr, num_elements, BLOCK_SIZE: tl.constexpr):
    # The mental model here is that we are programming at the block level, not the thread level.
    #     |        Block 0            |          Block 1          |      ...      |
    #                            BLOCK_SIZE                                 num_elements

    # Find the start position of the current block.
    # We get the block ID and multiply it by the block size to find our starting coordinate.
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE

    # Calculate the vectorized offsets for this entire block.
    # Unlike CUDA where we compute a single coordinate, here our offsets are a vector of indices.
    offsets = block_start + tl.arange(0, BLOCK_SIZE)

    # Handle boundaries using a mask.
    # If the block goes off the edge of our array, we need a mask to ignore out-of-bounds elements.
    mask = offsets < num_elements

    # Read from global memory (DRAM) to the SM's local high-speed memory (Registers).
    # We load the entire block in a single vectorized operation.
    x = tl.load(x_ptr + offsets, mask=mask)

    # Compute the approximate GeLU.
    # Triton doesn't have a built-in tanh function, so we compute it manually using the formula:
    # tanh(a) = (exp(2a) - 1) / (exp(2a) + 1)
    a = 0.79788456 * (x + 0.044715 * x * x * x)
    exp = tl.exp(2 * a)
    tanh = (exp - 1) / (exp + 1)
    y = 0.5 * x * (1 + tanh)

    # Write the computed values back into our output buffer in global memory (DRAM).
    tl.store(y_ptr + offsets, y, mask=mask)

# ========== CPU Wrapper Function (GeLU) ==========
def triton_gelu(x: torch.Tensor):
    # Ensure the tensor is on the GPU device.
    assert x.is_cuda
    
    # Ensure the tensor lives in a contiguous block of memory. 
    # This is critical because we are going to do indexing arithmetic based on memory pointers.
    assert x.is_contiguous()

    # Allocate the output tensor. 
    # We use empty_like instead of zeros_like to save on extra operations, since we will overwrite these values anyway.
    y = torch.empty_like(x)

    # Determine the grid size.
    num_elements = x.numel()
    block_size = 1024  
    
    # Calculate the total number of blocks needed.
    # We take the ceiling of the division to round up and ensure the trailing elements are computed.
    num_blocks = triton.cdiv(num_elements, block_size)

    # Launch the Triton kernel.
    # We pass the grid size and the block size as a compile-time constant (constexpr).
    triton_gelu_kernel[(num_blocks,)](x, y, num_elements, BLOCK_SIZE=block_size)

    return y

Triton Kernel 核心觀念解析:

  • 取得 Block ID (tl.program_id): 我們不再索取 thread ID,而是直接拿 block ID (pid) 來計算當前 block 在整個陣列中的起始位置 block_start
  • 向量化偏移量 (tl.arange): 這是與 CUDA 最大差異之處!在 CUDA 中我們計算的是一個單一整數座標 i;但在 Triton 中,我們透過 tl.arange(0, BLOCK_SIZE) 產生了一個陣列向量。這意味著我們程式碼的邏輯是一次針對整個 block (例如 1024 個元素) 同時進行操作。
  • 遮罩保護 (mask): 因為操作是向量化的,當最後一個 block 的長度超出實際元素總數時,我們利用 offsets < num_elements 建立一個 boolean mask,並將其傳入 tl.loadtl.store 中,Triton 就會自動幫我們忽略越界的部分。

進階實作:跨元素加總 (Reduction) 操作 — 以 Softmax 為例

GeLU 屬於逐元素 (element-wise) 操作,實作相對簡單。但如果我們要實作 softmax,它需要跨元素尋找最大值與加總 (reduction),這牽涉到複雜的通訊。我們可以使用以下方式實作 Triton Softmax:

# ========== GPU Triton Kernel (Softmax) ==========
@triton.jit
def triton_softmax_kernel(x_ptr, y_ptr, stride_row, num_cols, BLOCK_SIZE: tl.constexpr):
    # Get the row ID currently being processed.
    # Our grid should actually just be rows, so each SM is going to handle a single row. This is the optimal block design [1].
    row_idx = tl.program_id(0)

    # Calculate the starting pointer for this row in memory using the stride.
    row_start_ptr = x_ptr + row_idx * stride_row

    # Get the column offsets and create a mask for boundary protection.
    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < num_cols

    # Load the entire row into the SM's local memory at once. 
    # Out-of-bounds elements are masked with negative infinity for numerical safety [2].
    row = tl.load(row_start_ptr + col_offsets, mask=mask, other=-float('inf'))

    # Softmax computation logic (completed entirely within the high-speed local memory of the SM).
    # We subtract the max, take the exponent, sum it, and then divide to get the softmax normalized row [3].
    row_minus_max = row - tl.max(row, axis=0)
    numerator = tl.exp(row_minus_max)
    denominator = tl.sum(numerator, axis=0)
    softmax_out = numerator / denominator

    # Write the softmax normalized row back to global memory (DRAM) [3].
    out_row_start_ptr = y_ptr + row_idx * stride_row
    tl.store(out_row_start_ptr + col_offsets, softmax_out, mask=mask)

# ========== CPU Wrapper Function (Softmax) ==========
def triton_softmax(x: torch.Tensor):
    # Allocate output tensor (no need to zero out).
    y = torch.empty_like(x)
    num_rows, num_cols = x.shape

    # The block size must include padding to cover the entire column.
    # The block size should be the number of columns plus a little bit of buffer to fit all the columns. 
    # Using triton.next_power_of_2 is a nice way of padding out the columns [2].
    block_size = triton.next_power_of_2(num_cols)

    # Each block is responsible for one row, so the number of blocks is exactly the number of rows [2].
    num_blocks = num_rows

    # Launch the Triton kernel.
    triton_softmax_kernel[(num_blocks,)](
        x_ptr=x, y_ptr=y, 
        stride_row=x.stride(0), 
        num_cols=num_cols, 
        BLOCK_SIZE=block_size
    )
    
    return y

Softmax 實作核心觀念:

  • block 完美對應 row: 最優雅的作法是讓 grid 直接對應矩陣的 rows,每個 streaming multiprocessor 負責處理一整個 row
  • 利用 streaming multiprocessor 內部的高速記憶體: 只要一整列的資料能完全塞進 streaming multiprocessor 的本地記憶體中,我們就可以直接在內部極速算出最大值、相減、取指數、加總 與除法,最後再寫回 global memory。

PTX 底層機器碼揭秘

Triton 之所以快,是因為它能編譯出極度優化的 PTX (Parallel Thread Execution)。PTX 就像是 GPU 的組合語言。我們可以寫一個腳本來觀察 Triton 究竟產生了什麼樣的底層指令:

def print_ptx_main():
    # PTX (parallel thread execution) is like an assembly language for GPUs.
    # We can see the PTX code generated by Triton.
    ptx = print_ptx("triton_gelu", triton_gelu_kernel)
    print(ptx)

def print_ptx(name: str, kernel):
    # PTX is not generated when in interpret mode.
    if os.environ.get("TRITON_INTERPRET") == "1":
        return None
    ptx_path = f"var/{name}-ptx.txt"
    with open(ptx_path, "w") as f:
        return list(kernel.cache.values()).asm["ptx"]

觀察生成的 PTX 機器碼可以發現幾項重點:

  1. 記憶體存取 (ld.global. / st.global.*):* 這是對 global memory (DRAM) 讀寫的指令。仔細看會發現編譯器自動將其優化為一次載入 4 個連續數值,完美達成了記憶體合併 (memory coalescing),大幅節省了 DRAM 頻寬。
  2. 硬體指標對應: %ctaid.x 對應到了 block ID,而 %tid.x 則是 thread 的 id。
  3. 暫存器的高效利用: 程式碼中充滿了 %f* (浮點數暫存器) 與 %r* (整數暫存器)。我們寫的數學運算,全部都被放在這些 GPU 內部最快的高速暫存器中執行,直到最後一刻才寫回 DRAM。
  4. 執行緒粗化 (Thread Coarsening): 編譯器自動調整了工作負載,讓單一 thread 同時處理多個元素(例如 8 個或 4 個),以分攤控制邏輯的開銷。

效能結果與總結

  • 若要獲得真實的高效能並產生 PTX,必須確保環境變數 TRITON_INTERPRET=“0”(若設為 1 則處於純 Python 直譯除錯模式,會非常慢且不產生 PTX)。
  • Benchmark 執行時間對比 (以 GeLU 為例):
    • 手動拆解 Python 版:8.1 毫秒
    • 手寫 Triton 版:1.848 毫秒
    • 手寫 CUDA 版:1.84 毫秒
    • 自動編譯版 (torch.compile):1.47 毫秒
    • PyTorch 原生版 (C++ 高度優化):1.1 毫秒
  • Benchmark 執行時間對比 (以 softmax 為例):
    • 手動拆解 Python 版:3.7 毫秒
    • 手寫 Triton 版:1.9 毫秒
    • PyTorch 原生版:1.5 毫秒
    • 自動編譯版 (torch.compile):1.3 毫秒

實務結論: 現代的 JIT 編譯器非常強大!torch.compile 在底層其實也是自動幫我們生成 Triton 程式碼,效能甚至超越了我們自己手寫的 Triton,在 softmax 上甚至超越了原生的 PyTorch 實作。因此,在實務上我們不需要將語言模型的所有部分都用 CUDA 重寫,只有在開發全新架構或遇到 JIT 無法處理的特殊硬體層級優化(如 FlashAttention)時,才是真正需要親自撰寫 Triton 的時機

今日回顧與筆者的碎碎念

這兩天的內容差不多已經殺死筆者一半的腦細胞,也因此大部分內容都要依靠 NotebookLM 來撰寫筆記。而且這些讓人看了會當機的內容看起來應該會持續到後面兩天 Parallelism 的部分才結束。跟前面不同的是,筆者大部分內容在做筆記的時候大概能至少懂 80%,但這兩天的內容整體來說只有理解了 50~60%。筆者預計開始動 assignment 2 的時候,才會開始根據作業內容認真把尚未理解的部分弄懂。

3個讚