mirror of https://github.com/huggingface/candle.git
synced 2025-06-14 09:57:10 +00:00
Add Stable Diffusion 3 Example (#2558)
* Add stable diffusion 3 example
  - Add get_qkv_linear to handle different dimensionality in linears
  - Add use_quant_conv and use_post_quant_conv for vae in stable diffusion
  - Adapt existing AutoEncoderKLConfig to the change
  - Add forward_until_encoder_layer to ClipTextTransformer
  - Rename sd3 config to sd3_medium in mmdit; minor clean-up
  - Enable flash-attn for mmdit impl when the feature is enabled
  - Add sd3 example codebase
  - Add document crediting references
  - Pass the cargo fmt test
  - Pass the clippy test
* Fix typos
* Expose cfg_scale and time_shift as options
* Replace the sample image with JPG version; change image output format accordingly
* Make meaningful error messages
* Remove the tail-end assignment in sd3_vae_vb_rename
* Remove the CUDA requirement
* Use default_value in clap args
* Add use_flash_attn to turn on/off flash-attn for MMDiT at runtime
* Resolve clippy errors and warnings
* Use default_value_t
* Pin the web-sys dependency
* Clippy fix

Co-authored-by: Laurent <laurent.mazare@gmail.com>
@@ -122,3 +122,6 @@ required-features = ["onnx"]
 [[example]]
 name = "colpali"
 required-features = ["pdf2image"]
+
+[[example]]
+name = "stable-diffusion-3"
candle-examples/examples/stable-diffusion-3/README.md (new file, 54 lines)
@@ -0,0 +1,54 @@
# candle-stable-diffusion-3: Candle Implementation of Stable Diffusion 3 Medium

*(Sample image) A cute rusty robot holding a candle torch in its hand, with glowing neon text \"LETS GO RUSTY\" displayed on its chest, bright background, high quality, 4k*
Stable Diffusion 3 Medium is a text-to-image model based on the Multimodal Diffusion Transformer (MMDiT) architecture.

- [huggingface repo](https://huggingface.co/stabilityai/stable-diffusion-3-medium)
- [research paper](https://arxiv.org/pdf/2403.03206)
- [announcement blog post](https://stability.ai/news/stable-diffusion-3-medium)
## Getting access to the weights

The weights of Stable Diffusion 3 Medium are released by Stability AI under the Stability Community License. You will need to accept the conditions and acquire a license by visiting the [repo on HuggingFace Hub](https://huggingface.co/stabilityai/stable-diffusion-3-medium) to gain access to the weights for your HuggingFace account.

On the first run, the weights will be automatically downloaded from the Huggingface Hub. You might be prompted to configure a [Huggingface User Access Token](https://huggingface.co/docs/hub/en/security-tokens) (recommended) on your computer if you haven't done so before. After the download, the weights will be [cached](https://huggingface.co/docs/datasets/en/cache) and remain accessible locally.
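If you have not configured a token before, one way to set it up is sketched below. This assumes the `huggingface-cli` tool from the Python `huggingface_hub` package is installed; depending on its version, the `hf-hub` crate used by this example may also honor the `HF_TOKEN` environment variable.

```shell
# Interactive login; stores the token locally for later runs.
huggingface-cli login
# Alternatively, export the token just for the current shell session.
export HF_TOKEN=<your-access-token>
```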
## Running the model

```shell
cargo run --example stable-diffusion-3 --release --features=cuda -- \
  --height 1024 --width 1024 \
  --prompt 'A cute rusty robot holding a candle torch in its hand, with glowing neon text \"LETS GO RUSTY\" displayed on its chest, bright background, high quality, 4k'
```

To display the other available options:

```shell
cargo run --example stable-diffusion-3 --release --features=cuda -- --help
```

If your GPU supports it, Flash-Attention is a strongly recommended feature: it can greatly improve inference speed, since MMDiT is a transformer model that depends heavily on attention. To utilize [candle-flash-attn](https://github.com/huggingface/candle/tree/main/candle-flash-attn) in the demo, you will need both the `--features flash-attn` build flag and the `--use-flash-attn` runtime flag.

```shell
cargo run --example stable-diffusion-3 --release --features=cuda,flash-attn -- --use-flash-attn ...
```
## Performance Benchmark

The benchmark below generates a 1024x1024 image with 28 steps of Euler sampling and measures the average speed (iterations per second).
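That setup corresponds roughly to the following invocation (a sketch for clarity; 1024x1024 and 28 steps are already the defaults of this example):

```shell
cargo run --example stable-diffusion-3 --release --features=cuda,flash-attn -- \
  --use-flash-attn --height 1024 --width 1024 --num-inference-steps 28
```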
[candle](https://github.com/huggingface/candle) and [candle-flash-attn](https://github.com/huggingface/candle/tree/main/candle-flash-attn) are based on commit [0d96ec3](https://github.com/huggingface/candle/commit/0d96ec31e8be03f844ed0aed636d6217dee9c7bc).

System specs (Desktop PCIE 5 x8/x8 dual-GPU setup):

- Operating System: Ubuntu 23.10
- CPU: i9 12900K w/o overclocking
- RAM: 64G dual-channel DDR5 @ 4800 MT/s

| Speed (iter/s) | w/o flash-attn | w/ flash-attn |
| -------------- | -------------- | ------------- |
| RTX 3090 Ti    | 0.83           | 2.15          |
| RTX 4090       | 1.72           | 4.06          |
(Binary image file added, 81 KiB; contents not shown.)
candle-examples/examples/stable-diffusion-3/clip.rs (new file, 201 lines)
@@ -0,0 +1,201 @@
use anyhow::{Error as E, Ok, Result};
use candle::{DType, IndexOp, Module, Tensor, D};
use candle_transformers::models::{stable_diffusion, t5};
use tokenizers::tokenizer::Tokenizer;

struct ClipWithTokenizer {
    clip: stable_diffusion::clip::ClipTextTransformer,
    config: stable_diffusion::clip::Config,
    tokenizer: Tokenizer,
    max_position_embeddings: usize,
}

impl ClipWithTokenizer {
    fn new(
        vb: candle_nn::VarBuilder,
        config: stable_diffusion::clip::Config,
        tokenizer_path: &str,
        max_position_embeddings: usize,
    ) -> Result<Self> {
        let clip = stable_diffusion::clip::ClipTextTransformer::new(vb, &config)?;
        let path_buf = hf_hub::api::sync::Api::new()?
            .model(tokenizer_path.to_string())
            .get("tokenizer.json")?;
        let tokenizer = Tokenizer::from_file(path_buf.to_str().ok_or(E::msg(
            "Failed to serialize huggingface PathBuf of CLIP tokenizer",
        ))?)
        .map_err(E::msg)?;
        Ok(Self {
            clip,
            config,
            tokenizer,
            max_position_embeddings,
        })
    }

    fn encode_text_to_embedding(
        &self,
        prompt: &str,
        device: &candle::Device,
    ) -> Result<(Tensor, Tensor)> {
        let pad_id = match &self.config.pad_with {
            Some(padding) => *self
                .tokenizer
                .get_vocab(true)
                .get(padding.as_str())
                .ok_or(E::msg("Failed to tokenize CLIP padding."))?,
            None => *self
                .tokenizer
                .get_vocab(true)
                .get("<|endoftext|>")
                .ok_or(E::msg("Failed to tokenize CLIP end-of-text."))?,
        };

        let mut tokens = self
            .tokenizer
            .encode(prompt, true)
            .map_err(E::msg)?
            .get_ids()
            .to_vec();

        let eos_position = tokens.len() - 1;

        while tokens.len() < self.max_position_embeddings {
            tokens.push(pad_id)
        }
        let tokens = Tensor::new(tokens.as_slice(), device)?.unsqueeze(0)?;
        let (text_embeddings, text_embeddings_penultimate) = self
            .clip
            .forward_until_encoder_layer(&tokens, usize::MAX, -2)?;
        let text_embeddings_pooled = text_embeddings.i((0, eos_position, ..))?;

        Ok((text_embeddings_penultimate, text_embeddings_pooled))
    }
}

struct T5WithTokenizer {
    t5: t5::T5EncoderModel,
    tokenizer: Tokenizer,
    max_position_embeddings: usize,
}

impl T5WithTokenizer {
    fn new(vb: candle_nn::VarBuilder, max_position_embeddings: usize) -> Result<Self> {
        let api = hf_hub::api::sync::Api::new()?;
        let repo = api.repo(hf_hub::Repo::with_revision(
            "google/t5-v1_1-xxl".to_string(),
            hf_hub::RepoType::Model,
            "refs/pr/2".to_string(),
        ));
        let config_filename = repo.get("config.json")?;
        let config = std::fs::read_to_string(config_filename)?;
        let config: t5::Config = serde_json::from_str(&config)?;
        let model = t5::T5EncoderModel::load(vb, &config)?;

        let tokenizer_filename = api
            .model("lmz/mt5-tokenizers".to_string())
            .get("t5-v1_1-xxl.tokenizer.json")?;

        let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
        Ok(Self {
            t5: model,
            tokenizer,
            max_position_embeddings,
        })
    }

    fn encode_text_to_embedding(
        &mut self,
        prompt: &str,
        device: &candle::Device,
    ) -> Result<Tensor> {
        let mut tokens = self
            .tokenizer
            .encode(prompt, true)
            .map_err(E::msg)?
            .get_ids()
            .to_vec();
        tokens.resize(self.max_position_embeddings, 0);
        let input_token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
        let embeddings = self.t5.forward(&input_token_ids)?;
        Ok(embeddings)
    }
}

pub struct StableDiffusion3TripleClipWithTokenizer {
    clip_l: ClipWithTokenizer,
    clip_g: ClipWithTokenizer,
    clip_g_text_projection: candle_nn::Linear,
    t5: T5WithTokenizer,
}

impl StableDiffusion3TripleClipWithTokenizer {
    pub fn new(vb_fp16: candle_nn::VarBuilder, vb_fp32: candle_nn::VarBuilder) -> Result<Self> {
        let max_position_embeddings = 77usize;
        let clip_l = ClipWithTokenizer::new(
            vb_fp16.pp("clip_l.transformer"),
            stable_diffusion::clip::Config::sdxl(),
            "openai/clip-vit-large-patch14",
            max_position_embeddings,
        )?;

        let clip_g = ClipWithTokenizer::new(
            vb_fp16.pp("clip_g.transformer"),
            stable_diffusion::clip::Config::sdxl2(),
            "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k",
            max_position_embeddings,
        )?;

        let text_projection = candle_nn::linear_no_bias(
            1280,
            1280,
            vb_fp16.pp("clip_g.transformer.text_projection"),
        )?;

        // The current T5 implementation does not support fp16, so we use an fp32 VarBuilder for T5.
        // This is a temporary workaround until the T5 implementation is updated to support fp16.
        // Also see:
        // https://github.com/huggingface/candle/issues/2480
        // https://github.com/huggingface/candle/pull/2481
        let t5 = T5WithTokenizer::new(vb_fp32.pp("t5xxl.transformer"), max_position_embeddings)?;

        Ok(Self {
            clip_l,
            clip_g,
            clip_g_text_projection: text_projection,
            t5,
        })
    }

    pub fn encode_text_to_embedding(
        &mut self,
        prompt: &str,
        device: &candle::Device,
    ) -> Result<(Tensor, Tensor)> {
        let (clip_l_embeddings, clip_l_embeddings_pooled) =
            self.clip_l.encode_text_to_embedding(prompt, device)?;
        let (clip_g_embeddings, clip_g_embeddings_pooled) =
            self.clip_g.encode_text_to_embedding(prompt, device)?;

        let clip_g_embeddings_pooled = self
            .clip_g_text_projection
            .forward(&clip_g_embeddings_pooled.unsqueeze(0)?)?
            .squeeze(0)?;

        let y = Tensor::cat(&[&clip_l_embeddings_pooled, &clip_g_embeddings_pooled], 0)?
            .unsqueeze(0)?;
        let clip_embeddings_concat = Tensor::cat(
            &[&clip_l_embeddings, &clip_g_embeddings],
            D::Minus1,
        )?
        .pad_with_zeros(D::Minus1, 0, 2048)?;

        let t5_embeddings = self
            .t5
            .encode_text_to_embedding(prompt, device)?
            .to_dtype(DType::F16)?;
        let context = Tensor::cat(&[&clip_embeddings_concat, &t5_embeddings], D::Minus2)?;

        Ok((context, y))
    }
}
candle-examples/examples/stable-diffusion-3/main.rs (new file, 185 lines)
@@ -0,0 +1,185 @@
mod clip;
mod sampling;
mod vae;

use candle::{DType, IndexOp, Tensor};
use candle_transformers::models::mmdit::model::{Config as MMDiTConfig, MMDiT};

use crate::clip::StableDiffusion3TripleClipWithTokenizer;
use crate::vae::{build_sd3_vae_autoencoder, sd3_vae_vb_rename};

use anyhow::{Ok, Result};
use clap::Parser;

#[derive(Parser)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// The prompt to be used for image generation.
    #[arg(
        long,
        default_value = "A cute rusty robot holding a candle torch in its hand, \
        with glowing neon text \"LETS GO RUSTY\" displayed on its chest, \
        bright background, high quality, 4k"
    )]
    prompt: String,

    #[arg(long, default_value = "")]
    uncond_prompt: String,

    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// The CUDA device ID to use.
    #[arg(long, default_value = "0")]
    cuda_device_id: usize,

    /// Enable tracing (generates a trace-timestamp.json file).
    #[arg(long)]
    tracing: bool,

    /// Use flash_attn to accelerate attention operation in the MMDiT.
    #[arg(long)]
    use_flash_attn: bool,

    /// The height in pixels of the generated image.
    #[arg(long, default_value_t = 1024)]
    height: usize,

    /// The width in pixels of the generated image.
    #[arg(long, default_value_t = 1024)]
    width: usize,

    /// The number of inference steps used for sampling.
    #[arg(long, default_value_t = 28)]
    num_inference_steps: usize,

    /// CFG scale.
    #[arg(long, default_value_t = 4.0)]
    cfg_scale: f64,

    /// Time shift factor (alpha).
    #[arg(long, default_value_t = 3.0)]
    time_shift: f64,

    /// The seed to use when generating random samples.
    #[arg(long)]
    seed: Option<u64>,
}

fn main() -> Result<()> {
    let args = Args::parse();
    run(args)
}

fn run(args: Args) -> Result<()> {
    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;

    let Args {
        prompt,
        uncond_prompt,
        cpu,
        cuda_device_id,
        tracing,
        use_flash_attn,
        height,
        width,
        num_inference_steps,
        cfg_scale,
        time_shift,
        seed,
    } = args;

    let _guard = if tracing {
        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();
        Some(guard)
    } else {
        None
    };

    // TODO: Support and test on Metal.
    let device = if cpu {
        candle::Device::Cpu
    } else {
        candle::Device::cuda_if_available(cuda_device_id)?
    };

    let api = hf_hub::api::sync::Api::new()?;
    let sai_repo = {
        let name = "stabilityai/stable-diffusion-3-medium";
        api.repo(hf_hub::Repo::model(name.to_string()))
    };
    let model_file = sai_repo.get("sd3_medium_incl_clips_t5xxlfp16.safetensors")?;
    let vb_fp16 = unsafe {
        candle_nn::VarBuilder::from_mmaped_safetensors(&[model_file.clone()], DType::F16, &device)?
    };

    let (context, y) = {
        let vb_fp32 = unsafe {
            candle_nn::VarBuilder::from_mmaped_safetensors(
                &[model_file.clone()],
                DType::F32,
                &device,
            )?
        };
        let mut triple = StableDiffusion3TripleClipWithTokenizer::new(
            vb_fp16.pp("text_encoders"),
            vb_fp32.pp("text_encoders"),
        )?;
        let (context, y) = triple.encode_text_to_embedding(prompt.as_str(), &device)?;
        let (context_uncond, y_uncond) =
            triple.encode_text_to_embedding(uncond_prompt.as_str(), &device)?;
        (
            Tensor::cat(&[context, context_uncond], 0)?,
            Tensor::cat(&[y, y_uncond], 0)?,
        )
    };

    let x = {
        let mmdit = MMDiT::new(
            &MMDiTConfig::sd3_medium(),
            use_flash_attn,
            vb_fp16.pp("model.diffusion_model"),
        )?;

        if let Some(seed) = seed {
            device.set_seed(seed)?;
        }
        let start_time = std::time::Instant::now();
        let x = sampling::euler_sample(
            &mmdit,
            &y,
            &context,
            num_inference_steps,
            cfg_scale,
            time_shift,
            height,
            width,
        )?;
        let dt = start_time.elapsed().as_secs_f32();
        println!(
            "Sampling done. {num_inference_steps} steps. {:.2}s. Average rate: {:.2} iter/s",
            dt,
            num_inference_steps as f32 / dt
        );
        x
    };

    let img = {
        let vb_vae = vb_fp16
            .clone()
            .rename_f(sd3_vae_vb_rename)
            .pp("first_stage_model");
        let autoencoder = build_sd3_vae_autoencoder(vb_vae)?;

        // Apply the TAESD3 scale factor. It seems to significantly improve the quality of the image.
        // https://github.com/comfyanonymous/ComfyUI/blob/3c60ecd7a83da43d694e26a77ca6b93106891251/nodes.py#L721-L723
        autoencoder.decode(&((x.clone() / 1.5305)? + 0.0609)?)?
    };
    let img = ((img.clamp(-1f32, 1f32)? + 1.0)? * 127.5)?.to_dtype(candle::DType::U8)?;
    candle_examples::save_image(&img.i(0)?, "out.jpg")?;
    Ok(())
}
candle-examples/examples/stable-diffusion-3/sampling.rs (new file, 55 lines)
@@ -0,0 +1,55 @@
use anyhow::{Ok, Result};
use candle::{DType, Tensor};

use candle_transformers::models::flux; // for the get_noise function
use candle_transformers::models::mmdit::model::MMDiT;

#[allow(clippy::too_many_arguments)]
pub fn euler_sample(
    mmdit: &MMDiT,
    y: &Tensor,
    context: &Tensor,
    num_inference_steps: usize,
    cfg_scale: f64,
    time_shift: f64,
    height: usize,
    width: usize,
) -> Result<Tensor> {
    let mut x = flux::sampling::get_noise(1, height, width, y.device())?.to_dtype(DType::F16)?;
    let sigmas = (0..=num_inference_steps)
        .map(|x| x as f64 / num_inference_steps as f64)
        .rev()
        .map(|x| time_snr_shift(time_shift, x))
        .collect::<Vec<f64>>();

    for window in sigmas.windows(2) {
        let (s_curr, s_prev) = match window {
            [a, b] => (a, b),
            _ => continue,
        };

        let timestep = (*s_curr) * 1000.0;
        let noise_pred = mmdit.forward(
            &Tensor::cat(&[x.clone(), x.clone()], 0)?,
            &Tensor::full(timestep, (2,), x.device())?.contiguous()?,
            y,
            context,
        )?;
        x = (x + (apply_cfg(cfg_scale, &noise_pred)? * (*s_prev - *s_curr))?)?;
    }
    Ok(x)
}

// The "Resolution-dependent shifting of timestep schedules" recommended in the SD3 tech report.
// https://arxiv.org/pdf/2403.03206
// Following the implementation in ComfyUI:
// https://github.com/comfyanonymous/ComfyUI/blob/3c60ecd7a83da43d694e26a77ca6b93106891251/comfy/model_sampling.py#L181
fn time_snr_shift(alpha: f64, t: f64) -> f64 {
    alpha * t / (1.0 + (alpha - 1.0) * t)
}

fn apply_cfg(cfg_scale: f64, noise_pred: &Tensor) -> Result<Tensor> {
    Ok(((cfg_scale * noise_pred.narrow(0, 0, 1)?)?
        - ((cfg_scale - 1.0) * noise_pred.narrow(0, 1, 1)?)?)?)
}
candle-examples/examples/stable-diffusion-3/vae.rs (new file, 93 lines)
@@ -0,0 +1,93 @@
use anyhow::{Ok, Result};
use candle_transformers::models::stable_diffusion::vae;

pub fn build_sd3_vae_autoencoder(vb: candle_nn::VarBuilder) -> Result<vae::AutoEncoderKL> {
    let config = vae::AutoEncoderKLConfig {
        block_out_channels: vec![128, 256, 512, 512],
        layers_per_block: 2,
        latent_channels: 16,
        norm_num_groups: 32,
        use_quant_conv: false,
        use_post_quant_conv: false,
    };
    Ok(vae::AutoEncoderKL::new(vb, 3, 3, config)?)
}

pub fn sd3_vae_vb_rename(name: &str) -> String {
    let parts: Vec<&str> = name.split('.').collect();
    let mut result = Vec::new();
    let mut i = 0;

    while i < parts.len() {
        match parts[i] {
            "down_blocks" => {
                result.push("down");
            }
            "mid_block" => {
                result.push("mid");
            }
            "up_blocks" => {
                result.push("up");
                match parts[i + 1] {
                    // Reverse the order of up_blocks.
                    "0" => result.push("3"),
                    "1" => result.push("2"),
                    "2" => result.push("1"),
                    "3" => result.push("0"),
                    _ => {}
                }
                i += 1; // Skip the number after up_blocks.
            }
            "resnets" => {
                if i > 0 && parts[i - 1] == "mid_block" {
                    match parts[i + 1] {
                        "0" => result.push("block_1"),
                        "1" => result.push("block_2"),
                        _ => {}
                    }
                    i += 1; // Skip the number after resnets.
                } else {
                    result.push("block");
                }
            }
            "downsamplers" => {
                result.push("downsample");
                i += 1; // Skip the 0 after downsamplers.
            }
            "conv_shortcut" => {
                result.push("nin_shortcut");
            }
            "attentions" => {
                if parts[i + 1] == "0" {
                    result.push("attn_1")
                }
                i += 1; // Skip the number after attentions.
            }
            "group_norm" => {
                result.push("norm");
            }
            "query" => {
                result.push("q");
            }
            "key" => {
                result.push("k");
            }
            "value" => {
                result.push("v");
            }
            "proj_attn" => {
                result.push("proj_out");
            }
            "conv_norm_out" => {
                result.push("norm_out");
            }
            "upsamplers" => {
                result.push("upsample");
                i += 1; // Skip the 0 after upsamplers.
            }
            part => result.push(part),
        }
        i += 1;
    }
    result.join(".")
}
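As an illustration of the renaming logic above, here is a small sketch with hypothetical weight names traced by hand through the match arms (it assumes `sd3_vae_vb_rename` from the file above is in scope; the names are illustrative, not from the commit):

```rust
fn main() {
    // up_blocks indices are reversed: up_blocks.0 -> up.3
    assert_eq!(
        sd3_vae_vb_rename("decoder.up_blocks.0.resnets.1.conv_shortcut.weight"),
        "decoder.up.3.block.1.nin_shortcut.weight"
    );
    // mid_block resnets map to block_1 / block_2
    assert_eq!(
        sd3_vae_vb_rename("decoder.mid_block.resnets.0.group_norm.weight"),
        "decoder.mid.block_1.norm.weight"
    );
}
```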
@@ -194,10 +194,16 @@ pub struct JointBlock {
     x_block: DiTBlock,
     context_block: DiTBlock,
     num_heads: usize,
+    use_flash_attn: bool,
 }
 
 impl JointBlock {
-    pub fn new(hidden_size: usize, num_heads: usize, vb: nn::VarBuilder) -> Result<Self> {
+    pub fn new(
+        hidden_size: usize,
+        num_heads: usize,
+        use_flash_attn: bool,
+        vb: nn::VarBuilder,
+    ) -> Result<Self> {
         let x_block = DiTBlock::new(hidden_size, num_heads, vb.pp("x_block"))?;
         let context_block = DiTBlock::new(hidden_size, num_heads, vb.pp("context_block"))?;
 
@@ -205,13 +211,15 @@ impl JointBlock {
             x_block,
             context_block,
             num_heads,
+            use_flash_attn,
         })
     }
 
     pub fn forward(&self, context: &Tensor, x: &Tensor, c: &Tensor) -> Result<(Tensor, Tensor)> {
         let (context_qkv, context_interm) = self.context_block.pre_attention(context, c)?;
         let (x_qkv, x_interm) = self.x_block.pre_attention(x, c)?;
-        let (context_attn, x_attn) = joint_attn(&context_qkv, &x_qkv, self.num_heads)?;
+        let (context_attn, x_attn) =
+            joint_attn(&context_qkv, &x_qkv, self.num_heads, self.use_flash_attn)?;
         let context_out =
             self.context_block
                 .post_attention(&context_attn, context, &context_interm)?;
@@ -224,16 +232,23 @@ pub struct ContextQkvOnlyJointBlock {
     x_block: DiTBlock,
     context_block: QkvOnlyDiTBlock,
     num_heads: usize,
+    use_flash_attn: bool,
 }
 
 impl ContextQkvOnlyJointBlock {
-    pub fn new(hidden_size: usize, num_heads: usize, vb: nn::VarBuilder) -> Result<Self> {
+    pub fn new(
+        hidden_size: usize,
+        num_heads: usize,
+        use_flash_attn: bool,
+        vb: nn::VarBuilder,
+    ) -> Result<Self> {
         let x_block = DiTBlock::new(hidden_size, num_heads, vb.pp("x_block"))?;
         let context_block = QkvOnlyDiTBlock::new(hidden_size, num_heads, vb.pp("context_block"))?;
         Ok(Self {
             x_block,
             context_block,
             num_heads,
+            use_flash_attn,
         })
     }
 
@@ -241,7 +256,7 @@ impl ContextQkvOnlyJointBlock {
         let context_qkv = self.context_block.pre_attention(context, c)?;
         let (x_qkv, x_interm) = self.x_block.pre_attention(x, c)?;
 
-        let (_, x_attn) = joint_attn(&context_qkv, &x_qkv, self.num_heads)?;
+        let (_, x_attn) = joint_attn(&context_qkv, &x_qkv, self.num_heads, self.use_flash_attn)?;
 
         let x_out = self.x_block.post_attention(&x_attn, x, &x_interm)?;
         Ok(x_out)
@@ -266,7 +281,28 @@ fn flash_compatible_attention(
     attn_scores.reshape(q_dims_for_matmul)?.transpose(1, 2)
 }
 
-fn joint_attn(context_qkv: &Qkv, x_qkv: &Qkv, num_heads: usize) -> Result<(Tensor, Tensor)> {
+#[cfg(feature = "flash-attn")]
+fn flash_attn(
+    q: &Tensor,
+    k: &Tensor,
+    v: &Tensor,
+    softmax_scale: f32,
+    causal: bool,
+) -> Result<Tensor> {
+    candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal)
+}
+
+#[cfg(not(feature = "flash-attn"))]
+fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> {
+    unimplemented!("compile with '--features flash-attn'")
+}
+
+fn joint_attn(
+    context_qkv: &Qkv,
+    x_qkv: &Qkv,
+    num_heads: usize,
+    use_flash_attn: bool,
+) -> Result<(Tensor, Tensor)> {
     let qkv = Qkv {
         q: Tensor::cat(&[&context_qkv.q, &x_qkv.q], 1)?,
         k: Tensor::cat(&[&context_qkv.k, &x_qkv.k], 1)?,
@@ -282,8 +318,12 @@ fn joint_attn(context_qkv: &Qkv, x_qkv: &Qkv, num_heads: usize) -> Result<(Tensor, Tensor)> {
 
     let headdim = qkv.q.dim(D::Minus1)?;
     let softmax_scale = 1.0 / (headdim as f64).sqrt();
-    // let attn: Tensor = candle_flash_attn::flash_attn(&qkv.q, &qkv.k, &qkv.v, softmax_scale as f32, false)?;
-    let attn = flash_compatible_attention(&qkv.q, &qkv.k, &qkv.v, softmax_scale as f32)?;
+
+    let attn = if use_flash_attn {
+        flash_attn(&qkv.q, &qkv.k, &qkv.v, softmax_scale as f32, false)?
+    } else {
+        flash_compatible_attention(&qkv.q, &qkv.k, &qkv.v, softmax_scale as f32)?
+    };
 
     let attn = attn.reshape((batch_size, seqlen, ()))?;
     let context_qkv_seqlen = context_qkv.q.dim(1)?;
@@ -23,7 +23,7 @@ pub struct Config {
 }
 
 impl Config {
-    pub fn sd3() -> Self {
+    pub fn sd3_medium() -> Self {
         Self {
             patch_size: 2,
             in_channels: 16,
@@ -49,7 +49,7 @@ pub struct MMDiT {
 }
 
 impl MMDiT {
-    pub fn new(cfg: &Config, vb: nn::VarBuilder) -> Result<Self> {
+    pub fn new(cfg: &Config, use_flash_attn: bool, vb: nn::VarBuilder) -> Result<Self> {
         let hidden_size = cfg.head_size * cfg.depth;
         let core = MMDiTCore::new(
             cfg.depth,
@@ -57,6 +57,7 @@ impl MMDiT {
             cfg.depth,
             cfg.patch_size,
             cfg.out_channels,
+            use_flash_attn,
             vb.clone(),
         )?;
         let patch_embedder = PatchEmbedder::new(
@@ -135,6 +136,7 @@ impl MMDiTCore {
         num_heads: usize,
         patch_size: usize,
         out_channels: usize,
+        use_flash_attn: bool,
         vb: nn::VarBuilder,
     ) -> Result<Self> {
         let mut joint_blocks = Vec::with_capacity(depth - 1);
@@ -142,6 +144,7 @@ impl MMDiTCore {
             joint_blocks.push(JointBlock::new(
                 hidden_size,
                 num_heads,
+                use_flash_attn,
                 vb.pp(format!("joint_blocks.{}", i)),
             )?);
         }
@@ -151,6 +154,7 @@ impl MMDiTCore {
             context_qkv_only_joint_block: ContextQkvOnlyJointBlock::new(
                 hidden_size,
                 num_heads,
+                use_flash_attn,
                 vb.pp(format!("joint_blocks.{}", depth - 1)),
             )?,
             final_layer: FinalLayer::new(
@@ -42,7 +42,6 @@ pub struct QkvOnlyAttnProjections {
 
 impl QkvOnlyAttnProjections {
     pub fn new(dim: usize, num_heads: usize, vb: nn::VarBuilder) -> Result<Self> {
-        // {'dim': 1536, 'num_heads': 24}
         let head_dim = dim / num_heads;
         let qkv = nn::linear(dim, dim * 3, vb.pp("qkv"))?;
         Ok(Self { qkv, head_dim })
@@ -467,6 +467,24 @@ pub struct AttentionBlock {
     config: AttentionBlockConfig,
 }
 
+// In the .safetensors weights of the official Stable Diffusion 3 Medium Huggingface repo
+// https://huggingface.co/stabilityai/stable-diffusion-3-medium
+// a linear layer may use a different dimensionality for its weight, which is
+// incompatible with the current implementation of the nn::linear constructor.
+// This is a workaround to handle the different dimensions.
+fn get_qkv_linear(channels: usize, vs: nn::VarBuilder) -> Result<nn::Linear> {
+    match vs.get((channels, channels), "weight") {
+        Ok(_) => nn::linear(channels, channels, vs),
+        Err(_) => {
+            let weight = vs
+                .get((channels, channels, 1, 1), "weight")?
+                .reshape((channels, channels))?;
+            let bias = vs.get((channels,), "bias")?;
+            Ok(nn::Linear::new(weight, Some(bias)))
+        }
+    }
+}
+
 impl AttentionBlock {
     pub fn new(vs: nn::VarBuilder, channels: usize, config: AttentionBlockConfig) -> Result<Self> {
         let num_head_channels = config.num_head_channels.unwrap_or(channels);
@@ -478,10 +496,10 @@ impl AttentionBlock {
         } else {
             ("query", "key", "value", "proj_attn")
         };
-        let query = nn::linear(channels, channels, vs.pp(q_path))?;
-        let key = nn::linear(channels, channels, vs.pp(k_path))?;
-        let value = nn::linear(channels, channels, vs.pp(v_path))?;
-        let proj_attn = nn::linear(channels, channels, vs.pp(out_path))?;
+        let query = get_qkv_linear(channels, vs.pp(q_path))?;
+        let key = get_qkv_linear(channels, vs.pp(k_path))?;
+        let value = get_qkv_linear(channels, vs.pp(v_path))?;
+        let proj_attn = get_qkv_linear(channels, vs.pp(out_path))?;
         let span = tracing::span!(tracing::Level::TRACE, "attn-block");
         Ok(Self {
             group_norm,
@@ -388,6 +388,37 @@ impl ClipTextTransformer {
         let xs = self.encoder.forward(&xs, &causal_attention_mask)?;
         self.final_layer_norm.forward(&xs)
     }
+
+    pub fn forward_until_encoder_layer(
+        &self,
+        xs: &Tensor,
+        mask_after: usize,
+        until_layer: isize,
+    ) -> Result<(Tensor, Tensor)> {
+        let (bsz, seq_len) = xs.dims2()?;
+        let xs = self.embeddings.forward(xs)?;
+        let causal_attention_mask =
+            Self::build_causal_attention_mask(bsz, seq_len, mask_after, xs.device())?;
+
+        let mut xs = xs.clone();
+        let mut intermediate = xs.clone();
+
+        // Modified encoder.forward that returns the intermediate tensor along with the final output.
+        let until_layer = if until_layer < 0 {
+            self.encoder.layers.len() as isize + until_layer
+        } else {
+            until_layer
+        } as usize;
+
+        for (layer_id, layer) in self.encoder.layers.iter().enumerate() {
+            xs = layer.forward(&xs, &causal_attention_mask)?;
+            if layer_id == until_layer {
+                intermediate = xs.clone();
+            }
+        }
+
+        Ok((self.final_layer_norm.forward(&xs)?, intermediate))
+    }
 }
 
 impl Module for ClipTextTransformer {
@@ -65,6 +65,8 @@ impl StableDiffusionConfig {
             layers_per_block: 2,
             latent_channels: 4,
             norm_num_groups: 32,
+            use_quant_conv: true,
+            use_post_quant_conv: true,
         };
         let height = if let Some(height) = height {
             assert_eq!(height % 8, 0, "height has to be divisible by 8");
@@ -133,6 +135,8 @@ impl StableDiffusionConfig {
             layers_per_block: 2,
             latent_channels: 4,
             norm_num_groups: 32,
+            use_quant_conv: true,
+            use_post_quant_conv: true,
         };
         let scheduler = Arc::new(ddim::DDIMSchedulerConfig {
             prediction_type,
@@ -214,6 +218,8 @@ impl StableDiffusionConfig {
             layers_per_block: 2,
             latent_channels: 4,
             norm_num_groups: 32,
+            use_quant_conv: true,
+            use_post_quant_conv: true,
         };
         let scheduler = Arc::new(ddim::DDIMSchedulerConfig {
             prediction_type,
@@ -281,6 +287,8 @@ impl StableDiffusionConfig {
             layers_per_block: 2,
             latent_channels: 4,
             norm_num_groups: 32,
+            use_quant_conv: true,
+            use_post_quant_conv: true,
         };
         let scheduler = Arc::new(
             euler_ancestral_discrete::EulerAncestralDiscreteSchedulerConfig {
@@ -378,6 +386,8 @@ impl StableDiffusionConfig {
             layers_per_block: 2,
             latent_channels: 4,
             norm_num_groups: 32,
+            use_quant_conv: true,
+            use_post_quant_conv: true,
         };
         let scheduler = Arc::new(ddim::DDIMSchedulerConfig {
             ..Default::default()
@@ -275,6 +275,8 @@ pub struct AutoEncoderKLConfig {
     pub layers_per_block: usize,
     pub latent_channels: usize,
     pub norm_num_groups: usize,
+    pub use_quant_conv: bool,
+    pub use_post_quant_conv: bool,
 }
 
 impl Default for AutoEncoderKLConfig {
@@ -284,6 +286,8 @@ impl Default for AutoEncoderKLConfig {
             layers_per_block: 1,
             latent_channels: 4,
             norm_num_groups: 32,
+            use_quant_conv: true,
+            use_post_quant_conv: true,
         }
     }
 }
@@ -315,8 +319,8 @@ impl DiagonalGaussianDistribution {
 pub struct AutoEncoderKL {
     encoder: Encoder,
     decoder: Decoder,
-    quant_conv: nn::Conv2d,
-    post_quant_conv: nn::Conv2d,
+    quant_conv: Option<nn::Conv2d>,
+    post_quant_conv: Option<nn::Conv2d>,
     pub config: AutoEncoderKLConfig,
 }
 
@@ -342,20 +346,33 @@ impl AutoEncoderKL {
         };
         let decoder = Decoder::new(vs.pp("decoder"), latent_channels, out_channels, decoder_cfg)?;
         let conv_cfg = Default::default();
-        let quant_conv = nn::conv2d(
-            2 * latent_channels,
-            2 * latent_channels,
-            1,
-            conv_cfg,
-            vs.pp("quant_conv"),
-        )?;
-        let post_quant_conv = nn::conv2d(
-            latent_channels,
-            latent_channels,
-            1,
-            conv_cfg,
-            vs.pp("post_quant_conv"),
-        )?;
+
+        let quant_conv = {
+            if config.use_quant_conv {
+                Some(nn::conv2d(
+                    2 * latent_channels,
+                    2 * latent_channels,
+                    1,
+                    conv_cfg,
+                    vs.pp("quant_conv"),
+                )?)
+            } else {
+                None
+            }
+        };
+        let post_quant_conv = {
+            if config.use_post_quant_conv {
+                Some(nn::conv2d(
+                    latent_channels,
+                    latent_channels,
+                    1,
+                    conv_cfg,
+                    vs.pp("post_quant_conv"),
+                )?)
+            } else {
+                None
+            }
+        };
         Ok(Self {
             encoder,
             decoder,
@@ -368,13 +385,19 @@ impl AutoEncoderKL {
     /// Returns the distribution in the latent space.
     pub fn encode(&self, xs: &Tensor) -> Result<DiagonalGaussianDistribution> {
         let xs = self.encoder.forward(xs)?;
-        let parameters = self.quant_conv.forward(&xs)?;
+        let parameters = match &self.quant_conv {
+            None => xs,
+            Some(quant_conv) => quant_conv.forward(&xs)?,
+        };
         DiagonalGaussianDistribution::new(&parameters)
     }
 
     /// Takes as input some sampled values.
     pub fn decode(&self, xs: &Tensor) -> Result<Tensor> {
-        let xs = self.post_quant_conv.forward(xs)?;
-        self.decoder.forward(&xs)
+        let xs = match &self.post_quant_conv {
+            None => xs,
+            Some(post_quant_conv) => &post_quant_conv.forward(xs)?,
+        };
        self.decoder.forward(xs)
     }
 }
@@ -35,7 +35,7 @@ yew-agent = "0.2.0"
 yew = { version = "0.20.0", features = ["csr"] }
 
 [dependencies.web-sys]
-version = "0.3.70"
+version = "=0.3.70"
 features = [
   'Blob',
   'CanvasRenderingContext2d',
@@ -1,3 +1,4 @@
+#![allow(unused)]
 use candle::{
     quantized::{self, k_quants, GgmlDType, GgmlType},
     test_utils::to_vec2_round,