Stanford CS336 Day 20 - Alignment - RL 2

今日進度:17. Alignment - RL 2
今日花費時數:6

筆記

上一節課我們從 verifiable rewards 的角度對 RL 進行了概述,這節課則是要深入探討 policy gradient的運作機制。

RL Setup for Language Models

  • State s:由 prompt 加上目前為止所生成的回應共同組成。與傳統機器人領域受限於現實物理世界不同,語言模型的 state 完全由其生成的 tokens 構成,因此具有極大的靈活性,可以透過生成字詞來達到任何想要的狀態。
  • Action a:生成下一個 token。(註:在後續的 policy gradient 數學表示中,為了簡便,通常會將 a 擴展代指「整個生成的回應」)。
  • Reward R:用來衡量生成回應的品質。這裡主要會聚焦於以下特性:
    • Outcome rewards:reward 是根據「完整的回應」來評估,而不是在生成過程中逐步給予。(註:雖然也可以使用過程獎勵 Process rewards,但在語言模型中極難實作與評估)。
    • Verifiable rewards:reward 的計算是確定性的 (deterministic),能透過特定函數(如程式碼或規則)直接算出結果並給予分數,不需要依賴人類的評估。
    • 在這種設定下,傳統強化學習中常見的 discounting 與 bootstrapping 較不適用。同時,這也導致了 reward 具有極度稀疏 (sparse) 且延遲的挑戰(例如只有在完整解出數學題後才能判斷對錯)。如果模型初始能力太差,可能只會獲得 0 分的獎勵,導致產生零梯度而完全無法更新參數
  • Transition probabilities / dynamic *T(s^\prime ∣s,a)*:在語言模型中非常簡單,下一個 state 單純就是目前的 state 加上剛生成的動作(*s^{\prime}=s+a*,屬於完全確定性的過程。
    • 因為我們能完全掌握轉移動態,這使得語言模型能夠進行 planning 或 test-time compute,這是傳統機器人領域夢寐以求的優勢。
  • Policy \pi(a∣s):就是語言模型本身,通常是由現有的 pre-training model 進行 fine-tuning 而來。
  • Rollout / Episode / Trajectory:流程為從初始狀態 (prompt) 開始,連續生成一連串的動作 (token),直到最後獲得一個單一的數值作為 reward(過程可表示為:$s\rarr a\rarr\dots\rarr a\rarr R$)。
  • Objective最大化預期獎勵 (Maximize expected reward, *\mathbb{E}[R]*)。這個期望值是建立在環境所給予的 prompt 分布 p(s),以及當前 policy(語言模型)所生成的 token 分布 \pi(a\mid s) 之上。

Policy Gradient

為了簡化符號,接下來會用 a 代表整個回應。

我們希望最佳化預期獎勵 \mathbb{E}[R],這在數學上可表示為 \mathbb{E}[R] = \int p(s)\,\pi(a \mid s)\,R(s,a) 也就是對 prompt s 的分布與模型生成的動作(回應)$a$ 的分布取期望值。

最直接的做法是計算梯度:

\nabla \mathbb{E}[R] = \int p(s)\,\nabla \pi(a \mid s)\,R(s,a)

\nabla \mathbb{E}[R] = \int p(s)\,\pi(a \mid s)\,\nabla \log \pi(a \mid s)\,R(s,a)

\nabla \mathbb{E}[R] = \mathbb{E}\big[\nabla \log \pi(a \mid s)\,R(s,a)\big]

透過對 \mathbb{E}[R] 取梯度,最終可以得到更新公式的核心為 \nabla \log \pi(a \mid s)\,R(s,a)。這與 SFT 非常相似,差別只在於 SFT 的目標是由人類提供標準答案,而 policy gradient 則是模型自己生成答案後,將 reward 來進行加權更新

這意味著模型會根據收集到的正向獎勵,不斷增加生成該正確回應的機率。同時,隨著 policy 更新,生成的回應也會隨著時間改變。

Policy gradient 的最大挑戰在於 high noise/variance。在語言模型的數學解題等任務中,獎勵往往是二元的(0 或 1)。如果初始模型能力太差,大部分的回應都會拿到 0 分。當 *R=0* 時,梯度更新量為 0。這代表在這種情況下,模型完全沒有獲得任何更新訊號,參數不會有任何改變,陷入「完全卡住」的困境。相較之下,RLHF 的 reward models (從 pairwise preferences 中學習) 是連續的。

Baselines

為了解決 high variance 的問題,這裡要引入 baseline 的概念:

  • 核心想法:不要直接最大化 \mathbb{E}[R],而是最大化扣除 baselin 後的 reward *\mathbb{E}[R-b(s)]*
  • baseline function b(s) 可以是任何 function,唯一的條件是它只能與 state(prompt)s 有關,不能與動作(生成的回應)a 有關。,也就是根據 \nabla \log \pi(a \mid s)\,(R(s,a)-b(s)) 來更新。
  • 因為動作機率分佈 *\pi(a\mid s)* 的積分為 1,這代表機率對參數的梯度總和為 0。既然 *b(s)*a 無關,將其乘上梯度並取期望值後也會是 0。從數學期望值的角度來看,減去一個常數項並不會改變原本 policy gradient 的期望值(unbiased estimate),但卻能大幅降低梯度更新時的 variance,進而加快收斂速度。
  • 解決了 0 分不更新的問題:如果原本的獎勵是 0 和 1,且 baseline 為 0.1。那麼拿到 0 分的回應在扣除 baseline 後會變成 -0.1,模型便能依據這個負值將更新方向推離錯誤的答案。
  • 雖然理論上有個最佳的 baseline function b^*(s) = \frac{\mathbb{E}\big[(\nabla \pi(a \mid s))^2 \, R \mid s\big]}{\mathbb{E}\big[(\nabla \pi(a \mid s))^2 \mid s\big]} ,但這但在高維度模型中極難計算。因此,實務上最常見的啟發式選擇是直接使用「給定 state 下的 expected reward(mean reward)」,即 *b(s)=\mathbb{E}[R∣s]*。然而這依然很難計算,只能估計。

Advantage Function

b(s) 選擇與 advantage function 有關

  • V(s) = \mathbb{E}[R | s] = state s 下的 expected reward。
  • *Q(s,a)=\mathbb{E}[R∣s,a]* 在 state s 下採取 action a 的 expected reward。(註:因為我們設定為 outcome rewards,且 *a* 包含整個回應,所以這裡的 *Q(s,a)* 就等同於 $R(s,a)$)

Advantage function 定義為 *A(s,a)=Q(s,a)−V(s)*,代表「採取動作 a 比該狀態的平均表現好多少」。因此,當我們選擇 *\mathbb{E}[R∣s]* 作為 baseline 時,我們所最佳化的 *R(s,a)−b(s)*,本質上就等同於 advantage function *A(s,a)*

所有 policy gradient 演算法(包含 PPO、GRPO 等)最終都可以歸納成一個通用的參數更新公式結構: \nabla\log\pi(a∣s)\delta, 其中 *\delta* 代表某種基於 reward計算出來的 advantage estimate。我們會在後面看到,\delta 有多種選擇。

Training Walkthrough

Group Relative Policy Optimization (GRPO) [Shao+ 2024]

1. Task & Reward Functions

在 RL 中,reward design 是一門藝術。為了能快速展示完整的訓練流程,這裡設計了一個極其簡單的任務:「排序 N 個數字」。
然而,如果我們給予是二元獎勵(即「完全排序正確得 1 分,否則得 0 分」),這會引發嚴重的 sparse rewards 問題。在模型初始能力極差時,幾乎所有的生成結果都會拿到 0 分,導致完全沒有梯度可以更新,模型就會「卡死」。

因此,我們必須設計能給予「部分分數(partial credit)」的機制:

  1. sort_distance_reward:只檢查生成數字與正確答案「相同位置」的數量。但這可能會讓明明排得很差的模型,只因為剛好矇對一個位置而拿到分數。
  2. sort_inclusion_ordering_reward:這是實驗中主要採用的函數。它將給分標準拆分為「包含度(只要把 prompt 裡的數字印出來就給分)」與「局部排序(相鄰數字若遞增就給分)」。這提供了非常平滑的學習訊號,即使是很糟的答案,只要模型「有在努力」,就能獲得較高的分數。(註:老師在課堂中有提到,這個函數其實藏有一個可以讓模型「作弊拿高分」的漏洞,例如輸出極端數字來滿足遞增條件,這反映了 RL 中獎勵常被模型駭入(hackable)的風險。)
def compute_reward(prompts: torch.Tensor, responses: torch.Tensor, reward_fn: Callable[[list[int], list[int]], float]) -> torch.Tensor:
    """
    Args:
        prompts (int[batch pos])
        responses (int[batch trial pos])
    Returns:
        rewards (float[batch trial])
    """
    batch_size, num_responses, _ = responses.shape
    rewards = torch.empty(batch_size, num_responses, dtype=torch.float32)
    for i in range(batch_size):
        for j in range(num_responses):
            rewards[i, j] = reward_fn(prompts[i, :], responses[i, j, :])
    return rewards

def sort_distance_reward(prompt: list[int], response: list[int]) -> float:
    """
    Return how close response is to ground_truth = sorted(prompt).
    In particular, compute number of positions where the response matches the ground truth.
    """
    assert len(prompt) == len(response)
    ground_truth = sorted(prompt)
    return sum(1 for x, y in zip(response, ground_truth) if x == y)

def sort_inclusion_ordering_reward(prompt: list[int], response: list[int]) -> float:
    """
    Return how close response is to ground_truth = sorted(prompt).
    """
    assert len(prompt) == len(response)
    
    # Give one point for each token in the prompt that shows up in the response
    inclusion_reward = sum(1 for x in prompt if x in response)  # @inspect inclusion_reward
    
    # Give one point for each adjacent pair in response that's sorted
    ordering_reward = sum(1 for x, y in zip(response, response[1:]) if x <= y)
    
    return inclusion_reward + ordering_reward

2. Model Architecture

真正的語言模型在生成時是「自迴歸(Autoregressive)」的,必須依賴一個接一個的迴圈來生成字詞,這在程式實作上會變得非常冗長且複雜。
這裡範例更乾淨、並能利用簡單的矩陣運算展示原理,特意改用非自迴歸的架構。模型會一次性將 prompt 編碼,並透過 encode_weights 壓縮位置資訊,接著利用 decode_weights 對回應的「每一個位置獨立進行解碼」。這避開了複雜的迴圈,但要注意這並非真實 LLM 預設的做法(除非是在 speculative decoding 的特殊情境中)。

class Model(nn.Module):
    def __init__(self, vocab_size: int, embedding_dim: int, prompt_length: int, response_length: int):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        # For each position, we have a matrix for encoding and a matrix for decoding
        self.encode_weights = nn.Parameter(torch.randn(prompt_length, embedding_dim, embedding_dim) / math.sqrt(embedding_dim))
        self.decode_weights = nn.Parameter(torch.randn(response_length, embedding_dim, embedding_dim) / math.sqrt(embedding_dim))
    
    def forward(self, prompts: torch.Tensor) -> torch.Tensor:
        """
        Args:
            prompts: int[batch pos]
        Returns:
            logits: float[batch pos vocab]
        """
        # Embed the prompts
        embeddings = self.embedding(prompts)   # [batch pos dim]
        
        # Transform using per prompt position matrix, collapse into one vector
        encoded = einsum(embeddings, self.encode_weights, "batch pos dim1, pos dim1 dim2 -> batch dim2")
        
        # Turn into one vector per response position
        decoded = einsum(encoded, self.decode_weights, "batch dim2, pos dim2 dim1 -> batch pos dim1")
        
        # Convert to logits (input and output share embeddings)
        logits = einsum(decoded, self.embedding.weight, "batch pos dim1, vocab dim1 -> batch pos vocab")
        
        return logits

3. Generate Responses & Compute Log Probs

這是體現 GRPO 演算法優勢的關鍵步驟。傳統的 RL(如讓機器人走路)每次推演面對的狀態都不同;但語言模型有一個獨特的優勢:給定同一個 prompt,它可以輕易生成多個不同的回應,形成一個自然的「群組結構(group structure)
generate_responses 中,我們針對同一個 logits 分布,利用 torch.multinomial 進行多次獨立抽樣
。而在 compute_log_probs 中,我們則是將模型預測的 logits 取 log softmax,並利用剛剛抽樣出的 responses 作為 index 取出對應的對數機率(利用 gather 函數),作為後續計算梯度的分子。

def generate_responses(prompts: torch.Tensor, model: Model, num_responses: int) -> torch.Tensor:
    """
    Args:
        prompts (int[batch pos])
    Returns:
        generated responses: int[batch trial pos]
        
    Example (batch_size = 3, prompt_length = 3, num_responses = 2, response_length = 4)
    p1 p1 p1 r1 r1 r1 r1
             r2 r2 r2 r2
    p2 p2 p2 r3 r3 r3 r3
             r4 r4 r4 r4
    p3 p3 p3 r5 r5 r5 r5
             r6 r6 r6 r6
    """
    logits = model(prompts)  # [batch pos vocab]
    batch_size = prompts.shape[0]
    
    # Sample num_responses (independently) for each [batch pos]
    flattened_logits = rearrange(logits, "batch pos vocab -> (batch pos) vocab")
    flattened_responses = torch.multinomial(softmax(flattened_logits, dim=-1), num_samples=num_responses, replacement=True)  # [batch pos trial]
    responses = rearrange(flattened_responses, "(batch pos) trial -> batch trial pos", batch=batch_size)
    return responses

def compute_log_probs(prompts: torch.Tensor, responses: torch.Tensor, model: Model) -> torch.Tensor:
    """
    Args:
        prompts (int[batch pos])
        responses (int[batch trial pos])
    Returns:
        log_probs (float[batch trial pos]) under the model
    """
    # Compute log prob of responses under model
    logits = model(prompts)  # [batch pos vocab]
    log_probs = F.log_softmax(logits, dim=-1)  # [batch pos vocab]
    
    # Replicate to align with responses
    num_responses = responses.shape[1]
    log_probs = repeat(log_probs, "batch pos vocab -> batch trial pos vocab", trial=num_responses)  # [batch trial pos vocab]
    
    # Index into log_probs using responses
    log_probs = log_probs.gather(dim=-1, index=responses.unsqueeze(-1)).squeeze(-1)  # [batch trial pos]
    
    return log_probs

4. Baseline (Deltas), Freezing Parameters & Loss

這個區塊包含了 policy gradient 最核心的數學直覺與工程細節:

  1. Deltas & Centered Rewards
    如果直接用原始獎勵(如 9 分和 2 分)去更新,模型會盲目地推升分數高的那邊,即使那不是最佳解(high variance)。透過 centered_rewards 扣除群組的「平均獎勵」,我們就建立了 GRPO 的 baseline。
    • Ex:如果某次生成的 5 個回應全部都拿到 5 分(表現一樣好或一樣爛),center 後所有的 deltas 都會變成 0。這意味著 model 在此次樣本中不更新,因為沒有相對較好的答案可以學習,這能有效避免無效的梯度雜訊。
    • 另外一個有趣的變體是 max_rewards,它只保留最高分,其餘全設為零。這是為了解決模型為貪圖「部分分數(Low-hanging fruit)」而陷入局部最佳解的策略。
  2. Freezing Parameters
    在計算 PPO/GRPO 的更新量時,我們需要計算新策略與舊策略的機率比值:\text{Ratio} = p / p_{old}。這裡特別寫了 freezing_parameters 函式來展示一個常見錯誤:如果你不把 p_{old} 包在 torch.no_grad() 裡面,PyTorch 在反向傳播時會認為分子分母都是同一個會變動的參數。因為 p / p = 1,常數的微分為 0,這會導致更新的梯度變成 0。因此,舊模型的計算必須明確被視為常數。
  3. Clipped Loss
    我們利用 torch.clamp 將更新的 ratio 強制限制在 [1-\epsilon, 1+\epsilon] 之間。這是為了確保我們不會因為單次抽樣到極高獎勵的答案,就過度劇烈地修改模型參數,進而維持訓練的穩定性。
def compute_deltas(rewards: torch.Tensor, mode: str) -> torch.Tensor:  # @inspect rewards
    """
    Args:
        rewards (float[batch trial])
    Returns:
        deltas (float[batch trial]) which are advantage-like quantities for updating
    """
    if mode == "rewards":
        return rewards
    
    if mode == "centered_rewards":
        # Compute mean over all the responses (trial) for each prompt (batch)
        mean_rewards = rewards.mean(dim=-1, keepdim=True)  # @inspect mean_rewards
        centered_rewards = rewards - mean_rewards  # @inspect centered_rewards
        return centered_rewards
    
    if mode == "normalized_rewards":
        mean_rewards = rewards.mean(dim=-1, keepdim=True)  # @inspect mean_rewards
        std_rewards = rewards.std(dim=-1, keepdim=True)  # @inspect std_rewards
        centered_rewards = rewards - mean_rewards  # @inspect centered_rewards
        normalized_rewards = centered_rewards / (std_rewards + 1e-5)  # @inspect normalized_rewards
        return normalized_rewards
    
    if mode == "max_rewards":
        # Zero out any reward that isn't the maximum for each batch
        max_rewards = rewards.max(dim=-1, keepdim=True)[0]
        max_rewards = torch.where(rewards == max_rewards, rewards, torch.zeros_like(rewards))
        return max_rewards
    
    raise ValueError(f"Unknown mode: {mode}")

def compute_loss(log_probs: torch.Tensor, deltas: torch.Tensor, mode: str, old_log_probs: torch.Tensor | None = None) -> torch.Tensor:
    if mode == "naive":
        return -einsum(log_probs, deltas, "batch trial pos, batch trial -> batch trial pos").mean()
    
    if mode == "unclipped":
        ratios = torch.exp(log_probs - old_log_probs)  # [batch trial]
        return -einsum(ratios, deltas, "batch trial pos, batch trial -> batch trial pos").mean()
    
    if mode == "clipped":
        epsilon = 0.01
        unclipped_ratios = torch.exp(log_probs - old_log_probs)  # [batch trial]
        unclipped = einsum(unclipped_ratios, deltas, "batch trial pos, batch trial -> batch trial pos")
        
        clipped_ratios = torch.clamp(unclipped_ratios, min=1 - epsilon, max=1 + epsilon)
        clipped = einsum(clipped_ratios, deltas, "batch trial pos, batch trial -> batch trial pos")
        return -torch.minimum(unclipped, clipped).mean()
    
    raise ValueError(f"Unknown mode: {mode}")

def freezing_parameters():
    # Motivation: in GRPO you'll see ratios: p(a | s) / p_old(a | s)
    # When you're optimizing, it is important to freeze and not differentiate through p_old
    w = torch.tensor( 2., requires_grad= True)
    p = torch.nn.Sigmoid()(w)
    p_old = torch.nn.Sigmoid()(w)
    ratio = p / p_old
    ratio.backward()
    grad = w.grad # @inspect grad

    # Do it properly:
    w = torch.tensor( 2., requires_grad= True)
    p = torch.nn.Sigmoid()(w)
    with torch.no_grad(): # Important: treat p_old as a constant!
        p_old = torch.nn.Sigmoid()(w)
    ratio = p / p_old
    ratio.backward()
    grad = w.grad # @inspect grad

5. KL Penalty

當我們利用 RL 來教導語言模型新技能時,模型為了拿到最高分,可能會不擇手段,甚至喪失原本生成流暢語言的能力(例如為了遞增獎勵只會輸出亂碼)。為了防止這點,我們需要加入 KL penalty,將它拉回「參考模型(reference model)」的分佈。
在數學實作上,這邊介紹了一個 variance reduction 的技巧:與其單純計算 -\log(q/p),我們改用 q/p - \log(q/p) - 1。為什麼可以這樣寫?因為在統計期望值中,E[q/p] = 1,所以多加這兩項互相抵銷,並不會改變期望值的無偏性(unbiased estimate),但這在神經網路估計時,能大幅降低梯度更新時的變動雜訊。

def compute_kl_penalty(log_probs: torch.Tensor, ref_log_probs: torch.Tensor) -> torch.Tensor:
    """
    Compute an estimate of KL(model | ref_model), where the models are given by:
        log_probs [batch trial pos vocab]
        ref_log_probs [batch trial pos vocab]
    Use the estimate:
        KL(p || q) = E_p[q/p - log(q/p) - 1]
    """
    return (torch.exp(ref_log_probs - log_probs) - (ref_log_probs - log_probs) - 1).sum(dim=-1).mean()

6. Training Loop & System Complexity

  1. 管理多個模型:在單一個訓練迴圈中,我們不僅要維護「當前訓練的策略模型(policy)」,還要維護為了計算 KL 的「參考模型(reference model)」,以及計算比例用的「舊策略模型(old model)」。這對 VRAM 的消耗極大。實務上對於 old_model,我們通常不需儲存完整的模型權重,只需要將推論時算出來的 log_probs 暫存下來重複使用即可,這能省下大量記憶體。
  2. 雙層迴圈設計:inference 是非常昂貴的操作。因此外層迴圈(epoch)負責生成回應與計算獎勵;內層迴圈(step)則會針對同一批生成的回應,連續進行多次的梯度更新,最大化樣本使用效率。
def run_policy_gradient(num_epochs: int = 100,
                        num_steps_per_epoch: int = 10,
                        compute_ref_model_period: int = 10,
                        num_responses: int = 10,
                        deltas_mode: str = "rewards",
                        loss_mode: str = "naive",
                        kl_penalty: float = 0.0,
                        reward_fn: Callable[[list[int], list[int]], float] = sort_inclusion_ordering_reward,
                        use_cache: bool = False) -> tuple[str, str]:
    """Train a model using policy gradient.
    Return:
    - Path to the image of the learning curve.
    - Path to the log file
    """
    torch.manual_seed(5)
    
    image_path = f"var/policy_gradient_{deltas_mode}_{loss_mode}.png"
    log_path = f"var/policy_gradient_{deltas_mode}_{loss_mode}.txt"
    
    # Already ran, just cache it
    if use_cache and os.path.exists(image_path) and os.path.exists(log_path):
        return image_path, log_path
    
    # Define the data
    prompts = torch.tensor([[1, 0, 2], [3, 2, 4], [1, 2, 3]])
    vocab_size = prompts.max() + 1
    prompt_length = response_length = prompts.shape[1]
    
    model = Model(vocab_size=vocab_size, embedding_dim=10, prompt_length=prompt_length, response_length=response_length)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    
    records = []
    ref_log_probs = None
    ref_model = None
    old_log_probs = None
    
    if use_cache:
        out = open(log_path, "w")
    else:
        out = sys.stdout
    
    for epoch in tqdm(range(num_epochs), desc="epoch"):
        # If using KL penalty, need to get the reference model (freeze it every few epochs)
        if kl_penalty != 0:
            if epoch % compute_ref_model_period == 0:
                ref_model = model.clone()
        
        # Sample responses and evaluate their rewards
        responses = generate_responses(prompts=prompts, model=model, num_responses=num_responses)  # [batch trial pos]
        rewards = compute_reward(prompts=prompts, responses=responses, reward_fn=reward_fn)  # [batch trial]
        deltas = compute_deltas(rewards=rewards, mode=deltas_mode)  # [batch trial]
        
        if kl_penalty != 0:  # Compute under the reference model
            with torch.no_grad():
                ref_log_probs = compute_log_probs(prompts=prompts, responses=responses, model=ref_model)  # [batch trial]
        
        if loss_mode != "naive":  # Compute under the current model (but freeze while we do the inner steps)
            with torch.no_grad():
                old_log_probs = compute_log_probs(prompts=prompts, responses=responses, model=model)  # [batch trial]
        
        # Take a number of steps given the responses
        for step in range(num_steps_per_epoch):
            log_probs = compute_log_probs(prompts=prompts, responses=responses, model=model)  # [batch trial]
            loss = compute_loss(log_probs=log_probs, deltas=deltas, mode=loss_mode, old_log_probs=old_log_probs)  # @inspect loss
            if kl_penalty != 0:
                loss += kl_penalty * compute_kl_penalty(log_probs=log_probs, ref_log_probs=ref_log_probs)
            
            # Print information
            print_information(epoch=epoch, step=step, loss=loss, prompts=prompts, rewards=rewards, responses=responses, log_probs=log_probs, deltas=deltas, out=out)
            global_step = epoch * num_steps_per_epoch + step
            records.append({"epoch": epoch, "step": global_step, "loss": loss.item(), "mean_reward": rewards.mean().item()})
            
            # Backprop and update parameters
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    if use_cache:
        out.close()
    
    if use_cache:
        # Plot step versus loss and reward in two subplots
        steps = [r["step"] for r in records]
        losses = [r["loss"] for r in records]
        rewards = [r["mean_reward"] for r in records]
        
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
        
        # Loss subplot
        ax1.plot(steps, losses)
        ax1.set_xlabel("Step")
        ax1.set_ylabel("Train Loss")
        ax1.set_title("Train Loss")
        
        # Reward subplot
        ax2.plot(steps, rewards)
        ax2.set_xlabel("Step")
        ax2.set_ylabel("Mean Reward")
        ax2.set_title("Mean Reward")
        
        plt.tight_layout()
        plt.savefig(image_path)
        plt.close()
    return image_path, log_path

7. Experiment Observations & Insights

跑完訓練流程後,有幾個非常關鍵的觀察,可以完美詮釋強化學習的特性與痛點:

  1. Raw Rewards vs. Centered Rewards
    • Raw rewards 成效極差:如果單純使用原始分數來更新,模型幾乎學不會排序,只會產出一些跟 prompt 相關但未排序的數字。
    • 置中的威力:改用 centered rewards 後,表現有顯著提升。其背後的機制是:次佳的回應在扣除平均後會得到負的梯度更新,從而被推離;而如果該 batch 所有回應的表現都一樣(例如全拿 3 分),置中後 delta 為 0,模型就不會進行無意義的更新。儘管如此,模型依然很容易陷入 local optima。
  2. 標準化獎勵的侷限性與「長度偏見」
    • 實驗發現,在置中之後再除以標準差(normalized rewards),對於這個任務的表現並沒有帶來太大的差異
    • 更有趣的是,近期的相關研究(例如 Dr. GRPO)指出,執行這種標準化甚至可能會引發「長度偏見(length bias)」,也就是模型可能會學到「只要把答案寫得很長(即使是錯的)」就能駭入 GRPO 的最佳化機制來獲得好處。因此,有些最新的 GRPO 變體甚至刻意不使用標準化。
  3. Partial Credit 是一把雙面刃
    • 雖然前面提到,必須給予部分分數(partial credit)才能避免 sparse rewards 造成的「零梯度」死局,但這同時也帶來了副作用。
    • 可以觀察到,模型有時候會「貪圖好拿的部分分數」。例如,它發現只要把數字印出來、或者隨便湊幾個遞增的數字就能拿到 4 分,就安於現狀,不再努力去把整個數列完全排對。這凸顯了在 RL 中,設計一個「無法被模型駭入」的 reward 是一項極大的挑戰
  4. The Loss Curve is a Lie
    • 如果觀察 RL 的 training loss,有時會發現它看起來非常糟糕或不穩定,但同時期的「平均獎勵(mean reward)」卻在穩定上升。
    • 為什麼會這樣?因為在 RL 中,模型自己生成的訓練資料集(responses)隨著時間不斷在改變。這就像是你在一個基準點不斷飄移的動態目標上計算 loss,因此 loss 的絕對高低失去了意義。在強化學習中,唯一能真實反映模型是否進步的指標,只有「平均獎勵」
  5. RL 系統極度敏感且脆弱
    • 總結來說,即使是「排序 3 個數字」這麼微不足道的任務,模型都很容易卡在次佳狀態。這證明了 RL 絕對不是一件簡單的事,其 hyperparameters 的調校困難度遠高於監督式微調(SFT)。
1 Like