华中科技大学推出VA-VAE和LightningDiT框架:在不牺牲重建质量的前提下,显著提升生成性能

华中科技大学的研究团队聚焦于潜在扩散模型(latent diffusion models)中的优化困境,即视觉分词器(visual tokenizer)中每令牌特征维度的增加虽能提升重建质量,但会降低生成性能,且需要更大的扩散模型和更多训练迭代来实现可比的生成效果。例如,我们有一个图像生成任务,需要在保持图像细节的同时,快速生成高质量的图像。传统的LDMs在提高视觉标记器的特征维度以增强重建质量时,会遇到生成性能下降的问题。该团队提出的VA-VAE(Vision foundation model Aligned Variational AutoEncoder)和LightningDiT框架,可以在不牺牲重建质量的前提下,显著提升生成性能。

主要功能

  1. 优化视觉分词器:提出 VA - VAE(Vision foundation model Aligned Variational AutoEncoder),通过与预训练视觉基础模型对齐,在训练视觉分词器时优化其潜在空间,在保持高重建质量的同时提升生成性能,有效解决了潜在扩散模型中重建与生成之间的优化困境。
  2. 提升扩散模型效率:构建了 LightningDiT,这是一个增强的 Diffusion Transformers(DiT)基线模型,结合了改进的训练策略和架构设计,能够在高维潜在空间中更快地收敛,显著提高了扩散模型的训练效率,如在 ImageNet 256×256 生成任务中,仅用 64 个训练周期就达到了 2.11 的 FID 得分,相比原始 DiT 模型收敛速度提升超过 21 倍。

主要特点

  1. 创新的 VF Loss 设计:VA - VAE 中的 VF Loss 由边际余弦相似度损失(Marginal Cosine Similarity Loss)和边际距离矩阵相似度损失(Marginal Distance Matrix Similarity Loss)组成,通过精心设计的联合重建和对齐损失,正则化高维潜在空间,同时保持一定的灵活性,避免过强约束,有效引导视觉分词器学习更适合生成任务的潜在表示。
  2. 训练优化策略多样:LightningDiT 采用了多种优化策略,包括计算层面的加速技术(如 torch.compile 和 bfloat16 训练)、扩散优化方法(如 Rectified Flow、logit normal 分布采样和速度方向损失)以及模型架构层面的优化(如 RMSNorm、SwiGLU 和 RoPE),这些策略相互配合,进一步提升了模型的训练效率和性能。

工作原理

VA - VAE 的训练

  • 利用视觉基础模型对齐:在训练视觉分词器时,通过 VF Loss 将其潜在空间与预训练的视觉基础模型(如 DINOv2、MAE 等)进行对齐。对于给定图像,同时经过视觉分词器编码器和冻结的视觉基础模型处理,得到图像潜在表示和基础视觉表示,然后通过线性变换将视觉分词器的潜在表示投影到与基础模型相同的维度,计算两者之间的边际余弦相似度损失和边际距离矩阵相似度损失,以优化潜在空间。
  • 自适应权重调整:采用自适应加权机制,在反向传播前计算重建损失和 VF Loss 在编码器最后卷积层上的梯度,根据梯度比值设置自适应权重,确保 VF Loss 和重建损失对模型优化有相似影响,从而快速对齐不同 VAE 训练管道中的损失尺度,同时仍可手动调整超参数进一步优化性能。

LightningDiT 的构建与训练

  • 模型架构与训练设置:以 DiT - XL/2 为基础模型,采用 SD - VAE(f8d4 规格)作为视觉分词器,利用预提取的分词器潜在特征在 ImageNet 上进行训练。设置 DiT 的 patch size 为 1,确保整个系统下采样率为 16,其他架构参数与 DiT 一致。训练过程中采用多种优化策略,如计算层面的加速、扩散优化和模型架构优化等,且观察到不同加速策略间的非正交性,进行了合理组合。
  • 渐进训练策略:对于使用 VF Loss(DINOv2)的分词器,采用渐进训练策略延长训练时间至 125 个周期,以获得更强生成能力的 VA - VAE,然后使用 LightningDiT - XL 在特定参数设置下训练 800 个周期,在训练后期调整 lognorm 参数以提升学习效果。采样时使用 250 步 Euler 积分器,并采用 cfg interval 和 timestep shift 等技术提升采样性能。

具体应用场景

  1. 高质量图像生成:在需要生成高质量图像的领域,如数字艺术创作、广告设计、游戏开发等,该模型能够生成逼真、细节丰富的图像。例如,游戏开发者可以利用该模型快速生成游戏场景中的各种元素,如角色、道具、背景等,提高游戏开发效率;广告设计师可以根据客户需求生成吸引人的广告图片,增强广告效果。
  2. 图像编辑与增强:可用于图像编辑任务,如修复受损图像、提升图像分辨率、对图像进行风格转换等。例如,对于老旧照片的修复,模型可以根据周围像素信息和图像的整体结构,重建缺失或损坏的部分,使照片恢复清晰和完整;在图像风格转换方面,能够将普通照片转换为特定艺术风格的图像,满足用户多样化的审美需求。
  3. 计算机视觉研究中的数据增强:在计算机视觉任务的数据预处理阶段,用于扩充训练数据集,增加数据多样性,提高模型的泛化能力。例如,在目标检测、图像分类等任务中,通过生成多样化的图像样本,使模型能够学习到更广泛的特征表示,从而提升在实际应用中的性能表现。
0

评论0

没有账号?注册  忘记密码?