looyifan / ReflectDrive-2 论文阅读笔记

Created Sat, 09 May 2026 00:00:00 +0000 Modified Fri, 15 May 2026 14:43:15 +0800

来自理想汽车团队的最新研究,2026年5月发表

🎬 摘要:先立个 flag

想象一下,你是一位新司机,脑子里的第一个念头(“踩油门超过去!")并不总是最好的选择。聪明的做法是:先打个草稿,再审视一遍,改掉不合理的地方,最终才真正打方向盘。

ReflectDrive-2 就是这样一位"会反思的自动驾驶大脑”。它的核心流程是三步走:决策 → 起草轨迹 → 自我修改(Reflect),全程在同一套"离散 token 空间"里完成,不需要额外的修改网络。

论文最核心的发现是:光靠监督学习训练出来的"修改器"只会让性能提升 +0.3分(PDMS,一个综合驾驶质量分),但当引入强化学习让"起草"和"修改"联合优化之后,提升蹿到了 +1.9分。最终在 NAVSIM 基准测试中,仅用摄像头输入就达到了 91.0 PDMS,最优6选1时更达到了与人类持平的 94.8 PDMS,同时在 NVIDIA Thor 芯片上单帧只需 31.8ms


第一章:导言 —— 司机的"三步曲"

论文开篇抛出了 ReflectDrive-2 的核心工作流:“决策(Decision)—— 起草(Draft)—— 反思(Reflect)”

想象一下你正在开车:

  1. 决策: 环顾四周的摄像头画面、导航信息和自身状态,你心里选定了一个大方向(生成一个 Goal Token 目标词元)。
  2. 起草: 你在脑海中快速画出一条大致的行车路线(通过掩码离散扩散并行解码出一条蓝色的初始轨迹)。
  3. 反思(AutoEdit): 然发现前面有个小坑!你没有全盘否定路线,而是在脑海中对那一段方向盘微调了一下(在同一模型内就地修改词元,得到绿色的最终安全计划)。

这种在统一离散空间里的自我修改能力,不需要额外挂载"外挂"(如辅助微调网络),模型自己就能搞定!


第二章:相关工作 —— 站在巨人的肩膀上找痛点

  • 端到端与VLA(视觉-语言-动作)规划: 过去的端到端模型通常缺乏中间纠错机制。虽然有些VLA大模型引入了新思路,但 ReflectDrive-2 直接用"掩码离散扩散"取代了传统的自回归,几轮并行解码就能拿下一整条轨迹!

  • 扩散策略与强化学习(RL): 之前最接近的工作是 DriveFine(2026),但它的"起草员"和"修改员"是分开训练的,就像写手和编辑各拿各的工资,配合不够默契。ReflectDrive-2 的杀手锏在于:用强化学习(RL)把"起草"和"修改"绑定在一个回合里(composed rollout)。两个人共享最终的驾驶奖励,真正做到了"有福同享,有难同当",从而让编辑能力突飞猛进。


第三章:预备知识 —— 把开车变成"拼乐高"

问题设定

系统的输入是全景相机画面、导航指令和自车状态;输出则是离散的"轨迹词元"(Tokens)。

掩码离散扩散 (Masked Discrete Diffusion)

想象未来的轨迹是由一块块包含鸟瞰图(BEV)坐标的"乐高积木"拼成的。模型一开始拿到的是一堆被打上 assessments(马赛克)的未知积木,然后通过一个双向 Transformer,结合多模态上下文,一步步把马赛克替换成真实的坐标积木,实现有选择性的重新生成。


第四章:核心方法 —— ReflectDrive-2 架构大揭秘

这是这篇论文的灵魂所在,分为五步绝杀:

4.1 系统全貌

将"目标提议、轨迹起草、词元级轨迹纠错"全部统一在离散表示中。

4.2 条件目标掩码轨迹扩散

模型不是死板地给出唯一终点,而是预测一个"目标点后验概率分布",通过 top-k 采样和非极大值抑制(NMS)选出候选目标点,并把它们作为条件,并行起草完整的行驶轨迹。

4.3 AutoEdit 轨迹修正

这是核心魔法!模型具备感知结构的扰动能力,在推理阶段,直接识别出轨迹中那些"不靠谱"的局部 Token,并在原地将它们重写(Rewrite),完成轨迹的微调。

4.4 约束感知的监督目标

在监督学习阶段,引入了"可行驶区域场损失"(Drivable-area field loss),教模型不要把车开到马路牙子或草坪上。

4.5 在"起草-编辑"回放上的强化学习

终极武器! 系统把奖励(Reward)发给"编辑后"的最终轨迹,并通过策略梯度把功劳分配给前期的起草和后期的修改阶段。实验证明,这让模型的纠错能力产生了质的飞跃。


第五章:高效推理 —— 把速度压榨到极致

自动驾驶是人命关天的事,计算再复杂也不能卡顿!作者在 NVIDIA Thor 芯片 上硬生生把平均延迟压到了 31.8 毫秒

优化手段汇总

优化手段 延迟变化 加速比
共享前缀 KV 缓存复用 0.28ms → 0.08ms 3.5×
KV 缓存回退 + 合并重写 14.7ms → 11.5ms 1.28×
Action-Expert FFN(隐层 4096→1024) 2.47ms → 0.95ms 2.6×
融合 CUDA Unmask 内核 0.45ms → 0.06ms 7.5×
交替步解码(ASD)时序 AutoEdit 26.2ms → 7.6ms 3.4×
端到端平均帧延迟 45.0ms → 31.8ms 1.42×

交替步解码 (Alternating Step Decode)

这个设计太聪明了!因为连续两帧画面的路况其实差不多,模型不需要每帧都"从零思考"。它采用了**全步(Full-step)轻量步(Lite-step)**交替的策略。全步走一遍"决策-起草-反思";轻量步则直接拿上一帧的轨迹往前挪一挪,做个极简版的 Token 到 Token 的 AutoEdit 快速更新。这就好比司机眨了下眼,只需凭肌肉记忆微调方向盘,省下了巨大的算力!


第六章:实验 —— 用数字说话

6.1 实验设置

  • 数据集:NAVSIM(基于 nuPlan),训练集 1192 个场景,测试集 136 个场景,任务是预测 4 秒、2Hz 采样的自车轨迹
  • 评估指标 PDMS:综合了无责任碰撞(NC)、可行驶区域合规(DAC)、碰撞时间(TTC)、舒适性(Comf.)、自车进度(EP)
  • 模型:0.7B 离散扩散语言主干 + 0.1B ViT 视觉编码器,全部微调

6.2 RL 对 AutoEdit 的增益

训练设置 无 AutoEdit 有 AutoEdit 增益
只有 DLM 损失 84.8 85.0 +0.2
+ 可行驶区域场损失 87.2 87.3 +0.1
+ AutoEdit 监督训练 87.7 88.0 +0.3
+ 全流程 RL 89.1 91.0 +1.9

结论一目了然:没有 RL,AutoEdit 只是"摆设";有了 RL,AutoEdit 变成真正有价值的能力。

6.3 与其他方法对比

方法 输入 PDMS
GoalFlow 相机+激光雷达 90.3
ReCogDrive 仅相机 90.8
ReflectDrive-2(我们) 仅相机 91.0
ReflectDrive-2(6选1 oracle) 仅相机 94.8
人类参考 94.8

最亮眼的指标是EP(自车进度)= 89.4,所有方法中最高——说明车走得更积极,同时 DAC 和舒适性依然保持高位,做到了"进可攻守可守"。


第七章:结论 —— 学会"反思"的司机才是好司机

ReflectDrive-2 的核心启示是:自我纠正不是免费的午餐。一个训练好的编辑器,如果和起草器没有被共同的目标绑定,那它在实际中几乎毫无用处。只有当两者通过强化学习共享终局奖励,起草器才会"学会"生成可改善的草稿,编辑器才会"学会"做出真正有价值的修改。

局限与未来方向

  • 当前轨迹用固定分辨率的 BEV 坐标 token,精度受 bin 大小限制。未来可以考虑更细的词表、残差偏移或混合离散-连续动作头
  • RL 阶段的奖励是闭环规划代理分,尚非真实世界完整目标;接入更高保真仿真器和更丰富的安全奖励有望进一步改善
  • AutoEdit 的扰动目前只覆盖纵向和横向,未来可以扩展到交互层面的失败(让行时机、被加塞响应、间距选择)

附录:关键技术概念详解

离散的"轨迹词元"是什么?

第一步:轨迹是什么

自动驾驶规划的输出是未来几秒内车该去哪,具体表示为一串路径点(waypoints):

(x₁, y₁) → (x₂, y₂) → ... → (x₈, y₈)

这 8 个点描述了车辆在鸟瞰图(BEV)坐标系中未来 4 秒的行驶轨迹,每 0.5 秒一个点。

第二步:为什么要"离散化"?

坐标本来是连续实数,比如 x = 3.742 米。但论文把坐标空间切成格子,就像经纬度变成邮政编码一样:

真实坐标 x = 3.742m
         ↓  量化
格子编号  x_bin = 47  (第47个格子)

这个格子编号就是一个词元(token),和语言模型里的字词 token 是同一个概念。

每个路径点有 1 个纵向 token + 1 个横向 token,8 个路径点就是 16 个 token,构成整条轨迹的"词元序列"。

第三步:和语言模型的词元有什么相似之处?

语言模型 ReflectDrive-2
token 是什么 词/子词编号 BEV 坐标格子编号
序列长度 几十~几千 固定 16 个
词表大小 ~50,000 坐标格子数
生成方式 自回归/扩散 掩码扩散并行生成

两者在数学结构上完全一致,所以才能直接套用离散扩散语言模型(LLaDA 等)的整套框架。

第四步:离散化带来了什么好处?

正是因为轨迹变成了一串 token,才能做到:

  • 并行生成:不用一个点一个点地算,所有 token 同时从 assessments 开始,几轮并行 unmask 就完成
  • 原地编辑:想改某段轨迹?直接把那几个 token 换掉,就像文字编辑器里改几个字,不需要重新生成整条轨迹
  • AutoEdit 成为可能:编辑操作和生成操作用的是同一套模型、同一套词表,没有任何"模态鸿沟"

一个直觉类比

把轨迹想象成一首歌的简谱。连续坐标就像模拟音频波形,你没法直接"改几个音符";而离散 token 就像乐谱上的音符编号,你可以精确地找到第3小节第2个音,把它从 5 改成 6,其他音符完全不受影响。


掩码离散扩散详解

第一层:扩散模型是在解决什么问题?

先忘掉"扩散"这个词,想象一个更简单的问题:

我想让 AI 画一张猫的图片。

AI 怎么"生成"一张图?它不可能凭空变出像素。它需要某种从混乱到有序的过程。

扩散模型的核心思路是:

训练阶段,教 AI 学会"如何把一张乱图还原成好图":

好图 → 加噪声 → 加更多噪声 → ... → 纯噪声
                                        ↑
                              AI 学会反着走这条路

推理阶段,从一堆纯噪声出发,AI 一步一步去噪,最终生成一张清晰的猫图。

这就是扩散模型的本质:学习"去噪"这件事,然后用去噪来生成内容。

第二层:连续扩散 vs 离散扩散

上面说的是连续扩散——图片的像素是连续的实数,噪声是高斯噪声(正态分布的随机数)。

但如果你处理的不是图片,而是文字呢?

文字是离散的:"猫" 这个字不能加上 0.3 变成 "猫.3",这没有意义。你只能从一个词跳到另一个词。

所以就有了离散扩散,它把"加噪声"换成了"替换成随机的其他词/符号":

"今天天气真好"
→ "今天[随机词]真好"
→ "[随机词][随机词]真好"
→ "[随机词][随机词][随机词][随机词]"

AI 的任务同样是学会反着走:从一堆乱七八糟的词,还原出原本有意义的句子。

第三层:掩码扩散 —— 用 assessments 代替随机词

离散扩散有很多种"加噪声"的方式。掩码扩散(Masked Diffusion) 选择了最简单粗暴的一种:

不替换成随机词,直接盖住——换成一个特殊符号 assessments

"今天天气真好"
→ "今天 assessments 真好"
→ " assessments  assessments 真好"
→ " assessments  assessments  assessments  assessments "

你可能觉得这很眼熟——没错,这和 BERT 的完形填空几乎是同一个思路!

AI 的任务:看到带 assessments 的句子,猜出被盖住的是什么。

第四层:生成时怎么用?

训练完之后,想生成内容,就把"正向加噪"反过来走:

从全部都是 assessments 开始,每一轮让 AI 填一些空:

第0轮: assessments   assessments   assessments   assessments   assessments   assessments
         ↓ AI预测每个位置的候选词,选最有把握的填入
第1轮:今天   assessments   assessments   真     assessments   assessments
         ↓
第2轮:今天   天气    assessments  真    好了    assessments
         ↓
第3轮:今天   天气   真的    真    好了   啊

每轮都并行处理所有位置,所以比逐词生成的自回归模型快得多。

这个"每轮选最有把握的 token 填入"的策略,就是 MaskGIT 提出的置信度驱动解码,也是 ReflectDrive-2 起草轨迹的核心机制。

第五层:为什么天然支持"原地编辑"?

这是掩码扩散最妙的地方,也是 ReflectDrive-2 选择它的根本原因。

生成完成之后,你得到了一个完整的 token 序列。如果你觉得某几个 token 不对,想改——

你只需要把那几个 token 重新盖成 assessments,然后让 AI 重新填空!

原轨迹:  x₁  y₁  x₂  y₂  x₃  y₃  x₄  y₄
发现偏了:          ↑这两个位置有问题
重新盖住:x₁  y₁  assessments  assessments  x₃  y₃  x₄  y₄
AI 重填:  x₁  y₁  x₂'  y₂'  x₃  y₃  x₄  y₄  ✓

其他 token 完全不受影响,模型不需要重新跑一遍,也不需要额外的网络。这就像用橡皮擦掉草稿上的几个字,重新写——而不是把整页纸揉掉重写。

相比之下:

  • 连续扩散想修改,得把整个去噪过程重来
  • 自回归模型想修改,得从出问题的那个词开始把后面全部重新生成

总结:一张图记住全部

【训练】
好的序列  →  随机盖住一些token  →  让AI学会填空

【生成】
全 assessments  →  AI并行填最有把握的  →  几轮后得到完整序列

【编辑】(掩码扩散独有!)
完整序列  →  把想改的位置重新盖住  →  AI重新填空  →  局部修正完毕

AutoEdit 如何识别"不靠谱"的局部 Token?

这个问题分两个层面来回答。

第一个层面:推理时靠"模型置信度"

AutoEdit 不像草稿阶段那样从 assessments 开始填空。它的输入是一条已经完整的轨迹 token 序列,然后对每个位置都预测一个"替换 token":

当前轨迹:  x₁  y₁  x₂  y₂  x₃  y₃  x₄  y₄
               ↓ 模型对每个位置都预测替换值
预测替换:  x₁' y₁' x₂' y₂' x₃' y₃' x₄' y₄'

但不是所有位置都会真正被替换——这就是"识别不靠谱 token"的问题所在。

模型对每个 token 位置输出的不是一个确定的值,而是一个概率分布(softmax 输出):

位置 x₃ 的预测分布:
  格子46号:0.72  ← 最高概率
  格子47号:0.18
  格子48号:0.07
  其他格子:0.03

置信度高 = 模型对某个格子非常笃定(概率集中)

置信度低 = 模型很纠结,概率分散在好几个格子上

AutoEdit 选择置信度低的位置来替换——因为模型自己都不确定当前这个 token 对不对,说明这个位置可能有问题。

置信度:  高    高    低    低    高    高    高    高
               ↓ 只替换低置信度的位置
结果:   x₁   y₁   x₂'  y₂'  x₃   y₃   x₄   y₄

目标点 token(行为锚点)永远不替换,始终锁定。

第二个层面:训练时靠"结构感知扰动"

光靠置信度其实有个问题:模型怎么知道当前这个 token 该不该被替换?

如果 AutoEdit 从来没见过"有问题的轨迹长什么样",它的置信度判断就没有意义——它可能对一条偏出车道的轨迹也信心满满。

这就是结构感知扰动(SAP)训练的价值:

训练时,故意喂给模型两类典型错误轨迹

纵向扰动(进度失误):
  原轨迹走了10米 → 扰动后只走6米(刹车太早)
  原轨迹走了10米 → 扰动后走了14米(刹车太晚)

横向扰动(方向偏移):
  原轨迹直行 → 扰动后整体旋转5度(车道偏移)

然后要求模型把扰动轨迹直接映射回干净轨迹

这个训练过程让模型学会了:

“当我看到一条在纵向/横向上有规律性偏差的轨迹时,我应该对哪些 token 没信心,以及应该改成什么。”

两个层面怎么配合?

可以这样理解它们的分工:

SAP 训练塑造了模型内部的"错误感知能力"——它让模型的概率分布在面对有问题的轨迹时,自然地对相关 token 产生低置信度。

置信度筛选是推理时的执行机制——用模型自己暴露出来的不确定性,决定哪些 token 值得替换。

训练阶段:
  喂扰动轨迹 → 模型学会"这类错误我应该不确定" → 置信度分布变得有意义

推理阶段:
  喂草稿轨迹 → 模型对有问题的位置置信度低 → 筛出来替换

一个直觉类比

想象你是一位改卷老师,改了大量有规律错误的卷子(比如总是把加法做成减法)。久而久之,你扫一眼答案就能感觉到"这道题的答案看起来不对劲"——这种直觉来自训练。

AutoEdit 也一样:见过足够多的纵向偏移、横向偏移轨迹后,它在内部形成了对"轨迹哪里不对劲"的直觉,并通过置信度把这种直觉暴露出来,再由筛选机制决定是否动手修改。

AutoEdit 训练阶段伪代码

# ============================================================
#  ReflectDrive-2 · AutoEdit 训练阶段伪代码
#  对应论文 Section 3.2 / 4.3 / 4.4
# ============================================================
#
#  符号说明(与论文保持一致)
#  z0      : 干净的连续坐标路径点序列  [(x1,y1),...,(x8,y8)]
#  x0      : z0 离散化后的 token 序列  [x1,y1,...,x8,y8]  长度 L=16
#  x̃0      : 扰动轨迹的 token 序列
#  c       : 多模态上下文 (视觉tokens, 导航指令, 自车状态)
#  pθ      : 共享的条件 token 模型(Transformer 主干)
#   assessments  : 掩码符号,用于草稿阶段
#  t       : 掩码比例 ∈ [0, 1],决定盖住多少 token
# ============================================================


import random
import math


# ----------------------------------------------------------
# 工具函数
# ----------------------------------------------------------

def tokenize(waypoints):
    """
    把连续坐标路径点量化为离散 token 序列。
    每个路径点 (x, y) → 两个整数 token (x_bin, y_bin)。
    8 个路径点 → 长度 16 的 token 序列。
    """
    tokens = []
    for (x, y) in waypoints:
        tokens.append(quantize(x, axis='longitudinal'))  # 纵向坐标 bin
        tokens.append(quantize(y, axis='lateral'))        # 横向坐标 bin
    return tokens  # 长度 L = 16


def detokenize(tokens):
    """tokenize 的逆操作:token 序列 → 连续坐标路径点。"""
    waypoints = []
    for i in range(0, len(tokens), 2):
        x = dequantize(tokens[i],   axis='longitudinal')
        y = dequantize(tokens[i+1], axis='lateral')
        waypoints.append((x, y))
    return waypoints


# ----------------------------------------------------------
# Step 1 · 结构感知扰动(Structure-Aware Perturbation, SAP)
#          论文公式 (5)(6)
# ----------------------------------------------------------

def structure_aware_perturbation(z0, beta_range=(0.7, 1.3), alpha_max=0.1):
    """
    对干净轨迹 z0 施加两类结构化扰动,模拟驾驶中最常见的错误模式。

    Args:
        z0         : 干净路径点列表,[(x1,y1),...,(x8,y8)]
        beta_range : 纵向进度缩放系数范围 (βmin, βmax)
        alpha_max  : 横向旋转角度上限 αmax(弧度)

    Returns:
        z̃0 : 扰动后的路径点列表
    """

    # ── 扰动类型随机选一种(或两种都加)──────────────────────

    perturbation_type = random.choice(['longitudinal', 'lateral', 'both'])

    z_tilde = [wp for wp in z0]   # 先复制一份

    # ── 纵向进度扰动:沿弧长缩放 ────────────────────────────
    #    论文公式 (5):z̃i = Interp(z0, β·dᵢ)
    #    β < 1 → 进度不足(刹车过早)
    #    β > 1 → 进度超前(刹车过晚 / 超速)
    if perturbation_type in ('longitudinal', 'both'):

        beta = random.uniform(*beta_range)

        # 计算每个路径点的弧长累积值
        arc_lengths = compute_arc_lengths(z0)          # [d1, d2, ..., d8]

        # 用缩放后的弧长重新插值轨迹
        scaled_arc = [beta * d for d in arc_lengths]   # β·dᵢ
        z_tilde = interpolate_by_arc(z0, scaled_arc)   # 沿原轨迹重新采样


    # ── 横向方向扰动:在自车坐标系中旋转 ───────────────────
    #    论文公式 (6):z̃i = R(α)·zᵢ
    #    α 为随机旋转角,产生整体侧向偏移,但保持轨迹平滑性
    if perturbation_type in ('lateral', 'both'):

        alpha = random.uniform(-alpha_max, alpha_max)

        cos_a, sin_a = math.cos(alpha), math.sin(alpha)

        z_tilde = [
            (cos_a * x - sin_a * y,
             sin_a * x + cos_a * y)
            for (x, y) in z_tilde
        ]

    return z_tilde   # z̃0:扰动后的连续坐标路径点


# ----------------------------------------------------------
# Step 2 · AutoEdit 损失(LSAP)
#          论文公式 (7)(8)
# ----------------------------------------------------------

def compute_autoedit_loss(model, x0, context):
    """
    训练 AutoEdit:给定扰动 token 序列,预测干净 token 序列。

    注意:AutoEdit 不使用 assessments!
    输入是完整的(扰动后的)具体 token,目标是干净的具体 token。
    这与草稿阶段的掩码填空训练完全不同。

    Args:
        model   : 共享的条件 token 模型 pθ
        x0      : 干净轨迹的 token 序列,长度 L=16
        context : 多模态上下文 c = (视觉tokens, 导航, 自车状态)

    Returns:
        L_SAP : AutoEdit 纠错损失(标量)
    """

    L = len(x0)   # = 16

    # 2-a. 生成扰动后的连续坐标路径点
    z0     = detokenize(x0)                        # token → 连续坐标
    z_tilde = structure_aware_perturbation(z0)      # 施加纵向/横向扰动

    # 2-b. 把扰动路径点重新离散化为 token 序列
    x_tilde = tokenize(z_tilde)                    # 扰动 token 序列 x̃0

    # 2-c. 用模型对扰动 token 序列做预测
    #      输入:x̃0(扰动后的完整具体 token)+ 多模态上下文 c
    #      输出:每个位置上的替换 token 概率分布
    #      论文公式 (7):qθ(· | x̃0, c) = softmax( hθ(x̃0, c) )
    logits = model(x_tilde, context)               # shape: [L, vocab_size]
    probs  = softmax(logits, dim=-1)               # qθ(· | x̃0, c)

    # 2-d. 计算交叉熵损失(对所有 L 个位置平均)
    #      论文公式 (8):L_SAP = -E[ 1/L · Σ log qθ(x⁰ᵢ | x̃0, c) ]
    #      目标是干净 token x0,而非扰动 token x̃0
    L_SAP = 0.0
    for i in range(L):
        clean_token_id = x0[i]                     # 第 i 位的干净 token
        L_SAP += -math.log(probs[i][clean_token_id] + 1e-9)

    L_SAP /= L    # 对序列长度取平均

    return L_SAP


# ----------------------------------------------------------
# Step 3 · 草稿阶段掩码扩散损失(LDLM)
#          论文公式 (1)
# ----------------------------------------------------------

def compute_dlm_loss(model, x0, context):
    """
    训练草稿生成:随机盖住若干 token,让模型预测所有位置的原始 token。

    与 AutoEdit 损失的关键区别:
      - AutoEdit:输入是"有规律错误"的具体 token
      - DLM    :输入是"随机盖 assessments"的不完整序列

    Args:
        model   : 共享的条件 token 模型 pθ
        x0      : 干净轨迹的 token 序列
        context : 多模态上下文 c

    Returns:
        L_DLM : 掩码扩散损失(标量)
    """

    L = len(x0)   # = 16

    # 3-a. 随机采样掩码比例 t ∈ [0, 1]
    t = random.uniform(0, 1)

    # 3-b. 按比例 t 独立随机将每个 token 替换为 assessments
    x_masked = []
    for token in x0:
        if random.random() < t:
            x_masked.append(MASK_TOKEN_ID)   # 盖住
        else:
            x_masked.append(token)           # 保留

    # 3-c. 模型对所有位置(包括未盖住的)做预测
    #      注意:论文选择"全位置监督",而非仅对 assessments 位置监督
    #      实验证明全位置监督让训练更稳定、草稿更连贯
    logits = model(x_masked, context)        # shape: [L, vocab_size]
    probs  = softmax(logits, dim=-1)

    # 3-d. 对所有 L 个位置计算交叉熵(目标均为干净 token x0)
    #      论文公式 (1):L_DLM = -E[ 1/L · Σ log pθ(x⁰ᵢ | xt, c) ]
    L_DLM = 0.0
    for i in range(L):
        L_DLM += -math.log(probs[i][x0[i]] + 1e-9)

    L_DLM /= L

    return L_DLM


# ----------------------------------------------------------
# Step 4 · 可行驶区域场损失(Lfield)
#          论文公式 (11)(12)(13)(14)
# ----------------------------------------------------------

def compute_field_loss(model, x0, context, dac_cost_field):
    """
    空间惩罚:让模型在概率层面就"远离"不可行驶区域。

    Args:
        model          : 共享的条件 token 模型 pθ
        x0             : 干净轨迹的 token 序列
        context        : 多模态上下文 c
        dac_cost_field : 可行驶区域代价场 C ∈ R^[H×W]
                         非可行驶区域 = 高代价,可行驶区域 = 0

    Returns:
        L_field : 可行驶区域场损失(标量)
    """

    # 4-a. 用带随机掩码的输入让模型输出 logits(与 DLM 阶段共用)
    logits = model(apply_random_mask(x0), context)   # shape: [L, vocab_size]

    L_field = 0.0

    # 4-b. 对每个路径点 t(共 8 个)计算空间惩罚
    for t in range(8):

        x_logit = logits[2*t]       # 第 t 个路径点的纵向 logit
        y_logit = logits[2*t + 1]   # 第 t 个路径点的横向 logit

        p_x = softmax(x_logit)      # 纵向坐标的概率分布 p^(t)_x  shape:[W]
        p_y = softmax(y_logit)      # 横向坐标的概率分布 p^(t)_y  shape:[H]

        # 4-c. 构造联合空间分布(外积)
        #      论文公式 (11):p^(t)_xy[i,j] = p^(t)_x[i] · p^(t)_y[j]
        #      假设纵横坐标独立——近似但够用
        p_xy = outer_product(p_x, p_y)   # shape: [H, W]

        # 4-d. 用代价场加权的对数障碍函数计算惩罚
        #      论文公式 (12):L_field = Σ_t Σ_{i,j} -log(1 - p^(t)_xy[i,j]) · C[i,j]
        #
        #      直觉:
        #        · C[i,j] 大(靠近不可行驶区域)→ 这一项权重大
        #        · p_xy[i,j] 大(模型对这个位置置信高)→ 惩罚也大
        #        · 两个大值相乘 → 强烈惩罚"高置信度踩线"
        for i in range(H):
            for j in range(W):
                cost = dac_cost_field[i][j]
                if cost > 0:
                    penalty = -math.log(1 - p_xy[i][j] + 1e-9) * cost
                    L_field += penalty

    return L_field


# ----------------------------------------------------------
# Step 5 · 总监督损失 & 训练主循环
#          论文公式 (15)
# ----------------------------------------------------------

def supervised_training(model, dataloader, lambda_SAP=1.0, lambda_field=0.1):
    """
    监督微调(SFT)主循环:同时优化三个目标。

    总损失:L_sup = L_DLM + λ_SAP · L_SAP + λ_field · L_field
    论文公式 (15)

    三个损失的分工:
      L_DLM   → 教模型从 assessments 还原干净轨迹(起草能力)
      L_SAP   → 教模型将扰动 token 映射回干净 token(纠错能力)
      L_field → 在概率层面惩罚踩入不可行驶区域(空间约束)
    """

    optimizer = AdamW(model.parameters())

    for batch in dataloader:

        x0         = batch['trajectory_tokens']   # 干净轨迹 token 序列
        context    = batch['context']             # 多模态上下文 c
        cost_field = batch['dac_cost_field']      # 可行驶区域代价场

        # ── 计算三个损失分量 ───────────────────────────────
        L_DLM   = compute_dlm_loss(model, x0, context)
        L_SAP   = compute_autoedit_loss(model, x0, context)
        L_field = compute_field_loss(model, x0, context, cost_field)

        # ── 加权求和得到总监督损失 ─────────────────────────
        L_sup = L_DLM + lambda_SAP * L_SAP + lambda_field * L_field

        # ── 反向传播 & 更新参数 ────────────────────────────
        optimizer.zero_grad()
        L_sup.backward()
        optimizer.step()

        log(L_DLM=L_DLM, L_SAP=L_SAP, L_field=L_field, L_sup=L_sup)

    # SFT 结束后,进入 RL 微调阶段(见论文 Section 4.5)
    return model

整体结构对应论文的训练三角形,用一张图看清楚各部分的关系:

同一批干净轨迹 x0
       │
       ├─── 随机盖 assessments ──→  L_DLM   (教模型:从残缺填完整)
       │
       ├─── 纵向/横向扰动 ──→  L_SAP   (教模型:从错误改正确)
       │
       └─── 输出概率分布 ───→  L_field  (教模型:概率质量远离禁区)
                    │
                    └──── 三者加权求和 ──→ L_sup → 反向传播

有几个细节特别值得注意:

L_SAPL_DLM 的本质区别:两者用的是同一个模型,但输入性质完全不同。L_DLM 的输入是"信息随机缺失"(assessments),L_SAP 的输入是"信息系统性错误"(扰动坐标),后者才是 AutoEdit 真正的训练信号。

L_field 作用在概率层面:它不是在惩罚模型"预测了出界的 token",而是在惩罚模型"对出界位置赋予了高概率",惩罚力度随置信度和距离边界的远近双重加权,比简单的离散惩罚更平滑。

三个损失共享同一个模型参数:没有任何额外网络,SFT 结束后这个模型既能起草,又能纠错,还会绕开禁区——然后直接进入 RL 阶段做联合优化。


RL 微调阶段伪代码

# ============================================================
#  ReflectDrive-2 · RL 微调阶段伪代码
#  对应论文 Section 3.4 / 4.5
# ============================================================
#
#  核心思想:
#    SFT 之后,起草器和 AutoEdit 各自都有能力,但互不知道对方的存在。
#    RL 阶段把两者绑在一起:跑完"起草→AutoEdit"的完整 rollout,
#    只用最终轨迹的驾驶得分作为奖励,同时回传给两个阶段。
#    结果:起草器学会生成"可被 AutoEdit 改好"的草稿;
#          AutoEdit 学会做"真正提升驾驶分"的修改。
#
#  符号说明(与论文保持一致)
#    Ng       : 每帧采样的目标点数量(论文实验中 Ng=3)
#    I        : 每个目标点采样的草稿数(论文实验中 I=2)
#    G        : 总候选轨迹数 = Ng × I(论文实验中 G=6)
#    Sdraft   : 草稿阶段的 unmask 轮数
#    Sedit    : AutoEdit 阶段的修改轮数
#    S        : 总步数 = Sdraft + Sedit
#    ρg       : 第 g 条候选的完整 token 转换序列(长度 S+1)
#    τg       : 第 g 条候选的最终轨迹(连续坐标)
#    R(τg)    : 闭环规划奖励(即 PDMS 得分)
#    Ag       : 第 g 条候选的组相对优势
#    πθ       : 当前策略(即当前模型参数)
#    πθ_old   : rollout 时冻结的旧策略(用于计算重要性采样比)
#    πref     : SFT 之后冻结的参考策略(用于 KL 惩罚)
# ============================================================


import math
import copy


# ----------------------------------------------------------
# Step 1 · 单条候选的完整 Draft→AutoEdit Rollout
#          论文公式 (16)(17)
# ----------------------------------------------------------

def run_single_rollout(model, goal_token, context,
                       S_draft=3, S_edit=3):
    """
    对一个目标点执行完整的"起草→AutoEdit"rollout,
    并记录每一步的 token 转换序列(用于后续计算策略梯度)。

    Args:
        model      : 当前策略模型 πθ_old(rollout 时冻结)
        goal_token : 已选定的目标点 token(行为锚点)
        context    : 多模态上下文 c
        S_draft    : 草稿阶段 unmask 轮数
        S_edit     : AutoEdit 阶段修改轮数

    Returns:
        trajectory_history : 完整 token 状态序列
                             [x⁰, x¹, ..., x^S]  共 S+1 个快照
                             前 S_draft 步是草稿阶段,后 S_edit 步是 AutoEdit 阶段
        final_trajectory   : 最终连续坐标轨迹 τg(送入仿真器求奖励)
    """

    L = 16  # token 序列长度(8个路径点 × 2坐标)

    # ── 初始化:全部 assessments,目标点 token 固定 ────────────────
    x_current = [MASK_TOKEN_ID] * L
    x_current = set_goal_token(x_current, goal_token)  # 锁定目标点位置

    trajectory_history = [x_current.copy()]  # x⁰

    # ══════════════════════════════════════════════════════
    # 阶段 A · 草稿生成(Masked Diffusion Drafting)
    #   从全 assessments 出发,每轮并行 unmask 最有把握的 token
    # ══════════════════════════════════════════════════════
    for s in range(S_draft):

        # A-1. 模型预测每个位置的 token 概率分布
        #      输入:带 assessments 的不完整序列 + 多模态上下文
        logits = model(x_current, context)          # shape: [L, vocab_size]
        probs  = softmax(logits, dim=-1)            # πθ_old(· | x_current, c)

        # A-2. 对每个仍是 assessments 的位置,采样一个 token
        x_predicted = []
        confidence  = []
        for i in range(L):
            if x_current[i] == MASK_TOKEN_ID:
                token_id = sample_from(probs[i])    # 按概率采样(非贪心)
                conf     = probs[i][token_id]       # 该 token 的概率作为置信度
            else:
                token_id = x_current[i]             # 已确定的 token 保持不变
                conf     = 1.0
            x_predicted.append(token_id)
            confidence.append(conf)

        # A-3. 按置信度从高到低排序,本轮只 unmask 最有把握的那一批
        #      每轮 unmask 比例随步数递增(第 s 轮 unmask 约 s/S_draft 的 token)
        n_to_unmask = compute_unmask_count(s, S_draft, n_masked=count_masks(x_current))
        unmask_indices = top_k_indices(confidence, k=n_to_unmask,
                                       only_masked_positions=True)

        # A-4. 把选中位置的 assessments 替换为预测 token,其余保持 assessments
        x_next = x_current.copy()
        for idx in unmask_indices:
            x_next[idx] = x_predicted[idx]

        x_current = x_next
        trajectory_history.append(x_current.copy())   # 记录 x^(s+1)

    # 草稿完成,x_current 此时应没有 assessments(全部填满)
    x_draft = x_current.copy()   # x^(S_draft)

    # ══════════════════════════════════════════════════════
    # 阶段 B · AutoEdit 修改(Token-to-Token Rewriting)
    #   从完整草稿出发,直接 token→token 重写低置信度位置
    #   注意:不再引入 assessments,输入始终是完整的具体 token 序列
    # ══════════════════════════════════════════════════════
    for k in range(S_edit):

        # B-1. 模型对当前完整 token 序列预测替换 token
        #      输入:当前具体 token 序列(无 assessments)+ 多模态上下文
        logits      = model(x_current, context)     # shape: [L, vocab_size]
        probs       = softmax(logits, dim=-1)
        x_hat       = argmax(probs, dim=-1)         # 每个位置的最优替换 token
        confidence  = [probs[i][x_hat[i]] for i in range(L)]

        # B-2. 构造 commit mask:选出置信度低的非目标点 token
        #      论文公式 (10):x^(k+1) = m^(k) ⊙ x̂^(k+1) + (1-m^(k)) ⊙ x^(k)
        #      目标点 token 永远不进入候选(行为锚点锁定)
        commit_mask = compute_commit_mask(
            confidence    = confidence,
            x_current     = x_current,
            goal_positions = get_goal_positions(),   # 目标点 token 的位置索引
            threshold      = CONFIDENCE_THRESHOLD    # 低于此值才会被替换
        )   # shape: [L],1=替换,0=保留

        # B-3. 按 commit mask 执行原地替换
        x_next = [
            x_hat[i] if commit_mask[i] == 1 else x_current[i]
            for i in range(L)
        ]

        x_current = x_next
        trajectory_history.append(x_current.copy())   # 记录 x^(S_draft + k + 1)

    # 最终轨迹(连续坐标)送入闭环仿真器求奖励
    final_trajectory = detokenize(x_current)          # τg

    # trajectory_history 包含 S_draft + S_edit + 1 个快照
    # 论文公式 (16):ρg = (x⁰g, x¹g, ..., x^(Sdraft+Sedit)g)
    return trajectory_history, final_trajectory


# ----------------------------------------------------------
# Step 2 · 对一个场景采样 G 条候选 rollout
#          论文公式 (16)(17)(18)
# ----------------------------------------------------------

def sample_group_rollouts(model, context, N_g=3, I=2,
                          S_draft=3, S_edit=3):
    """
    对当前场景:
      1. 采样 Ng 个目标点(top-k + NMS)
      2. 每个目标点生成 I 条草稿
      3. 每条草稿跑完完整的 AutoEdit rollout
      4. 把最终轨迹送入仿真器,得到 G 个驾驶奖励

    Args:
        model   : 旧策略模型 πθ_old(rollout 期间完全冻结)
        context : 多模态上下文 c
        N_g     : 目标点采样数(论文实验值 3)
        I       : 每个目标点的草稿数(论文实验值 2)

    Returns:
        group : 长度 G = Ng×I 的列表,每项包含:
                  - history       : 完整 token 转换序列 ρg
                  - final_traj    : 最终连续坐标轨迹 τg
                  - reward        : 闭环驾驶奖励 R(τg)
                  - advantage     : 组相对优势 Ag(统一计算后填入)
    """

    G = N_g * I
    group = []

    # ── 2-a. 预测目标点后验分布,top-k 采样 + NMS ────────────
    goal_logits = model.goal_head(context)        # 目标点概率分布
    goal_tokens = top_k_nms_sample(goal_logits,
                                   k=N_g,
                                   nms_threshold=1.2)   # NMS 阈值约 1.2m

    # ── 2-b. 对每个目标点生成 I 条 rollout ───────────────────
    for goal_token in goal_tokens:          # 外层:Ng 个目标点
        for _ in range(I):                  # 内层:每个目标点 I 条草稿

            history, final_traj = run_single_rollout(
                model, goal_token, context, S_draft, S_edit
            )

            # 送入闭环仿真器(NAVSIM)计算 PDMS 奖励
            # PDMS = 加权综合(NC, DAC, TTC, 舒适性, EP)
            reward = closed_loop_simulator.score(final_traj)   # R(τg)

            group.append({
                'history'    : history,      # ρg:[x⁰, x¹, ..., x^S]
                'final_traj' : final_traj,   # τg
                'reward'     : reward,       # R(τg)
                'goal_token' : goal_token,
            })

    # ── 2-c. 计算组相对优势(Group-Relative Advantage)────────
    #         论文公式 (18):Ag = R(τg) - 1/G · Σj R(τj)
    #
    #         直觉:比组内平均分高 → 正优势(这条路走对了,强化它)
    #               比组内平均分低 → 负优势(这条路走错了,抑制它)
    mean_reward = sum(item['reward'] for item in group) / G

    for item in group:
        item['advantage'] = item['reward'] - mean_reward  # Ag

    return group   # 长度 G 的候选列表,每项含 ρg, τg, R(τg), Ag


# ----------------------------------------------------------
# Step 3 · 计算单条 rollout 的策略梯度损失
#          论文公式 (2)(19)
# ----------------------------------------------------------

def compute_pg_loss_single(model, model_old, item, epsilon=0.2):
    """
    对一条候选 rollout 计算 PPO-clip 风格的策略梯度损失。

    核心设计:
      · 只有"在这一步真正发生变化"的 token 才获得梯度信号
        (草稿阶段:从 assessments 变成具体 token 的位置)
        (AutoEdit 阶段:被 commit mask 选中并替换的位置)
      · 使用重要性采样比 r = πθ / πθ_old 做 off-policy 修正
      · clip 截断防止更新步幅过大(PPO 技巧)

    论文公式 (2) 中的指示函数:
      δ^s_{g,p} = 1{ x^(s+1)_{g,p} ≠ x^s_{g,p} }
      只有 token 在第 s 步发生了变化,才把这步的梯度纳入损失。

    Args:
        model     : 当前策略 πθ(需要更新)
        model_old : 旧策略 πθ_old(rollout 时冻结,用于计算比值)
        item      : 单条 rollout 信息(含 history, advantage)
        epsilon   : PPO clip 系数

    Returns:
        loss_pg : 该条 rollout 的策略梯度损失(标量)
    """

    history   = item['history']    # [x⁰, x¹, ..., x^S]  共 S+1 个快照
    Ag        = item['advantage']  # 组相对优势(标量)
    context   = item['context']
    S         = len(history) - 1   # 总步数 = S_draft + S_edit
    L         = len(history[0])    # token 序列长度 = 16

    loss_pg = 0.0
    n_credited = 0   # 记录实际获得梯度的 (步, 位置) 对数量

    for s in range(S):

        x_s      = history[s]       # 第 s 步的 token 状态(模型输入)
        x_s_next = history[s + 1]   # 第 s+1 步的 token 状态(目标)

        # 3-a. 用当前策略和旧策略分别计算每个位置的 token 概率
        logits_new = model(x_s, context)          # πθ 的输出
        logits_old = model_old(x_s, context)      # πθ_old 的输出(不求梯度)

        probs_new = softmax(logits_new, dim=-1)   # shape: [L, vocab_size]
        probs_old = softmax(logits_old, dim=-1)   # shape: [L, vocab_size]

        for p in range(L):

            # 3-b. 指示函数:只有这个位置在这一步发生了变化,才纳入梯度
            #      论文公式 (19):δ^s_{g,p} = 1{ x^(s+1)_{g,p} ≠ x^s_{g,p} }
            #
            #      草稿阶段:assessments → 具体 token,发生变化 → δ=1
            #      AutoEdit:被替换的 token 位置,发生变化 → δ=1
            #      两个阶段没变化的位置:δ=0,跳过,不贡献梯度
            token_changed = (x_s_next[p] != x_s[p])
            if not token_changed:
                continue

            target_token = x_s_next[p]   # 这一步实际写入的 token

            # 3-c. 计算重要性采样比 r^s_{g,p} = πθ / πθ_old
            #      分子:当前策略对 target_token 的概率
            #      分母:旧策略对 target_token 的概率(采样时的概率)
            prob_new = probs_new[p][target_token]
            prob_old = probs_old[p][target_token]
            ratio    = prob_new / (prob_old + 1e-8)   # r^s_{g,p}

            # 3-d. PPO-clip:截断过大的更新步幅
            #      论文公式 (2):min( r·Ag,  clip(r, 1-ε, 1+ε)·Ag )
            clipped_ratio = clip(ratio, 1 - epsilon, 1 + epsilon)
            objective     = min(ratio * Ag, clipped_ratio * Ag)

            loss_pg += -objective   # 最大化目标 = 最小化负目标
            n_credited += 1

    # 对所有 (步, 位置) 取平均
    if n_credited > 0:
        loss_pg /= n_credited

    return loss_pg


# ----------------------------------------------------------
# Step 4 · KL 散度惩罚项
#          论文公式 (2) 中的 λ_KL · DKL(πθ ‖ πref)
# ----------------------------------------------------------

def compute_kl_penalty(model, model_ref, x_masked, context):
    """
    防止 RL 微调跑偏太远,在 SFT 基础上保持合理的分布。

    πref:SFT 结束后冻结的参考策略(不随 RL 更新)
    πθ  :当前正在更新的策略

    对一个随机掩码输入,计算两者输出分布的 KL 散度。

    Args:
        model     : 当前策略 πθ
        model_ref : SFT 后冻结的参考策略 πref
        x_masked  : 随机掩码后的 token 序列
        context   : 多模态上下文

    Returns:
        kl : KL 散度(标量)
    """

    probs_theta = softmax(model(x_masked, context),     dim=-1)  # πθ
    probs_ref   = softmax(model_ref(x_masked, context), dim=-1)  # πref(无梯度)

    # KL(πθ ‖ πref) = Σ πθ · log(πθ / πref)
    kl = 0.0
    L  = probs_theta.shape[0]

    for i in range(L):
        for v in range(VOCAB_SIZE):
            p = probs_theta[i][v]
            q = probs_ref[i][v]
            if p > 1e-9:
                kl += p * math.log(p / (q + 1e-9))

    return kl / L   # 对位置取平均


# ----------------------------------------------------------
# Step 5 · RL 微调主循环
#          论文公式 (2)
# ----------------------------------------------------------

def rl_finetuning(model_sft, dataloader,
                  N_g=3, I=2, S_draft=3, S_edit=3,
                  epsilon=0.2, lambda_kl=0.01,
                  n_epochs=1):
    """
    在 SFT 模型基础上做 RL 微调(Reinforcement Fine-Tuning, RFT)。

    Args:
        model_sft : SFT 阶段训练好的模型(作为起点和参考策略)
        dataloader: 驾驶场景数据集
        N_g       : 每场景目标点数
        I         : 每目标点草稿数
        S_draft   : 草稿阶段步数
        S_edit    : AutoEdit 阶段步数
        epsilon   : PPO clip 系数
        lambda_kl : KL 惩罚权重
        n_epochs  : RL 训练轮数
    """

    # ── 初始化 ──────────────────────────────────────────────
    model     = copy.deepcopy(model_sft)   # πθ:当前策略(持续更新)
    model_ref = copy.deepcopy(model_sft)   # πref:SFT 参考策略(永久冻结)
    freeze(model_ref)

    optimizer = AdamW(model.parameters(), lr=1e-6)

    for epoch in range(n_epochs):
        for scene in dataloader:

            context = scene['context']   # 多模态上下文 c

            # ══════════════════════════════════════════════
            # Phase 1 · Rollout(用旧策略采样,不求梯度)
            # ══════════════════════════════════════════════

            # 冻结当前策略作为 πθ_old,用于 rollout 和计算重要性采样比
            model_old = copy.deepcopy(model)
            freeze(model_old)

            # 采样 G = Ng × I 条完整 Draft→AutoEdit rollout
            # 并计算每条轨迹的驾驶奖励和组相对优势
            with no_grad():
                group = sample_group_rollouts(
                    model_old, context, N_g, I, S_draft, S_edit
                )
            # group 中每项已包含:history(ρg), final_traj(τg), reward(R), advantage(Ag)

            # ══════════════════════════════════════════════
            # Phase 2 · 计算总损失(需要梯度)
            # ══════════════════════════════════════════════

            total_loss = 0.0

            # 2-a. 对 G 条 rollout 分别计算策略梯度损失,取平均
            #      论文公式 (2) 的主项
            loss_pg_total = 0.0
            for item in group:
                item['context'] = context
                loss_pg_total  += compute_pg_loss_single(
                    model, model_old, item, epsilon
                )
            loss_pg = loss_pg_total / len(group)   # 对 G 条取平均

            # 2-b. KL 惩罚:防止策略偏离 SFT 太远
            x_sample = sample_masked_sequence(scene['trajectory_tokens'])
            loss_kl  = compute_kl_penalty(model, model_ref, x_sample, context)

            # 2-c. 合并
            #      L(θ) = L_PG + λ_KL · DKL(πθ ‖ πref)
            total_loss = loss_pg + lambda_kl * loss_kl

            # ══════════════════════════════════════════════
            # Phase 3 · 更新参数
            # ══════════════════════════════════════════════

            optimizer.zero_grad()
            total_loss.backward()
            clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            log(
                epoch       = epoch,
                loss_pg     = loss_pg,
                loss_kl     = loss_kl,
                total_loss  = total_loss,
                mean_reward = sum(item['reward'] for item in group) / len(group),
                autoedit_gain = compute_autoedit_gain(group),
                # autoedit_gain = 有 AutoEdit 的平均分 - 无 AutoEdit 的平均分
                # RL 训练成功的标志:这个值从 ~0.3 增长到 ~1.9
            )

    return model   # RL 微调完成,起草器和 AutoEdit 已联合优化

整个 RL 阶段分三个 Phase 循环执行,用一张图总结结构:

每个场景迭代:

Phase 1 · Rollout(冻结旧策略,不求梯度)
  ┌─────────────────────────────────────────────┐
  │  目标点 NMS 采样(Ng=3)                      │
  │    × 每个目标点 I=2 条草稿                    │
  │    = G=6 条完整 Draft→AutoEdit rollout        │
  │  每条送入 NAVSIM 得 R(τg)                     │
  │  统一计算组相对优势 Ag = R(τg) - mean(R)      │
  └─────────────────────────────────────────────┘
           ↓
Phase 2 · 计算损失(当前策略,需要梯度)
  ┌─────────────────────────────────────────────┐
  │  对 G 条 rollout,逐步逐位置:               │
  │    只看 token 真正改变的位置(指示函数 δ)   │
  │    计算 r = πθ/πθ_old                        │
  │    loss = -min(r·Ag, clip(r)·Ag)            │
  │  + λ_KL · DKL(πθ ‖ πref)                   │
  └─────────────────────────────────────────────┘
           ↓
Phase 3 · 反向传播更新

有三个设计决策最值得注意:

指示函数 δ 是灵魂。草稿阶段里从 assessments 变成具体 token 的位置、AutoEdit 阶段里被替换的位置,两类"变化"用同一个 δ 函数统一处理,终局奖励自然流入两个阶段,不需要任何人为分拆。

两套冻结模型各司其职πθ_old 在每轮迭代开始时复制当前模型,用于 rollout 和计算重要性采样比(保证 off-policy 修正的正确性);πref 是 SFT 结束后一次性冻结,全程不动,只用来算 KL 惩罚防止策略跑偏太远。

奖励只在终点打分。仿真器只看最终修改后的轨迹,不会给草稿中间过程评分。这逼着起草器主动生成"留有改善余地"的草稿,而不是一开始就出最优解——正是这个机制让 AutoEdit 增益从 +0.3 跳到 +1.9。


Action-Expert FFN

背景:Transformer 里的 FFN 是什么

标准 Transformer 的每一层都有两个子模块:

输入 token
  → 自注意力(Attention):让 token 互相"看"彼此
  → FFN(前馈网络):对每个 token 独立做非线性变换
  → 输出 token

FFN 是 Transformer 里参数量最大、计算量最重的部分,约占整个模型计算量的 2/3。

核心问题:轨迹 token 需要这么大的 FFN 吗?

模型里同时存在两类 token,它们的性质天差地别:

语言/视觉 token 轨迹 token
词表大小 ~50,000 个词 几十个坐标 bin
语义复杂度 极高(理解场景、推理意图) 极低(就是一个坐标数字)
需要的表达能力
占序列长度 大部分 固定 16 个

用一个 4096 维的 FFN 去处理"第 47 号坐标格子"这种极其简单的信息,就像用核弹炸蚊子——严重过剩。

Action-Expert FFN 的解决方案

给轨迹 token 单独配一套更小的 FFN,语言/视觉 token 继续用原来的大 FFN:

每一层 Transformer:

  语言/视觉 token ──→ 主 FFN(d_ffn = 4096)──→ 继续
                              ↑ 按 token 类型路由
  轨迹 token      ──→ Action FFN(d_ffn = 1024)──→ 继续

两套 FFN 的参数完全独立,在同一层里并行存在,根据 token 的类型决定走哪条路。

关键的设计洞察

Attention 不分路由,FFN 才分。这很重要——轨迹 token 需要通过 Attention"看到"整个场景(障碍物在哪、车道线在哪)才能生成正确的坐标,这部分绝不能缩减。FFN 只是事后对每个 token 独立做变换,轨迹 token 在这一步根本不需要理解语义,1024 维完全够用。

路由是零成本的。轨迹 token 始终在序列末尾的固定位置,不需要任何路由网络或运行时判断,直接切片就完成了分流,没有额外开销。

精度反升是最有趣的发现。大 FFN 让语言梯度和轨迹梯度在同一组参数里相互干扰;拆分之后 Action FFN 的梯度来源纯粹,加上容量受限带来的隐式正则化,模型反而学到了更干净的坐标变换规律。

Action-Expert FFN 实现伪代码

# ============================================================
#  Action-Expert FFN 实现伪代码
#  对应论文 Section 5(推理优化部分)
# ============================================================
#
#  核心思想:
#    轨迹 token 的语义空间极小(只是坐标格子编号),
#    不需要和语言/视觉 token 一样大的 FFN。
#    用一个更小的专用 FFN 处理轨迹 token,
#    在不损失(甚至提升)精度的前提下大幅降低计算量。
#
#  论文数值:
#    主 FFN   d_ffn = 4096(论文实验中约为 d_model 的 2 倍)
#    Action FFN d_ffn = 1024(压缩到 1/4)
#    加速比:2.6×
#    meanSADE:不降反升(说明大 FFN 对轨迹 token 反而是过拟合)
# ============================================================

import torch
import torch.nn as nn


# ----------------------------------------------------------
#  1. 标准 FFN(用于语言/视觉 token)
# ----------------------------------------------------------

class StandardFFN(nn.Module):
    """
    Transformer 主干的标准前馈网络。
    处理语言 token、视觉 token、导航指令 token、自车状态 token。

    维度:d_model → d_ffn(大) → d_model
    """

    def __init__(self, d_model=2048, d_ffn=4096, dropout=0.0):
        super().__init__()
        self.w1      = nn.Linear(d_model, d_ffn,   bias=False)
        self.w2      = nn.Linear(d_ffn,   d_model, bias=False)
        self.act     = nn.SiLU()   # 论文用 SwiGLU 变体,这里简化为 SiLU
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # x: [batch, seq_len, d_model]
        return self.dropout(self.w2(self.act(self.w1(x))))


# ----------------------------------------------------------
#  2. Action-Expert FFN(专用于轨迹 token)
# ----------------------------------------------------------

class ActionExpertFFN(nn.Module):
    """
    轨迹 token 专用的小型前馈网络。
    只处理 16 个轨迹 token(8 个路径点 × 2 坐标)。

    维度:d_model → d_action_ffn(小) → d_model
    d_action_ffn = 1024(是标准 FFN 的 1/4)

    为什么更小也够用?
      轨迹 token 的 "语义" 极其简单:就是一个坐标格子的编号。
      它不需要理解"转弯"、"让行"这类高层语义(那是 Attention 的工作)。
      FFN 只需要做一个低维的坐标→隐空间的映射,1024 维绰绰有余。
    """

    def __init__(self, d_model=2048, d_action_ffn=1024, dropout=0.0):
        super().__init__()
        self.w1      = nn.Linear(d_model, d_action_ffn, bias=False)
        self.w2      = nn.Linear(d_action_ffn, d_model, bias=False)
        self.act     = nn.SiLU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # x: [batch, n_action_tokens, d_model]  n_action_tokens ≤ 16
        return self.dropout(self.w2(self.act(self.w1(x))))


# ----------------------------------------------------------
#  3. 混合 FFN 层:按 token 类型路由
# ----------------------------------------------------------

class MixedExpertFFNLayer(nn.Module):
    """
    一个完整的 Transformer FFN 子层,内含两套 FFN:
      - StandardFFN    → 处理所有非轨迹 token
      - ActionExpertFFN→ 处理所有轨迹 token

    路由依据:token 位置索引(轨迹 token 始终在序列末尾固定位置)。

    序列结构(论文 Section 4.1):
      [视觉 token (V)] [导航 token (N)] [自车状态 token (E)]
      [目标点 token (G)] [轨迹 token (A)]
                               这 16 个位置走 Action FFN
    """

    def __init__(self,
                 d_model=2048,
                 d_ffn=4096,
                 d_action_ffn=1024,
                 n_action_tokens=16):
        super().__init__()

        self.standard_ffn     = StandardFFN(d_model, d_ffn)
        self.action_expert_ffn = ActionExpertFFN(d_model, d_action_ffn)
        self.n_action_tokens  = n_action_tokens

        # 轨迹 token 始终在序列末尾,位置固定
        # 这使得路由无需任何运行时判断,直接切片即可
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, x, action_token_mask=None):
        """
        Args:
            x                : [batch, seq_len, d_model]
                               完整序列(视觉+语言+目标点+轨迹 token 混在一起)
            action_token_mask: [batch, seq_len] bool tensor
                               True 表示该位置是轨迹 token
                               如果为 None,则默认末尾 n_action_tokens 个是轨迹 token

        Returns:
            out: [batch, seq_len, d_model]
        """

        batch, seq_len, d = x.shape
        out = x.clone()   # 输出在原位修改

        # ── 构造路由 mask ─────────────────────────────────────
        if action_token_mask is None:
            # 默认:序列末尾 n_action_tokens 个位置是轨迹 token
            action_token_mask = torch.zeros(batch, seq_len, dtype=torch.bool)
            action_token_mask[:, -self.n_action_tokens:] = True

        context_mask = ~action_token_mask   # 非轨迹 token 的 mask

        # ── 路由 1:非轨迹 token → StandardFFN ───────────────
        # 把所有非轨迹位置的 token 聚合成一个 batch 送入大 FFN
        # 用 mask 索引实现,避免 for 循环
        if context_mask.any():
            x_context = x[context_mask]        # [n_context, d_model]
            x_context = self.layer_norm(x_context)
            out_context = self.standard_ffn(x_context)   # [n_context, d_model]
            out[context_mask] = out_context

        # ── 路由 2:轨迹 token → ActionExpertFFN ─────────────
        if action_token_mask.any():
            x_action = x[action_token_mask]    # [n_action, d_model]
                                               # n_action = batch × 16(或更少)
            x_action = self.layer_norm(x_action)
            out_action = self.action_expert_ffn(x_action)  # [n_action, d_model]
            out[action_token_mask] = out_action

        return out


# ----------------------------------------------------------
#  4. 将其嵌入完整 Transformer Block
# ----------------------------------------------------------

class TransformerBlockWithActionExpert(nn.Module):
    """
    完整的 Transformer 层,Attention 部分不变,
    FFN 部分替换为 MixedExpertFFNLayer。

    注意:Attention 对所有 token 统一计算(轨迹 token 需要
    "看"场景上下文才能生成正确坐标),只有 FFN 做了分路由。
    """

    def __init__(self, d_model=2048, n_heads=16,
                 d_ffn=4096, d_action_ffn=1024):
        super().__init__()

        # Attention 部分:所有 token 统一处理(双向注意力)
        self.attn     = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.norm1    = nn.LayerNorm(d_model)

        # FFN 部分:按 token 类型路由
        self.ffn      = MixedExpertFFNLayer(d_model, d_ffn, d_action_ffn)
        self.norm2    = nn.LayerNorm(d_model)

    def forward(self, x, attn_mask=None, action_token_mask=None):

        # ── Attention 子层(Pre-Norm + 残差)────────────────
        x_norm = self.norm1(x)
        attn_out, _ = self.attn(x_norm, x_norm, x_norm, attn_mask=attn_mask)
        x = x + attn_out                   # 残差连接

        # ── 混合 FFN 子层(Post-Norm + 残差)────────────────
        ffn_out = self.ffn(self.norm2(x), action_token_mask)
        x = x + ffn_out                    # 残差连接

        return x


# ----------------------------------------------------------
#  5. 参数量 & 计算量对比(以单层 FFN 为例)
# ----------------------------------------------------------

def compare_params_and_flops(d_model=2048, d_ffn=4096, d_action_ffn=1024,
                              n_context_tokens=512, n_action_tokens=16):
    """
    对比标准方案(所有 token 用大 FFN)
    vs   Action-Expert 方案(轨迹 token 用小 FFN)
    的参数量和 FLOPs 差异。
    """

    # ── 参数量 ────────────────────────────────────────────
    # 标准 FFN:W1(d_model→d_ffn) + W2(d_ffn→d_model)
    params_standard     = 2 * d_model * d_ffn          # 2 × 2048 × 4096

    # Action-Expert 方案:大 FFN + 小 Action FFN
    params_action_expert = (2 * d_model * d_ffn          # 大 FFN(不变)
                           + 2 * d_model * d_action_ffn)  # 小 Action FFN(新增)

    # 新增参数量很少:2 × 2048 × 1024 = 4M,相对于主干几乎可忽略

    # ── FLOPs(以 batch_size=1 为例)────────────────────────
    n_total = n_context_tokens + n_action_tokens

    # 标准方案:所有 token 都走大 FFN
    flops_standard = n_total * 2 * d_model * d_ffn * 2
    #                ↑token数  ↑两层线性  ↑每次乘加2op

    # Action-Expert 方案:
    flops_context = n_context_tokens * 2 * d_model * d_ffn   * 2
    flops_action  = n_action_tokens  * 2 * d_model * d_action_ffn * 2
    flops_expert  = flops_context + flops_action

    speedup = flops_standard / flops_expert

    print(f"标准方案 参数量:{params_standard/1e6:.1f}M")
    print(f"Expert 方案 参数量:{params_action_expert/1e6:.1f}M(+{(params_action_expert-params_standard)/1e6:.1f}M)")
    print()
    print(f"标准方案 FLOPs:{flops_standard/1e9:.2f}G")
    print(f"Expert 方案 FLOPs:{flops_expert/1e9:.2f}G")
    print(f"理论加速比:{speedup:.2f}×")
    print()
    print("直觉:轨迹 token 只占序列的 16/(512+16) ≈ 3%,")
    print("      但 Action FFN 是大 FFN 的 1/4,")
    print("      所以整体 FFN FLOPs 节省约 3% × 75% ≈ 2.25%。")
    print("      论文报告的 2.6× 是针对 Action FFN 模块本身的加速,")
    print("      不是全模型加速。")

    # 实际上 2.6× 指的是 Action FFN 这个操作本身的延迟下降:
    # 原来每个 token 都调用大 FFN kernel,现在轨迹 token 调用小 FFN kernel
    # kernel launch + 矩阵乘法的延迟同时下降 → 2.6×


# ----------------------------------------------------------
#  6. 为什么精度不降反升?
# ----------------------------------------------------------
#
#  论文里提到一个反直觉的现象:
#    Action FFN 从 d_ffn=4096 压缩到 d_action_ffn=1024 之后,
#    meanSADE(平均 SAD 误差,越小越好)反而变好了。
#
#  可能的解释:
#
#  ① 过参数化导致"记忆"而非"泛化"
#     大 FFN 有足够的容量把训练集里每条轨迹"背下来",
#     小 FFN 容量受限,被迫学习更通用的坐标变换规律。
#
#  ② 正则化效应
#     小 FFN 相当于对轨迹 token 的表示施加了隐式的信息瓶颈,
#     逼着模型把高层语义(该往哪走)完全依赖 Attention 解决,
#     而 FFN 只做纯粹的坐标数值变换。
#     两者分工更清晰,反而更健康。
#
#  ③ 梯度干净
#     大 FFN 里,语言和轨迹 token 共享参数,梯度相互干扰。
#     分离之后,Action FFN 的梯度只来自轨迹相关的损失,
#     更新方向更纯粹。

用一张图把整个层的结构说清楚:

Transformer 某一层的输入序列(共 ~528 个 token):
┌──────────────────────────────────┬────────────────┐
│  视觉/导航/状态/目标点 token      │  轨迹 token     │
│  ~512 个,语义复杂                │  固定 16 个     │
└──────────────────────────────────┴────────────────┘
              ↓ 自注意力(所有 token 统一计算,不分路由)
              ↓
     ┌────────┴────────┐
     │    按类型路由    │
     ↓                 ↓
 StandardFFN      ActionExpertFFN
 d_ffn = 4096     d_ffn = 1024(1/4 大小)
     ↓                 ↓
     └────────┬────────┘
              ↓ 拼回原位,形状不变

KV 缓存的合并重写

第一层:普通 KV 缓存是什么

标准自回归语言模型(GPT 系列)每次只生成一个 token,是单向因果注意力

生成第 4 个 token 时,它只能看 token 1、2、3
所以 token 1、2、3 的 K 和 V 在之前已经算过了,直接缓存复用即可

这就是 KV 缓存:已经算过的 K/V 不重复算,每步只算新 token 的 K/V。

第二层:掩码扩散为什么破坏了 KV 缓存

掩码扩散用的是双向注意力——每个 token 能看到序列里所有其他 token。而且每一步都有若干个 assessments 变成具体 token,序列在变化:

步骤 0: assessments   assessments   assessments   assessments   assessments   assessments
步骤 1: assessments  x₂    assessments  x₄    assessments   assessments   ← 位置 2、4 被填入
步骤 2: x₁    x₂    x₃     x₄    assessments   assessments   ← 位置 1、3 被填入

问题来了:步骤 1 中位置 2 从 assessments 变成了 x₂,这个变化会影响所有其他位置对它的注意力结果,理论上其他位置的 K/V 都要重算。

朴素做法:每步都把整个序列的 K/V 从头计算一遍。这就是为什么表格里"无优化"的延迟是 14.7ms。

第三层:合并重写的关键洞察

虽然双向注意力让 token 互相影响,但有一个重要的局部性质:

某个 token 在第 l 层的 K 和 V,只依赖它自己在第 l-1 层的输出,不直接依赖同层其他 token 的值。

第 l 层:
  K_i^l = W_K · h_i^(l-1)   ← 只取决于 token i 自己的上层输出
  V_i^l = W_V · h_i^(l-1)   ← 同上

合并重写的近似假设:在同一步里,只有少数位置发生变化,大部分 token 的 K/V 变化很小,可以直接复用上一步的缓存,只对真正发生变化的位置重新计算 K/V 并写回缓存。

这是一个以极小精度损失换取大幅速度提升的近似,论文验证精度影响可忽略不计。

具体实现示意

草稿阶段第 s 步,位置 {3, 7, 12} 刚被 unmask:

旧 KV 缓存(步骤 s-1):
  位置  0   1   2   3   4   5   6   7   8   9  10  11  12  13  14  15
  K/V  [✓] [✓] [✓] [✗] [✓] [✓] [✓] [✗] [✓] [✓] [✓] [✓] [✗] [✓] [✓] [✓]
                      ↑               ↑               ↑
                   刚变化            刚变化           刚变化

合并重写操作:
  位置 3、7、12  → 重新计算 K/V,写回缓存   (3 个位置的计算量)
  其余 13 个位置 → 直接复用,零计算

新 KV 缓存(步骤 s):
  位置  0   1   2   3'  4   5   6   7'  8   9  10  11  12' 13  14  15
  K/V  [✓] [✓] [✓] [✓] [✓] [✓] [✓] [✓] [✓] [✓] [✓] [✓] [✓] [✓] [✓] [✓]

KV 缓存合并重写伪代码

# ============================================================
#  KV 缓存回退 + 合并重写 伪代码
#  对应论文 Section 5(推理优化部分)
# ============================================================
#
#  问题背景:
#    掩码扩散用双向注意力,每步有若干 token 从 assessments 变成具体值。
#    朴素做法:每步把全序列 K/V 从头算一遍 → 14.7ms/step
#    合并重写:只重算发生变化的位置 → 11.5ms/step,加速 1.28×
#
#  核心近似:
#    当步骤 s 中只有少数位置 Δs 发生变化时,
#    其他位置的 K/V 在 s 和 s+1 步之间变化很小,
#    可以直接复用上一步的缓存值。
# ============================================================

import torch


# ----------------------------------------------------------
#  数据结构:多层 KV 缓存
# ----------------------------------------------------------

class KVCache:
    """
    存储所有 Transformer 层的 K 和 V 张量。

    shape: [n_layers, batch, n_heads, seq_len, head_dim]

    同时维护一个"版本戳",记录每个位置的 token 在哪一步最后更新,
    用于 AutoEdit 的回退操作。
    """

    def __init__(self, n_layers, batch, n_heads, seq_len, head_dim):
        self.K = torch.zeros(n_layers, batch, n_heads, seq_len, head_dim)
        self.V = torch.zeros(n_layers, batch, n_heads, seq_len, head_dim)

        # 版本戳:position_version[i] = 该位置的 K/V 在第几步被算出来
        self.position_version = torch.full((seq_len,), -1, dtype=torch.long)

        # 快照栈:用于 AutoEdit 阶段的回退(rollback)
        # 每次"保存快照"就往栈里压一帧,回退就弹出
        self._snapshots = []

    def update_positions(self, changed_positions, new_K, new_V, step):
        """
        合并重写:只更新发生变化的位置,其余保持不变。

        Args:
            changed_positions : List[int],本步发生变化的 token 位置索引
            new_K             : [n_layers, batch, n_heads, len(changed), head_dim]
            new_V             : [n_layers, batch, n_heads, len(changed), head_dim]
            step              : 当前步骤编号
        """
        for k, pos in enumerate(changed_positions):
            self.K[:, :, :, pos, :] = new_K[:, :, :, k, :]
            self.V[:, :, :, pos, :] = new_V[:, :, :, k, :]
            self.position_version[pos] = step   # 记录版本

    def save_snapshot(self):
        """保存当前完整 KV 状态(用于 AutoEdit 回退)。"""
        self._snapshots.append({
            'K': self.K.clone(),
            'V': self.V.clone(),
            'version': self.position_version.clone()
        })

    def rollback(self):
        """
        回退到上一个快照。
        在 AutoEdit 阶段,如果当前修改使得奖励变差,
        可以撤销这一轮的 K/V 更新,重新选择要修改的位置。
        """
        assert len(self._snapshots) > 0, "没有可回退的快照"
        snap = self._snapshots.pop()
        self.K                 = snap['K']
        self.V                 = snap['V']
        self.position_version  = snap['version']


# ----------------------------------------------------------
#  核心函数:计算变化位置的新 K/V
# ----------------------------------------------------------

def recompute_kv_for_changed_positions(model, token_sequence,
                                       changed_positions, kv_cache):
    """
    只对 changed_positions 上的 token 重新计算 K 和 V,
    其余位置直接复用 kv_cache 中已有的值。

    这就是"合并重写"的实质:
      - 旧缓存(不变位置):直接读取,零计算
      - 新计算(变化位置):只算这几个 token,写回缓存

    Args:
        model             : Transformer 模型
        token_sequence    : 当前完整 token 序列 [batch, seq_len]
        changed_positions : 本步发生变化的位置索引列表
        kv_cache          : 当前 KV 缓存对象

    Returns:
        updated_kv_cache  : 更新后的 KV 缓存
    """

    if len(changed_positions) == 0:
        return kv_cache   # 没有变化,直接返回

    # ── Step 1:只取出变化位置的 token embedding ──────────
    #    不需要对整个序列做 embedding,只做几个 token
    changed_token_ids = token_sequence[:, changed_positions]  # [batch, n_changed]
    changed_embeddings = model.embed(changed_token_ids)        # [batch, n_changed, d_model]

    # ── Step 2:逐层计算这些位置的新 K 和 V ───────────────
    #
    #    关键近似在这里:
    #    每层的输入 h_i^(l-1) 理论上依赖上层所有 token 的注意力结果。
    #    近似处理:对变化位置,用当前层已有的完整 K/V 做注意力,
    #              得到这些位置的新隐状态,再算新的 K/V。
    #    对不变位置:直接复用缓存,不做任何计算。
    #
    new_K_all_layers = []
    new_V_all_layers = []

    h = changed_embeddings   # [batch, n_changed, d_model]  初始为 embedding

    for layer_idx, layer in enumerate(model.layers):

        # 2-a. 用当前层已缓存的完整 K/V 做 cross-attention:
        #      变化位置的 query 去 attend 整个序列(包括不变位置的缓存 K/V)
        #      这样变化位置的隐状态就能看到完整的上下文
        K_full = kv_cache.K[layer_idx]   # [batch, heads, seq_len, head_dim]
        V_full = kv_cache.V[layer_idx]   # [batch, heads, seq_len, head_dim]

        # 只对变化位置做 attention(query 来自变化位置,key/value 来自完整缓存)
        Q_changed = layer.W_Q(h)   # [batch, n_changed, d_model]
        h = layer.cross_attend(Q_changed, K_full, V_full)  # [batch, n_changed, d_model]
        h = layer.ffn(h)                                   # [batch, n_changed, d_model]

        # 2-b. 用更新后的隐状态计算这些位置的新 K 和 V
        new_K = layer.W_K(h)   # [batch, n_changed, d_model] → reshape to [batch, heads, n_changed, head_dim]
        new_V = layer.W_V(h)

        new_K_all_layers.append(new_K)
        new_V_all_layers.append(new_V)

    # ── Step 3:合并写回缓存(只写变化的位置)─────────────
    #    "合并":新算的几个位置 + 原来缓存的其余位置 = 完整 KV 缓存
    new_K_tensor = torch.stack(new_K_all_layers)   # [n_layers, batch, heads, n_changed, head_dim]
    new_V_tensor = torch.stack(new_V_all_layers)

    kv_cache.update_positions(changed_positions, new_K_tensor, new_V_tensor,
                              step=current_step)

    return kv_cache


# ----------------------------------------------------------
#  草稿阶段:带合并重写的掩码扩散解码
# ----------------------------------------------------------

def draft_with_kv_merge_rewrite(model, context_tokens, S_draft=3):
    """
    草稿阶段的推理循环。
    每步只 unmask 少数 token,用合并重写更新 KV 缓存。

    Args:
        context_tokens : 场景上下文 token(视觉/导航/状态)[batch, n_ctx]
        S_draft        : 草稿阶段总步数

    Returns:
        draft_tokens   : 填满的轨迹 token 序列 [batch, L]
        kv_cache       : 最终 KV 缓存(供 AutoEdit 阶段复用)
    """

    L = 16   # 轨迹 token 数量
    batch = context_tokens.shape[0]

    # ── 初始化:轨迹部分全 assessments,和上下文拼在一起 ────────
    traj_tokens = torch.full((batch, L), MASK_TOKEN_ID)
    full_sequence = torch.cat([context_tokens, traj_tokens], dim=1)
    # full_sequence: [batch, n_ctx + L]

    # ── 计算上下文前缀的 KV 缓存(只算一次,整个推理过程复用)
    #    上下文 token 用因果注意力,所以这部分可以完美缓存
    kv_cache = compute_prefix_kv(model, context_tokens)   # 见论文"共享前缀 KV 缓存"

    # ── 初始化轨迹 token 的 KV(全 assessments 的 embedding)─────
    mask_embedding = model.embed(traj_tokens)   # [batch, L, d_model]
    initial_K, initial_V = model.compute_kv(mask_embedding)
    kv_cache.update_positions(list(range(L)), initial_K, initial_V, step=0)

    # ── 草稿循环 ─────────────────────────────────────────────
    unmasked_positions = set()   # 已经填好的位置

    for s in range(S_draft):

        # 用当前完整 KV 缓存做前向,得到所有位置的 logits
        # (只对仍是 assessments 的位置采样,已填好的位置直接保留)
        logits = model.forward_with_kv(full_sequence, kv_cache)
        # logits: [batch, n_ctx + L, vocab_size]

        traj_logits = logits[:, -L:, :]   # 只取轨迹部分的 logits
        probs = softmax(traj_logits, dim=-1)

        # 选出本步要填入的位置(最高置信度的那批 assessments)
        n_to_unmask = compute_unmask_count(s, S_draft, n_masked=(L - len(unmasked_positions)))
        still_masked = [i for i in range(L) if i not in unmasked_positions]

        confidences = {i: probs[0, i, :].max().item() for i in still_masked}
        newly_unmasked = sorted(still_masked,
                                key=lambda i: -confidences[i])[:n_to_unmask]

        # 采样并填入
        for pos in newly_unmasked:
            sampled_token = sample_from(probs[0, pos])
            full_sequence[0, -L + pos] = sampled_token
            unmasked_positions.add(pos)

        # ════════════════════════════════════════════════
        #  合并重写:只对新填入的位置重算 K/V
        #  之前已 unmask 的位置 + 仍是 assessments 的位置 → 直接复用缓存
        # ════════════════════════════════════════════════
        kv_cache = recompute_kv_for_changed_positions(
            model          = model,
            token_sequence = full_sequence,
            changed_positions = [n_ctx + pos for pos in newly_unmasked],
            #                    ↑ 转换为全序列索引(轨迹在末尾)
            kv_cache       = kv_cache
        )
        # 本步只重算了 n_to_unmask 个位置的 K/V,其余全部复用!

    draft_tokens = full_sequence[:, -L:]   # 提取轨迹 token
    return draft_tokens, kv_cache


# ----------------------------------------------------------
#  AutoEdit 阶段:带回退 + 合并重写的轨迹修正
# ----------------------------------------------------------

def autoedit_with_rollback(model, draft_tokens, context_tokens,
                           kv_cache, S_edit=3):
    """
    AutoEdit 阶段的推理循环。
    比草稿阶段多了一个"回退"能力:
    如果某次修改后置信度反而更低,可以撤销这次修改。

    Args:
        draft_tokens   : 草稿轨迹 token [batch, L]
        context_tokens : 场景上下文 token [batch, n_ctx]
        kv_cache       : 草稿阶段结束时的 KV 缓存(直接继承,无需重算上下文!)
        S_edit         : AutoEdit 迭代轮数

    Returns:
        edited_tokens  : 修正后的轨迹 token [batch, L]
    """

    L = 16
    full_sequence = torch.cat([context_tokens, draft_tokens], dim=1)

    for k in range(S_edit):

        # ── 保存当前快照,以备回退 ────────────────────────
        kv_cache.save_snapshot()
        prev_sequence = full_sequence.clone()

        # ── 用当前完整 KV 做前向,得到每个位置的替换建议 ──
        logits = model.forward_with_kv(full_sequence, kv_cache)
        traj_logits = logits[:, -L:, :]
        probs  = softmax(traj_logits, dim=-1)

        # 每个位置的"替换 token"和"置信度"
        x_hat       = probs.argmax(dim=-1)   # [batch, L]
        confidence  = probs.max(dim=-1).values   # [batch, L]

        # ── 构造 commit mask:选低置信度的非目标点位置 ────
        goal_positions = get_goal_token_positions()
        commit_mask = torch.zeros(L, dtype=torch.bool)

        for i in range(L):
            if i in goal_positions:
                commit_mask[i] = False   # 目标点永远不替换
            elif confidence[0, i] < CONFIDENCE_THRESHOLD:
                commit_mask[i] = True    # 置信度低 → 标记为替换

        changed_positions = commit_mask.nonzero().squeeze(-1).tolist()

        if len(changed_positions) == 0:
            # 没有低置信度的位置,不需要修改,丢弃快照直接退出
            kv_cache._snapshots.pop()
            break

        # ── 执行替换 ──────────────────────────────────────
        for pos in changed_positions:
            full_sequence[0, -L + pos] = x_hat[0, pos]

        # ── 合并重写:只更新被替换位置的 K/V ─────────────
        kv_cache = recompute_kv_for_changed_positions(
            model             = model,
            token_sequence    = full_sequence,
            changed_positions = [-L + pos for pos in changed_positions],
            kv_cache          = kv_cache
        )

        # ── 可选的质量检查:如果修改后置信度整体更低,回退 ─
        new_logits     = model.forward_with_kv(full_sequence, kv_cache)
        new_confidence = softmax(new_logits[:, -L:, :], dim=-1).max(dim=-1).values.mean()
        old_confidence = confidence.mean()

        if new_confidence < old_confidence * 0.95:
            # 修改让模型"更不确定"了,撤销这一轮
            kv_cache.rollback()
            full_sequence = prev_sequence
        else:
            # 修改有效,丢弃快照(不再需要回退到这个点)
            kv_cache._snapshots.pop()

    edited_tokens = full_sequence[:, -L:]
    return edited_tokens


# ----------------------------------------------------------
#  合并重写 vs 朴素重算:计算量对比
# ----------------------------------------------------------
#
#  设:
#    seq_len      = n_ctx + L = 512 + 16 = 528
#    每步变化位置  = n_changed(草稿阶段约 5~6,AutoEdit 阶段约 2~3)
#    n_layers     = 28
#
#  朴素重算(每步):
#    对完整序列重算所有 K/V
#    FLOPs ∝ n_layers × seq_len × d_model²
#          = 28 × 528 × 2048²  ≈ 124G FLOPs
#
#  合并重写(每步):
#    只对 n_changed 个位置重算 K/V
#    FLOPs ∝ n_layers × n_changed × d_model²
#          = 28 × 6 × 2048²  ≈ 1.4G FLOPs
#    (其余 522 个位置直接读缓存,零 FLOPs)
#
#  理论加速:124G / 1.4G ≈ 88×
#
#  实际加速 1.28×(14.7ms → 11.5ms)的原因:
#    · GPU kernel 启动开销、内存带宽瓶颈等 overhead 不随 FLOPs 线性缩减
#    · KV 缓存的读写本身也有延迟
#    · 合并重写引入了额外的 mask 判断和条件写操作
#    但在绝对延迟上节省了 3.2ms/帧,对 31.8ms 的总延迟已是可观的贡献

为什么 K 和 V 只依赖自己的上一层输出?

Transformer 每一层做了什么?

每一层的计算可以拆成两步:

输入:上一层的隐状态矩阵 H^(l-1),形状 [seq_len, d_model]
      每一行是一个 token 的向量表示

Step 1:自注意力(Attention)
Step 2:前馈网络(FFN)

输出:本层的隐状态矩阵 H^(l),形状 [seq_len, d_model]

关键在于:K 和 V 是在哪里算出来的?

K 和 V 是在第 l 层的注意力计算开始之前,直接从上一层的输出线性变换得到的:

K^(l) = H^(l-1) · W_K    ← 矩阵乘法,逐行独立
V^(l) = H^(l-1) · W_V    ← 矩阵乘法,逐行独立
Q^(l) = H^(l-1) · W_Q    ← 矩阵乘法,逐行独立

逐行独立是关键词。把上面的矩阵乘法写成逐行的形式:

K^(l)_i = h^(l-1)_i · W_K    ← 只用了第 i 行(第 i 个 token 的向量)
V^(l)_i = h^(l-1)_i · W_V    ← 同上

这个线性变换里完全没有其他 token 的信息,所以:

K^(l)_iV^(l)_i 只依赖 h^(l-1)_i,不依赖 h^(l-1)_jj ≠ i)。

那"互相影响"发生在哪里?

发生在注意力的输出计算里,也就是 Q 和 K/V 做完点积之后:

注意力输出:
  A^(l)_i = softmax( Q^(l)_i · (K^(l))ᵀ ) · V^(l)
                      ↑                      ↑
              这里 token i 的 Q        这里混入了所有 token 的 V
              去和所有 token 的 K 做点积

所以注意力输出 A^(l)_i 确实包含了所有其他 token 的信息,但 K 和 V 本身在被计算出来的时候,还没有经历这一步混合。

用一张流程图看清楚顺序

第 l 层:

H^(l-1)
  │
  ├─[×W_K]─→ K^(l)_i = h^(l-1)_i · W_K   ← 此时只有自己
  ├─[×W_V]─→ V^(l)_i = h^(l-1)_i · W_V   ← 此时只有自己
  └─[×W_Q]─→ Q^(l)_i = h^(l-1)_i · W_Q   ← 此时只有自己
                │              │
                └──────┬───────┘
                       ↓
              Attention(Q, K, V)            ← 此处发生混合
                       ↓
                    A^(l)_i                 ← 此时包含所有 token 的信息
                       ↓
                    FFN(A^(l)_i)
                       ↓
                    H^(l)_i                 ← 下一层的输入

用一个比喻来理解

想象一场圆桌会议,每个人(token)在发言前先准备自己的问题卡(Q)、名片(K)、资料包(V)。

准备这三样东西的时候,每个人只看自己昨天(上一层)的笔记,不看别人的。所以名片和资料包只属于自己。

然后会议开始:每个人拿着自己的问题卡去和所有人的名片做匹配,再根据匹配程度加权收集大家的资料包。混合发生在这一步,而不是准备名片和资料包的时候。


相关论文

  • [[DiffusionDriveV2 论文阅读笔记]]
  • [[DriveTransformer 论文阅读笔记]]
  • [[RAD-2 论文阅读笔记]]
  • [[World4Drive - 无需感知标注的端到端世界模型]]

总结一句话:ReflectDrive-2 教会了自动驾驶模型像人类老司机一样——先打草稿,再反思修改,最后稳稳开车。 而这种反思能力之所以有效,不是因为模型"天生就会改",而是因为强化学习让起草和修改成了真正的搭档,共同对最终结果负责。