[RSCH] 7 分鐘閱讀OraCore 編輯部

如何防止 LLM 微調災難性遺忘

用 Anchored Weight Decay 在 LLM 微調時降低舊任務漂移,保住原有能力並檢查模型是否回復。

分享 LinkedIn
如何防止 LLM 微調災難性遺忘

用 Anchored Weight Decay 在 LLM 微調時降低舊任務漂移,保住原有能力並檢查模型是否回復。

這篇給做 LLM 微調、持續後訓練與研究評估的 ML 工程師和研究者看。照著做完,你會得到一套可重複的流程,用來觀察模型是否在新增能力時遺忘舊能力。

你也會得到一個可直接套用的訓練範本,包含基準評估、短跑式 ES 試驗、Anchored Weight Decay、以及最後的 checkpoint 挑選規則。

開始之前

訂閱 AI 趨勢週報

每週精選模型發布、工具應用與深度分析,直送信箱。不定期,不騷擾。

不會寄垃圾信,隨時可取消。

  • 可用的 LLM 微調環境,例如 GPU 工作站或雲端訓練工作。
  • Python 3.10+。
  • PyTorch 2.1+。
  • 至少一個舊任務評估集,例如 HellaSwag 或自建驗證集。
  • Evolution Strategies 實作或研究程式庫。
  • GitHub 參考倉庫與來源文章:GitHubCognizant

Step 1: 鎖定舊任務基準分數

先做這一步的目的,是建立一份可對照的舊能力清單。你要在任何新微調開始前,先量測模型在保留任務上的表現,這樣後面的分數變化才有意義。

如何防止 LLM 微調災難性遺忘

把每個任務的指標記下來,像是 accuracy、exact match 或 pass rate,並和 base checkpoint 名稱一起存檔。若你用 HellaSwag,請固定 prompt 格式、解碼參數與 evaluation seed。

驗收時,你應該看到多次重跑的分數落在很小的誤差範圍內。若波動很大,先修正評估流程,再碰模型權重。

python eval_prior_tasks.py \
  --model base_checkpoint \
  --tasks hellaswag,heldout_set \
  --seed 42 \
  --output baseline_metrics.json

你應該看到一份名為 baseline_metrics.json 的基準結果檔,裡面每個任務都有固定分數。

Step 2: 跑一個短版 ES 微調試驗

這一步的目的,是先拿到一條完整訓練軌跡,而不是只看最後一個 checkpoint。先用小規模 Evolution Strategies 跑新任務,觀察模型是否在早期快速漂移。

如何防止 LLM 微調災難性遺忘
python train_es.py \
  --base-model meta-llama/Llama-3.1-8B-Instruct \
  --target-task countdown \
  --population-size 64 \
  --iterations 300 \
  --eval-every 10 \
  --save-checkpoints true

驗收時,你應該看到每隔固定步數就輸出一次 checkpoint,並且能沿著時間序列比對舊任務分數。

如果前期分數先掉、後面又回升,這代表的是暫時漂移,不一定是永久遺忘。

Step 3: 加入 Anchored Weight Decay

這一步的目的,是把更新限制在離起點不要太遠的範圍內。Anchored Weight Decay 會保留一份初始權重副本,並在更新時加入朝向錨點回拉的懲罰項。

實作上,你要把初始權重凍結成 anchor,然後在 ES 更新規則裡加上 decay term。這個設計的重點,是讓模型優先學新任務,但避免參數走得太遠而洗掉舊技能。

驗收時,你應該看到舊任務分數的上下震盪變小。若新任務仍有進步,而舊任務不再大幅崩落,代表錨定策略開始生效。

class AnchoredWeightDecay:
    def __init__(self, anchor_state_dict, lambda_awd=0.01):
        self.anchor = anchor_state_dict
        self.lambda_awd = lambda_awd

    def penalty(self, model):
        total = 0.0
        for name, param in model.named_parameters():
            total += (param - self.anchor[name]).pow(2).sum()
        return self.lambda_awd * total

你應該看到訓練損失多出一個 AWD 懲罰項,且舊任務曲線不再一路下滑。

Step 4: 用同一套任務比較 ES 與 GRPO

這一步的目的,是確認遺忘現象是不是只出現在 ES,還是更廣泛的後訓練問題。請用同一個 base model、同一組 prompt 與同一個 benchmark 排程,重跑 GRPO。

要同時記錄新任務進展與舊任務平均分數,重點不只是最後結果,而是訓練過程中的曲線形狀。

驗收時,你應該看到 GRPO 也可能讓舊任務掉分,即使漂移型態和 ES 不同。若兩者都會傷到舊能力,問題通常不是換方法就能解決,而是要加強保留機制。

python train_grpo.py \
  --base-model meta-llama/Llama-3.1-8B-Instruct \
  --target-task countdown \
  --eval-suite hellaswag,heldout_set \
  --save-checkpoints true

你應該看到一份對照表,能直接比較 ES 與 GRPO 在新任務與舊任務上的曲線差異。

Step 5: 用漂移檢查挑出最佳 checkpoint

這一步的目的,是建立可上線的挑選規則。請選出同時滿足新任務門檻、且舊任務損失在可接受範圍內的模型版本。

在正式保留前,重新跑完整舊任務套件,並和原始 baseline 對照。只有當模型保住你在意的能力,而且漂移在容忍範圍內時,才把這個 checkpoint 視為可用版本。

驗收時,你應該看到最終 checkpoint 的舊任務分數接近 baseline,而不是永久掉檔。這就是模型有適應新任務、但沒有發生災難性遺忘的訊號。

python select_checkpoint.py \
  --metrics-dir runs/es_awd \
  --target-threshold 0.80 \
  --max-prior-drop 0.03 \
  --output promoted_checkpoint.txt

你應該看到一個名為 promoted_checkpoint.txt 的產出,裡面只有最後被升級的 checkpoint 路徑。

指標基準/優化前結果/優化後
HellaSwag 舊任務準確率Baseline checkpoint前期下降 8%,後期回升接近 baseline
訓練解讀假設是不可逆遺忘觀察到的是暫時性參數漂移
保留策略Plain ES updateES + Anchored Weight Decay

常見錯誤

  • 只看最後一個 checkpoint。修法:在訓練中途持續記錄舊任務分數,才能抓到暫時漂移與回復。
  • 每次跑評估都改 prompt 或 seed。修法:固定 prompt、解碼設定與隨機種子,讓 baseline 可比較。
  • 把遺忘只怪在 ES。修法:先用 GRPO 或另一種後訓練方法交叉檢查,再判定問題來源。

接下來可以看什麼

如果你已經能穩定保留舊能力,下一步可以把同一套流程擴到多任務持續微調,並為不同領域調整 AWD 強度,再加上 checkpoint 自動回滾規則。