在图像生成领域,DiT(Diffusion Transformer)架构凭借其卓越的表现成为前沿技术。然而,该架构的核心——用于建模令牌间关系的注意力机制,由于其计算复杂度为二次方,导致在处理高分辨率图像时面临显著延迟的问题。为了突破这一瓶颈,研究人员致力于开发一种线性复杂度的注意力机制,以提高效率而不牺牲质量。
线性注意力机制的关键因素
新加坡国立大学的研究团队对现有的高效注意力机制进行了深入分析,总结出四个对于成功实现预训练DiT线性化至关重要的关键要素:
- 局部性:限制特征交互到每个查询令牌周围的局部区域。
- 公式一致性:确保新机制与原机制在数学表达上保持一致。
- 高秩注意力图:维持注意力分布的丰富性和多样性。
- 特征完整性:保留输入信息的完整性和细节。
基于这些发现,研究者们提出了CLEAR(Classical Local Attention with Reduced complexity),一种类卷积局部注意力策略。它通过将特征交互限制在局部窗口内,有效地降低了计算复杂度至线性级别,提升在高分辨率图像生成任务中的效率。通过这种方法,模型在生成8K分辨率图像时,能够将注意力计算减少99.5%,并将生成速度提高6.3倍。
主要功能:
- 提高效率:CLEAR通过将DiTs的复杂度从二次降低到线性,显著提高了高分辨率图像生成的效率。
- 保持性能:尽管复杂度降低,但CLEAR能够保持与原始DiTs相当的图像生成质量。
主要特点:
- 局部性(Locality):CLEAR限制每个查询令牌(query token)只与局部窗口内的键值令牌(key-value tokens)进行交互。
- 公式一致性(Formulation Consistency):CLEAR保持了与原始注意力机制相同的softmax基础公式。
- 高秩注意力图(High-Rank Attention Maps):CLEAR生成的注意力图能够捕捉复杂的令牌间关系。
- 特征完整性(Feature Integrity):CLEAR保留了原始查询、键和值特征的完整性,而不是压缩版本。
工作原理:
CLEAR通过引入一种类似卷积的局部注意力策略,每个查询只与预定义距离内的键值令牌进行交互。由于与每个查询交互的键值令牌数量是固定的,因此DiT的复杂度与图像分辨率呈线性关系。CLEAR使用圆形窗口来确定局部性,这意味着在每个查询的欧几里得距离内小于预定义半径的键值令牌将被考虑在内。
实验结果
实验结果显示,仅需对10,000个自动生成的样本进行10,000次迭代微调,即可成功地将预训练DiT的知识转移到具有线性复杂度的学生模型中。这不仅使得生成的图像质量与教师模型相媲美,还大幅减少了99.5%的注意力计算量,并将8K分辨率图像的生成速度提高了6.3倍。
此外,蒸馏注意力层展现出跨不同模型和插件的零样本泛化能力,以及对多GPU并行推理的更佳支持,进一步增强了其实用性。
支持的模型与应用
研究人员已经发布了一系列名为FLUX-1.dev的线性化变体,它们拥有不同的局部窗口大小。实验表明,较小的窗口(如8x8)可能会导致重复图案的出现。为了解决这个问题,在某些变体中加入了降采样的键值令牌,以便在局部令牌之外进行注意力交互。
推理指南
为了方便用户比较线性化FLUX与原始模型的性能,我们提供了inference_t2i.ipynb
笔记本文件。如果您希望利用CLEAR加速高分辨率图像生成,请使用inference_t2i_highres.ipynb
,其中采用了SDEdit策略,即先生成低分辨率图像,再逐步放大至所需尺寸。
在进行推理时,您可以通过配置down_factor
和window_size
来选择不同的CLEAR变体。若不希望包含降采样的键值令牌,可以设置down_factor=1
。首次使用时,模型会自动下载至ckpt
目录。
对于高分辨率图像生成,建议使用配备有48GB显存的GPU卡,以确保最佳性能。
评论0