Add Stable Diffusion 3 Example (#2558)

* Add stable diffusion 3 example

Add get_qkv_linear to handle different dimensionality in linears

Add stable diffusion 3 example

Add use_quant_conv and use_post_quant_conv for vae in stable diffusion

adapt existing AutoEncoderKLConfig to the change

add forward_until_encoder_layer to ClipTextTransformer

rename sd3 config to sd3_medium in mmdit; minor clean-up

Enable flash-attn for mmdit impl when the feature is enabled.

Add sd3 example codebase

add document

crediting references

pass the cargo fmt test

pass the clippy test

* fix typos

* expose cfg_scale and time_shift as options

* Replace the sample image with JPG version. Change image output format accordingly.

* make meaningful error messages

* remove the tail-end assignment in sd3_vae_vb_rename

* remove the CUDA requirement

* use default_value in clap args

* add use_flash_attn to turn on/off flash-attn for MMDiT at runtime

* resolve clippy errors and warnings

* use default_value_t

* Pin the web-sys dependency.

* Clippy fix.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
Author: Czxck001
Committed: 2024-10-13 13:08:40 -07:00 (via GitHub)
Parent: 0d96ec31e8
Commit: ca7cf5cb3b
16 changed files with 751 additions and 34 deletions


@@ -122,3 +122,6 @@ required-features = ["onnx"]
[[example]]
name = "colpali"
required-features = ["pdf2image"]
[[example]]
name = "stable-diffusion-3"


@@ -0,0 +1,54 @@
# candle-stable-diffusion-3: Candle Implementation of Stable Diffusion 3 Medium
![](assets/stable-diffusion-3.jpg)
*A cute rusty robot holding a candle torch in its hand, with glowing neon text \"LETS GO RUSTY\" displayed on its chest, bright background, high quality, 4k*
Stable Diffusion 3 Medium is a text-to-image model based on the Multimodal Diffusion Transformer (MMDiT) architecture.
- [huggingface repo](https://huggingface.co/stabilityai/stable-diffusion-3-medium)
- [research paper](https://arxiv.org/pdf/2403.03206)
- [announcement blog post](https://stability.ai/news/stable-diffusion-3-medium)
## Getting access to the weights
The weights of Stable Diffusion 3 Medium are released by Stability AI under the Stability Community License. To gain access to the weights for your Hugging Face account, visit the [repo on the Hugging Face Hub](https://huggingface.co/stabilityai/stable-diffusion-3-medium), accept the conditions, and acquire a license.
On the first run, the weights are downloaded automatically from the Hugging Face Hub. You might be prompted to configure a [Hugging Face User Access Token](https://huggingface.co/docs/hub/en/security-tokens) (recommended) on your computer if you haven't done so before. After the download, the weights are [cached](https://huggingface.co/docs/datasets/en/cache) and remain accessible locally.
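One way to set up the token is with the Hugging Face CLI, sketched below; this is just one option, and any method that writes the token to the local Hugging Face cache (typically `~/.cache/huggingface/token`) should work.
```shell
pip install -U "huggingface_hub[cli]"
huggingface-cli login
```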
## Running the model
```shell
cargo run --example stable-diffusion-3 --release --features=cuda -- \
--height 1024 --width 1024 \
--prompt 'A cute rusty robot holding a candle torch in its hand, with glowing neon text \"LETS GO RUSTY\" displayed on its chest, bright background, high quality, 4k'
```
To display the other available options:
```shell
cargo run --example stable-diffusion-3 --release --features=cuda -- --help
```
If your GPU supports it, Flash Attention is strongly recommended: it can greatly improve inference speed, since MMDiT is a transformer model that relies heavily on attention. To use [candle-flash-attn](https://github.com/huggingface/candle/tree/main/candle-flash-attn) in the demo, you will need both `--features flash-attn` at build time and `--use-flash-attn` at runtime.
```shell
cargo run --example stable-diffusion-3 --release --features=cuda,flash-attn -- --use-flash-attn ...
```
## Performance Benchmark
The benchmark below generates a 1024x1024 image with 28 steps of Euler sampling and measures the average speed (iterations per second).
[candle](https://github.com/huggingface/candle) and [candle-flash-attn](https://github.com/huggingface/candle/tree/main/candle-flash-attn) are built from commit [0d96ec3](https://github.com/huggingface/candle/commit/0d96ec31e8be03f844ed0aed636d6217dee9c7bc).
System specs (desktop, PCIe 5.0 x8/x8 dual-GPU setup):
- Operating System: Ubuntu 23.10
- CPU: i9 12900K w/o overclocking.
- RAM: 64G dual-channel DDR5 @ 4800 MT/s
| Speed (iter/s) | w/o flash-attn | w/ flash-attn |
| -------------- | -------------- | ------------- |
| RTX 3090 Ti | 0.83 | 2.15 |
| RTX 4090 | 1.72 | 4.06 |
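At these rates, a single 28-step image takes roughly 34 s without flash-attn and 13 s with it on the RTX 3090 Ti, and roughly 16 s versus 7 s on the RTX 4090.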

(Binary file added: the sample image assets/stable-diffusion-3.jpg, 81 KiB; not shown.)


@@ -0,0 +1,201 @@
use anyhow::{Error as E, Ok, Result};
use candle::{DType, IndexOp, Module, Tensor, D};
use candle_transformers::models::{stable_diffusion, t5};
use tokenizers::tokenizer::Tokenizer;
struct ClipWithTokenizer {
clip: stable_diffusion::clip::ClipTextTransformer,
config: stable_diffusion::clip::Config,
tokenizer: Tokenizer,
max_position_embeddings: usize,
}
impl ClipWithTokenizer {
fn new(
vb: candle_nn::VarBuilder,
config: stable_diffusion::clip::Config,
tokenizer_path: &str,
max_position_embeddings: usize,
) -> Result<Self> {
let clip = stable_diffusion::clip::ClipTextTransformer::new(vb, &config)?;
let path_buf = hf_hub::api::sync::Api::new()?
.model(tokenizer_path.to_string())
.get("tokenizer.json")?;
let tokenizer = Tokenizer::from_file(path_buf.to_str().ok_or(E::msg(
"Failed to convert the CLIP tokenizer path to a string",
))?)
.map_err(E::msg)?;
Ok(Self {
clip,
config,
tokenizer,
max_position_embeddings,
})
}
fn encode_text_to_embedding(
&self,
prompt: &str,
device: &candle::Device,
) -> Result<(Tensor, Tensor)> {
let pad_id = match &self.config.pad_with {
Some(padding) => *self
.tokenizer
.get_vocab(true)
.get(padding.as_str())
.ok_or(E::msg("Failed to tokenize CLIP padding."))?,
None => *self
.tokenizer
.get_vocab(true)
.get("<|endoftext|>")
.ok_or(E::msg("Failed to tokenize CLIP end-of-text."))?,
};
let mut tokens = self
.tokenizer
.encode(prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
let eos_position = tokens.len() - 1;
while tokens.len() < self.max_position_embeddings {
tokens.push(pad_id)
}
let tokens = Tensor::new(tokens.as_slice(), device)?.unsqueeze(0)?;
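// Run CLIP, keeping both the final hidden states and those of the penultimate (-2) encoder
// layer; SD3 conditions on the penultimate states, while the pooled embedding is taken from
// the final states at the EOS token position.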
let (text_embeddings, text_embeddings_penultimate) = self
.clip
.forward_until_encoder_layer(&tokens, usize::MAX, -2)?;
let text_embeddings_pooled = text_embeddings.i((0, eos_position, ..))?;
Ok((text_embeddings_penultimate, text_embeddings_pooled))
}
}
struct T5WithTokenizer {
t5: t5::T5EncoderModel,
tokenizer: Tokenizer,
max_position_embeddings: usize,
}
impl T5WithTokenizer {
fn new(vb: candle_nn::VarBuilder, max_position_embeddings: usize) -> Result<Self> {
let api = hf_hub::api::sync::Api::new()?;
let repo = api.repo(hf_hub::Repo::with_revision(
"google/t5-v1_1-xxl".to_string(),
hf_hub::RepoType::Model,
"refs/pr/2".to_string(),
));
let config_filename = repo.get("config.json")?;
let config = std::fs::read_to_string(config_filename)?;
let config: t5::Config = serde_json::from_str(&config)?;
let model = t5::T5EncoderModel::load(vb, &config)?;
let tokenizer_filename = api
.model("lmz/mt5-tokenizers".to_string())
.get("t5-v1_1-xxl.tokenizer.json")?;
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
Ok(Self {
t5: model,
tokenizer,
max_position_embeddings,
})
}
fn encode_text_to_embedding(
&mut self,
prompt: &str,
device: &candle::Device,
) -> Result<Tensor> {
let mut tokens = self
.tokenizer
.encode(prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
tokens.resize(self.max_position_embeddings, 0);
let input_token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
let embeddings = self.t5.forward(&input_token_ids)?;
Ok(embeddings)
}
}
pub struct StableDiffusion3TripleClipWithTokenizer {
clip_l: ClipWithTokenizer,
clip_g: ClipWithTokenizer,
clip_g_text_projection: candle_nn::Linear,
t5: T5WithTokenizer,
}
impl StableDiffusion3TripleClipWithTokenizer {
pub fn new(vb_fp16: candle_nn::VarBuilder, vb_fp32: candle_nn::VarBuilder) -> Result<Self> {
let max_position_embeddings = 77usize;
let clip_l = ClipWithTokenizer::new(
vb_fp16.pp("clip_l.transformer"),
stable_diffusion::clip::Config::sdxl(),
"openai/clip-vit-large-patch14",
max_position_embeddings,
)?;
let clip_g = ClipWithTokenizer::new(
vb_fp16.pp("clip_g.transformer"),
stable_diffusion::clip::Config::sdxl2(),
"laion/CLIP-ViT-bigG-14-laion2B-39B-b160k",
max_position_embeddings,
)?;
let text_projection = candle_nn::linear_no_bias(
1280,
1280,
vb_fp16.pp("clip_g.transformer.text_projection"),
)?;
// Current T5 implementation does not support fp16, so we use fp32 VarBuilder for T5.
// This is a temporary workaround until the T5 implementation is updated to support fp16.
// Also see:
// https://github.com/huggingface/candle/issues/2480
// https://github.com/huggingface/candle/pull/2481
let t5 = T5WithTokenizer::new(vb_fp32.pp("t5xxl.transformer"), max_position_embeddings)?;
Ok(Self {
clip_l,
clip_g,
clip_g_text_projection: text_projection,
t5,
})
}
pub fn encode_text_to_embedding(
&mut self,
prompt: &str,
device: &candle::Device,
) -> Result<(Tensor, Tensor)> {
let (clip_l_embeddings, clip_l_embeddings_pooled) =
self.clip_l.encode_text_to_embedding(prompt, device)?;
let (clip_g_embeddings, clip_g_embeddings_pooled) =
self.clip_g.encode_text_to_embedding(prompt, device)?;
let clip_g_embeddings_pooled = self
.clip_g_text_projection
.forward(&clip_g_embeddings_pooled.unsqueeze(0)?)?
.squeeze(0)?;
let y = Tensor::cat(&[&clip_l_embeddings_pooled, &clip_g_embeddings_pooled], 0)?
.unsqueeze(0)?;
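// Concatenate the CLIP-L (768-dim) and CLIP-G (1280-dim) hidden states along the feature
// dimension and zero-pad the result from 2048 to 4096 features so it lines up with the
// T5-XXL embeddings, then append the T5 sequence along the token dimension.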
let clip_embeddings_concat = Tensor::cat(
&[&clip_l_embeddings, &clip_g_embeddings],
D::Minus1,
)?
.pad_with_zeros(D::Minus1, 0, 2048)?;
let t5_embeddings = self
.t5
.encode_text_to_embedding(prompt, device)?
.to_dtype(DType::F16)?;
let context = Tensor::cat(&[&clip_embeddings_concat, &t5_embeddings], D::Minus2)?;
Ok((context, y))
}
}


@@ -0,0 +1,185 @@
mod clip;
mod sampling;
mod vae;
use candle::{DType, IndexOp, Tensor};
use candle_transformers::models::mmdit::model::{Config as MMDiTConfig, MMDiT};
use crate::clip::StableDiffusion3TripleClipWithTokenizer;
use crate::vae::{build_sd3_vae_autoencoder, sd3_vae_vb_rename};
use anyhow::{Ok, Result};
use clap::Parser;
#[derive(Parser)]
#[command(author, version, about, long_about = None)]
struct Args {
/// The prompt to be used for image generation.
#[arg(
long,
default_value = "A cute rusty robot holding a candle torch in its hand, \
with glowing neon text \"LETS GO RUSTY\" displayed on its chest, \
bright background, high quality, 4k"
)]
prompt: String,
#[arg(long, default_value = "")]
uncond_prompt: String,
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// The CUDA device ID to use.
#[arg(long, default_value = "0")]
cuda_device_id: usize,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
/// Use flash_attn to accelerate attention operation in the MMDiT.
#[arg(long)]
use_flash_attn: bool,
/// The height in pixels of the generated image.
#[arg(long, default_value_t = 1024)]
height: usize,
/// The width in pixels of the generated image.
#[arg(long, default_value_t = 1024)]
width: usize,
/// The number of sampling steps.
#[arg(long, default_value_t = 28)]
num_inference_steps: usize,
/// CFG scale.
#[arg(long, default_value_t = 4.0)]
cfg_scale: f64,
/// Time shift factor (alpha).
#[arg(long, default_value_t = 3.0)]
time_shift: f64,
/// The seed to use when generating random samples.
#[arg(long)]
seed: Option<u64>,
}
fn main() -> Result<()> {
let args = Args::parse();
run(args)
}
fn run(args: Args) -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let Args {
prompt,
uncond_prompt,
cpu,
cuda_device_id,
tracing,
use_flash_attn,
height,
width,
num_inference_steps,
cfg_scale,
time_shift,
seed,
} = args;
let _guard = if tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
// TODO: Support and test on Metal.
let device = if cpu {
candle::Device::Cpu
} else {
candle::Device::cuda_if_available(cuda_device_id)?
};
let api = hf_hub::api::sync::Api::new()?;
let sai_repo = {
let name = "stabilityai/stable-diffusion-3-medium";
api.repo(hf_hub::Repo::model(name.to_string()))
};
let model_file = sai_repo.get("sd3_medium_incl_clips_t5xxlfp16.safetensors")?;
let vb_fp16 = unsafe {
candle_nn::VarBuilder::from_mmaped_safetensors(&[model_file.clone()], DType::F16, &device)?
};
let (context, y) = {
let vb_fp32 = unsafe {
candle_nn::VarBuilder::from_mmaped_safetensors(
&[model_file.clone()],
DType::F32,
&device,
)?
};
let mut triple = StableDiffusion3TripleClipWithTokenizer::new(
vb_fp16.pp("text_encoders"),
vb_fp32.pp("text_encoders"),
)?;
let (context, y) = triple.encode_text_to_embedding(prompt.as_str(), &device)?;
let (context_uncond, y_uncond) =
triple.encode_text_to_embedding(uncond_prompt.as_str(), &device)?;
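// Batch the conditional and unconditional embeddings together so one MMDiT forward pass
// produces both predictions needed for classifier-free guidance.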
(
Tensor::cat(&[context, context_uncond], 0)?,
Tensor::cat(&[y, y_uncond], 0)?,
)
};
let x = {
let mmdit = MMDiT::new(
&MMDiTConfig::sd3_medium(),
use_flash_attn,
vb_fp16.pp("model.diffusion_model"),
)?;
if let Some(seed) = seed {
device.set_seed(seed)?;
}
let start_time = std::time::Instant::now();
let x = sampling::euler_sample(
&mmdit,
&y,
&context,
num_inference_steps,
cfg_scale,
time_shift,
height,
width,
)?;
let dt = start_time.elapsed().as_secs_f32();
println!(
"Sampling done. {num_inference_steps} steps. {:.2}s. Average rate: {:.2} iter/s",
dt,
num_inference_steps as f32 / dt
);
x
};
let img = {
let vb_vae = vb_fp16
.clone()
.rename_f(sd3_vae_vb_rename)
.pp("first_stage_model");
let autoencoder = build_sd3_vae_autoencoder(vb_vae)?;
// Apply the TAESD3 scale factor. This seems to significantly improve the quality of the image.
// https://github.com/comfyanonymous/ComfyUI/blob/3c60ecd7a83da43d694e26a77ca6b93106891251/nodes.py#L721-L723
autoencoder.decode(&((x.clone() / 1.5305)? + 0.0609)?)?
};
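// Map the decoded image from [-1, 1] to [0, 255] and convert to u8 before saving.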
let img = ((img.clamp(-1f32, 1f32)? + 1.0)? * 127.5)?.to_dtype(candle::DType::U8)?;
candle_examples::save_image(&img.i(0)?, "out.jpg")?;
Ok(())
}


@@ -0,0 +1,55 @@
use anyhow::{Ok, Result};
use candle::{DType, Tensor};
use candle_transformers::models::flux; // for the get_noise function
use candle_transformers::models::mmdit::model::MMDiT;
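// Euler sampling over the shifted sigma schedule: start from Gaussian noise and repeatedly
// update x <- x + v * (sigma_prev - sigma_curr), where v is the CFG-combined model prediction.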
#[allow(clippy::too_many_arguments)]
pub fn euler_sample(
mmdit: &MMDiT,
y: &Tensor,
context: &Tensor,
num_inference_steps: usize,
cfg_scale: f64,
time_shift: f64,
height: usize,
width: usize,
) -> Result<Tensor> {
let mut x = flux::sampling::get_noise(1, height, width, y.device())?.to_dtype(DType::F16)?;
let sigmas = (0..=num_inference_steps)
.map(|x| x as f64 / num_inference_steps as f64)
.rev()
.map(|x| time_snr_shift(time_shift, x))
.collect::<Vec<f64>>();
for window in sigmas.windows(2) {
let (s_curr, s_prev) = match window {
[a, b] => (a, b),
_ => continue,
};
let timestep = (*s_curr) * 1000.0;
let noise_pred = mmdit.forward(
&Tensor::cat(&[x.clone(), x.clone()], 0)?,
&Tensor::full(timestep, (2,), x.device())?.contiguous()?,
y,
context,
)?;
x = (x + (apply_cfg(cfg_scale, &noise_pred)? * (*s_prev - *s_curr))?)?;
}
Ok(x)
}
// The "Resolution-dependent shifting of timestep schedules" recommended in the SD3 tech report paper
// https://arxiv.org/pdf/2403.03206
// Following the implementation in ComfyUI:
// https://github.com/comfyanonymous/ComfyUI/blob/3c60ecd7a83da43d694e26a77ca6b93106891251/
// comfy/model_sampling.py#L181
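// For example, with the default time_shift alpha = 3.0, t = 0.5 maps to
// 3.0 * 0.5 / (1.0 + 2.0 * 0.5) = 0.75, biasing the schedule toward higher noise levels.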
fn time_snr_shift(alpha: f64, t: f64) -> f64 {
alpha * t / (1.0 + (alpha - 1.0) * t)
}
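// Classifier-free guidance: batch index 0 holds the conditional prediction and index 1 the
// unconditional one; the result is cfg_scale * cond - (cfg_scale - 1.0) * uncond,
// i.e. uncond + cfg_scale * (cond - uncond).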
fn apply_cfg(cfg_scale: f64, noise_pred: &Tensor) -> Result<Tensor> {
Ok(((cfg_scale * noise_pred.narrow(0, 0, 1)?)?
- ((cfg_scale - 1.0) * noise_pred.narrow(0, 1, 1)?)?)?)
}


@@ -0,0 +1,93 @@
use anyhow::{Ok, Result};
use candle_transformers::models::stable_diffusion::vae;
pub fn build_sd3_vae_autoencoder(vb: candle_nn::VarBuilder) -> Result<vae::AutoEncoderKL> {
let config = vae::AutoEncoderKLConfig {
block_out_channels: vec![128, 256, 512, 512],
layers_per_block: 2,
latent_channels: 16,
norm_num_groups: 32,
use_quant_conv: false,
use_post_quant_conv: false,
};
Ok(vae::AutoEncoderKL::new(vb, 3, 3, config)?)
}
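// Map the diffusers-style parameter names used by the candle VAE onto the LDM-style names
// stored in the SD3 checkpoint. For example, "decoder.up_blocks.0.resnets.1.conv1.weight" is
// expected to become "decoder.up.3.block.1.conv1.weight" (up blocks reversed, "resnets" -> "block").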
pub fn sd3_vae_vb_rename(name: &str) -> String {
let parts: Vec<&str> = name.split('.').collect();
let mut result = Vec::new();
let mut i = 0;
while i < parts.len() {
match parts[i] {
"down_blocks" => {
result.push("down");
}
"mid_block" => {
result.push("mid");
}
"up_blocks" => {
result.push("up");
match parts[i + 1] {
// Reverse the order of up_blocks.
"0" => result.push("3"),
"1" => result.push("2"),
"2" => result.push("1"),
"3" => result.push("0"),
_ => {}
}
i += 1; // Skip the number after up_blocks.
}
"resnets" => {
if i > 0 && parts[i - 1] == "mid_block" {
match parts[i + 1] {
"0" => result.push("block_1"),
"1" => result.push("block_2"),
_ => {}
}
i += 1; // Skip the number after resnets.
} else {
result.push("block");
}
}
"downsamplers" => {
result.push("downsample");
i += 1; // Skip the 0 after downsamplers.
}
"conv_shortcut" => {
result.push("nin_shortcut");
}
"attentions" => {
if parts[i + 1] == "0" {
result.push("attn_1")
}
i += 1; // Skip the number after attentions.
}
"group_norm" => {
result.push("norm");
}
"query" => {
result.push("q");
}
"key" => {
result.push("k");
}
"value" => {
result.push("v");
}
"proj_attn" => {
result.push("proj_out");
}
"conv_norm_out" => {
result.push("norm_out");
}
"upsamplers" => {
result.push("upsample");
i += 1; // Skip the 0 after upsamplers.
}
part => result.push(part),
}
i += 1;
}
result.join(".")
}