端到端自动驾驶的视频生成与轨迹规划
🎯 一句话概括
Epona 是一个自回归扩散世界模型,它像拍连续剧一样根据历史画面预测未来,同时用扩散模型保证每一帧画质高清——不仅能"脑补"出未来 2 分钟的驾驶场景,还能学会"红灯停、避让行人"等物理规则。
🧠 核心设计理念
为什么需要 Epona?
在自动驾驶领域,存在两类模型各有优劣:
| 模型类型 | 优势 | 劣势 |
|---|---|---|
| 扩散模型 | 画质逼真、细节丰富 | 短视,难以生成长视频,不懂数理逻辑 |
| 自回归 Transformer | 懂因果、能长程推理 | 图像压缩粗糙,画质模糊 |
Epona 的思路:为什么不能兼得?于是采用 “自回归 + 扩散” 混合架构:
- 像写连续剧一样(自回归)根据历史预测未来
- 同时用扩散模型保证每一帧画质高清
三大核心创新
- 分工明确:时空处理分离,效率大幅提升
- 异步生成:轨迹规划和视频生成分开进行
- 连锁前向训练:解决误差累积问题,能生成长达 2 分钟的视频
🏗️ 架构详解
Epona 由三个核心模块组成,像一个精密配合的团队:
┌─────────────────────────────────────────────────────────────┐
│ Epona 架构 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 历史 T 帧 ──┐ │
│ │ ┌─────────┐ │
│ 历史动作 ────┼───►│ MST │──► 特征 F │
│ │ │(记忆大师)│ │ │
│ ┘ └─────────┘ │ │
│ │ │
│ ┌──────────┴──────────┐ │
│ ▼ ▼ │
│ ┌──────────┐ ┌──────────┐ │
│ │ TrajDiT │ │ VisDiT │ │
│ │(领航员) │ │ (画师) │ │
│ └────┬─────┘ └────┬─────┘ │
│ │ │ │
│ ▼ ▼ │
│ 未来轨迹 下一帧画面 │
│ │
└─────────────────────────────────────────────────────────────┘
📚 2.1 MST (Multimodal Spatiotemporal Transformer)
🎭 角色:超级记忆大师
MST 的任务是将过去复杂的视频画面和驾驶操作,压缩成一个精炼的特征向量。就像一个记忆力超群的人,看一眼就能记住所有关键信息。
输入预处理
原始输入:
├── 视觉:过去 T 帧 (如 10 帧) 图像,分辨率 512×1024
└── 动作:每帧对应的历史轨迹(速度、方向盘转角等)
DCAE 压缩处理:
├── 图像压缩 16 倍:512×1024 → 32×64 特征图
├── 铺平成 Token:32×64 = 2048 个视觉 Token (记作 L)
└── 动作投影:动作向量映射到同维度 Token
最终输入张量 E:
├── 形状:[Batch, T, (L+3), D]
├── L+3 = 2048 个视觉 Token + 3 个动作 Token
└── D = 特征维度
🔄 时空分离处理 —— “先看时间,再看空间”
MST 不是同时处理时空,而是交替进行,像这样:
步骤 A:时间层 —— “串联历史”
目标:让图像中同一个坐标位置的像素点,去查阅自己在过去 $T$ 帧的变化。
# 输入变换
原始形状:[B, T, S, D] # S 是空间 Token 数 L+3
变换后:[(B * S), T, D] # 把空间维度和 Batch 绑在一起
# 物理含义
现在模型眼里的"一个样本",不再是整段视频,
而是视频中某个特定位置的像素点随时间的变化序列。
# 关键技术:Causal Mask(因果遮罩)
第 5 帧的像素只能看第 1, 2, 3, 4 帧的自己,不能偷看第 6 帧。
步骤 B:空间层 —— “理解当下”
目标:把每一帧看作独立的图片,让图像里的车、路、树木以及动作指令 Token 进行全注意力交互。
# 输入变换
变换后:[(B * T), S, D] # 把时间维度和 Batch 绑在一起
# 多模态融合
视觉信息和动作意图在此处深度融合。
📍 3D 位置编码 (EmbedND)
Epona 使用分块对角旋转位置编码 (RoPE) 来编码时空位置:
def EmbedND(dim, theta, axes_dim):
"""
为视频中的每个像素点生成 3D 位置嵌入
维度分配示例:[Time: 2维, Height: 2维, Width: 2维]
"""
for i, (pos, dim) in enumerate(zip(axes, axes_dim)):
out.append(rope(pos, dim, theta))
return torch.cat(out, dim=-1)
形象例子:假设要给坐标 (t=5, h=10, w=20) 的像素编码:
循环 1 (Time): Embed(5) → [0.1, 0.9]
循环 2 (Height): Embed(10) → [0.5, 0.5]
循环 3 (Width): Embed(20) → [0.8, 0.2]
最终拼接:[0.1, 0.9, 0.5, 0.5, 0.8, 0.2]
这样,最终向量同时包含时间、高度和宽度信息,互不干扰。
输出
经过 $N$ 层循环后,提取序列中最后一帧的特征 $\mathbf{F}$。这是包含丰富历史语义和当前状态的高维特征向量,作为后续两个模块的基石。
🚗 2.2 TrajDiT (Trajectory Planning DiT)
🎭 角色:决策中枢 & 老司机
拿到 MST 给的局面 $\mathbf{F}$,在不生成图像的情况下,极速规划出未来 3 秒怎么开。
架构:双流融合
这是一个专门"画线"(轨迹)的轻量级扩散模型。
输入准备:
├── 条件 (Cond):来自 MST 的特征 F
└── 噪声 (Input):随机高斯噪声 x_T(代表未来轨迹的草稿)
双流阶段 (Dual-Stream Phase):
├── 环境流:处理特征 F
├── 轨迹流:处理噪声轨迹
└── 通过 Cross-Attention 交换信息
单流阶段 (Single-Stream Phase):
├── 两条流拼接,深度混合推理
└── 确保轨迹与环境严丝合缝
🔧 Modulation 调制机制
Modulation 是将时间嵌入转化为神经网络控制参数的关键组件:
class Modulation:
def __init__(self, dim, double):
self.multiplier = 6 if double else 3
self.lin = nn.Linear(dim, dim * self.multiplier)
def forward(self, vec):
out = self.lin(vec.silu()).chunk(self.multiplier, dim=-1)
return out # 返回 (shift, scale, gate) 组
参数含义:
- Shift (β):偏移量,平移特征
- Scale (γ):缩放因子,拉伸/压缩特征
- Gate (α):门控值,控制残差连接强度
📊 DoubleStreamBlock vs SingleStreamBlock
| 特性 | DoubleStreamBlock | SingleStreamBlock |
|---|---|---|
| 数据流 | 两条独立流 (环境+轨迹) | 一条混合流 |
| 调制参数 | 每条流 6 个,共 12 个 | 仅 3 个 |
| 结构 | 串行逻辑 | 并行逻辑 |
| 用途 | TrajDiT 前期,保护环境特征 | TrajDiT 后期/VisDiT,高效推理 |
在 DoubleStreamBlock 中:
# 轨迹流
img_mod1 → 控制 Attention 的 AdaLN 和门控
img_mod2 → 控制 MLP 的 AdaLN 和门控
# 环境流
cond_mod1 → 控制 Attention 的 AdaLN 和门控
cond_mod2 → 控制 MLP 的 AdaLN 和门控
🎨 2.3 VisDiT (Next-frame Prediction DiT)
🎭 角色:超写实画师
根据 MST 的特征和 TrajDiT 的轨迹规划,生成下一帧图像。
输入准备
画布噪声:随机高斯噪声潜变量 Z_{T+1}
环境参考:MST 的特征 F
动作指令:TrajDiT 预测的轨迹(关键!)
核心机制:动作调制
轨迹向量转化为控制神经网络的旋钮参数:
# 轨迹向量转化为缩放因子和偏移量
Input = Input * Scale(a) + Shift(a)
# 通过 AdaLN 注入到 Transformer 每一层
- 如果规划是"左转",调制会强迫网络关注左侧特征
- 保证生成画面与规划动作一致
⏱️ 分辨率感知的时间偏移 (get_schedule)
这是一个**“智能时间管理大师”**:
def get_schedule(num_steps, image_seq_len, base_shift=0.5, max_shift=1.15):
timesteps = torch.linspace(1, 0, num_steps + 1) # 基础进度条
if shift:
mu = get_lin_function(base_shift, max_shift)(image_seq_len)
timesteps = time_shift(mu, 1.0, timesteps) # 偏移
return (1 - timesteps).tolist()
为什么需要它?
- 痛点:画大图比画小图更难,需要在"宏观构图"阶段多花点时间
- 解决方案:根据序列长度自动调整时间表
- 序列长(大图):在高噪声阶段停留更久,先定大轮廓
- 序列短(小图/轨迹):匀速搞定即可
在 Epona 中:
- TrajDiT:序列短,时间表几乎不偏移
- VisDiT:序列长(2048 Token),显著偏移
时序感知解码
- 使用 Temporal-aware DCAE Decoder 解压潜变量
- 参考上一帧的潜变量,消除频闪和抖动
- 输出 512×1024 高清图像
🎓 关键技术细节
3.1 傅立叶嵌入 (timestep_embedding)
将低维动作数据映射到高维空间,增强神经网络对细微变化的感知能力:
def timestep_embedding(t, dim, max_period=10000):
t = time_factor * t
half = dim // 2
freqs = exp(-log(10000) * arange(0, half) / half)
args = t[:, None] * freqs[None]
embedding = cat([cos(args), sin(args)], dim=-1)
return embedding
原理:
- 原始输入:低维向量
- 输出:高维特征,包含从低频到高频的丰富信号
- 效果:神经网络能"看到"微小变化
3.2 RoPE vs 正弦编码
| 特性 | 正弦编码 | RoPE |
|---|---|---|
| 相对位置感知 | 弱 | 强(点积只取决于相对距离) |
| 长度外推性 | 差 | 好(周期性,不死记硬背) |
| 维度解耦 | 难 | 优雅(分块对角矩阵) |
3.3 连锁前向训练 (Chain-of-Forward Training)
痛点:自回归模式的误差累积——第一帧歪一点,第 100 帧就崩了
解决方案:
- 训练时偶尔用模型自己生成的(有瑕疵的)预测结果作为下一轮输入
- 模型被迫学会自我修正
效果:能生成长达 2 分钟不崩坏的视频
🔄 完整推理流程
def step_eval(latents, rel_pose, rel_yaw):
# 1. MST 编码:压缩历史信息
stt_features, pose_emb = model.evaluate(latents, poses, yaws)
# 2. TrajDiT 规划:决定未来怎么走
noise_traj = randn(...)
predict_traj = traj_dit.sample(noise_traj, traj_ids, stt_features)
# 3. 提取下一步动作
predict_pose, predict_yaw = predict_traj[:, 0:1, ...]
pose_emb = model.get_pose_emb(predict_pose, predict_yaw)
# 4. VisDiT 生成:脑补下一帧画面
noise = randn(...)
predict_latents = dit.sample(noise, img_ids, stt_features, pose_emb)
return predict_traj, predict_latents
📊 实验成果
| 指标 | 结果 |
|---|---|
| 视频生成 FVD | 7.4 (优于 Vista 7.9, 远超 DriveGAN 73.4) |
| 视频长度 | 2 分钟 且逻辑连贯 |
| 物理理解 | 自学懂"红灯停"、“避让行人"等规则 |
| 规划能力 | NAVSIM 评测超过多个专门规划模型 |
💡 总结
Epona 通过 MST(压缩理解)、TrajDiT(规划导航) 和 VisDiT(受控绘图) 三者的精密配合,实现了从"看懂路"到"决定怎么开"再到"脑补未来后果"的完整闭环。
它不仅是一个视频生成器,更是一个具备潜力的端到端自动驾驶大脑。
📎 相关链接
- 论文:Epona: Autoregressive Diffusion World Model for Autonomous Driving
- 相关工作:[[World4Drive - 无需感知标注的端到端世界模型]]、[[LAW - Latent World Model for E2E Driving]]