在大模型后训练领域,监督微调(SFT)长期陷在学新忘旧的困境中——为贴合固定标注数据,模型不仅要学习冗余信息,还需大幅调整参数,最终导致旧知识被覆盖。而Iterative-SFT的出现,以反向筛选样本的核心逻辑提供了优化路径,其迭代思路甚至能与经典的监督算法形成有趣对照。这篇内容将拆解Iterative-SFT的核心价值,用实操代码落地,并厘清它与传统SFT、XGBoost的本质差异。
一、核心矛盾:传统SFT的样本问题
传统SFT的困境根源,在于对off-policy离线数据的依赖和全盘接收的样本逻辑。这类数据由人工标注生成,与模型当前策略分布可能存在显著偏差,比如标注数据中包含非核心的表述习惯、冗余修饰语,模型为贴合数据分布,不得不将这些无关信息纳入参数学习。
更关键的是,SFT不做样本筛选——无论数据是否契合模型当前能力,都一股脑用于训练。这种数据主导、模型迁就的模式,必然导致参数大幅变动,进而引发灾难性遗忘。例如在MMLU常识任务微调后,模型对原有算术推理能力的保留率常低于60%(Shenfeld et al., 2025),而冗余信息的学习会进一步挤压旧知识的参数空间,加剧这一问题。
二、Iterative-SFT的优化路径:反向筛选样本的逻辑
Iterative-SFT的核心是将样本利用逻辑从全盘接收转为择优录用,核心流程可概括为每轮epoch初用当前模型生成样本→筛选正确响应→用有效样本微调。这种逻辑背后是对近似on-policy数据的依赖——数据由当前模型实时生成,天然贴合其能力分布,微调时无需大幅改动参数,从而实现学新不丢旧。
举个具体场景:用Qwen 2.5 1.5B模型做常识问答微调,每个prompt让模型生成5个响应,只保留答案匹配真值的样本。第一轮可能仅30%的生成样本有效,但微调后模型能力提升,第二轮有效样本率会增至50%以上——模型在“巩固已有能力”的基础上逐步扩展,既避免遗忘,又能稳步提升目标任务性能。
三、关键对比:Iterative-SFT与SFT、XGBoost的核心差异
很多人会将Iterative-SFT的迭代逻辑与XGBoost的集成学习思路类比,三者的核心差异集中在“样本筛选”和“优化目标”上,用表格可清晰呈现:
| 对比维度 | 传统SFT | Iterative-SFT | XGBoost(二分类场景) |
|---|---|---|---|
| 样本来源 | 固定离线标注数据 | 当前模型实时生成的动态数据 | 固定离线标注数据 |
| 筛选逻辑 | 无筛选,全盘接收 | 选“当前模型能做对”的样本 | 选“当前模型做不对”的样本(残差大) |
| 优化目标 | 快速贴合数据分布,忽视遗忘 | 抗遗忘+稳步提升目标任务性能 | 修正误差,提升预测准确率 |
| 参数更新方式 | 单模型参数大幅调整 | 单模型参数小幅迭代 | 加法模型,新增树修正误差 |
简单来说,传统SFT是硬啃所有数据,XGBoost是专攻错题,而Iterative-SFT是先吃透会做的题,再挑战难题——三种逻辑没有优劣,只是适配不同需求,Iterative-SFT的逻辑恰好解决了大模型抗遗忘的痛点。
四、实操落地:Iterative-SFT完整代码(PyTorch+Transformers)
以下代码以Qwen 2.5 1.5B模型为例,适配MMLU常识任务,包含样本生成、迭代训练、遗忘率评估全流程,可直接运行,GPU显存12GB即可支撑。
import torch
from transformers import (AutoModelForCausalLM, AutoTokenizer,
Trainer, TrainingArguments, DataCollatorForLanguageModeling)
from datasets import Dataset
import evaluate
# 1. 基础配置(适配12GB GPU)
MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
TOKENIZER = AutoTokenizer.from_pretrained(MODEL_NAME)
TOKENIZER.pad_token = TOKENIZER.eos_token # 补全pad token
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MAX_SEQ_LEN = 512
NUM_EPOCHS = 2 # 论文推荐轮次
NUM_RESPONSES = 5 # 每个prompt生成响应数
LEARNING_RATE = 1e-5 # 1.5B模型适配学习率
# 2. 模型初始化(显存优化)
model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
torch_dtype=torch.bfloat16, # 混合精度训练
device_map=DEVICE,
gradient_checkpointing=True # 节省显存
)
# 3. 基础Prompt数据加载(实际需扩展至10k+样本)
def load_base_data():
sample_data = [
{
"prompt": "User: 下列属于质数的是?\n选项:A.1 B.2 C.4 D.6\n回答格式:The answer is: X\nAssistant:",
"ground_truth": "B"
},
{
"prompt": "User: 水的化学式是?\n选项:A.H2O B.CO2 C.O2 D.NaCl\n回答格式:The answer is: X\nAssistant:",
"ground_truth": "A"
}
]
return Dataset.from_list(sample_data)
base_dataset = load_base_data()
# 4. 核心:生成近似on-policy数据
def gen_valid_data(model, tokenizer, dataset):
valid_samples = []
model.eval()
with torch.no_grad():
for item in dataset:
prompt = item["prompt"]
gt = item["ground_truth"]
# 编码输入
inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=MAX_SEQ_LEN).to(DEVICE)
# 生成多个响应
outputs = model.generate(
**inputs, max_new_tokens=32, num_return_sequences=NUM_RESPONSES,
temperature=0.7, do_sample=True, pad_token_id=tokenizer.eos_token_id
)
# 筛选正确样本
for output in outputs:
response = tokenizer.decode(output, skip_special_tokens=True).split("Assistant:")[-1].strip()
if f"The answer is: {gt}" in response:
valid_samples.append({"input_ids": output.cpu().numpy(), "labels": output.cpu().numpy()})
return Dataset.from_list(valid_samples)
# 5. 遗忘率评估函数
def calc_forgetting(base_model, curr_model, non_target_data):
acc_metric = evaluate.load("accuracy")
def get_preds(model, data):
preds = []
model.eval()
with torch.no_grad():
for item in data:
inputs = tokenizer(item["prompt"], return_tensors="pt").to(DEVICE)
outputs = model.generate(**inputs, max_new_tokens=32, do_sample=False)
pred = tokenizer.decode(outputs[0], skip_special_tokens=True).split("is:")[-1].strip()
preds.append(pred)
return preds
base_preds = get_preds(base_model, non_target_data)
curr_preds = get_preds(curr_model, non_target_data)
base_acc = acc_metric.compute(predictions=base_preds, references=[d["ground_truth"] for d in non_target_data])["accuracy"]
curr_acc = acc_metric.compute(predictions=curr_preds, references=[d["ground_truth"] for d in non_target_data])["accuracy"]
return base_acc - curr_acc
# 6. 迭代训练主流程
def iterative_sft():
# 保存基础模型用于遗忘率基准
base_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.bfloat16, device_map=DEVICE)
non_target_data = load_base_data() # 实际替换为非目标任务数据(如MATH)
curr_model = model
for epoch in range(NUM_EPOCHS):
print(f"\n=== 第{epoch+1}轮迭代 ===")
# 生成当前轮次有效数据
valid_data = gen_valid_data(curr_model, TOKENIZER, base_dataset)
print(f"有效样本数:{len(valid_data)}")
if not valid_data:
continue
# 配置训练参数
train_args = TrainingArguments(
output_dir=f"./ckpt/epoch_{epoch+1}",
per_device_train_batch_size=8, learning_rate=LEARNING_RATE,
num_train_epochs=1, logging_steps=5, save_strategy="epoch",
bf16=True, remove_unused_columns=False
)
# 初始化Trainer
data_collator = DataCollatorForLanguageModeling(tokenizer=TOKENIZER, mlm=False)
trainer = Trainer(
model=curr_model, args=train_args, train_dataset=valid_data, data_collator=data_collator
)
# 启动训练
trainer.train()
# 评估遗忘率
forgetting_rate = calc_forgetting(base_model, curr_model, non_target_data)
print(f"遗忘率:{forgetting_rate:.4f}")
# 保存最终模型
curr_model.save_pretrained("./final_iterative_sft_model")
TOKENIZER.save_pretrained("./final_iterative_sft_model")
print("\n模型保存完成")
# 7. 启动训练(首次运行需安装依赖:pip install torch transformers datasets evaluate)
if __name__ == "__main__":
iterative_sft()
五、实用价值:低成本落地的关键建议
基于代码和对比分析,Iterative-SFT的落地无需复杂资源,核心注意三点:
样本规模:示例用若干样本,实际需扩展至10k+(如MMLU训练集1.2万样本),否则有效样本不足,难以体现抗遗忘效果;
非目标任务选择:评估遗忘率时,需用与目标任务无关的数据(如目标是常识问答,非目标用算术推理MATH数据集),避免数据分布重叠导致评估穿越;
参考文献
Chen, H., Razin, N., Narasimhan, K., & Chen, D. (2025). Retaining by doing: The role of on-policy data in mitigating forgetting. arXiv. https://arxiv.org/abs/2510.18874
Shenfeld, I., Pari, J., & Agrawal, P. (2025). RL’s Razor: Why online reinforcement learning forgets less. arXiv. https://arxiv.org/abs/2509.04259