Stable Diffusion(KerasCV)

使用KerasCV的稳定扩散图像生成。Stable Diffusion是一个强大的文本 -> 图像模型,有Stability AI开源。虽然存在多种开源实现,可以轻松地根据文本提示创建图像,但KerasCV提供了一些优势:其中包括XLA编译和混合精度支持,他们共同实现最优的生成,使用KerasCV调用Stable Diffusion非常简单。我们传入一个字符串,通常称为提示,批量大小为3。模型能够生成三张令人惊艳的图片,正如提示所描述:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import time
import keras_cv
import keras
import matplotlib.pyplot as plt

# 首先,我们构建一个模型
model = keras_cv.models.StableDiffusion(img_width=512, img_height=512, jit_compile=False)

# 接下来,我们给它一个提示:
images = model.text_to_image("photograph of an astronaut riding a horse", batch_size=3)
def plot_images(images):
plt.figure(figsize=(20, 20))
for i in range(len(images)):
ax = plt.subplot(1, len(images), i + 1)
plt.imshow(images[i])
plt.axis("off")


plot_images(images)

# 但这并不是该模型所能做的全部。让我们尝试一个更复杂的提示:
images = model.text_to_image(
"cute magical flying dog, fantasy art, "
"golden color, high quality, highly detailed, elegant, sharp focus, "
"concept art, character concepts, digital painting, mystery, adventure",
batch_size=3,
)
plot_images(images)

Stable Diffusion是怎么工作的?

要从潜在扩散文本 —> 图像,您需要添加一个关键特征:通过提示关键字控制生成的视觉内容的能力。这是通过“调节”来完成的,这是一种经典的深度学习技术,其中包括将一段文本的向量连接到噪声补丁,然后在{image:caption}的数据集上训练模型。这就产生了稳定扩散架构。 稳定扩散由三部分组成:

  • 文本编码器,可将您的提示转换为潜在向量。
  • 扩散模型,反复对64x64潜在图像块进行“去噪”。
  • 解码器,将最终的64x64潜在补丁转换为更高分辨率的512x512图像。

首先,您的文本提示由文本编码器投影到潜在向量空间中,这只是一个预训练的冻结语言模型。然后,将该提示向量连接到随机生成的噪声块,该噪声块通过扩散模型在一系列“步骤”上重复“去噪”(运行的步骤越多,图像就会越清晰、越好-默认值为50次)。最后,64x64潜在图像通过解码器发送,用高分辨率渲染它。

上图中有一个文本编码器将,可以将提示字符串转换潜在的向量,该向量连接到随机生成的噪声补丁。新的向量将通过扩散模型重复去噪,最后潜在图像通过解码器,将64 x 64图像块转换为更高分辨率的512 x 512图像。

总而言之,这是一个非常简单的系统 — Keras的实现包含在四个文件中,总共不到500行代码:

  • text_encoder.py: 87 LOC
  • diffusion_model.py: 181 LOC
  • decoder.py: 86 LOC
  • stable_diffusion.py: 106 LOC

但一旦你训练了数十亿张图片及其标题,这个相对简单的系统就开始看起来像魔法一样。 正如费曼对宇宙的评价:“宇宙并不复杂,只是很多!”

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
benchmark_result = []
start = time.time()
images = model.text_to_image(
"A cute otter in a rainbow whirlpool holding shells, watercolor",
batch_size=3,
)
end = time.time()
benchmark_result.append(["Standard", end - start])
plot_images(images)

print(f"Standard model: {(end - start):.2f} seconds")
keras.backend.clear_session() # Clear session to preserve memory.

# 50/50 ━━━━━━━━━━━━━━━━━━━━ 10s 209ms/step
# Standard model: 10.57 seconds

现在我们打开混合精度,使用float16精度执行计算,同时存储float32存储权重。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# Warm up model to run graph tracing before benchmarking.
model.text_to_image("warming up the model", batch_size=3)

start = time.time()
images = model.text_to_image(
"a cute magical flying dog, fantasy art, "
"golden color, high quality, highly detailed, elegant, sharp focus, "
"concept art, character concepts, digital painting, mystery, adventure",
batch_size=3,
)
end = time.time()
benchmark_result.append(["Mixed Precision", end - start])
plot_images(images)

print(f"Mixed precision model: {(end - start):.2f} seconds")
keras.backend.clear_session()

# 50/50 ━━━━━━━━━━━━━━━━━━━━ 42s 132ms/step
# 50/50 ━━━━━━━━━━━━━━━━━━━━ 6s 129ms/step
# Mixed precision model: 6.65 seconds

这更快是因为NVIDIA GPU具有专门的FP16运算内核。其运行速度比FP32同类产品更快。接下来我们尝试一下XLA编译,我们可以通过再次构建模型时将jit_compile标志设置为true来做到这一点。让我们对XLA模型进行基准测试:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# Set back to the default for benchmarking purposes.
keras.mixed_precision.set_global_policy("float32")
model = keras_cv.models.StableDiffusion(jit_compile=True)

start = time.time()
# Before we benchmark the model, we run inference once to make sure the TensorFlow
# graph has already been traced.
images = model.text_to_image("A cute otter in a rainbow whirlpool holding shells, watercolor",batch_size=3,)
end = time.time()

benchmark_result.append(["XLA", end - start])
plot_images(images)

print(f"With XLA: {(end - start):.2f} seconds")
keras.backend.clear_session()

# 50/50 ━━━━━━━━━━━━━━━━━━━━ 11s 210ms/step
# With XLA: 10.63 seconds

A100 GPU上,我们获得了大约2倍的加速。最后我们可以将所有内容放在一起,并打开混合精度和XLA编译,这次只花了大约6.66s

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
keras.mixed_precision.set_global_policy("mixed_float16")
model = keras_cv.models.StableDiffusion(jit_compile=True)

start = time.time()
images = model.text_to_image(
"A mysterious dark stranger visits the great pyramids of egypt, "
"high quality, highly detailed, elegant, sharp focus, "
"concept art, character concepts, digital painting",
batch_size=3,
)
end = time.time()
benchmark_result.append(["XLA + Mixed Precision", end - start])
plot_images(images)

print(f"XLA + mixed precision: {(end - start):.2f} seconds")

# 50/50 ━━━━━━━━━━━━━━━━━━━━ 6s 130ms/step
# XLA + mixed precision: 6.66 seconds

结论

KerasCV提供了最先进的稳定扩散实现 - 通过使用XLA和混合精度,它提供了截至20229月可用的最快的稳定扩散管道。