如何防止 LLM 微調災難性遺忘
用 Anchored Weight Decay 在 LLM 微調時降低舊任務漂移,保住原有能力並檢查模型是否回復。

用 Anchored Weight Decay 在 LLM 微調時降低舊任務漂移,保住原有能力並檢查模型是否回復。
這篇給做 LLM 微調、持續後訓練與研究評估的 ML 工程師和研究者看。照著做完,你會得到一套可重複的流程,用來觀察模型是否在新增能力時遺忘舊能力。
你也會得到一個可直接套用的訓練範本,包含基準評估、短跑式 ES 試驗、Anchored Weight Decay、以及最後的 checkpoint 挑選規則。
開始之前
訂閱 AI 趨勢週報
每週精選模型發布、工具應用與深度分析,直送信箱。不定期,不騷擾。
不會寄垃圾信,隨時可取消。
- 可用的 LLM 微調環境,例如 GPU 工作站或雲端訓練工作。
- Python 3.10+。
- PyTorch 2.1+。
- 至少一個舊任務評估集,例如 HellaSwag 或自建驗證集。
- Evolution Strategies 實作或研究程式庫。
- GitHub 參考倉庫與來源文章:GitHub、Cognizant。
Step 1: 鎖定舊任務基準分數
先做這一步的目的,是建立一份可對照的舊能力清單。你要在任何新微調開始前,先量測模型在保留任務上的表現,這樣後面的分數變化才有意義。

把每個任務的指標記下來,像是 accuracy、exact match 或 pass rate,並和 base checkpoint 名稱一起存檔。若你用 HellaSwag,請固定 prompt 格式、解碼參數與 evaluation seed。
驗收時,你應該看到多次重跑的分數落在很小的誤差範圍內。若波動很大,先修正評估流程,再碰模型權重。
python eval_prior_tasks.py \
--model base_checkpoint \
--tasks hellaswag,heldout_set \
--seed 42 \
--output baseline_metrics.json你應該看到一份名為 baseline_metrics.json 的基準結果檔,裡面每個任務都有固定分數。
Step 2: 跑一個短版 ES 微調試驗
這一步的目的,是先拿到一條完整訓練軌跡,而不是只看最後一個 checkpoint。先用小規模 Evolution Strategies 跑新任務,觀察模型是否在早期快速漂移。

python train_es.py \
--base-model meta-llama/Llama-3.1-8B-Instruct \
--target-task countdown \
--population-size 64 \
--iterations 300 \
--eval-every 10 \
--save-checkpoints true驗收時,你應該看到每隔固定步數就輸出一次 checkpoint,並且能沿著時間序列比對舊任務分數。
如果前期分數先掉、後面又回升,這代表的是暫時漂移,不一定是永久遺忘。
Step 3: 加入 Anchored Weight Decay
這一步的目的,是把更新限制在離起點不要太遠的範圍內。Anchored Weight Decay 會保留一份初始權重副本,並在更新時加入朝向錨點回拉的懲罰項。
實作上,你要把初始權重凍結成 anchor,然後在 ES 更新規則裡加上 decay term。這個設計的重點,是讓模型優先學新任務,但避免參數走得太遠而洗掉舊技能。
驗收時,你應該看到舊任務分數的上下震盪變小。若新任務仍有進步,而舊任務不再大幅崩落,代表錨定策略開始生效。
class AnchoredWeightDecay:
def __init__(self, anchor_state_dict, lambda_awd=0.01):
self.anchor = anchor_state_dict
self.lambda_awd = lambda_awd
def penalty(self, model):
total = 0.0
for name, param in model.named_parameters():
total += (param - self.anchor[name]).pow(2).sum()
return self.lambda_awd * total你應該看到訓練損失多出一個 AWD 懲罰項,且舊任務曲線不再一路下滑。
Step 4: 用同一套任務比較 ES 與 GRPO
這一步的目的,是確認遺忘現象是不是只出現在 ES,還是更廣泛的後訓練問題。請用同一個 base model、同一組 prompt 與同一個 benchmark 排程,重跑 GRPO。
要同時記錄新任務進展與舊任務平均分數,重點不只是最後結果,而是訓練過程中的曲線形狀。
驗收時,你應該看到 GRPO 也可能讓舊任務掉分,即使漂移型態和 ES 不同。若兩者都會傷到舊能力,問題通常不是換方法就能解決,而是要加強保留機制。
python train_grpo.py \
--base-model meta-llama/Llama-3.1-8B-Instruct \
--target-task countdown \
--eval-suite hellaswag,heldout_set \
--save-checkpoints true你應該看到一份對照表,能直接比較 ES 與 GRPO 在新任務與舊任務上的曲線差異。
Step 5: 用漂移檢查挑出最佳 checkpoint
這一步的目的,是建立可上線的挑選規則。請選出同時滿足新任務門檻、且舊任務損失在可接受範圍內的模型版本。
在正式保留前,重新跑完整舊任務套件,並和原始 baseline 對照。只有當模型保住你在意的能力,而且漂移在容忍範圍內時,才把這個 checkpoint 視為可用版本。
驗收時,你應該看到最終 checkpoint 的舊任務分數接近 baseline,而不是永久掉檔。這就是模型有適應新任務、但沒有發生災難性遺忘的訊號。
python select_checkpoint.py \
--metrics-dir runs/es_awd \
--target-threshold 0.80 \
--max-prior-drop 0.03 \
--output promoted_checkpoint.txt你應該看到一個名為 promoted_checkpoint.txt 的產出,裡面只有最後被升級的 checkpoint 路徑。
| 指標 | 基準/優化前 | 結果/優化後 |
|---|---|---|
| HellaSwag 舊任務準確率 | Baseline checkpoint | 前期下降 8%,後期回升接近 baseline |
| 訓練解讀 | 假設是不可逆遺忘 | 觀察到的是暫時性參數漂移 |
| 保留策略 | Plain ES update | ES + Anchored Weight Decay |
常見錯誤
- 只看最後一個 checkpoint。修法:在訓練中途持續記錄舊任務分數,才能抓到暫時漂移與回復。
- 每次跑評估都改 prompt 或 seed。修法:固定 prompt、解碼設定與隨機種子,讓 baseline 可比較。
- 把遺忘只怪在 ES。修法:先用 GRPO 或另一種後訓練方法交叉檢查,再判定問題來源。
接下來可以看什麼
如果你已經能穩定保留舊能力,下一步可以把同一套流程擴到多任務持續微調,並為不同領域調整 AWD 強度,再加上 checkpoint 自動回滾規則。