来自理想汽车团队的最新研究,2026年5月发表
🎬 摘要:先立个 flag
想象一下,你是一位新司机,脑子里的第一个念头(“踩油门超过去!")并不总是最好的选择。聪明的做法是:先打个草稿,再审视一遍,改掉不合理的地方,最终才真正打方向盘。
ReflectDrive-2 就是这样一位"会反思的自动驾驶大脑”。它的核心流程是三步走:决策 → 起草轨迹 → 自我修改(Reflect),全程在同一套"离散 token 空间"里完成,不需要额外的修改网络。
论文最核心的发现是:光靠监督学习训练出来的"修改器"只会让性能提升 +0.3分(PDMS,一个综合驾驶质量分),但当引入强化学习让"起草"和"修改"联合优化之后,提升蹿到了 +1.9分。最终在 NAVSIM 基准测试中,仅用摄像头输入就达到了 91.0 PDMS,最优6选1时更达到了与人类持平的 94.8 PDMS,同时在 NVIDIA Thor 芯片上单帧只需 31.8ms。
第一章:导言 —— 司机的"三步曲"
论文开篇抛出了 ReflectDrive-2 的核心工作流:“决策(Decision)—— 起草(Draft)—— 反思(Reflect)”。
想象一下你正在开车:
- 决策: 环顾四周的摄像头画面、导航信息和自身状态,你心里选定了一个大方向(生成一个 Goal Token 目标词元)。
- 起草: 你在脑海中快速画出一条大致的行车路线(通过掩码离散扩散并行解码出一条蓝色的初始轨迹)。
- 反思(AutoEdit): 然发现前面有个小坑!你没有全盘否定路线,而是在脑海中对那一段方向盘微调了一下(在同一模型内就地修改词元,得到绿色的最终安全计划)。
这种在统一离散空间里的自我修改能力,不需要额外挂载"外挂"(如辅助微调网络),模型自己就能搞定!
第二章:相关工作 —— 站在巨人的肩膀上找痛点
-
端到端与VLA(视觉-语言-动作)规划: 过去的端到端模型通常缺乏中间纠错机制。虽然有些VLA大模型引入了新思路,但 ReflectDrive-2 直接用"掩码离散扩散"取代了传统的自回归,几轮并行解码就能拿下一整条轨迹!
-
扩散策略与强化学习(RL): 之前最接近的工作是 DriveFine(2026),但它的"起草员"和"修改员"是分开训练的,就像写手和编辑各拿各的工资,配合不够默契。ReflectDrive-2 的杀手锏在于:用强化学习(RL)把"起草"和"修改"绑定在一个回合里(composed rollout)。两个人共享最终的驾驶奖励,真正做到了"有福同享,有难同当",从而让编辑能力突飞猛进。
第三章:预备知识 —— 把开车变成"拼乐高"
问题设定
系统的输入是全景相机画面、导航指令和自车状态;输出则是离散的"轨迹词元"(Tokens)。
掩码离散扩散 (Masked Discrete Diffusion)
想象未来的轨迹是由一块块包含鸟瞰图(BEV)坐标的"乐高积木"拼成的。模型一开始拿到的是一堆被打上 assessments(马赛克)的未知积木,然后通过一个双向 Transformer,结合多模态上下文,一步步把马赛克替换成真实的坐标积木,实现有选择性的重新生成。
第四章:核心方法 —— ReflectDrive-2 架构大揭秘
这是这篇论文的灵魂所在,分为五步绝杀:
4.1 系统全貌
将"目标提议、轨迹起草、词元级轨迹纠错"全部统一在离散表示中。
4.2 条件目标掩码轨迹扩散
模型不是死板地给出唯一终点,而是预测一个"目标点后验概率分布",通过 top-k 采样和非极大值抑制(NMS)选出候选目标点,并把它们作为条件,并行起草完整的行驶轨迹。
4.3 AutoEdit 轨迹修正
这是核心魔法!模型具备感知结构的扰动能力,在推理阶段,直接识别出轨迹中那些"不靠谱"的局部 Token,并在原地将它们重写(Rewrite),完成轨迹的微调。
4.4 约束感知的监督目标
在监督学习阶段,引入了"可行驶区域场损失"(Drivable-area field loss),教模型不要把车开到马路牙子或草坪上。
4.5 在"起草-编辑"回放上的强化学习
终极武器! 系统把奖励(Reward)发给"编辑后"的最终轨迹,并通过策略梯度把功劳分配给前期的起草和后期的修改阶段。实验证明,这让模型的纠错能力产生了质的飞跃。
第五章:高效推理 —— 把速度压榨到极致
自动驾驶是人命关天的事,计算再复杂也不能卡顿!作者在 NVIDIA Thor 芯片 上硬生生把平均延迟压到了 31.8 毫秒!
优化手段汇总
| 优化手段 | 延迟变化 | 加速比 |
|---|---|---|
| 共享前缀 KV 缓存复用 | 0.28ms → 0.08ms | 3.5× |
| KV 缓存回退 + 合并重写 | 14.7ms → 11.5ms | 1.28× |
| Action-Expert FFN(隐层 4096→1024) | 2.47ms → 0.95ms | 2.6× |
| 融合 CUDA Unmask 内核 | 0.45ms → 0.06ms | 7.5× |
| 交替步解码(ASD)时序 AutoEdit | 26.2ms → 7.6ms | 3.4× |
| 端到端平均帧延迟 | 45.0ms → 31.8ms | 1.42× |
交替步解码 (Alternating Step Decode)
这个设计太聪明了!因为连续两帧画面的路况其实差不多,模型不需要每帧都"从零思考"。它采用了**全步(Full-step)和轻量步(Lite-step)**交替的策略。全步走一遍"决策-起草-反思";轻量步则直接拿上一帧的轨迹往前挪一挪,做个极简版的 Token 到 Token 的 AutoEdit 快速更新。这就好比司机眨了下眼,只需凭肌肉记忆微调方向盘,省下了巨大的算力!
第六章:实验 —— 用数字说话
6.1 实验设置
- 数据集:NAVSIM(基于 nuPlan),训练集 1192 个场景,测试集 136 个场景,任务是预测 4 秒、2Hz 采样的自车轨迹
- 评估指标 PDMS:综合了无责任碰撞(NC)、可行驶区域合规(DAC)、碰撞时间(TTC)、舒适性(Comf.)、自车进度(EP)
- 模型:0.7B 离散扩散语言主干 + 0.1B ViT 视觉编码器,全部微调
6.2 RL 对 AutoEdit 的增益
| 训练设置 | 无 AutoEdit | 有 AutoEdit | 增益 |
|---|---|---|---|
| 只有 DLM 损失 | 84.8 | 85.0 | +0.2 |
| + 可行驶区域场损失 | 87.2 | 87.3 | +0.1 |
| + AutoEdit 监督训练 | 87.7 | 88.0 | +0.3 |
| + 全流程 RL | 89.1 | 91.0 | +1.9 |
结论一目了然:没有 RL,AutoEdit 只是"摆设";有了 RL,AutoEdit 变成真正有价值的能力。
6.3 与其他方法对比
| 方法 | 输入 | PDMS |
|---|---|---|
| GoalFlow | 相机+激光雷达 | 90.3 |
| ReCogDrive | 仅相机 | 90.8 |
| ReflectDrive-2(我们) | 仅相机 | 91.0 |
| ReflectDrive-2(6选1 oracle) | 仅相机 | 94.8 |
| 人类参考 | — | 94.8 |
最亮眼的指标是EP(自车进度)= 89.4,所有方法中最高——说明车走得更积极,同时 DAC 和舒适性依然保持高位,做到了"进可攻守可守"。
第七章:结论 —— 学会"反思"的司机才是好司机
ReflectDrive-2 的核心启示是:自我纠正不是免费的午餐。一个训练好的编辑器,如果和起草器没有被共同的目标绑定,那它在实际中几乎毫无用处。只有当两者通过强化学习共享终局奖励,起草器才会"学会"生成可改善的草稿,编辑器才会"学会"做出真正有价值的修改。
局限与未来方向
- 当前轨迹用固定分辨率的 BEV 坐标 token,精度受 bin 大小限制。未来可以考虑更细的词表、残差偏移或混合离散-连续动作头
- RL 阶段的奖励是闭环规划代理分,尚非真实世界完整目标;接入更高保真仿真器和更丰富的安全奖励有望进一步改善
- AutoEdit 的扰动目前只覆盖纵向和横向,未来可以扩展到交互层面的失败(让行时机、被加塞响应、间距选择)
附录:关键技术概念详解
离散的"轨迹词元"是什么?
第一步:轨迹是什么
自动驾驶规划的输出是未来几秒内车该去哪,具体表示为一串路径点(waypoints):
(x₁, y₁) → (x₂, y₂) → ... → (x₈, y₈)
这 8 个点描述了车辆在鸟瞰图(BEV)坐标系中未来 4 秒的行驶轨迹,每 0.5 秒一个点。
第二步:为什么要"离散化"?
坐标本来是连续实数,比如 x = 3.742 米。但论文把坐标空间切成格子,就像经纬度变成邮政编码一样:
真实坐标 x = 3.742m
↓ 量化
格子编号 x_bin = 47 (第47个格子)
这个格子编号就是一个词元(token),和语言模型里的字词 token 是同一个概念。
每个路径点有 1 个纵向 token + 1 个横向 token,8 个路径点就是 16 个 token,构成整条轨迹的"词元序列"。
第三步:和语言模型的词元有什么相似之处?
| 语言模型 | ReflectDrive-2 | |
|---|---|---|
| token 是什么 | 词/子词编号 | BEV 坐标格子编号 |
| 序列长度 | 几十~几千 | 固定 16 个 |
| 词表大小 | ~50,000 | 坐标格子数 |
| 生成方式 | 自回归/扩散 | 掩码扩散并行生成 |
两者在数学结构上完全一致,所以才能直接套用离散扩散语言模型(LLaDA 等)的整套框架。
第四步:离散化带来了什么好处?
正是因为轨迹变成了一串 token,才能做到:
- 并行生成:不用一个点一个点地算,所有 token 同时从
assessments开始,几轮并行 unmask 就完成 - 原地编辑:想改某段轨迹?直接把那几个 token 换掉,就像文字编辑器里改几个字,不需要重新生成整条轨迹
- AutoEdit 成为可能:编辑操作和生成操作用的是同一套模型、同一套词表,没有任何"模态鸿沟"
一个直觉类比
把轨迹想象成一首歌的简谱。连续坐标就像模拟音频波形,你没法直接"改几个音符";而离散 token 就像乐谱上的音符编号,你可以精确地找到第3小节第2个音,把它从
5改成6,其他音符完全不受影响。
掩码离散扩散详解
第一层:扩散模型是在解决什么问题?
先忘掉"扩散"这个词,想象一个更简单的问题:
我想让 AI 画一张猫的图片。
AI 怎么"生成"一张图?它不可能凭空变出像素。它需要某种从混乱到有序的过程。
扩散模型的核心思路是:
训练阶段,教 AI 学会"如何把一张乱图还原成好图":
好图 → 加噪声 → 加更多噪声 → ... → 纯噪声
↑
AI 学会反着走这条路
推理阶段,从一堆纯噪声出发,AI 一步一步去噪,最终生成一张清晰的猫图。
这就是扩散模型的本质:学习"去噪"这件事,然后用去噪来生成内容。
第二层:连续扩散 vs 离散扩散
上面说的是连续扩散——图片的像素是连续的实数,噪声是高斯噪声(正态分布的随机数)。
但如果你处理的不是图片,而是文字呢?
文字是离散的:"猫" 这个字不能加上 0.3 变成 "猫.3",这没有意义。你只能从一个词跳到另一个词。
所以就有了离散扩散,它把"加噪声"换成了"替换成随机的其他词/符号":
"今天天气真好"
→ "今天[随机词]真好"
→ "[随机词][随机词]真好"
→ "[随机词][随机词][随机词][随机词]"
AI 的任务同样是学会反着走:从一堆乱七八糟的词,还原出原本有意义的句子。
第三层:掩码扩散 —— 用 assessments 代替随机词
离散扩散有很多种"加噪声"的方式。掩码扩散(Masked Diffusion) 选择了最简单粗暴的一种:
不替换成随机词,直接盖住——换成一个特殊符号
assessments。
"今天天气真好"
→ "今天 assessments 真好"
→ " assessments assessments 真好"
→ " assessments assessments assessments assessments "
你可能觉得这很眼熟——没错,这和 BERT 的完形填空几乎是同一个思路!
AI 的任务:看到带 assessments 的句子,猜出被盖住的是什么。
第四层:生成时怎么用?
训练完之后,想生成内容,就把"正向加噪"反过来走:
从全部都是 assessments 开始,每一轮让 AI 填一些空:
第0轮: assessments assessments assessments assessments assessments assessments
↓ AI预测每个位置的候选词,选最有把握的填入
第1轮:今天 assessments assessments 真 assessments assessments
↓
第2轮:今天 天气 assessments 真 好了 assessments
↓
第3轮:今天 天气 真的 真 好了 啊
每轮都并行处理所有位置,所以比逐词生成的自回归模型快得多。
这个"每轮选最有把握的 token 填入"的策略,就是 MaskGIT 提出的置信度驱动解码,也是 ReflectDrive-2 起草轨迹的核心机制。
第五层:为什么天然支持"原地编辑"?
这是掩码扩散最妙的地方,也是 ReflectDrive-2 选择它的根本原因。
生成完成之后,你得到了一个完整的 token 序列。如果你觉得某几个 token 不对,想改——
你只需要把那几个 token 重新盖成 assessments,然后让 AI 重新填空!
原轨迹: x₁ y₁ x₂ y₂ x₃ y₃ x₄ y₄
发现偏了: ↑这两个位置有问题
重新盖住:x₁ y₁ assessments assessments x₃ y₃ x₄ y₄
AI 重填: x₁ y₁ x₂' y₂' x₃ y₃ x₄ y₄ ✓
其他 token 完全不受影响,模型不需要重新跑一遍,也不需要额外的网络。这就像用橡皮擦掉草稿上的几个字,重新写——而不是把整页纸揉掉重写。
相比之下:
- 连续扩散想修改,得把整个去噪过程重来
- 自回归模型想修改,得从出问题的那个词开始把后面全部重新生成
总结:一张图记住全部
【训练】
好的序列 → 随机盖住一些token → 让AI学会填空
【生成】
全 assessments → AI并行填最有把握的 → 几轮后得到完整序列
【编辑】(掩码扩散独有!)
完整序列 → 把想改的位置重新盖住 → AI重新填空 → 局部修正完毕
AutoEdit 如何识别"不靠谱"的局部 Token?
这个问题分两个层面来回答。
第一个层面:推理时靠"模型置信度"
AutoEdit 不像草稿阶段那样从 assessments 开始填空。它的输入是一条已经完整的轨迹 token 序列,然后对每个位置都预测一个"替换 token":
当前轨迹: x₁ y₁ x₂ y₂ x₃ y₃ x₄ y₄
↓ 模型对每个位置都预测替换值
预测替换: x₁' y₁' x₂' y₂' x₃' y₃' x₄' y₄'
但不是所有位置都会真正被替换——这就是"识别不靠谱 token"的问题所在。
模型对每个 token 位置输出的不是一个确定的值,而是一个概率分布(softmax 输出):
位置 x₃ 的预测分布:
格子46号:0.72 ← 最高概率
格子47号:0.18
格子48号:0.07
其他格子:0.03
置信度高 = 模型对某个格子非常笃定(概率集中)
置信度低 = 模型很纠结,概率分散在好几个格子上
AutoEdit 选择置信度低的位置来替换——因为模型自己都不确定当前这个 token 对不对,说明这个位置可能有问题。
置信度: 高 高 低 低 高 高 高 高
↓ 只替换低置信度的位置
结果: x₁ y₁ x₂' y₂' x₃ y₃ x₄ y₄
目标点 token(行为锚点)永远不替换,始终锁定。
第二个层面:训练时靠"结构感知扰动"
光靠置信度其实有个问题:模型怎么知道当前这个 token 该不该被替换?
如果 AutoEdit 从来没见过"有问题的轨迹长什么样",它的置信度判断就没有意义——它可能对一条偏出车道的轨迹也信心满满。
这就是结构感知扰动(SAP)训练的价值:
训练时,故意喂给模型两类典型错误轨迹:
纵向扰动(进度失误):
原轨迹走了10米 → 扰动后只走6米(刹车太早)
原轨迹走了10米 → 扰动后走了14米(刹车太晚)
横向扰动(方向偏移):
原轨迹直行 → 扰动后整体旋转5度(车道偏移)
然后要求模型把扰动轨迹直接映射回干净轨迹。
这个训练过程让模型学会了:
“当我看到一条在纵向/横向上有规律性偏差的轨迹时,我应该对哪些 token 没信心,以及应该改成什么。”
两个层面怎么配合?
可以这样理解它们的分工:
SAP 训练塑造了模型内部的"错误感知能力"——它让模型的概率分布在面对有问题的轨迹时,自然地对相关 token 产生低置信度。
置信度筛选是推理时的执行机制——用模型自己暴露出来的不确定性,决定哪些 token 值得替换。
训练阶段:
喂扰动轨迹 → 模型学会"这类错误我应该不确定" → 置信度分布变得有意义
推理阶段:
喂草稿轨迹 → 模型对有问题的位置置信度低 → 筛出来替换
一个直觉类比
想象你是一位改卷老师,改了大量有规律错误的卷子(比如总是把加法做成减法)。久而久之,你扫一眼答案就能感觉到"这道题的答案看起来不对劲"——这种直觉来自训练。
AutoEdit 也一样:见过足够多的纵向偏移、横向偏移轨迹后,它在内部形成了对"轨迹哪里不对劲"的直觉,并通过置信度把这种直觉暴露出来,再由筛选机制决定是否动手修改。
AutoEdit 训练阶段伪代码
# ============================================================
# ReflectDrive-2 · AutoEdit 训练阶段伪代码
# 对应论文 Section 3.2 / 4.3 / 4.4
# ============================================================
#
# 符号说明(与论文保持一致)
# z0 : 干净的连续坐标路径点序列 [(x1,y1),...,(x8,y8)]
# x0 : z0 离散化后的 token 序列 [x1,y1,...,x8,y8] 长度 L=16
# x̃0 : 扰动轨迹的 token 序列
# c : 多模态上下文 (视觉tokens, 导航指令, 自车状态)
# pθ : 共享的条件 token 模型(Transformer 主干)
# assessments : 掩码符号,用于草稿阶段
# t : 掩码比例 ∈ [0, 1],决定盖住多少 token
# ============================================================
import random
import math
# ----------------------------------------------------------
# 工具函数
# ----------------------------------------------------------
def tokenize(waypoints):
"""
把连续坐标路径点量化为离散 token 序列。
每个路径点 (x, y) → 两个整数 token (x_bin, y_bin)。
8 个路径点 → 长度 16 的 token 序列。
"""
tokens = []
for (x, y) in waypoints:
tokens.append(quantize(x, axis='longitudinal')) # 纵向坐标 bin
tokens.append(quantize(y, axis='lateral')) # 横向坐标 bin
return tokens # 长度 L = 16
def detokenize(tokens):
"""tokenize 的逆操作:token 序列 → 连续坐标路径点。"""
waypoints = []
for i in range(0, len(tokens), 2):
x = dequantize(tokens[i], axis='longitudinal')
y = dequantize(tokens[i+1], axis='lateral')
waypoints.append((x, y))
return waypoints
# ----------------------------------------------------------
# Step 1 · 结构感知扰动(Structure-Aware Perturbation, SAP)
# 论文公式 (5)(6)
# ----------------------------------------------------------
def structure_aware_perturbation(z0, beta_range=(0.7, 1.3), alpha_max=0.1):
"""
对干净轨迹 z0 施加两类结构化扰动,模拟驾驶中最常见的错误模式。
Args:
z0 : 干净路径点列表,[(x1,y1),...,(x8,y8)]
beta_range : 纵向进度缩放系数范围 (βmin, βmax)
alpha_max : 横向旋转角度上限 αmax(弧度)
Returns:
z̃0 : 扰动后的路径点列表
"""
# ── 扰动类型随机选一种(或两种都加)──────────────────────
perturbation_type = random.choice(['longitudinal', 'lateral', 'both'])
z_tilde = [wp for wp in z0] # 先复制一份
# ── 纵向进度扰动:沿弧长缩放 ────────────────────────────
# 论文公式 (5):z̃i = Interp(z0, β·dᵢ)
# β < 1 → 进度不足(刹车过早)
# β > 1 → 进度超前(刹车过晚 / 超速)
if perturbation_type in ('longitudinal', 'both'):
beta = random.uniform(*beta_range)
# 计算每个路径点的弧长累积值
arc_lengths = compute_arc_lengths(z0) # [d1, d2, ..., d8]
# 用缩放后的弧长重新插值轨迹
scaled_arc = [beta * d for d in arc_lengths] # β·dᵢ
z_tilde = interpolate_by_arc(z0, scaled_arc) # 沿原轨迹重新采样
# ── 横向方向扰动:在自车坐标系中旋转 ───────────────────
# 论文公式 (6):z̃i = R(α)·zᵢ
# α 为随机旋转角,产生整体侧向偏移,但保持轨迹平滑性
if perturbation_type in ('lateral', 'both'):
alpha = random.uniform(-alpha_max, alpha_max)
cos_a, sin_a = math.cos(alpha), math.sin(alpha)
z_tilde = [
(cos_a * x - sin_a * y,
sin_a * x + cos_a * y)
for (x, y) in z_tilde
]
return z_tilde # z̃0:扰动后的连续坐标路径点
# ----------------------------------------------------------
# Step 2 · AutoEdit 损失(LSAP)
# 论文公式 (7)(8)
# ----------------------------------------------------------
def compute_autoedit_loss(model, x0, context):
"""
训练 AutoEdit:给定扰动 token 序列,预测干净 token 序列。
注意:AutoEdit 不使用 assessments!
输入是完整的(扰动后的)具体 token,目标是干净的具体 token。
这与草稿阶段的掩码填空训练完全不同。
Args:
model : 共享的条件 token 模型 pθ
x0 : 干净轨迹的 token 序列,长度 L=16
context : 多模态上下文 c = (视觉tokens, 导航, 自车状态)
Returns:
L_SAP : AutoEdit 纠错损失(标量)
"""
L = len(x0) # = 16
# 2-a. 生成扰动后的连续坐标路径点
z0 = detokenize(x0) # token → 连续坐标
z_tilde = structure_aware_perturbation(z0) # 施加纵向/横向扰动
# 2-b. 把扰动路径点重新离散化为 token 序列
x_tilde = tokenize(z_tilde) # 扰动 token 序列 x̃0
# 2-c. 用模型对扰动 token 序列做预测
# 输入:x̃0(扰动后的完整具体 token)+ 多模态上下文 c
# 输出:每个位置上的替换 token 概率分布
# 论文公式 (7):qθ(· | x̃0, c) = softmax( hθ(x̃0, c) )
logits = model(x_tilde, context) # shape: [L, vocab_size]
probs = softmax(logits, dim=-1) # qθ(· | x̃0, c)
# 2-d. 计算交叉熵损失(对所有 L 个位置平均)
# 论文公式 (8):L_SAP = -E[ 1/L · Σ log qθ(x⁰ᵢ | x̃0, c) ]
# 目标是干净 token x0,而非扰动 token x̃0
L_SAP = 0.0
for i in range(L):
clean_token_id = x0[i] # 第 i 位的干净 token
L_SAP += -math.log(probs[i][clean_token_id] + 1e-9)
L_SAP /= L # 对序列长度取平均
return L_SAP
# ----------------------------------------------------------
# Step 3 · 草稿阶段掩码扩散损失(LDLM)
# 论文公式 (1)
# ----------------------------------------------------------
def compute_dlm_loss(model, x0, context):
"""
训练草稿生成:随机盖住若干 token,让模型预测所有位置的原始 token。
与 AutoEdit 损失的关键区别:
- AutoEdit:输入是"有规律错误"的具体 token
- DLM :输入是"随机盖 assessments"的不完整序列
Args:
model : 共享的条件 token 模型 pθ
x0 : 干净轨迹的 token 序列
context : 多模态上下文 c
Returns:
L_DLM : 掩码扩散损失(标量)
"""
L = len(x0) # = 16
# 3-a. 随机采样掩码比例 t ∈ [0, 1]
t = random.uniform(0, 1)
# 3-b. 按比例 t 独立随机将每个 token 替换为 assessments
x_masked = []
for token in x0:
if random.random() < t:
x_masked.append(MASK_TOKEN_ID) # 盖住
else:
x_masked.append(token) # 保留
# 3-c. 模型对所有位置(包括未盖住的)做预测
# 注意:论文选择"全位置监督",而非仅对 assessments 位置监督
# 实验证明全位置监督让训练更稳定、草稿更连贯
logits = model(x_masked, context) # shape: [L, vocab_size]
probs = softmax(logits, dim=-1)
# 3-d. 对所有 L 个位置计算交叉熵(目标均为干净 token x0)
# 论文公式 (1):L_DLM = -E[ 1/L · Σ log pθ(x⁰ᵢ | xt, c) ]
L_DLM = 0.0
for i in range(L):
L_DLM += -math.log(probs[i][x0[i]] + 1e-9)
L_DLM /= L
return L_DLM
# ----------------------------------------------------------
# Step 4 · 可行驶区域场损失(Lfield)
# 论文公式 (11)(12)(13)(14)
# ----------------------------------------------------------
def compute_field_loss(model, x0, context, dac_cost_field):
"""
空间惩罚:让模型在概率层面就"远离"不可行驶区域。
Args:
model : 共享的条件 token 模型 pθ
x0 : 干净轨迹的 token 序列
context : 多模态上下文 c
dac_cost_field : 可行驶区域代价场 C ∈ R^[H×W]
非可行驶区域 = 高代价,可行驶区域 = 0
Returns:
L_field : 可行驶区域场损失(标量)
"""
# 4-a. 用带随机掩码的输入让模型输出 logits(与 DLM 阶段共用)
logits = model(apply_random_mask(x0), context) # shape: [L, vocab_size]
L_field = 0.0
# 4-b. 对每个路径点 t(共 8 个)计算空间惩罚
for t in range(8):
x_logit = logits[2*t] # 第 t 个路径点的纵向 logit
y_logit = logits[2*t + 1] # 第 t 个路径点的横向 logit
p_x = softmax(x_logit) # 纵向坐标的概率分布 p^(t)_x shape:[W]
p_y = softmax(y_logit) # 横向坐标的概率分布 p^(t)_y shape:[H]
# 4-c. 构造联合空间分布(外积)
# 论文公式 (11):p^(t)_xy[i,j] = p^(t)_x[i] · p^(t)_y[j]
# 假设纵横坐标独立——近似但够用
p_xy = outer_product(p_x, p_y) # shape: [H, W]
# 4-d. 用代价场加权的对数障碍函数计算惩罚
# 论文公式 (12):L_field = Σ_t Σ_{i,j} -log(1 - p^(t)_xy[i,j]) · C[i,j]
#
# 直觉:
# · C[i,j] 大(靠近不可行驶区域)→ 这一项权重大
# · p_xy[i,j] 大(模型对这个位置置信高)→ 惩罚也大
# · 两个大值相乘 → 强烈惩罚"高置信度踩线"
for i in range(H):
for j in range(W):
cost = dac_cost_field[i][j]
if cost > 0:
penalty = -math.log(1 - p_xy[i][j] + 1e-9) * cost
L_field += penalty
return L_field
# ----------------------------------------------------------
# Step 5 · 总监督损失 & 训练主循环
# 论文公式 (15)
# ----------------------------------------------------------
def supervised_training(model, dataloader, lambda_SAP=1.0, lambda_field=0.1):
"""
监督微调(SFT)主循环:同时优化三个目标。
总损失:L_sup = L_DLM + λ_SAP · L_SAP + λ_field · L_field
论文公式 (15)
三个损失的分工:
L_DLM → 教模型从 assessments 还原干净轨迹(起草能力)
L_SAP → 教模型将扰动 token 映射回干净 token(纠错能力)
L_field → 在概率层面惩罚踩入不可行驶区域(空间约束)
"""
optimizer = AdamW(model.parameters())
for batch in dataloader:
x0 = batch['trajectory_tokens'] # 干净轨迹 token 序列
context = batch['context'] # 多模态上下文 c
cost_field = batch['dac_cost_field'] # 可行驶区域代价场
# ── 计算三个损失分量 ───────────────────────────────
L_DLM = compute_dlm_loss(model, x0, context)
L_SAP = compute_autoedit_loss(model, x0, context)
L_field = compute_field_loss(model, x0, context, cost_field)
# ── 加权求和得到总监督损失 ─────────────────────────
L_sup = L_DLM + lambda_SAP * L_SAP + lambda_field * L_field
# ── 反向传播 & 更新参数 ────────────────────────────
optimizer.zero_grad()
L_sup.backward()
optimizer.step()
log(L_DLM=L_DLM, L_SAP=L_SAP, L_field=L_field, L_sup=L_sup)
# SFT 结束后,进入 RL 微调阶段(见论文 Section 4.5)
return model
整体结构对应论文的训练三角形,用一张图看清楚各部分的关系:
同一批干净轨迹 x0
│
├─── 随机盖 assessments ──→ L_DLM (教模型:从残缺填完整)
│
├─── 纵向/横向扰动 ──→ L_SAP (教模型:从错误改正确)
│
└─── 输出概率分布 ───→ L_field (教模型:概率质量远离禁区)
│
└──── 三者加权求和 ──→ L_sup → 反向传播
有几个细节特别值得注意:
L_SAP 和 L_DLM 的本质区别:两者用的是同一个模型,但输入性质完全不同。L_DLM 的输入是"信息随机缺失"(assessments),L_SAP 的输入是"信息系统性错误"(扰动坐标),后者才是 AutoEdit 真正的训练信号。
L_field 作用在概率层面:它不是在惩罚模型"预测了出界的 token",而是在惩罚模型"对出界位置赋予了高概率",惩罚力度随置信度和距离边界的远近双重加权,比简单的离散惩罚更平滑。
三个损失共享同一个模型参数:没有任何额外网络,SFT 结束后这个模型既能起草,又能纠错,还会绕开禁区——然后直接进入 RL 阶段做联合优化。
RL 微调阶段伪代码
# ============================================================
# ReflectDrive-2 · RL 微调阶段伪代码
# 对应论文 Section 3.4 / 4.5
# ============================================================
#
# 核心思想:
# SFT 之后,起草器和 AutoEdit 各自都有能力,但互不知道对方的存在。
# RL 阶段把两者绑在一起:跑完"起草→AutoEdit"的完整 rollout,
# 只用最终轨迹的驾驶得分作为奖励,同时回传给两个阶段。
# 结果:起草器学会生成"可被 AutoEdit 改好"的草稿;
# AutoEdit 学会做"真正提升驾驶分"的修改。
#
# 符号说明(与论文保持一致)
# Ng : 每帧采样的目标点数量(论文实验中 Ng=3)
# I : 每个目标点采样的草稿数(论文实验中 I=2)
# G : 总候选轨迹数 = Ng × I(论文实验中 G=6)
# Sdraft : 草稿阶段的 unmask 轮数
# Sedit : AutoEdit 阶段的修改轮数
# S : 总步数 = Sdraft + Sedit
# ρg : 第 g 条候选的完整 token 转换序列(长度 S+1)
# τg : 第 g 条候选的最终轨迹(连续坐标)
# R(τg) : 闭环规划奖励(即 PDMS 得分)
# Ag : 第 g 条候选的组相对优势
# πθ : 当前策略(即当前模型参数)
# πθ_old : rollout 时冻结的旧策略(用于计算重要性采样比)
# πref : SFT 之后冻结的参考策略(用于 KL 惩罚)
# ============================================================
import math
import copy
# ----------------------------------------------------------
# Step 1 · 单条候选的完整 Draft→AutoEdit Rollout
# 论文公式 (16)(17)
# ----------------------------------------------------------
def run_single_rollout(model, goal_token, context,
S_draft=3, S_edit=3):
"""
对一个目标点执行完整的"起草→AutoEdit"rollout,
并记录每一步的 token 转换序列(用于后续计算策略梯度)。
Args:
model : 当前策略模型 πθ_old(rollout 时冻结)
goal_token : 已选定的目标点 token(行为锚点)
context : 多模态上下文 c
S_draft : 草稿阶段 unmask 轮数
S_edit : AutoEdit 阶段修改轮数
Returns:
trajectory_history : 完整 token 状态序列
[x⁰, x¹, ..., x^S] 共 S+1 个快照
前 S_draft 步是草稿阶段,后 S_edit 步是 AutoEdit 阶段
final_trajectory : 最终连续坐标轨迹 τg(送入仿真器求奖励)
"""
L = 16 # token 序列长度(8个路径点 × 2坐标)
# ── 初始化:全部 assessments,目标点 token 固定 ────────────────
x_current = [MASK_TOKEN_ID] * L
x_current = set_goal_token(x_current, goal_token) # 锁定目标点位置
trajectory_history = [x_current.copy()] # x⁰
# ══════════════════════════════════════════════════════
# 阶段 A · 草稿生成(Masked Diffusion Drafting)
# 从全 assessments 出发,每轮并行 unmask 最有把握的 token
# ══════════════════════════════════════════════════════
for s in range(S_draft):
# A-1. 模型预测每个位置的 token 概率分布
# 输入:带 assessments 的不完整序列 + 多模态上下文
logits = model(x_current, context) # shape: [L, vocab_size]
probs = softmax(logits, dim=-1) # πθ_old(· | x_current, c)
# A-2. 对每个仍是 assessments 的位置,采样一个 token
x_predicted = []
confidence = []
for i in range(L):
if x_current[i] == MASK_TOKEN_ID:
token_id = sample_from(probs[i]) # 按概率采样(非贪心)
conf = probs[i][token_id] # 该 token 的概率作为置信度
else:
token_id = x_current[i] # 已确定的 token 保持不变
conf = 1.0
x_predicted.append(token_id)
confidence.append(conf)
# A-3. 按置信度从高到低排序,本轮只 unmask 最有把握的那一批
# 每轮 unmask 比例随步数递增(第 s 轮 unmask 约 s/S_draft 的 token)
n_to_unmask = compute_unmask_count(s, S_draft, n_masked=count_masks(x_current))
unmask_indices = top_k_indices(confidence, k=n_to_unmask,
only_masked_positions=True)
# A-4. 把选中位置的 assessments 替换为预测 token,其余保持 assessments
x_next = x_current.copy()
for idx in unmask_indices:
x_next[idx] = x_predicted[idx]
x_current = x_next
trajectory_history.append(x_current.copy()) # 记录 x^(s+1)
# 草稿完成,x_current 此时应没有 assessments(全部填满)
x_draft = x_current.copy() # x^(S_draft)
# ══════════════════════════════════════════════════════
# 阶段 B · AutoEdit 修改(Token-to-Token Rewriting)
# 从完整草稿出发,直接 token→token 重写低置信度位置
# 注意:不再引入 assessments,输入始终是完整的具体 token 序列
# ══════════════════════════════════════════════════════
for k in range(S_edit):
# B-1. 模型对当前完整 token 序列预测替换 token
# 输入:当前具体 token 序列(无 assessments)+ 多模态上下文
logits = model(x_current, context) # shape: [L, vocab_size]
probs = softmax(logits, dim=-1)
x_hat = argmax(probs, dim=-1) # 每个位置的最优替换 token
confidence = [probs[i][x_hat[i]] for i in range(L)]
# B-2. 构造 commit mask:选出置信度低的非目标点 token
# 论文公式 (10):x^(k+1) = m^(k) ⊙ x̂^(k+1) + (1-m^(k)) ⊙ x^(k)
# 目标点 token 永远不进入候选(行为锚点锁定)
commit_mask = compute_commit_mask(
confidence = confidence,
x_current = x_current,
goal_positions = get_goal_positions(), # 目标点 token 的位置索引
threshold = CONFIDENCE_THRESHOLD # 低于此值才会被替换
) # shape: [L],1=替换,0=保留
# B-3. 按 commit mask 执行原地替换
x_next = [
x_hat[i] if commit_mask[i] == 1 else x_current[i]
for i in range(L)
]
x_current = x_next
trajectory_history.append(x_current.copy()) # 记录 x^(S_draft + k + 1)
# 最终轨迹(连续坐标)送入闭环仿真器求奖励
final_trajectory = detokenize(x_current) # τg
# trajectory_history 包含 S_draft + S_edit + 1 个快照
# 论文公式 (16):ρg = (x⁰g, x¹g, ..., x^(Sdraft+Sedit)g)
return trajectory_history, final_trajectory
# ----------------------------------------------------------
# Step 2 · 对一个场景采样 G 条候选 rollout
# 论文公式 (16)(17)(18)
# ----------------------------------------------------------
def sample_group_rollouts(model, context, N_g=3, I=2,
S_draft=3, S_edit=3):
"""
对当前场景:
1. 采样 Ng 个目标点(top-k + NMS)
2. 每个目标点生成 I 条草稿
3. 每条草稿跑完完整的 AutoEdit rollout
4. 把最终轨迹送入仿真器,得到 G 个驾驶奖励
Args:
model : 旧策略模型 πθ_old(rollout 期间完全冻结)
context : 多模态上下文 c
N_g : 目标点采样数(论文实验值 3)
I : 每个目标点的草稿数(论文实验值 2)
Returns:
group : 长度 G = Ng×I 的列表,每项包含:
- history : 完整 token 转换序列 ρg
- final_traj : 最终连续坐标轨迹 τg
- reward : 闭环驾驶奖励 R(τg)
- advantage : 组相对优势 Ag(统一计算后填入)
"""
G = N_g * I
group = []
# ── 2-a. 预测目标点后验分布,top-k 采样 + NMS ────────────
goal_logits = model.goal_head(context) # 目标点概率分布
goal_tokens = top_k_nms_sample(goal_logits,
k=N_g,
nms_threshold=1.2) # NMS 阈值约 1.2m
# ── 2-b. 对每个目标点生成 I 条 rollout ───────────────────
for goal_token in goal_tokens: # 外层:Ng 个目标点
for _ in range(I): # 内层:每个目标点 I 条草稿
history, final_traj = run_single_rollout(
model, goal_token, context, S_draft, S_edit
)
# 送入闭环仿真器(NAVSIM)计算 PDMS 奖励
# PDMS = 加权综合(NC, DAC, TTC, 舒适性, EP)
reward = closed_loop_simulator.score(final_traj) # R(τg)
group.append({
'history' : history, # ρg:[x⁰, x¹, ..., x^S]
'final_traj' : final_traj, # τg
'reward' : reward, # R(τg)
'goal_token' : goal_token,
})
# ── 2-c. 计算组相对优势(Group-Relative Advantage)────────
# 论文公式 (18):Ag = R(τg) - 1/G · Σj R(τj)
#
# 直觉:比组内平均分高 → 正优势(这条路走对了,强化它)
# 比组内平均分低 → 负优势(这条路走错了,抑制它)
mean_reward = sum(item['reward'] for item in group) / G
for item in group:
item['advantage'] = item['reward'] - mean_reward # Ag
return group # 长度 G 的候选列表,每项含 ρg, τg, R(τg), Ag
# ----------------------------------------------------------
# Step 3 · 计算单条 rollout 的策略梯度损失
# 论文公式 (2)(19)
# ----------------------------------------------------------
def compute_pg_loss_single(model, model_old, item, epsilon=0.2):
"""
对一条候选 rollout 计算 PPO-clip 风格的策略梯度损失。
核心设计:
· 只有"在这一步真正发生变化"的 token 才获得梯度信号
(草稿阶段:从 assessments 变成具体 token 的位置)
(AutoEdit 阶段:被 commit mask 选中并替换的位置)
· 使用重要性采样比 r = πθ / πθ_old 做 off-policy 修正
· clip 截断防止更新步幅过大(PPO 技巧)
论文公式 (2) 中的指示函数:
δ^s_{g,p} = 1{ x^(s+1)_{g,p} ≠ x^s_{g,p} }
只有 token 在第 s 步发生了变化,才把这步的梯度纳入损失。
Args:
model : 当前策略 πθ(需要更新)
model_old : 旧策略 πθ_old(rollout 时冻结,用于计算比值)
item : 单条 rollout 信息(含 history, advantage)
epsilon : PPO clip 系数
Returns:
loss_pg : 该条 rollout 的策略梯度损失(标量)
"""
history = item['history'] # [x⁰, x¹, ..., x^S] 共 S+1 个快照
Ag = item['advantage'] # 组相对优势(标量)
context = item['context']
S = len(history) - 1 # 总步数 = S_draft + S_edit
L = len(history[0]) # token 序列长度 = 16
loss_pg = 0.0
n_credited = 0 # 记录实际获得梯度的 (步, 位置) 对数量
for s in range(S):
x_s = history[s] # 第 s 步的 token 状态(模型输入)
x_s_next = history[s + 1] # 第 s+1 步的 token 状态(目标)
# 3-a. 用当前策略和旧策略分别计算每个位置的 token 概率
logits_new = model(x_s, context) # πθ 的输出
logits_old = model_old(x_s, context) # πθ_old 的输出(不求梯度)
probs_new = softmax(logits_new, dim=-1) # shape: [L, vocab_size]
probs_old = softmax(logits_old, dim=-1) # shape: [L, vocab_size]
for p in range(L):
# 3-b. 指示函数:只有这个位置在这一步发生了变化,才纳入梯度
# 论文公式 (19):δ^s_{g,p} = 1{ x^(s+1)_{g,p} ≠ x^s_{g,p} }
#
# 草稿阶段:assessments → 具体 token,发生变化 → δ=1
# AutoEdit:被替换的 token 位置,发生变化 → δ=1
# 两个阶段没变化的位置:δ=0,跳过,不贡献梯度
token_changed = (x_s_next[p] != x_s[p])
if not token_changed:
continue
target_token = x_s_next[p] # 这一步实际写入的 token
# 3-c. 计算重要性采样比 r^s_{g,p} = πθ / πθ_old
# 分子:当前策略对 target_token 的概率
# 分母:旧策略对 target_token 的概率(采样时的概率)
prob_new = probs_new[p][target_token]
prob_old = probs_old[p][target_token]
ratio = prob_new / (prob_old + 1e-8) # r^s_{g,p}
# 3-d. PPO-clip:截断过大的更新步幅
# 论文公式 (2):min( r·Ag, clip(r, 1-ε, 1+ε)·Ag )
clipped_ratio = clip(ratio, 1 - epsilon, 1 + epsilon)
objective = min(ratio * Ag, clipped_ratio * Ag)
loss_pg += -objective # 最大化目标 = 最小化负目标
n_credited += 1
# 对所有 (步, 位置) 取平均
if n_credited > 0:
loss_pg /= n_credited
return loss_pg
# ----------------------------------------------------------
# Step 4 · KL 散度惩罚项
# 论文公式 (2) 中的 λ_KL · DKL(πθ ‖ πref)
# ----------------------------------------------------------
def compute_kl_penalty(model, model_ref, x_masked, context):
"""
防止 RL 微调跑偏太远,在 SFT 基础上保持合理的分布。
πref:SFT 结束后冻结的参考策略(不随 RL 更新)
πθ :当前正在更新的策略
对一个随机掩码输入,计算两者输出分布的 KL 散度。
Args:
model : 当前策略 πθ
model_ref : SFT 后冻结的参考策略 πref
x_masked : 随机掩码后的 token 序列
context : 多模态上下文
Returns:
kl : KL 散度(标量)
"""
probs_theta = softmax(model(x_masked, context), dim=-1) # πθ
probs_ref = softmax(model_ref(x_masked, context), dim=-1) # πref(无梯度)
# KL(πθ ‖ πref) = Σ πθ · log(πθ / πref)
kl = 0.0
L = probs_theta.shape[0]
for i in range(L):
for v in range(VOCAB_SIZE):
p = probs_theta[i][v]
q = probs_ref[i][v]
if p > 1e-9:
kl += p * math.log(p / (q + 1e-9))
return kl / L # 对位置取平均
# ----------------------------------------------------------
# Step 5 · RL 微调主循环
# 论文公式 (2)
# ----------------------------------------------------------
def rl_finetuning(model_sft, dataloader,
N_g=3, I=2, S_draft=3, S_edit=3,
epsilon=0.2, lambda_kl=0.01,
n_epochs=1):
"""
在 SFT 模型基础上做 RL 微调(Reinforcement Fine-Tuning, RFT)。
Args:
model_sft : SFT 阶段训练好的模型(作为起点和参考策略)
dataloader: 驾驶场景数据集
N_g : 每场景目标点数
I : 每目标点草稿数
S_draft : 草稿阶段步数
S_edit : AutoEdit 阶段步数
epsilon : PPO clip 系数
lambda_kl : KL 惩罚权重
n_epochs : RL 训练轮数
"""
# ── 初始化 ──────────────────────────────────────────────
model = copy.deepcopy(model_sft) # πθ:当前策略(持续更新)
model_ref = copy.deepcopy(model_sft) # πref:SFT 参考策略(永久冻结)
freeze(model_ref)
optimizer = AdamW(model.parameters(), lr=1e-6)
for epoch in range(n_epochs):
for scene in dataloader:
context = scene['context'] # 多模态上下文 c
# ══════════════════════════════════════════════
# Phase 1 · Rollout(用旧策略采样,不求梯度)
# ══════════════════════════════════════════════
# 冻结当前策略作为 πθ_old,用于 rollout 和计算重要性采样比
model_old = copy.deepcopy(model)
freeze(model_old)
# 采样 G = Ng × I 条完整 Draft→AutoEdit rollout
# 并计算每条轨迹的驾驶奖励和组相对优势
with no_grad():
group = sample_group_rollouts(
model_old, context, N_g, I, S_draft, S_edit
)
# group 中每项已包含:history(ρg), final_traj(τg), reward(R), advantage(Ag)
# ══════════════════════════════════════════════
# Phase 2 · 计算总损失(需要梯度)
# ══════════════════════════════════════════════
total_loss = 0.0
# 2-a. 对 G 条 rollout 分别计算策略梯度损失,取平均
# 论文公式 (2) 的主项
loss_pg_total = 0.0
for item in group:
item['context'] = context
loss_pg_total += compute_pg_loss_single(
model, model_old, item, epsilon
)
loss_pg = loss_pg_total / len(group) # 对 G 条取平均
# 2-b. KL 惩罚:防止策略偏离 SFT 太远
x_sample = sample_masked_sequence(scene['trajectory_tokens'])
loss_kl = compute_kl_penalty(model, model_ref, x_sample, context)
# 2-c. 合并
# L(θ) = L_PG + λ_KL · DKL(πθ ‖ πref)
total_loss = loss_pg + lambda_kl * loss_kl
# ══════════════════════════════════════════════
# Phase 3 · 更新参数
# ══════════════════════════════════════════════
optimizer.zero_grad()
total_loss.backward()
clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
log(
epoch = epoch,
loss_pg = loss_pg,
loss_kl = loss_kl,
total_loss = total_loss,
mean_reward = sum(item['reward'] for item in group) / len(group),
autoedit_gain = compute_autoedit_gain(group),
# autoedit_gain = 有 AutoEdit 的平均分 - 无 AutoEdit 的平均分
# RL 训练成功的标志:这个值从 ~0.3 增长到 ~1.9
)
return model # RL 微调完成,起草器和 AutoEdit 已联合优化
整个 RL 阶段分三个 Phase 循环执行,用一张图总结结构:
每个场景迭代:
Phase 1 · Rollout(冻结旧策略,不求梯度)
┌─────────────────────────────────────────────┐
│ 目标点 NMS 采样(Ng=3) │
│ × 每个目标点 I=2 条草稿 │
│ = G=6 条完整 Draft→AutoEdit rollout │
│ 每条送入 NAVSIM 得 R(τg) │
│ 统一计算组相对优势 Ag = R(τg) - mean(R) │
└─────────────────────────────────────────────┘
↓
Phase 2 · 计算损失(当前策略,需要梯度)
┌─────────────────────────────────────────────┐
│ 对 G 条 rollout,逐步逐位置: │
│ 只看 token 真正改变的位置(指示函数 δ) │
│ 计算 r = πθ/πθ_old │
│ loss = -min(r·Ag, clip(r)·Ag) │
│ + λ_KL · DKL(πθ ‖ πref) │
└─────────────────────────────────────────────┘
↓
Phase 3 · 反向传播更新
有三个设计决策最值得注意:
指示函数 δ 是灵魂。草稿阶段里从 assessments 变成具体 token 的位置、AutoEdit 阶段里被替换的位置,两类"变化"用同一个 δ 函数统一处理,终局奖励自然流入两个阶段,不需要任何人为分拆。
两套冻结模型各司其职。πθ_old 在每轮迭代开始时复制当前模型,用于 rollout 和计算重要性采样比(保证 off-policy 修正的正确性);πref 是 SFT 结束后一次性冻结,全程不动,只用来算 KL 惩罚防止策略跑偏太远。
奖励只在终点打分。仿真器只看最终修改后的轨迹,不会给草稿中间过程评分。这逼着起草器主动生成"留有改善余地"的草稿,而不是一开始就出最优解——正是这个机制让 AutoEdit 增益从 +0.3 跳到 +1.9。
Action-Expert FFN
背景:Transformer 里的 FFN 是什么
标准 Transformer 的每一层都有两个子模块:
输入 token
→ 自注意力(Attention):让 token 互相"看"彼此
→ FFN(前馈网络):对每个 token 独立做非线性变换
→ 输出 token
FFN 是 Transformer 里参数量最大、计算量最重的部分,约占整个模型计算量的 2/3。
核心问题:轨迹 token 需要这么大的 FFN 吗?
模型里同时存在两类 token,它们的性质天差地别:
| 语言/视觉 token | 轨迹 token | |
|---|---|---|
| 词表大小 | ~50,000 个词 | 几十个坐标 bin |
| 语义复杂度 | 极高(理解场景、推理意图) | 极低(就是一个坐标数字) |
| 需要的表达能力 | 强 | 弱 |
| 占序列长度 | 大部分 | 固定 16 个 |
用一个 4096 维的 FFN 去处理"第 47 号坐标格子"这种极其简单的信息,就像用核弹炸蚊子——严重过剩。
Action-Expert FFN 的解决方案
给轨迹 token 单独配一套更小的 FFN,语言/视觉 token 继续用原来的大 FFN:
每一层 Transformer:
语言/视觉 token ──→ 主 FFN(d_ffn = 4096)──→ 继续
↑ 按 token 类型路由
轨迹 token ──→ Action FFN(d_ffn = 1024)──→ 继续
两套 FFN 的参数完全独立,在同一层里并行存在,根据 token 的类型决定走哪条路。
关键的设计洞察
Attention 不分路由,FFN 才分。这很重要——轨迹 token 需要通过 Attention"看到"整个场景(障碍物在哪、车道线在哪)才能生成正确的坐标,这部分绝不能缩减。FFN 只是事后对每个 token 独立做变换,轨迹 token 在这一步根本不需要理解语义,1024 维完全够用。
路由是零成本的。轨迹 token 始终在序列末尾的固定位置,不需要任何路由网络或运行时判断,直接切片就完成了分流,没有额外开销。
精度反升是最有趣的发现。大 FFN 让语言梯度和轨迹梯度在同一组参数里相互干扰;拆分之后 Action FFN 的梯度来源纯粹,加上容量受限带来的隐式正则化,模型反而学到了更干净的坐标变换规律。
Action-Expert FFN 实现伪代码
# ============================================================
# Action-Expert FFN 实现伪代码
# 对应论文 Section 5(推理优化部分)
# ============================================================
#
# 核心思想:
# 轨迹 token 的语义空间极小(只是坐标格子编号),
# 不需要和语言/视觉 token 一样大的 FFN。
# 用一个更小的专用 FFN 处理轨迹 token,
# 在不损失(甚至提升)精度的前提下大幅降低计算量。
#
# 论文数值:
# 主 FFN d_ffn = 4096(论文实验中约为 d_model 的 2 倍)
# Action FFN d_ffn = 1024(压缩到 1/4)
# 加速比:2.6×
# meanSADE:不降反升(说明大 FFN 对轨迹 token 反而是过拟合)
# ============================================================
import torch
import torch.nn as nn
# ----------------------------------------------------------
# 1. 标准 FFN(用于语言/视觉 token)
# ----------------------------------------------------------
class StandardFFN(nn.Module):
"""
Transformer 主干的标准前馈网络。
处理语言 token、视觉 token、导航指令 token、自车状态 token。
维度:d_model → d_ffn(大) → d_model
"""
def __init__(self, d_model=2048, d_ffn=4096, dropout=0.0):
super().__init__()
self.w1 = nn.Linear(d_model, d_ffn, bias=False)
self.w2 = nn.Linear(d_ffn, d_model, bias=False)
self.act = nn.SiLU() # 论文用 SwiGLU 变体,这里简化为 SiLU
self.dropout = nn.Dropout(dropout)
def forward(self, x):
# x: [batch, seq_len, d_model]
return self.dropout(self.w2(self.act(self.w1(x))))
# ----------------------------------------------------------
# 2. Action-Expert FFN(专用于轨迹 token)
# ----------------------------------------------------------
class ActionExpertFFN(nn.Module):
"""
轨迹 token 专用的小型前馈网络。
只处理 16 个轨迹 token(8 个路径点 × 2 坐标)。
维度:d_model → d_action_ffn(小) → d_model
d_action_ffn = 1024(是标准 FFN 的 1/4)
为什么更小也够用?
轨迹 token 的 "语义" 极其简单:就是一个坐标格子的编号。
它不需要理解"转弯"、"让行"这类高层语义(那是 Attention 的工作)。
FFN 只需要做一个低维的坐标→隐空间的映射,1024 维绰绰有余。
"""
def __init__(self, d_model=2048, d_action_ffn=1024, dropout=0.0):
super().__init__()
self.w1 = nn.Linear(d_model, d_action_ffn, bias=False)
self.w2 = nn.Linear(d_action_ffn, d_model, bias=False)
self.act = nn.SiLU()
self.dropout = nn.Dropout(dropout)
def forward(self, x):
# x: [batch, n_action_tokens, d_model] n_action_tokens ≤ 16
return self.dropout(self.w2(self.act(self.w1(x))))
# ----------------------------------------------------------
# 3. 混合 FFN 层:按 token 类型路由
# ----------------------------------------------------------
class MixedExpertFFNLayer(nn.Module):
"""
一个完整的 Transformer FFN 子层,内含两套 FFN:
- StandardFFN → 处理所有非轨迹 token
- ActionExpertFFN→ 处理所有轨迹 token
路由依据:token 位置索引(轨迹 token 始终在序列末尾固定位置)。
序列结构(论文 Section 4.1):
[视觉 token (V)] [导航 token (N)] [自车状态 token (E)]
[目标点 token (G)] [轨迹 token (A)]
↑
这 16 个位置走 Action FFN
"""
def __init__(self,
d_model=2048,
d_ffn=4096,
d_action_ffn=1024,
n_action_tokens=16):
super().__init__()
self.standard_ffn = StandardFFN(d_model, d_ffn)
self.action_expert_ffn = ActionExpertFFN(d_model, d_action_ffn)
self.n_action_tokens = n_action_tokens
# 轨迹 token 始终在序列末尾,位置固定
# 这使得路由无需任何运行时判断,直接切片即可
self.layer_norm = nn.LayerNorm(d_model)
def forward(self, x, action_token_mask=None):
"""
Args:
x : [batch, seq_len, d_model]
完整序列(视觉+语言+目标点+轨迹 token 混在一起)
action_token_mask: [batch, seq_len] bool tensor
True 表示该位置是轨迹 token
如果为 None,则默认末尾 n_action_tokens 个是轨迹 token
Returns:
out: [batch, seq_len, d_model]
"""
batch, seq_len, d = x.shape
out = x.clone() # 输出在原位修改
# ── 构造路由 mask ─────────────────────────────────────
if action_token_mask is None:
# 默认:序列末尾 n_action_tokens 个位置是轨迹 token
action_token_mask = torch.zeros(batch, seq_len, dtype=torch.bool)
action_token_mask[:, -self.n_action_tokens:] = True
context_mask = ~action_token_mask # 非轨迹 token 的 mask
# ── 路由 1:非轨迹 token → StandardFFN ───────────────
# 把所有非轨迹位置的 token 聚合成一个 batch 送入大 FFN
# 用 mask 索引实现,避免 for 循环
if context_mask.any():
x_context = x[context_mask] # [n_context, d_model]
x_context = self.layer_norm(x_context)
out_context = self.standard_ffn(x_context) # [n_context, d_model]
out[context_mask] = out_context
# ── 路由 2:轨迹 token → ActionExpertFFN ─────────────
if action_token_mask.any():
x_action = x[action_token_mask] # [n_action, d_model]
# n_action = batch × 16(或更少)
x_action = self.layer_norm(x_action)
out_action = self.action_expert_ffn(x_action) # [n_action, d_model]
out[action_token_mask] = out_action
return out
# ----------------------------------------------------------
# 4. 将其嵌入完整 Transformer Block
# ----------------------------------------------------------
class TransformerBlockWithActionExpert(nn.Module):
"""
完整的 Transformer 层,Attention 部分不变,
FFN 部分替换为 MixedExpertFFNLayer。
注意:Attention 对所有 token 统一计算(轨迹 token 需要
"看"场景上下文才能生成正确坐标),只有 FFN 做了分路由。
"""
def __init__(self, d_model=2048, n_heads=16,
d_ffn=4096, d_action_ffn=1024):
super().__init__()
# Attention 部分:所有 token 统一处理(双向注意力)
self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
self.norm1 = nn.LayerNorm(d_model)
# FFN 部分:按 token 类型路由
self.ffn = MixedExpertFFNLayer(d_model, d_ffn, d_action_ffn)
self.norm2 = nn.LayerNorm(d_model)
def forward(self, x, attn_mask=None, action_token_mask=None):
# ── Attention 子层(Pre-Norm + 残差)────────────────
x_norm = self.norm1(x)
attn_out, _ = self.attn(x_norm, x_norm, x_norm, attn_mask=attn_mask)
x = x + attn_out # 残差连接
# ── 混合 FFN 子层(Post-Norm + 残差)────────────────
ffn_out = self.ffn(self.norm2(x), action_token_mask)
x = x + ffn_out # 残差连接
return x
# ----------------------------------------------------------
# 5. 参数量 & 计算量对比(以单层 FFN 为例)
# ----------------------------------------------------------
def compare_params_and_flops(d_model=2048, d_ffn=4096, d_action_ffn=1024,
n_context_tokens=512, n_action_tokens=16):
"""
对比标准方案(所有 token 用大 FFN)
vs Action-Expert 方案(轨迹 token 用小 FFN)
的参数量和 FLOPs 差异。
"""
# ── 参数量 ────────────────────────────────────────────
# 标准 FFN:W1(d_model→d_ffn) + W2(d_ffn→d_model)
params_standard = 2 * d_model * d_ffn # 2 × 2048 × 4096
# Action-Expert 方案:大 FFN + 小 Action FFN
params_action_expert = (2 * d_model * d_ffn # 大 FFN(不变)
+ 2 * d_model * d_action_ffn) # 小 Action FFN(新增)
# 新增参数量很少:2 × 2048 × 1024 = 4M,相对于主干几乎可忽略
# ── FLOPs(以 batch_size=1 为例)────────────────────────
n_total = n_context_tokens + n_action_tokens
# 标准方案:所有 token 都走大 FFN
flops_standard = n_total * 2 * d_model * d_ffn * 2
# ↑token数 ↑两层线性 ↑每次乘加2op
# Action-Expert 方案:
flops_context = n_context_tokens * 2 * d_model * d_ffn * 2
flops_action = n_action_tokens * 2 * d_model * d_action_ffn * 2
flops_expert = flops_context + flops_action
speedup = flops_standard / flops_expert
print(f"标准方案 参数量:{params_standard/1e6:.1f}M")
print(f"Expert 方案 参数量:{params_action_expert/1e6:.1f}M(+{(params_action_expert-params_standard)/1e6:.1f}M)")
print()
print(f"标准方案 FLOPs:{flops_standard/1e9:.2f}G")
print(f"Expert 方案 FLOPs:{flops_expert/1e9:.2f}G")
print(f"理论加速比:{speedup:.2f}×")
print()
print("直觉:轨迹 token 只占序列的 16/(512+16) ≈ 3%,")
print(" 但 Action FFN 是大 FFN 的 1/4,")
print(" 所以整体 FFN FLOPs 节省约 3% × 75% ≈ 2.25%。")
print(" 论文报告的 2.6× 是针对 Action FFN 模块本身的加速,")
print(" 不是全模型加速。")
# 实际上 2.6× 指的是 Action FFN 这个操作本身的延迟下降:
# 原来每个 token 都调用大 FFN kernel,现在轨迹 token 调用小 FFN kernel
# kernel launch + 矩阵乘法的延迟同时下降 → 2.6×
# ----------------------------------------------------------
# 6. 为什么精度不降反升?
# ----------------------------------------------------------
#
# 论文里提到一个反直觉的现象:
# Action FFN 从 d_ffn=4096 压缩到 d_action_ffn=1024 之后,
# meanSADE(平均 SAD 误差,越小越好)反而变好了。
#
# 可能的解释:
#
# ① 过参数化导致"记忆"而非"泛化"
# 大 FFN 有足够的容量把训练集里每条轨迹"背下来",
# 小 FFN 容量受限,被迫学习更通用的坐标变换规律。
#
# ② 正则化效应
# 小 FFN 相当于对轨迹 token 的表示施加了隐式的信息瓶颈,
# 逼着模型把高层语义(该往哪走)完全依赖 Attention 解决,
# 而 FFN 只做纯粹的坐标数值变换。
# 两者分工更清晰,反而更健康。
#
# ③ 梯度干净
# 大 FFN 里,语言和轨迹 token 共享参数,梯度相互干扰。
# 分离之后,Action FFN 的梯度只来自轨迹相关的损失,
# 更新方向更纯粹。
用一张图把整个层的结构说清楚:
Transformer 某一层的输入序列(共 ~528 个 token):
┌──────────────────────────────────┬────────────────┐
│ 视觉/导航/状态/目标点 token │ 轨迹 token │
│ ~512 个,语义复杂 │ 固定 16 个 │
└──────────────────────────────────┴────────────────┘
↓ 自注意力(所有 token 统一计算,不分路由)
↓
┌────────┴────────┐
│ 按类型路由 │
↓ ↓
StandardFFN ActionExpertFFN
d_ffn = 4096 d_ffn = 1024(1/4 大小)
↓ ↓
└────────┬────────┘
↓ 拼回原位,形状不变
KV 缓存的合并重写
第一层:普通 KV 缓存是什么
标准自回归语言模型(GPT 系列)每次只生成一个 token,是单向因果注意力:
生成第 4 个 token 时,它只能看 token 1、2、3
所以 token 1、2、3 的 K 和 V 在之前已经算过了,直接缓存复用即可
这就是 KV 缓存:已经算过的 K/V 不重复算,每步只算新 token 的 K/V。
第二层:掩码扩散为什么破坏了 KV 缓存
掩码扩散用的是双向注意力——每个 token 能看到序列里所有其他 token。而且每一步都有若干个 assessments 变成具体 token,序列在变化:
步骤 0: assessments assessments assessments assessments assessments assessments
步骤 1: assessments x₂ assessments x₄ assessments assessments ← 位置 2、4 被填入
步骤 2: x₁ x₂ x₃ x₄ assessments assessments ← 位置 1、3 被填入
问题来了:步骤 1 中位置 2 从 assessments 变成了 x₂,这个变化会影响所有其他位置对它的注意力结果,理论上其他位置的 K/V 都要重算。
朴素做法:每步都把整个序列的 K/V 从头计算一遍。这就是为什么表格里"无优化"的延迟是 14.7ms。
第三层:合并重写的关键洞察
虽然双向注意力让 token 互相影响,但有一个重要的局部性质:
某个 token 在第 l 层的 K 和 V,只依赖它自己在第 l-1 层的输出,不直接依赖同层其他 token 的值。
第 l 层:
K_i^l = W_K · h_i^(l-1) ← 只取决于 token i 自己的上层输出
V_i^l = W_V · h_i^(l-1) ← 同上
合并重写的近似假设:在同一步里,只有少数位置发生变化,大部分 token 的 K/V 变化很小,可以直接复用上一步的缓存,只对真正发生变化的位置重新计算 K/V 并写回缓存。
这是一个以极小精度损失换取大幅速度提升的近似,论文验证精度影响可忽略不计。
具体实现示意
草稿阶段第 s 步,位置 {3, 7, 12} 刚被 unmask:
旧 KV 缓存(步骤 s-1):
位置 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
K/V [✓] [✓] [✓] [✗] [✓] [✓] [✓] [✗] [✓] [✓] [✓] [✓] [✗] [✓] [✓] [✓]
↑ ↑ ↑
刚变化 刚变化 刚变化
合并重写操作:
位置 3、7、12 → 重新计算 K/V,写回缓存 (3 个位置的计算量)
其余 13 个位置 → 直接复用,零计算
新 KV 缓存(步骤 s):
位置 0 1 2 3' 4 5 6 7' 8 9 10 11 12' 13 14 15
K/V [✓] [✓] [✓] [✓] [✓] [✓] [✓] [✓] [✓] [✓] [✓] [✓] [✓] [✓] [✓] [✓]
KV 缓存合并重写伪代码
# ============================================================
# KV 缓存回退 + 合并重写 伪代码
# 对应论文 Section 5(推理优化部分)
# ============================================================
#
# 问题背景:
# 掩码扩散用双向注意力,每步有若干 token 从 assessments 变成具体值。
# 朴素做法:每步把全序列 K/V 从头算一遍 → 14.7ms/step
# 合并重写:只重算发生变化的位置 → 11.5ms/step,加速 1.28×
#
# 核心近似:
# 当步骤 s 中只有少数位置 Δs 发生变化时,
# 其他位置的 K/V 在 s 和 s+1 步之间变化很小,
# 可以直接复用上一步的缓存值。
# ============================================================
import torch
# ----------------------------------------------------------
# 数据结构:多层 KV 缓存
# ----------------------------------------------------------
class KVCache:
"""
存储所有 Transformer 层的 K 和 V 张量。
shape: [n_layers, batch, n_heads, seq_len, head_dim]
同时维护一个"版本戳",记录每个位置的 token 在哪一步最后更新,
用于 AutoEdit 的回退操作。
"""
def __init__(self, n_layers, batch, n_heads, seq_len, head_dim):
self.K = torch.zeros(n_layers, batch, n_heads, seq_len, head_dim)
self.V = torch.zeros(n_layers, batch, n_heads, seq_len, head_dim)
# 版本戳:position_version[i] = 该位置的 K/V 在第几步被算出来
self.position_version = torch.full((seq_len,), -1, dtype=torch.long)
# 快照栈:用于 AutoEdit 阶段的回退(rollback)
# 每次"保存快照"就往栈里压一帧,回退就弹出
self._snapshots = []
def update_positions(self, changed_positions, new_K, new_V, step):
"""
合并重写:只更新发生变化的位置,其余保持不变。
Args:
changed_positions : List[int],本步发生变化的 token 位置索引
new_K : [n_layers, batch, n_heads, len(changed), head_dim]
new_V : [n_layers, batch, n_heads, len(changed), head_dim]
step : 当前步骤编号
"""
for k, pos in enumerate(changed_positions):
self.K[:, :, :, pos, :] = new_K[:, :, :, k, :]
self.V[:, :, :, pos, :] = new_V[:, :, :, k, :]
self.position_version[pos] = step # 记录版本
def save_snapshot(self):
"""保存当前完整 KV 状态(用于 AutoEdit 回退)。"""
self._snapshots.append({
'K': self.K.clone(),
'V': self.V.clone(),
'version': self.position_version.clone()
})
def rollback(self):
"""
回退到上一个快照。
在 AutoEdit 阶段,如果当前修改使得奖励变差,
可以撤销这一轮的 K/V 更新,重新选择要修改的位置。
"""
assert len(self._snapshots) > 0, "没有可回退的快照"
snap = self._snapshots.pop()
self.K = snap['K']
self.V = snap['V']
self.position_version = snap['version']
# ----------------------------------------------------------
# 核心函数:计算变化位置的新 K/V
# ----------------------------------------------------------
def recompute_kv_for_changed_positions(model, token_sequence,
changed_positions, kv_cache):
"""
只对 changed_positions 上的 token 重新计算 K 和 V,
其余位置直接复用 kv_cache 中已有的值。
这就是"合并重写"的实质:
- 旧缓存(不变位置):直接读取,零计算
- 新计算(变化位置):只算这几个 token,写回缓存
Args:
model : Transformer 模型
token_sequence : 当前完整 token 序列 [batch, seq_len]
changed_positions : 本步发生变化的位置索引列表
kv_cache : 当前 KV 缓存对象
Returns:
updated_kv_cache : 更新后的 KV 缓存
"""
if len(changed_positions) == 0:
return kv_cache # 没有变化,直接返回
# ── Step 1:只取出变化位置的 token embedding ──────────
# 不需要对整个序列做 embedding,只做几个 token
changed_token_ids = token_sequence[:, changed_positions] # [batch, n_changed]
changed_embeddings = model.embed(changed_token_ids) # [batch, n_changed, d_model]
# ── Step 2:逐层计算这些位置的新 K 和 V ───────────────
#
# 关键近似在这里:
# 每层的输入 h_i^(l-1) 理论上依赖上层所有 token 的注意力结果。
# 近似处理:对变化位置,用当前层已有的完整 K/V 做注意力,
# 得到这些位置的新隐状态,再算新的 K/V。
# 对不变位置:直接复用缓存,不做任何计算。
#
new_K_all_layers = []
new_V_all_layers = []
h = changed_embeddings # [batch, n_changed, d_model] 初始为 embedding
for layer_idx, layer in enumerate(model.layers):
# 2-a. 用当前层已缓存的完整 K/V 做 cross-attention:
# 变化位置的 query 去 attend 整个序列(包括不变位置的缓存 K/V)
# 这样变化位置的隐状态就能看到完整的上下文
K_full = kv_cache.K[layer_idx] # [batch, heads, seq_len, head_dim]
V_full = kv_cache.V[layer_idx] # [batch, heads, seq_len, head_dim]
# 只对变化位置做 attention(query 来自变化位置,key/value 来自完整缓存)
Q_changed = layer.W_Q(h) # [batch, n_changed, d_model]
h = layer.cross_attend(Q_changed, K_full, V_full) # [batch, n_changed, d_model]
h = layer.ffn(h) # [batch, n_changed, d_model]
# 2-b. 用更新后的隐状态计算这些位置的新 K 和 V
new_K = layer.W_K(h) # [batch, n_changed, d_model] → reshape to [batch, heads, n_changed, head_dim]
new_V = layer.W_V(h)
new_K_all_layers.append(new_K)
new_V_all_layers.append(new_V)
# ── Step 3:合并写回缓存(只写变化的位置)─────────────
# "合并":新算的几个位置 + 原来缓存的其余位置 = 完整 KV 缓存
new_K_tensor = torch.stack(new_K_all_layers) # [n_layers, batch, heads, n_changed, head_dim]
new_V_tensor = torch.stack(new_V_all_layers)
kv_cache.update_positions(changed_positions, new_K_tensor, new_V_tensor,
step=current_step)
return kv_cache
# ----------------------------------------------------------
# 草稿阶段:带合并重写的掩码扩散解码
# ----------------------------------------------------------
def draft_with_kv_merge_rewrite(model, context_tokens, S_draft=3):
"""
草稿阶段的推理循环。
每步只 unmask 少数 token,用合并重写更新 KV 缓存。
Args:
context_tokens : 场景上下文 token(视觉/导航/状态)[batch, n_ctx]
S_draft : 草稿阶段总步数
Returns:
draft_tokens : 填满的轨迹 token 序列 [batch, L]
kv_cache : 最终 KV 缓存(供 AutoEdit 阶段复用)
"""
L = 16 # 轨迹 token 数量
batch = context_tokens.shape[0]
# ── 初始化:轨迹部分全 assessments,和上下文拼在一起 ────────
traj_tokens = torch.full((batch, L), MASK_TOKEN_ID)
full_sequence = torch.cat([context_tokens, traj_tokens], dim=1)
# full_sequence: [batch, n_ctx + L]
# ── 计算上下文前缀的 KV 缓存(只算一次,整个推理过程复用)
# 上下文 token 用因果注意力,所以这部分可以完美缓存
kv_cache = compute_prefix_kv(model, context_tokens) # 见论文"共享前缀 KV 缓存"
# ── 初始化轨迹 token 的 KV(全 assessments 的 embedding)─────
mask_embedding = model.embed(traj_tokens) # [batch, L, d_model]
initial_K, initial_V = model.compute_kv(mask_embedding)
kv_cache.update_positions(list(range(L)), initial_K, initial_V, step=0)
# ── 草稿循环 ─────────────────────────────────────────────
unmasked_positions = set() # 已经填好的位置
for s in range(S_draft):
# 用当前完整 KV 缓存做前向,得到所有位置的 logits
# (只对仍是 assessments 的位置采样,已填好的位置直接保留)
logits = model.forward_with_kv(full_sequence, kv_cache)
# logits: [batch, n_ctx + L, vocab_size]
traj_logits = logits[:, -L:, :] # 只取轨迹部分的 logits
probs = softmax(traj_logits, dim=-1)
# 选出本步要填入的位置(最高置信度的那批 assessments)
n_to_unmask = compute_unmask_count(s, S_draft, n_masked=(L - len(unmasked_positions)))
still_masked = [i for i in range(L) if i not in unmasked_positions]
confidences = {i: probs[0, i, :].max().item() for i in still_masked}
newly_unmasked = sorted(still_masked,
key=lambda i: -confidences[i])[:n_to_unmask]
# 采样并填入
for pos in newly_unmasked:
sampled_token = sample_from(probs[0, pos])
full_sequence[0, -L + pos] = sampled_token
unmasked_positions.add(pos)
# ════════════════════════════════════════════════
# 合并重写:只对新填入的位置重算 K/V
# 之前已 unmask 的位置 + 仍是 assessments 的位置 → 直接复用缓存
# ════════════════════════════════════════════════
kv_cache = recompute_kv_for_changed_positions(
model = model,
token_sequence = full_sequence,
changed_positions = [n_ctx + pos for pos in newly_unmasked],
# ↑ 转换为全序列索引(轨迹在末尾)
kv_cache = kv_cache
)
# 本步只重算了 n_to_unmask 个位置的 K/V,其余全部复用!
draft_tokens = full_sequence[:, -L:] # 提取轨迹 token
return draft_tokens, kv_cache
# ----------------------------------------------------------
# AutoEdit 阶段:带回退 + 合并重写的轨迹修正
# ----------------------------------------------------------
def autoedit_with_rollback(model, draft_tokens, context_tokens,
kv_cache, S_edit=3):
"""
AutoEdit 阶段的推理循环。
比草稿阶段多了一个"回退"能力:
如果某次修改后置信度反而更低,可以撤销这次修改。
Args:
draft_tokens : 草稿轨迹 token [batch, L]
context_tokens : 场景上下文 token [batch, n_ctx]
kv_cache : 草稿阶段结束时的 KV 缓存(直接继承,无需重算上下文!)
S_edit : AutoEdit 迭代轮数
Returns:
edited_tokens : 修正后的轨迹 token [batch, L]
"""
L = 16
full_sequence = torch.cat([context_tokens, draft_tokens], dim=1)
for k in range(S_edit):
# ── 保存当前快照,以备回退 ────────────────────────
kv_cache.save_snapshot()
prev_sequence = full_sequence.clone()
# ── 用当前完整 KV 做前向,得到每个位置的替换建议 ──
logits = model.forward_with_kv(full_sequence, kv_cache)
traj_logits = logits[:, -L:, :]
probs = softmax(traj_logits, dim=-1)
# 每个位置的"替换 token"和"置信度"
x_hat = probs.argmax(dim=-1) # [batch, L]
confidence = probs.max(dim=-1).values # [batch, L]
# ── 构造 commit mask:选低置信度的非目标点位置 ────
goal_positions = get_goal_token_positions()
commit_mask = torch.zeros(L, dtype=torch.bool)
for i in range(L):
if i in goal_positions:
commit_mask[i] = False # 目标点永远不替换
elif confidence[0, i] < CONFIDENCE_THRESHOLD:
commit_mask[i] = True # 置信度低 → 标记为替换
changed_positions = commit_mask.nonzero().squeeze(-1).tolist()
if len(changed_positions) == 0:
# 没有低置信度的位置,不需要修改,丢弃快照直接退出
kv_cache._snapshots.pop()
break
# ── 执行替换 ──────────────────────────────────────
for pos in changed_positions:
full_sequence[0, -L + pos] = x_hat[0, pos]
# ── 合并重写:只更新被替换位置的 K/V ─────────────
kv_cache = recompute_kv_for_changed_positions(
model = model,
token_sequence = full_sequence,
changed_positions = [-L + pos for pos in changed_positions],
kv_cache = kv_cache
)
# ── 可选的质量检查:如果修改后置信度整体更低,回退 ─
new_logits = model.forward_with_kv(full_sequence, kv_cache)
new_confidence = softmax(new_logits[:, -L:, :], dim=-1).max(dim=-1).values.mean()
old_confidence = confidence.mean()
if new_confidence < old_confidence * 0.95:
# 修改让模型"更不确定"了,撤销这一轮
kv_cache.rollback()
full_sequence = prev_sequence
else:
# 修改有效,丢弃快照(不再需要回退到这个点)
kv_cache._snapshots.pop()
edited_tokens = full_sequence[:, -L:]
return edited_tokens
# ----------------------------------------------------------
# 合并重写 vs 朴素重算:计算量对比
# ----------------------------------------------------------
#
# 设:
# seq_len = n_ctx + L = 512 + 16 = 528
# 每步变化位置 = n_changed(草稿阶段约 5~6,AutoEdit 阶段约 2~3)
# n_layers = 28
#
# 朴素重算(每步):
# 对完整序列重算所有 K/V
# FLOPs ∝ n_layers × seq_len × d_model²
# = 28 × 528 × 2048² ≈ 124G FLOPs
#
# 合并重写(每步):
# 只对 n_changed 个位置重算 K/V
# FLOPs ∝ n_layers × n_changed × d_model²
# = 28 × 6 × 2048² ≈ 1.4G FLOPs
# (其余 522 个位置直接读缓存,零 FLOPs)
#
# 理论加速:124G / 1.4G ≈ 88×
#
# 实际加速 1.28×(14.7ms → 11.5ms)的原因:
# · GPU kernel 启动开销、内存带宽瓶颈等 overhead 不随 FLOPs 线性缩减
# · KV 缓存的读写本身也有延迟
# · 合并重写引入了额外的 mask 判断和条件写操作
# 但在绝对延迟上节省了 3.2ms/帧,对 31.8ms 的总延迟已是可观的贡献
为什么 K 和 V 只依赖自己的上一层输出?
Transformer 每一层做了什么?
每一层的计算可以拆成两步:
输入:上一层的隐状态矩阵 H^(l-1),形状 [seq_len, d_model]
每一行是一个 token 的向量表示
Step 1:自注意力(Attention)
Step 2:前馈网络(FFN)
输出:本层的隐状态矩阵 H^(l),形状 [seq_len, d_model]
关键在于:K 和 V 是在哪里算出来的?
K 和 V 是在第 l 层的注意力计算开始之前,直接从上一层的输出线性变换得到的:
K^(l) = H^(l-1) · W_K ← 矩阵乘法,逐行独立
V^(l) = H^(l-1) · W_V ← 矩阵乘法,逐行独立
Q^(l) = H^(l-1) · W_Q ← 矩阵乘法,逐行独立
逐行独立是关键词。把上面的矩阵乘法写成逐行的形式:
K^(l)_i = h^(l-1)_i · W_K ← 只用了第 i 行(第 i 个 token 的向量)
V^(l)_i = h^(l-1)_i · W_V ← 同上
这个线性变换里完全没有其他 token 的信息,所以:
K^(l)_i和V^(l)_i只依赖h^(l-1)_i,不依赖h^(l-1)_j(j ≠ i)。
那"互相影响"发生在哪里?
发生在注意力的输出计算里,也就是 Q 和 K/V 做完点积之后:
注意力输出:
A^(l)_i = softmax( Q^(l)_i · (K^(l))ᵀ ) · V^(l)
↑ ↑
这里 token i 的 Q 这里混入了所有 token 的 V
去和所有 token 的 K 做点积
所以注意力输出 A^(l)_i 确实包含了所有其他 token 的信息,但 K 和 V 本身在被计算出来的时候,还没有经历这一步混合。
用一张流程图看清楚顺序
第 l 层:
H^(l-1)
│
├─[×W_K]─→ K^(l)_i = h^(l-1)_i · W_K ← 此时只有自己
├─[×W_V]─→ V^(l)_i = h^(l-1)_i · W_V ← 此时只有自己
└─[×W_Q]─→ Q^(l)_i = h^(l-1)_i · W_Q ← 此时只有自己
│ │
└──────┬───────┘
↓
Attention(Q, K, V) ← 此处发生混合
↓
A^(l)_i ← 此时包含所有 token 的信息
↓
FFN(A^(l)_i)
↓
H^(l)_i ← 下一层的输入
用一个比喻来理解
想象一场圆桌会议,每个人(token)在发言前先准备自己的问题卡(Q)、名片(K)、资料包(V)。
准备这三样东西的时候,每个人只看自己昨天(上一层)的笔记,不看别人的。所以名片和资料包只属于自己。
然后会议开始:每个人拿着自己的问题卡去和所有人的名片做匹配,再根据匹配程度加权收集大家的资料包。混合发生在这一步,而不是准备名片和资料包的时候。
相关论文
- [[DiffusionDriveV2 论文阅读笔记]]
- [[DriveTransformer 论文阅读笔记]]
- [[RAD-2 论文阅读笔记]]
- [[World4Drive - 无需感知标注的端到端世界模型]]
总结一句话:ReflectDrive-2 教会了自动驾驶模型像人类老司机一样——先打草稿,再反思修改,最后稳稳开车。 而这种反思能力之所以有效,不是因为模型"天生就会改",而是因为强化学习让起草和修改成了真正的搭档,共同对最终结果负责。