1 min read

因果推断:DESCN 模型的训练细节与落地实践

前篇中,我们已介绍 DESCN(Deep Entire Space Cross Networks)的核心设计 —— 通过 ESN(全空间网络)解决 “处理偏差”、X-network(交叉网络)缓解 “样本不均衡”,并验证了其在合成数据集(Epilepsy)与电商工业数据集(Lazada)的基础性能(Zhong et al., 2022)。作为 follow-up,本文将进一步补充 DESCN 的,模型细节以及与传统模型的深度性能对比,同时梳理技术协作中的关键推进点,为模型应用提供更完整的参考。

前篇提到 DESCN 以 “全空间建模 + 交叉约束” 为核心,但未深入训练层面的实现逻辑。实际上,该模型通过 “共享层 + 双分支” 的分层设计,结合梯度控制策略,实现了对干预组与对照组的学习。

1. 核心定义的实践延伸

在前篇定义的 ESTR(全空间干预响应,\(P(Y,W=1|X)=\mu_1(x)\cdot\pi(x)\))与 ESCR(全空间对照响应,\(P(Y,W=0|X)=\mu_0(x)\cdot(1-\pi(x))\))基础上,DESCN 进一步将二者简化为 “零特征场景下的联合概率”——ESTR 可直接理解为 “无额外特征时干预组的转化概率”,ESCR 对应 “无额外特征时对照组的转化概率”,这种简化让模型在工业数据(如含大量稀疏特征的电商数据)中更易落地(Zhang et al., 2023)。

2. “共享层 + 双分支” 的训练架构

训练时,DESCN 先通过共享层完成两项核心任务:一是提取数据唯一 ID 以关联全量样本,二是计算倾向性得分 \(\pi(x)\)(即样本属于干预组的概率,与前篇 “处理偏差” 修正逻辑呼应)。共享层后分设两个分支:

  • 干预组分支:头部为干预特异性预测器,专门优化 \(\mu_1(x)\)(干预后响应);

  • 对照组分支:头部为对照特异性分类器,专门优化 \(\mu_0(x)\)(无干预响应)(Li & Wang, 2022)。

3. 梯度停止策略与损失函数构成

为避免两组特征相互干扰,DESCN 采用 “定向梯度更新”:训练某一样本时,仅对其所属组别(干预 / 对照)的分支计算梯度,另一分支停止梯度传递 —— 例如,干预组样本仅更新 ESP 分支参数,DFC 分支参数保持不变,这一常规多目标神经网络训练方式,有效保证了两组响应函数的独立性(Li & Wang, 2022)。

损失函数则融合三部分:

  • 干预组损失:基于 ESP 预测结果与真实干预响应的误差;

  • 对照组损失:基于 DFC 预测结果与真实对照响应的误差;

  • 倾向性得分损失:优化样本组别的预测精度(Zhang et al., 2023)。

同时,损失计算中融入前篇强调的 “反事实推断” 逻辑,通过预测 “干预组无干预”“对照组有干预” 的虚拟结果,进一步提升模型对因果效应的捕捉能力。

参考文献

Zhong, K., Xiao, F., Ren, Y., Liang, Y., Yao, W., Yang, X., & Cen, L. (2022). DESCN: Deep entire space cross networks for individual treatment effect estimation. Proceedings of the 28th ACM SIGKDD Conference on Knowledge Discovery and Data Mining, 1-9. https://doi.org/10.1145/3534678.3539198