Mamba遇见扩散模型!DiM:无需微调,高分辨图像生成更高效!
点击下方卡片,关注“CVer”公众号
点击下方卡片,关注“CVer”公众号
AI/CV重磅干货,第一时间送达
AI/CV重磅干货,第一时间送达
添加微信:CVer5555,小助手会拉你进群!
扫描下方二维码,加入CVer学术星球!可以获得最新顶会/顶刊上的论文idea和CV从入门到精通资料,及最前沿应用!发论文/搞科研/涨薪,强烈推荐!
添加微信:CVer5555,小助手会拉你进群!
扫描下方二维码,加入CVer学术星球!可以获得最新顶会/顶刊上的论文idea和CV从入门到精通资料,及最前沿应用!发论文/搞科研/涨薪,强烈推荐!
导读
本文提出了一种新的基于 Mamba 的扩散模型 DiM,用于高效的高分辨率图像生成。Mamba 本是用于处理一维信号的模型,作者提出了几种有效的设计来使其能够对二维图像进行建模。
本文目录
1 DiM:高效高分辨率图像生成的 Diffusion Mamba
(来自香港大学,华为诺亚方舟实验室)
1 DiM 论文解读
1.1 用 Mamba 架构进行高分辨率图像生成
1.2 状态空间模型
1.3 Diffusion Mamba 架构
1.4 训练和推理策略
1.5 实验设置
1.6 效率分析
1.7 实验结果
太长不看版
扩散模型在图像生成方面取得了巨大成功,Backbone 从 U-Net 演变到 Vision Transformer。然而, Transformer 的计算成本与 token 的数量成二次方,在处理高分辨率图像时面临重大挑战。本文提出 Diffusion Mamba (DiM),它结合了 Mamba 的效率,且具有扩散模型的表达能力,以实现高效的高分辨率图像合成。Mamba 是一种基于状态空间模型 (State Space Models, SSM) 的序列模型。
为了解决 Mamba 不能泛化到 2D 信号的挑战,作者提出了几种架构设计,包括多方向扫描、每行和列末尾的 learnable padding tokens 以及轻量级局部特征增强。DiM 架构可以高效地生成高分辨率图像。此外,为了进一步提高 DiM 高分辨率图像生成的训练效率,作者研究了在低分辨率图像 (256×256) 上预训练 DiM 的 "weak-to-strong" 训练策略,然后在高分辨率图像上微调它 (512×512)。作者进一步探索了 training-free 的上采样策略,使模型能够生成更高分辨率的图像 (例如 1024×1024 和 1536×1536),而无需进一步微调。实验证明了 DiM 的有效性和效率。
本文做了哪些具体的工作
提出了一种新的基于 Mamba 的扩散模型 DiM,用于高效的高分辨率图像生成。Mamba 本是用于处理一维信号的模型,作者提出了几种有效的设计来使其能够对二维图像进行建模。 为了解决高分辨率图像训练的高成本,作者研究了微调在低分辨率图像上预训练的 DiM 以进行高分辨率图像生成的策略。此外还探索了 training-free 的上采样方案,使模型在无需进一步的微调的情况下生成更高分辨率的图像。 在 ImageNet 和 CIFAR 上的实验证明了 DiM 在高分辨率图像生成中的训练效率、推理效率和有效性。
1 DiM:高效高分辨率图像生成的 Diffusion Mamba
论文名称:DiM: Diffusion Mamba for Efficient High-Resolution Image Synthesis (Arxiv 2024.05)
论文地址:
https://arxiv.org/pdf/2405.14224
代码链接:
http://github.com/tyshiwo1/DiM-DiffusionMamba/
1.1 用 Mamba 架构进行高分辨率图像生成
扩散模型在图像生成方面取得了巨大的成功。由于 Transformer 架构的有效性和可扩展性,扩散模型的 Backbone 已经从以 U-Net[1]为代表的卷积神经网络发展到 Vision Transformer[2][3][4][5]。基于 Transformer 的扩散模型将图像编码为 latent 特征图,再把 latent 特征图分成 Patches,再把这些 Patches 投影为 tokens。然后,应用 Transformer 对图像 tokens 进行去噪。但是,Transformer 中的 Self-attention 层的复杂度与 tokens 的数量呈二次方关系,使得高分辨率图像生成的计算成本面临重大挑战。
Mamba[6]是一种基于状态空间模型 (State Space Models, SSM) 的序列模型 Backbone,在语言、音频和基因组学等几种模式中显示出显著的有效性和效率。Mamba 实现了与 Transformer 相当的性能,且具有更好的推理时间和效率。与 Transformer 的二次计算复杂度相比,Mamba 在长序列建模中显示出巨大的前景,因为 Mamba 的计算复杂度与 token 的数量成线性关系。Mamba 的这些特性促使本文作者将 Mamba 作为扩散模型的新的 Backbone 引入,尤其是对于高效的高分辨率图像生成。
然而,当将 Mamba 与扩散模型相结合以进行高分辨率图像生成时,会出现一些挑战。主要挑战是 Mamba 的因果序列建模与图像的二维 (2D) 数据结构之间的不匹配。Mamba 架构是为序列信号的一维 (1D) 因果建模而设计的,不能直接用于建模二维图像 tokens。一个简单的解决方案是使用光栅扫描顺序将 2D 数据转换为 1D 的序列。但是,它将每个位置的感受野限制为只有光栅扫描顺序中的先前位置。此外,在光栅扫描顺序中,当前行的结尾后面是下一行的开始,它们之间不共享空间连续性。第二个挑战是,尽管 Mamba 具有高效推理的优势,但在高分辨率图像上训练基于 Mamba 的扩散模型的训练代价依然昂贵。
为了缓解第一个挑战,本文作者提出了 Diffusion Mamba (DiM),如下图 2 所示,这是一种基于 Mamba 的扩散模型 Backbone,用于高效的高分辨率图像生成。在 DiM 中,作者遵循 Diffusion Transformer 的做法将图像编码为 Patch 特征。然后,作者使用 Mamba 架构作为 Backbone 来建模特征。为了避免补丁之间的单向因果关系并赋予每个 token 全局感受野,作者设计了 Mamba Block 交替执行四个扫描方向。此外,作者在扫描顺序相邻的两个标记之间插入 learnable padding tokens,但在空间域中不相邻,从而允许 Mamba 块识别图像边界并避免误导序列模型。作者还将 3×3 Depth-Wise Convolution 添加到网络的输入层和输出层,以增强生成图像的局部相干性。此外,作者在浅层和深层之间添加了长跳跃连接[1][3]。基于 Transfo,以将低级信息传播到高级特征,这一点已经被证明有利于扩散模型中的像素级预测目标。
为了解决 DiM 高分辨率图像生成训练效率的挑战,作者探索了资源高效的方法,使得在低分辨率图像上预训练的 DiM 模型可以完成高分辨率的图像生成。作者首先观察到在低分辨率图像上预训练的 DiM 可以为高分辨率图像生成提供合理的先验。因此,作者探索了 "weak-to-strong" 训练策略[7]。作者首先在低分辨率图像上训练 DiM,然后使用预训练模型作为初始化来有效地微调高分辨率图像。该策略大大降低了高分辨率图像生成的训练时间成本。作者还探索了 training-free 的上采样方法,使 DiM 进一步适应更高分辨率的图像,而无需进一步微调。
1.2 状态空间模型
状态空间模型 (State Space Models, SSM) 常用于编码和解码一维序列输入。在连续时间 SSM 中, 首先将输入信号 编码为隐藏状态向量 , 然后根据以下常微分方程 (Ordinary Differential Equations, ODEs) 解码为输出信号 :
其中 表示 的导数, 表示 SSM 的权重。通常, 自然语言和二维视觉的输入是离散信号, 因此 Mamba 利用零阶保持 (Zero-order Hold, ZOH) 规则进行离散化。因此, 上述 ODE 可以反复求解:
其中 也是一个模型参数。最近, Mamba 提出通过将时不变参数更改为时变来提高 SSM 的灵活性。这种修改涉及用依赖于输入 的动态权重替换静态模型权重 。具有输入相关参数的过程称为选择性扫描。
1.3 Diffusion Mamba 架构
Mamba 主要用于处理一维输入,因此 Mamba 很难在没有任何修改的情况下学习图像的二维数据结构。因此,作者提出了一些新的架构设计,使 DiM 能够处理空间结构。
总体架构
如图2所示,DiM 框架可以处理有噪声的二维 (2D) 输入,比如图像或者 latent 的特征,同时需要输入 time step 和 class condition。这种噪声输入可以被视为由对应于输入时间步长的特定高斯噪声级别扰动的干净信号。噪声输入首先被分成 2D Patches,每个 Patch 可以通过全连接层转换为高维特征向量。接下来,这些 Patches 被送入 3×3 Depth-Wise Convolution 层,其中局部信息被注入到 Patches 中。Patches 也在行和列的末尾用可学习的 tokens 填充,允许模型在一维顺序扫描期间感知二维空间结构。然后,使用图2所示的四个扫描模式之一,将 Patch tokens 展平为 Patch 序列。time step 和 class condition 也通过全连接层转换为 tokens,然后附加到序列中。随后,序列被送入 Mamba Blocks 进行扫描。此外,作者还在浅层和深层之间添加了长跳跃连接,以将低级信息传播到高级特征,这也被证明有利于扩散模型中的像素级预测目标。
扫描模式
全局感受野对于本文模型有效地捕获图像中的空间结构至关重要。在单个光栅扫描方向上扫描图像 Patches 会导致单向有限的感受野。例如,左上角的第一个扫描 Patch 永远不会聚合来自其他 Patch 的信息。为了使每个 Patch 具有全局感受野,作者在不同的模型块中采用了不同的扫描模式。具体来说,如图2所示,在第1个 Block 中,采用行主扫描,即逐行扫描图像补丁序列,每一行从左到右水平扫描,然后移动到下一行。在第2个 Block 中,反转序列顺序并以相同的方式扫描序列。在随后的 Block 中,作者以正向和反向顺序执行列主扫描。在遍历所有扫描模式后,在下一个模型块中再次循环它们。
可学习的 padding token
图像的空间结构的学习可能会被光栅扫描破坏。具体来说,当将图像展平为 Patch 序列时,图像一行中最右边的补丁变得与第二行最左边的补丁相邻。然而,这两个特征向量所代表的内容可能存在很大差异。因为图像有固有的连续性和空间结构,但是扫描的方式与这种结构相矛盾,从而阻碍了学习的过程。为了缓解这个问题,作者在每一行或者每一列的末尾增加可学习的 padding token,使得模型意识到 End-Of-Line (EOL)。
轻量级局部特征增强
图像的局部结构会会被扫描的序列化过程所破坏。比如在行主扫描中, 行 和列 处的 Patch 不再与行 和列 处的 Patch 相邻。此外, 由于 Mamba 专为极端的效率优化而设计, 因此选择通过在网络的开头和结尾添加几个轻量级模块来增强局部结构, 而不改变 Mamba Block 本身。
具体来讲,作者引入了两个 3×3 Depth-Wise Convolution 层。在将 tokens 输入给 Mamba Block 之前,在 Patchify 层之后插入一个卷积层。在 Unpatchify 和输出层之前,在所有 Mamba Block 之后插入另一个卷积层。这些轻量级的 Depth-Wise Convolution 层为 DiM 提供了对 2D 局部连续性的认识。
1.4 训练和推理策略
尽管推理效率不错,但在高分辨率图像上训练 DiM 仍需要大量的时间和计算资源。
"Weak-to-strong" 训练和微调
从头开始训练高分辨率图像的扩散模型需要大量的时间和计算资源。作者观察到,在低分辨率图像上预训练的 DiM 可以为高分辨率训练提供粗略的初始化,如下图3所示。因此,作者考虑了一种 "Weak-to-strong" 的训练策略,在低分辨率图像上从头开始训练 DiM 模型,然后对更高的分辨率进行微调。在微调期间,我们将图像的长度和宽度提高2倍。该策略大大降低了使用 DiM 训练高分辨率图像生成器的计算成本。
Training-free 上采样
极高分辨率图像不容易获得,因此很难微调到 DiM 以更高的分辨率。因此,本文探索了模型 Training-free 的上采样能力。例如,直接使用在 512×512 数据集[8]上训练的模型来生成 1024×1024 图像。然而,在 DiM 中执行无训练的超分辨率图像生成并非易事。作者观察到,直接用更高分辨率的高斯噪声馈送给网络会导致具有重复模式、损坏的全局结构和折叠空间布局的图像。只有局部结构和细节表现出相对较好的质量。为了生成更好的全局结构,作者在早期扩散时间步 (比如前 30% 的时间步) 使用上采样引导 (Upsample Guidance)[9]:
其中, 表示本文基于噪声预测的 Mamba 的模型, 表示缩放因子 (比如 ), 表示尺度 最近的上采样算子, 表示步幅 的平均池化算子 (下采样), 表示噪声输入, 表示输入扩散时间步长, 表示其信噪比为 的信噪比为 倍的时间步长, 表示在每个时间步校准预测噪声的整体功率的系数, 表示上采样指导的权重。在后来的扩散时间步中, 作者直接将更高分辨率的噪声输入输入到 DiM 中进行噪声预测。
1.5 实验设置
如下图4所示为 DiM 模型的3个版本,模型大小不同。
输入大小设置为 32×32,没有 image auto-encoder[10]。Mamba Block 的超参数作者遵循标准设置[6]。将 ImageNet 和 CIFAR 上训练的 DiM 的 Patch Size 设置为 2×2。
所有的训练实验都在 8 A100-80G 上执行。继之前的工作 U-ViT 之后,作者使用相同的 DDPM scheduler、预训练的图像自动编码器和 DPM-Solver。使用随机翻转作为数据增强。学习率设置为 2×10−4。还使用速率为0.9999 的 EMA。
评价指标为 FID-50K[11],数据集为 CIFAR 和 ImageNet。作者还使用 classifier-free guidance 进行评估,计算 FID 的指导权重与 U-ViT 中的指导权重相同。考虑到有限的 GPU 资源,在 ImageNet 256×256 上进行预训练时,将 DiM-Large 和 DiM-Huge 的 Batch Size 设置为 1024 和 768。在 ImageNet 512 × 512 上微调 DiM-Huge 时,将 Batch Size 设置为 240,梯度累积。
在 ImageNet 上,作者首先在 256×256 分辨率下预训练超过 300K iterations 的 DiM 模型。然后,以 512×512 的分辨率微调预训练模型。为了在不增加训练成本的情况下实现更高的分辨率,作者进一步使用 training-free 的上采样技术生成 1024×1024 和 1536×1536 图像,DiM-Huge 在 512×512 分辨率下训练。
1.6 效率分析
作者检查 DiM 的效率,并将其与 Transformer 主干进行比较。单个选择性扫描比 FlashAttention V2[12]更高效。然而,为了保持相似数量的参数,标准 Mamba 的 Block 数是 Transformer 的两倍。扫描的加倍增加了计算复杂度。此外,作者提出的包括扫描模式的切换在内的模块也会造成轻微的延迟。为了比较 Transformer 和 Mamba 在图像生成方面的实际效率,作者在单个 H800 GPU 上进行了实验。
作者在图5中展示了本文模型,U-ViT 和 Mamba Baseline 的推理速度。这些模型具有相似的参数量 (0.9B) 和相同的 2×2 Patch Size。可以看到,原始的 Mamba 基线和 DiM 在分辨率低于 1024×1024 的情况下比优化良好的基于 Transformer 的模型慢。然而,在分辨率高于 1280×1280 情况下,DiM 比 Transformer 更快,这要归功于其线性复杂度。
而且,DiM 的效率仅略低于 Mamba Baseline,这表明作者添加到原始 Mamba 的设计使 Mamba 适应 2D 图像,但不会造成较大的额外计算成本。
1.7 实验结果
ImageNet 数据集生成质量
作者选择一组生成的图像进行可视化。结果表明,在 ImageNet 上预训练的 DiM-Huge 可以生成高质量的 256×256 图像,如图 6(b) 所示。本文模型在分辨率为 512×512 的 ImageNet 上微调的模型也显示出出色的性能,如图 6(a) 所示。
可以使用在 512×512 ImageNet 上训练的模型直接生成 1024×1024 和 1536×1536 图像。如图7所示,即使分辨率增加到训练的3倍,本文模型仍然能够生成具有上采样引导的视觉上吸引人的图像。
ImageNet 256×256 预训练数值结果
考虑到有限的计算资源和时间,作者只能在最多 319 million 图像样本上训练模型。作者将 DiM 与其他基于 Transformer 和 SSM 的扩散模型进行了比较,如下图8所示。值得注意的是,在对 319 million 个图像样本进行训练后,DiM-Huge 在 FID-50K 上可以达到 2.40 的分数。在使用 U-ViT (319M vs 500M) 63% 的训练数据的情况下,本文模型的性能与其他基于 Transformer 的扩散模型相当,即在 FID-50K 上仅差约 0.1。此外,与 DiffuSSM-XL 相比,基于 Mamba 的扩散模型的 GFLOPs 要小得多,即 DiM 需要更少的资源进行推理。
ImageNet 512×512 微调数值结果
在 512×512 图像样本进行训练需要大量的计算资源。此外,这种较大的分辨率在训练和推理过程中造成了不可忽略的延迟,如图9所示。因此,作者没有从头开始训练,而是从在 ImageNet 256×256 上进行预训练之后的模型为初始化,来微调本文的 DiM-Huge。作者只使用了 U-ViT 的 512×512 训练数据的 3% (15M vs 500M),DiM-Huge 就达到了 3.94 FID50K。尽管仍然远未达到最佳性能,但 DiM-Huge 能够产生视觉上吸引人的 512×512 图像,如图7(a)所示。
CIFAR-10 的实验结果如下图10所示。本文方法可以与其他具有相似参数量的方法相比实现相当的性能。
消融实验结果
作者在 CIFAR-10 数据集上进行了消融实验,FIDs 结果如图11所示。其中第一行包含最佳性能模型的结果,其他行的性能对应于没有某些组件的模型。
根据第一行和最后一行,多次扫描模式会对结果有帮助,说明全局感受野的重要性。
作者还发现长距离 skip-connection 有助于收敛。
此外,卷积和 padding token 也有助于提升性能。
参考
^abU-Net: Convolutional Networks for Biomedical Image Segmentation
^Scalable Diffusion Models with Transformers
^abAll are worth words: A vit backbone for diffusion models
^Pixart-alpha: Fast training of diffusion transformer for photorealistic text-to-image synthesis
^Scaling rectified flow transformers for high-resolution image synthesis
^abMamba: Linear-Time Sequence Modeling with Selective State Spaces
^PixArt-Σ: Weak-to-Strong Training of Diffusion Transformer for 4K Text-to-Image Generation
^ImageNet: A large-scale hierarchical image database
^Upsample guidance: Scale up diffusion models without training
^Auto-Encoding Variational Bayes
^GANs Trained by a Two Time-Scale Update Rule Converge to a Local Nash Equilibrium
^FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning
何恺明在MIT授课的课件PPT下载
在CVer公众号后台回复:何恺明,即可下载本课程的所有566页课件PPT!赶紧学起来!
CVPR 2024 论文和代码下载
CVPR 2024 论文和代码下载
在CVer公众号后台回复:CVPR2024,即可下载CVPR 2024论文和代码开源的论文合集
Mamba、多模态和扩散模型交流群成立
扫描下方二维码,或者添加微信:CVer5555,即可添加CVer小助手微信,便可申请加入CVer-Mamba、多模态学习或者扩散模型微信交流群。另外其他垂直方向已涵盖:目标检测、图像分割、目标跟踪、人脸检测&识别、OCR、姿态估计、超分辨率、SLAM、医疗影像、Re-ID、GAN、NAS、深度估计、自动驾驶、强化学习、车道线检测、模型剪枝&压缩、去噪、去雾、去雨、风格迁移、遥感图像、行为识别、视频理解、图像融合、图像检索、论文投稿&交流、PyTorch、TensorFlow和Transformer、NeRF、3DGS、Mamba等。
一定要备注:研究方向+地点+学校/公司+昵称(如Mamba、多模态学习或者扩散模型+上海+上交+卡卡),根据格式备注,可更快被通过且邀请进群
▲扫码或加微信号: CVer5555,进交流群
CVer计算机视觉(知识星球)来了!想要了解最新最快最好的CV/DL/AI论文速递、优质实战项目、AI行业前沿、从入门到精通学习教程等资料,欢迎扫描下方二维码,加入CVer计算机视觉(知识星球),已汇集近万人!
▲扫码加入星球学习
▲点击上方卡片,关注CVer公众号
整理不易,请赞和在看
▲扫码或加微信号: CVer5555,进交流群
CVer计算机视觉(知识星球)来了!想要了解最新最快最好的CV/DL/AI论文速递、优质实战项目、AI行业前沿、从入门到精通学习教程等资料,欢迎扫描下方二维码,加入CVer计算机视觉(知识星球),已汇集近万人!
▲扫码加入星球学习
▲点击上方卡片,关注CVer公众号
整理不易,请赞和在看
微信扫码关注该文公众号作者