2 min read

LLM:近似on-policy数据抗遗忘(2)Iterative-SFT

LLM 系列导航

1 LLM:Function Call(1)从传统工具调用到函数驱动 2020-09-08
2 LLM:关注因果推断研究进展 2023-06-22
3 LLM:人机协作 2024-05-19
4 LLM:分层管理风险定性 2024-08-08
5 LLM:二阶段FN分层分析与模型提升空间测算 2024-09-12
6 LLM:二阶段FN分层分析与模型提升空间测算(2) 2024-09-12
7 LLM:二阶段FN分层分析与模型提升空间测算(3) 2024-09-12
8 LLM:二阶段FN分层分析与模型提升空间测算(4) 2024-09-12
9 LLM:推理不可复现的探索 2025-06-11
10 LLM:SFT 与 RL 的关系 2025-07-29
11 LLM:SFT 与 RL 的关系(理论修正与实践补充) 2025-07-30
12 LLM:SFT 与 RL 的关系(理论修正与实践补充II) 2025-07-31
13 LLM:低数据场景的决策树生成 2025-08-03
14 LLM:低数据场景的决策树生成(2)落地冷启动 2025-08-03
15 LLM:表格数据特征工程 2025-08-03
16 LLM:撰写大模型落地提示词的关键 2025-08-27
17 LLM:从 Prompt 设计到工作流落地 2025-08-28
18 LLM:RL's Razor 抗遗忘 2025-09-04
19 LLM:RL's Razor 抗遗忘(2) 2025-09-04
20 LLM:RL's Razor 抗遗忘(3)SFT 与 RL 的认知偏差及应用 2025-09-04
21 LLM:RL's Razor 抗遗忘(4)on-policy认知误区 2025-09-04
22 LLM:RL's Razor 抗遗忘(5)RL在线生成样本 2025-09-04
23 LLM:MuseGraph融合GNN与LLM的通用图挖掘新框架 2025-09-05
24 LLM:幻觉成因与解决 2025-09-08
25 LLM:Agent 逻辑与应用场景 2025-09-23
26 LLM:拆解大模型缩放定律失效的三重分解 2025-10-05
27 LLM:近似on-policy数据抗遗忘 2025-10-21
28 LLM:幻觉治理 2025-10-28
29 LLM:先验偏见(1)变量名带偏判断 2025-11-13
30 LLM:先验偏见(2)分层分析应对 2025-11-13
31 LLM:先验偏见(3)实验效度的场景化适配 2025-11-13
32 LLM:分层落地 2025-11-13
33 LLM:零样本在金融场景落地 2025-11-13
34 LLM:跨难度泛化的局限与量化 2025-11-26
35 LLM:重复提示词解锁非推理性能上限 2025-12-17
36 LLM:用失败样本提升指令遵循能力 2025-12-29
37 LLM:概率引导的高价值信号筛选 2026-01-14
38 LLM:低成本安全检测的级联方案 2026-01-16
39 LLM:定性编码的假阳性解决方案 2026-01-16
40 LLM:先验偏见(4)挑战与落地解决方案 2026-01-22
41 LLM:先验偏见(5)工程化方案 2026-01-22
42 LLM:SimRL(1)分层评估 2026-02-03
43 LLM:SimRL(2)理论逻辑与工程落地 2026-02-03

在大模型后训练领域,监督微调(SFT)长期陷在学新忘旧的困境中——为贴合固定标注数据,模型不仅要学习冗余信息,还需大幅调整参数,最终导致旧知识被覆盖。而Iterative-SFT的出现,以反向筛选样本的核心逻辑提供了优化路径,其迭代思路甚至能与经典的监督算法形成有趣对照。这篇内容将拆解Iterative-SFT的核心价值,用实操代码落地,并厘清它与传统SFT、XGBoost的本质差异。

一、核心矛盾:传统SFT的样本问题

传统SFT的困境根源,在于对off-policy离线数据的依赖和全盘接收的样本逻辑。这类数据由人工标注生成,与模型当前策略分布可能存在显著偏差,比如标注数据中包含非核心的表述习惯、冗余修饰语,模型为贴合数据分布,不得不将这些无关信息纳入参数学习。

更关键的是,SFT不做样本筛选——无论数据是否契合模型当前能力,都一股脑用于训练。这种数据主导、模型迁就的模式,必然导致参数大幅变动,进而引发灾难性遗忘。例如在MMLU常识任务微调后,模型对原有算术推理能力的保留率常低于60%(Shenfeld et al., 2025),而冗余信息的学习会进一步挤压旧知识的参数空间,加剧这一问题。

二、Iterative-SFT的优化路径:反向筛选样本的逻辑

Iterative-SFT的核心是将样本利用逻辑从全盘接收转为择优录用,核心流程可概括为每轮epoch初用当前模型生成样本→筛选正确响应→用有效样本微调。这种逻辑背后是对近似on-policy数据的依赖——数据由当前模型实时生成,天然贴合其能力分布,微调时无需大幅改动参数,从而实现学新不丢旧。

举个具体场景:用Qwen 2.5 1.5B模型做常识问答微调,每个prompt让模型生成5个响应,只保留答案匹配真值的样本。第一轮可能仅30%的生成样本有效,但微调后模型能力提升,第二轮有效样本率会增至50%以上——模型在“巩固已有能力”的基础上逐步扩展,既避免遗忘,又能稳步提升目标任务性能。

三、关键对比:Iterative-SFT与SFT、XGBoost的核心差异

很多人会将Iterative-SFT的迭代逻辑与XGBoost的集成学习思路类比,三者的核心差异集中在“样本筛选”和“优化目标”上,用表格可清晰呈现:

对比维度 传统SFT Iterative-SFT XGBoost(二分类场景)
样本来源 固定离线标注数据 当前模型实时生成的动态数据 固定离线标注数据
筛选逻辑 无筛选,全盘接收 选“当前模型能做对”的样本 选“当前模型做不对”的样本(残差大)
优化目标 快速贴合数据分布,忽视遗忘 抗遗忘+稳步提升目标任务性能 修正误差,提升预测准确率
参数更新方式 单模型参数大幅调整 单模型参数小幅迭代 加法模型,新增树修正误差

简单来说,传统SFT是硬啃所有数据,XGBoost是专攻错题,而Iterative-SFT是先吃透会做的题,再挑战难题——三种逻辑没有优劣,只是适配不同需求,Iterative-SFT的逻辑恰好解决了大模型抗遗忘的痛点。

四、实操落地:Iterative-SFT完整代码(PyTorch+Transformers)

以下代码以Qwen 2.5 1.5B模型为例,适配MMLU常识任务,包含样本生成、迭代训练、遗忘率评估全流程,可直接运行,GPU显存12GB即可支撑。


import torch
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          Trainer, TrainingArguments, DataCollatorForLanguageModeling)
from datasets import Dataset
import evaluate

# 1. 基础配置(适配12GB GPU)
MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
TOKENIZER = AutoTokenizer.from_pretrained(MODEL_NAME)
TOKENIZER.pad_token = TOKENIZER.eos_token  # 补全pad token
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MAX_SEQ_LEN = 512
NUM_EPOCHS = 2  # 论文推荐轮次
NUM_RESPONSES = 5  # 每个prompt生成响应数
LEARNING_RATE = 1e-5  # 1.5B模型适配学习率

# 2. 模型初始化(显存优化)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.bfloat16,  # 混合精度训练
    device_map=DEVICE,
    gradient_checkpointing=True  # 节省显存
)

# 3. 基础Prompt数据加载(实际需扩展至10k+样本)
def load_base_data():
    sample_data = [
        {
            "prompt": "User: 下列属于质数的是?\n选项:A.1 B.2 C.4 D.6\n回答格式:The answer is: X\nAssistant:",
            "ground_truth": "B"
        },
        {
            "prompt": "User: 水的化学式是?\n选项:A.H2O B.CO2 C.O2 D.NaCl\n回答格式:The answer is: X\nAssistant:",
            "ground_truth": "A"
        }
    ]
    return Dataset.from_list(sample_data)

base_dataset = load_base_data()

# 4. 核心:生成近似on-policy数据
def gen_valid_data(model, tokenizer, dataset):
    valid_samples = []
    model.eval()
    with torch.no_grad():
        for item in dataset:
            prompt = item["prompt"]
            gt = item["ground_truth"]
            # 编码输入
            inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=MAX_SEQ_LEN).to(DEVICE)
            # 生成多个响应
            outputs = model.generate(
                **inputs, max_new_tokens=32, num_return_sequences=NUM_RESPONSES,
                temperature=0.7, do_sample=True, pad_token_id=tokenizer.eos_token_id
            )
            # 筛选正确样本
            for output in outputs:
                response = tokenizer.decode(output, skip_special_tokens=True).split("Assistant:")[-1].strip()
                if f"The answer is: {gt}" in response:
                    valid_samples.append({"input_ids": output.cpu().numpy(), "labels": output.cpu().numpy()})
    return Dataset.from_list(valid_samples)

# 5. 遗忘率评估函数
def calc_forgetting(base_model, curr_model, non_target_data):
    acc_metric = evaluate.load("accuracy")
    def get_preds(model, data):
        preds = []
        model.eval()
        with torch.no_grad():
            for item in data:
                inputs = tokenizer(item["prompt"], return_tensors="pt").to(DEVICE)
                outputs = model.generate(**inputs, max_new_tokens=32, do_sample=False)
                pred = tokenizer.decode(outputs[0], skip_special_tokens=True).split("is:")[-1].strip()
                preds.append(pred)
        return preds
    base_preds = get_preds(base_model, non_target_data)
    curr_preds = get_preds(curr_model, non_target_data)
    base_acc = acc_metric.compute(predictions=base_preds, references=[d["ground_truth"] for d in non_target_data])["accuracy"]
    curr_acc = acc_metric.compute(predictions=curr_preds, references=[d["ground_truth"] for d in non_target_data])["accuracy"]
    return base_acc - curr_acc

# 6. 迭代训练主流程
def iterative_sft():
    # 保存基础模型用于遗忘率基准
    base_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.bfloat16, device_map=DEVICE)
    non_target_data = load_base_data()  # 实际替换为非目标任务数据(如MATH)
    curr_model = model
    
    for epoch in range(NUM_EPOCHS):
        print(f"\n=== 第{epoch+1}轮迭代 ===")
        # 生成当前轮次有效数据
        valid_data = gen_valid_data(curr_model, TOKENIZER, base_dataset)
        print(f"有效样本数:{len(valid_data)}")
        if not valid_data:
            continue
        
        # 配置训练参数
        train_args = TrainingArguments(
            output_dir=f"./ckpt/epoch_{epoch+1}",
            per_device_train_batch_size=8, learning_rate=LEARNING_RATE,
            num_train_epochs=1, logging_steps=5, save_strategy="epoch",
            bf16=True, remove_unused_columns=False
        )
        
        # 初始化Trainer
        data_collator = DataCollatorForLanguageModeling(tokenizer=TOKENIZER, mlm=False)
        trainer = Trainer(
            model=curr_model, args=train_args, train_dataset=valid_data, data_collator=data_collator
        )
        
        # 启动训练
        trainer.train()
        
        # 评估遗忘率
        forgetting_rate = calc_forgetting(base_model, curr_model, non_target_data)
        print(f"遗忘率:{forgetting_rate:.4f}")
    
    # 保存最终模型
    curr_model.save_pretrained("./final_iterative_sft_model")
    TOKENIZER.save_pretrained("./final_iterative_sft_model")
    print("\n模型保存完成")

# 7. 启动训练(首次运行需安装依赖:pip install torch transformers datasets evaluate)
if __name__ == "__main__":
    iterative_sft()

五、实用价值:低成本落地的关键建议

基于代码和对比分析,Iterative-SFT的落地无需复杂资源,核心注意三点:

  • 样本规模:示例用若干样本,实际需扩展至10k+(如MMLU训练集1.2万样本),否则有效样本不足,难以体现抗遗忘效果;

  • 非目标任务选择:评估遗忘率时,需用与目标任务无关的数据(如目标是常识问答,非目标用算术推理MATH数据集),避免数据分布重叠导致评估穿越;

参考文献

Chen, H., Razin, N., Narasimhan, K., & Chen, D. (2025). Retaining by doing: The role of on-policy data in mitigating forgetting. arXiv. https://arxiv.org/abs/2510.18874

Shenfeld, I., Pari, J., & Agrawal, P. (2025). RL’s Razor: Why online reinforcement learning forgets less. arXiv. https://arxiv.org/abs/2509.04259

LLM 系列导航

1 LLM:Function Call(1)从传统工具调用到函数驱动 2020-09-08
2 LLM:关注因果推断研究进展 2023-06-22
3 LLM:人机协作 2024-05-19
4 LLM:分层管理风险定性 2024-08-08
5 LLM:二阶段FN分层分析与模型提升空间测算 2024-09-12
6 LLM:二阶段FN分层分析与模型提升空间测算(2) 2024-09-12
7 LLM:二阶段FN分层分析与模型提升空间测算(3) 2024-09-12
8 LLM:二阶段FN分层分析与模型提升空间测算(4) 2024-09-12
9 LLM:推理不可复现的探索 2025-06-11
10 LLM:SFT 与 RL 的关系 2025-07-29
11 LLM:SFT 与 RL 的关系(理论修正与实践补充) 2025-07-30
12 LLM:SFT 与 RL 的关系(理论修正与实践补充II) 2025-07-31
13 LLM:低数据场景的决策树生成 2025-08-03
14 LLM:低数据场景的决策树生成(2)落地冷启动 2025-08-03
15 LLM:表格数据特征工程 2025-08-03
16 LLM:撰写大模型落地提示词的关键 2025-08-27
17 LLM:从 Prompt 设计到工作流落地 2025-08-28
18 LLM:RL's Razor 抗遗忘 2025-09-04
19 LLM:RL's Razor 抗遗忘(2) 2025-09-04
20 LLM:RL's Razor 抗遗忘(3)SFT 与 RL 的认知偏差及应用 2025-09-04
21 LLM:RL's Razor 抗遗忘(4)on-policy认知误区 2025-09-04
22 LLM:RL's Razor 抗遗忘(5)RL在线生成样本 2025-09-04
23 LLM:MuseGraph融合GNN与LLM的通用图挖掘新框架 2025-09-05
24 LLM:幻觉成因与解决 2025-09-08
25 LLM:Agent 逻辑与应用场景 2025-09-23
26 LLM:拆解大模型缩放定律失效的三重分解 2025-10-05
27 LLM:近似on-policy数据抗遗忘 2025-10-21
28 LLM:幻觉治理 2025-10-28
29 LLM:先验偏见(1)变量名带偏判断 2025-11-13
30 LLM:先验偏见(2)分层分析应对 2025-11-13
31 LLM:先验偏见(3)实验效度的场景化适配 2025-11-13
32 LLM:分层落地 2025-11-13
33 LLM:零样本在金融场景落地 2025-11-13
34 LLM:跨难度泛化的局限与量化 2025-11-26
35 LLM:重复提示词解锁非推理性能上限 2025-12-17
36 LLM:用失败样本提升指令遵循能力 2025-12-29
37 LLM:概率引导的高价值信号筛选 2026-01-14
38 LLM:低成本安全检测的级联方案 2026-01-16
39 LLM:定性编码的假阳性解决方案 2026-01-16
40 LLM:先验偏见(4)挑战与落地解决方案 2026-01-22
41 LLM:先验偏见(5)工程化方案 2026-01-22
42 LLM:SimRL(1)分层评估 2026-02-03
43 LLM:SimRL(2)理论逻辑与工程落地 2026-02-03