Mirror of https://github.com/huggingface/candle.git, synced 2025-06-16 10:38:54 +00:00
Stable diffusion 3.5 support. (#2578)
* Stable diffusion 3.5 support.
* Clippy fixes.
* CFG fix.
* Remove some unnecessary clones.
* Avoid duplicating some of the code.
@@ -1,6 +1,7 @@
 use anyhow::{Error as E, Ok, Result};
 use candle::{DType, IndexOp, Module, Tensor, D};
 use candle_transformers::models::{stable_diffusion, t5};
+use std::path::PathBuf;
 use tokenizers::tokenizer::Tokenizer;
 
 struct ClipWithTokenizer {
@@ -130,6 +131,53 @@ pub struct StableDiffusion3TripleClipWithTokenizer {
 }
 
 impl StableDiffusion3TripleClipWithTokenizer {
+    pub fn new_split(
+        clip_g_file: &PathBuf,
+        clip_l_file: &PathBuf,
+        t5xxl_file: &PathBuf,
+        device: &candle::Device,
+    ) -> Result<Self> {
+        let vb_clip_g = unsafe {
+            candle_nn::VarBuilder::from_mmaped_safetensors(&[clip_g_file], DType::F16, device)?
+        };
+        let vb_clip_l = unsafe {
+            candle_nn::VarBuilder::from_mmaped_safetensors(&[clip_l_file], DType::F16, device)?
+        };
+        let vb_t5 = unsafe {
+            candle_nn::VarBuilder::from_mmaped_safetensors(&[t5xxl_file], DType::F32, device)?
+        };
+        let max_position_embeddings = 77usize;
+        let clip_l = ClipWithTokenizer::new(
+            vb_clip_l,
+            stable_diffusion::clip::Config::sdxl(),
+            "openai/clip-vit-large-patch14",
+            max_position_embeddings,
+        )?;
+
+        let text_projection =
+            candle_nn::linear_no_bias(1280, 1280, vb_clip_g.pp("text_projection"))?;
+
+        let clip_g = ClipWithTokenizer::new(
+            vb_clip_g,
+            stable_diffusion::clip::Config::sdxl2(),
+            "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k",
+            max_position_embeddings,
+        )?;
+
+        // Current T5 implementation does not support fp16, so we use fp32 VarBuilder for T5.
+        // This is a temporary workaround until the T5 implementation is updated to support fp16.
+        // Also see:
+        // https://github.com/huggingface/candle/issues/2480
+        // https://github.com/huggingface/candle/pull/2481
+        let t5 = T5WithTokenizer::new(vb_t5, max_position_embeddings)?;
+        Ok(Self {
+            clip_l,
+            clip_g,
+            clip_g_text_projection: text_projection,
+            t5,
+        })
+    }
+
     pub fn new(vb_fp16: candle_nn::VarBuilder, vb_fp32: candle_nn::VarBuilder) -> Result<Self> {
         let max_position_embeddings = 77usize;
         let clip_l = ClipWithTokenizer::new(
@@ -158,7 +206,6 @@ impl StableDiffusion3TripleClipWithTokenizer {
         // https://github.com/huggingface/candle/issues/2480
         // https://github.com/huggingface/candle/pull/2481
         let t5 = T5WithTokenizer::new(vb_fp32.pp("t5xxl.transformer"), max_position_embeddings)?;
-
         Ok(Self {
             clip_l,
             clip_g,
@@ -195,7 +242,6 @@ impl StableDiffusion3TripleClipWithTokenizer {
             .encode_text_to_embedding(prompt, device)?
             .to_dtype(DType::F16)?;
         let context = Tensor::cat(&[&clip_embeddings_concat, &t5_embeddings], D::Minus2)?;
-
         Ok((context, y))
     }
 }
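The `new_split` constructor above loads each text encoder from its own safetensors file instead of one bundled checkpoint. A minimal sketch of the loading pattern it relies on, with hypothetical local file names standing in for the downloaded paths: `from_mmaped_safetensors` memory-maps the files, which is `unsafe` because the mapped data must not change while the `VarBuilder` is alive, and the T5 weights are requested at F32 because the T5 implementation does not yet support F16 (see the issue links in the diff).

```rust
use candle::{DType, Device};
use candle_nn::VarBuilder;

fn load_split_encoders(device: &Device) -> anyhow::Result<()> {
    // Memory-mapping avoids reading the whole checkpoint up front; it is
    // `unsafe` because the file must stay unmodified while mapped.
    let vb_clip = unsafe {
        // "clip_g.safetensors" is a hypothetical local path.
        VarBuilder::from_mmaped_safetensors(&["clip_g.safetensors"], DType::F16, device)?
    };
    let vb_t5 = unsafe {
        // The VarBuilder converts the stored f16 values to f32 on access.
        VarBuilder::from_mmaped_safetensors(&["t5xxl_fp16.safetensors"], DType::F32, device)?
    };
    let _ = (vb_clip, vb_t5); // hand these to the encoder constructors
    Ok(())
}
```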
@@ -11,6 +11,25 @@ use crate::vae::{build_sd3_vae_autoencoder, sd3_vae_vb_rename};
 use anyhow::{Ok, Result};
 use clap::Parser;
 
+#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
+enum Which {
+    #[value(name = "3-medium")]
+    V3Medium,
+    #[value(name = "3.5-large")]
+    V3_5Large,
+    #[value(name = "3.5-large-turbo")]
+    V3_5LargeTurbo,
+}
+
+impl Which {
+    fn is_3_5(&self) -> bool {
+        match self {
+            Self::V3Medium => false,
+            Self::V3_5Large | Self::V3_5LargeTurbo => true,
+        }
+    }
+}
+
 #[derive(Parser)]
 #[command(author, version, about, long_about = None)]
 struct Args {
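The `Which` enum plugs into clap's derive API: `#[value(name = "...")]` controls the literal accepted on the command line, so the user types `3.5-large` rather than the Rust identifier. A small self-contained sketch of the same pattern (the `parse_from` call is only for illustration; a real binary would call `Args::parse()`):

```rust
use clap::Parser;

#[derive(Clone, Copy, Debug, PartialEq, Eq, clap::ValueEnum)]
enum Which {
    #[value(name = "3-medium")]
    V3Medium,
    #[value(name = "3.5-large")]
    V3_5Large,
}

#[derive(Parser)]
struct Args {
    /// The model to use.
    #[arg(long, default_value = "3-medium")]
    which: Which,
}

fn main() {
    // `--which 3.5-large` maps to Which::V3_5Large; omitting the flag
    // falls back to the "3-medium" default.
    let args = Args::parse_from(["sd3", "--which", "3.5-large"]);
    assert_eq!(args.which, Which::V3_5Large);
}
```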
@@ -30,10 +49,6 @@ struct Args {
     #[arg(long)]
     cpu: bool,
 
-    /// The GPU device ID to use.
-    #[arg(long, default_value_t = 0)]
-    gpu_device_id: usize,
-
     /// Enable tracing (generates a trace-timestamp.json file).
     #[arg(long)]
     tracing: bool,
@@ -50,13 +65,17 @@ struct Args {
     #[arg(long, default_value_t = 1024)]
     width: usize,
 
+    /// The model to use.
+    #[arg(long, default_value = "3-medium")]
+    which: Which,
+
     /// The seed to use when generating random samples.
-    #[arg(long, default_value_t = 28)]
-    num_inference_steps: usize,
+    #[arg(long)]
+    num_inference_steps: Option<usize>,
 
     // CFG scale.
-    #[arg(long, default_value_t = 4.0)]
-    cfg_scale: f64,
+    #[arg(long)]
+    cfg_scale: Option<f64>,
 
     // Time shift factor (alpha).
     #[arg(long, default_value_t = 3.0)]
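Turning `num_inference_steps` and `cfg_scale` into `Option`s lets the program distinguish "user passed a value" from "use the per-model default"; the defaults are then resolved with `unwrap_or`, as a later hunk shows. The resolution logic in isolation, with the enum abbreviated (the step counts are the ones from this commit; the turbo model is distilled for few-step sampling, hence its much lower default):

```rust
#[derive(Clone, Copy)]
enum Which {
    V3Medium,
    V3_5Large,
    V3_5LargeTurbo,
}

fn resolve_steps(which: Which, num_inference_steps: Option<usize>) -> usize {
    // Per-model default, only used when the flag was omitted.
    let default_inference_steps = match which {
        Which::V3_5Large => 28,
        Which::V3_5LargeTurbo => 4,
        Which::V3Medium => 28,
    };
    num_inference_steps.unwrap_or(default_inference_steps)
}

fn main() {
    assert_eq!(resolve_steps(Which::V3_5LargeTurbo, None), 4);
    assert_eq!(resolve_steps(Which::V3_5LargeTurbo, Some(8)), 8);
}
```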
@@ -68,12 +87,6 @@ struct Args {
 }
 
 fn main() -> Result<()> {
-    let args = Args::parse();
-    // Your main code here
-    run(args)
-}
-
-fn run(args: Args) -> Result<()> {
     use tracing_chrome::ChromeLayerBuilder;
     use tracing_subscriber::prelude::*;
 
@@ -81,7 +94,6 @@ fn run(args: Args) -> Result<()> {
         prompt,
         uncond_prompt,
         cpu,
-        gpu_device_id,
         tracing,
         use_flash_attn,
         height,
@@ -90,7 +102,8 @@ fn run(args: Args) -> Result<()> {
         cfg_scale,
         time_shift,
         seed,
-    } = args;
+        which,
+    } = Args::parse();
 
     let _guard = if tracing {
         let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
@@ -100,87 +113,110 @@ fn run(args: Args) -> Result<()> {
         None
     };
 
-    let device = if cpu {
-        candle::Device::Cpu
-    } else if candle::utils::cuda_is_available() {
-        candle::Device::new_cuda(gpu_device_id)?
-    } else if candle::utils::metal_is_available() {
-        candle::Device::new_metal(gpu_device_id)?
-    } else {
-        candle::Device::Cpu
+    let device = candle_examples::device(cpu)?;
+    let default_inference_steps = match which {
+        Which::V3_5Large => 28,
+        Which::V3_5LargeTurbo => 4,
+        Which::V3Medium => 28,
     };
+    let num_inference_steps = num_inference_steps.unwrap_or(default_inference_steps);
+    let default_cfg_scale = match which {
+        Which::V3_5Large => 4.0,
+        Which::V3_5LargeTurbo => 1.0,
+        Which::V3Medium => 4.0,
+    };
+    let cfg_scale = cfg_scale.unwrap_or(default_cfg_scale);
 
     let api = hf_hub::api::sync::Api::new()?;
-    let sai_repo = {
-        let name = "stabilityai/stable-diffusion-3-medium";
-        api.repo(hf_hub::Repo::model(name.to_string()))
-    };
-    let model_file = sai_repo.get("sd3_medium_incl_clips_t5xxlfp16.safetensors")?;
-    let vb_fp16 = unsafe {
-        candle_nn::VarBuilder::from_mmaped_safetensors(&[model_file.clone()], DType::F16, &device)?
-    };
-
-    let (context, y) = {
-        let vb_fp32 = unsafe {
-            candle_nn::VarBuilder::from_mmaped_safetensors(
-                &[model_file.clone()],
-                DType::F32,
-                &device,
-            )?
+    let (mmdit_config, mut triple, vb) = if which.is_3_5() {
+        let sai_repo = {
+            let name = match which {
+                Which::V3_5Large => "stabilityai/stable-diffusion-3.5-large",
+                Which::V3_5LargeTurbo => "stabilityai/stable-diffusion-3.5-large-turbo",
+                Which::V3Medium => unreachable!(),
+            };
+            api.repo(hf_hub::Repo::model(name.to_string()))
         };
-        let mut triple = StableDiffusion3TripleClipWithTokenizer::new(
+        let clip_g_file = sai_repo.get("text_encoders/clip_g.safetensors")?;
+        let clip_l_file = sai_repo.get("text_encoders/clip_l.safetensors")?;
+        let t5xxl_file = sai_repo.get("text_encoders/t5xxl_fp16.safetensors")?;
+        let model_file = {
+            let model_file = match which {
+                Which::V3_5Large => "sd3.5_large.safetensors",
+                Which::V3_5LargeTurbo => "sd3.5_large_turbo.safetensors",
+                Which::V3Medium => unreachable!(),
+            };
+            sai_repo.get(model_file)?
+        };
+        let triple = StableDiffusion3TripleClipWithTokenizer::new_split(
+            &clip_g_file,
+            &clip_l_file,
+            &t5xxl_file,
+            &device,
+        )?;
+        let vb = unsafe {
+            candle_nn::VarBuilder::from_mmaped_safetensors(&[model_file], DType::F16, &device)?
+        };
+        (MMDiTConfig::sd3_5_large(), triple, vb)
+    } else {
+        let sai_repo = {
+            let name = "stabilityai/stable-diffusion-3-medium";
+            api.repo(hf_hub::Repo::model(name.to_string()))
+        };
+        let model_file = sai_repo.get("sd3_medium_incl_clips_t5xxlfp16.safetensors")?;
+        let vb_fp16 = unsafe {
+            candle_nn::VarBuilder::from_mmaped_safetensors(&[&model_file], DType::F16, &device)?
+        };
+
+        let vb_fp32 = unsafe {
+            candle_nn::VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)?
+        };
+        let triple = StableDiffusion3TripleClipWithTokenizer::new(
             vb_fp16.pp("text_encoders"),
             vb_fp32.pp("text_encoders"),
         )?;
-        let (context, y) = triple.encode_text_to_embedding(prompt.as_str(), &device)?;
-        let (context_uncond, y_uncond) =
-            triple.encode_text_to_embedding(uncond_prompt.as_str(), &device)?;
-        (
-            Tensor::cat(&[context, context_uncond], 0)?,
-            Tensor::cat(&[y, y_uncond], 0)?,
-        )
+        (MMDiTConfig::sd3_medium(), triple, vb_fp16)
     };
+    let (context, y) = triple.encode_text_to_embedding(prompt.as_str(), &device)?;
+    let (context_uncond, y_uncond) =
+        triple.encode_text_to_embedding(uncond_prompt.as_str(), &device)?;
+    let context = Tensor::cat(&[context, context_uncond], 0)?;
+    let y = Tensor::cat(&[y, y_uncond], 0)?;
 
-    let x = {
-        let mmdit = MMDiT::new(
-            &MMDiTConfig::sd3_medium(),
-            use_flash_attn,
-            vb_fp16.pp("model.diffusion_model"),
-        )?;
+    let mmdit = MMDiT::new(
+        &mmdit_config,
+        use_flash_attn,
+        vb.pp("model.diffusion_model"),
+    )?;
 
-        if let Some(seed) = seed {
-            device.set_seed(seed)?;
-        }
-        let start_time = std::time::Instant::now();
-        let x = sampling::euler_sample(
-            &mmdit,
-            &y,
-            &context,
-            num_inference_steps,
-            cfg_scale,
-            time_shift,
-            height,
-            width,
-        )?;
-        let dt = start_time.elapsed().as_secs_f32();
-        println!(
-            "Sampling done. {num_inference_steps} steps. {:.2}s. Average rate: {:.2} iter/s",
-            dt,
-            num_inference_steps as f32 / dt
-        );
-        x
-    };
+    if let Some(seed) = seed {
+        device.set_seed(seed)?;
+    }
+    let start_time = std::time::Instant::now();
+    let x = sampling::euler_sample(
+        &mmdit,
+        &y,
+        &context,
+        num_inference_steps,
+        cfg_scale,
+        time_shift,
+        height,
+        width,
+    )?;
+    let dt = start_time.elapsed().as_secs_f32();
+    println!(
+        "Sampling done. {num_inference_steps} steps. {:.2}s. Average rate: {:.2} iter/s",
+        dt,
+        num_inference_steps as f32 / dt
+    );
 
     let img = {
-        let vb_vae = vb_fp16
-            .clone()
-            .rename_f(sd3_vae_vb_rename)
-            .pp("first_stage_model");
+        let vb_vae = vb.rename_f(sd3_vae_vb_rename).pp("first_stage_model");
         let autoencoder = build_sd3_vae_autoencoder(vb_vae)?;
 
         // Apply TAESD3 scale factor. Seems to be significantly improving the quality of the image.
         // https://github.com/comfyanonymous/ComfyUI/blob/3c60ecd7a83da43d694e26a77ca6b93106891251/nodes.py#L721-L723
-        autoencoder.decode(&((x.clone() / 1.5305)? + 0.0609)?)?
+        autoencoder.decode(&((x / 1.5305)? + 0.0609)?)?
     };
     let img = ((img.clamp(-1f32, 1f32)? + 1.0)? * 127.5)?.to_dtype(candle::DType::U8)?;
     candle_examples::save_image(&img.i(0)?, "out.jpg")?;
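Note how the run batches the prompt and negative-prompt embeddings along dimension 0 (`Tensor::cat(&[context, context_uncond], 0)`), so the MMDiT runs once per step on a doubled batch. The other half of that bargain happens after the forward pass, when the two predictions are recombined using the CFG scale. A sketch of one common recombination; this is not necessarily the exact formulation inside `sampling::euler_sample`:

```rust
use candle::{Result, Tensor};

/// Split a doubled-batch prediction into its conditional (first half) and
/// unconditional (second half) parts and apply classifier-free guidance:
/// uncond + cfg_scale * (cond - uncond).
fn apply_cfg(noise_pred: &Tensor, cfg_scale: f64) -> Result<Tensor> {
    let cond = noise_pred.narrow(0, 0, 1)?;
    let uncond = noise_pred.narrow(0, 1, 1)?;
    ((cond - &uncond)? * cfg_scale)? + uncond
}
```

With the turbo default of `cfg_scale = 1.0` this collapses to the conditional prediction alone, which fits a distilled few-step model.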
@@ -30,7 +30,7 @@ pub fn euler_sample(
 
         let timestep = (*s_curr) * 1000.0;
         let noise_pred = mmdit.forward(
-            &Tensor::cat(&[x.clone(), x.clone()], 0)?,
+            &Tensor::cat(&[&x, &x], 0)?,
             &Tensor::full(timestep as f32, (2,), x.device())?.contiguous()?,
             y,
             context,
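This one-line change is the "remove some unnecessary clones" item from the commit message: `Tensor::cat` is generic over `AsRef<Tensor>`, so references work as well as owned values. Candle tensors are Arc-backed handles, so the clones were cheap anyway, but dropping them is tidier. A sketch:

```rust
use candle::{DType, Device, Tensor};

fn main() -> candle::Result<()> {
    let x = Tensor::zeros((1, 4), DType::F32, &Device::Cpu)?;
    // Tensor::cat accepts &[&Tensor]; no clone (however cheap) is needed
    // to build the doubled CFG batch.
    let doubled = Tensor::cat(&[&x, &x], 0)?;
    assert_eq!(doubled.dims(), &[2, 4]);
    Ok(())
}
```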
@@ -36,6 +36,20 @@ impl Config {
             frequency_embedding_size: 256,
         }
     }
+
+    pub fn sd3_5_large() -> Self {
+        Self {
+            patch_size: 2,
+            in_channels: 16,
+            out_channels: 16,
+            depth: 38,
+            head_size: 64,
+            adm_in_channels: 2048,
+            pos_embed_max_size: 192,
+            context_embed_size: 4096,
+            frequency_embedding_size: 256,
+        }
+    }
 }
 
 pub struct MMDiT {
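The new `sd3_5_large` preset differs from the medium one mainly in `depth`. In MMDiT the transformer width is derived from the config as depth * head_size, so depth 38 with 64-wide heads gives a 2432-wide model. A quick check of that arithmetic (the medium preset's depth of 24 is assumed here, recalled from the `sd3_medium` config that is not shown in this hunk):

```rust
// Width arithmetic for the presets: hidden = depth * head_size.
fn hidden_size(depth: usize, head_size: usize) -> usize {
    depth * head_size
}

fn main() {
    assert_eq!(hidden_size(38, 64), 2432); // sd3_5_large, values from the diff
    assert_eq!(hidden_size(24, 64), 1536); // sd3_medium, assumed depth 24
}
```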
@@ -56,6 +56,8 @@ impl QkvOnlyAttnProjections {
 pub struct AttnProjections {
     head_dim: usize,
     qkv: nn::Linear,
+    ln_k: Option<candle_nn::RmsNorm>,
+    ln_q: Option<candle_nn::RmsNorm>,
     proj: nn::Linear,
 }
 
@@ -64,16 +66,42 @@ impl AttnProjections {
         let head_dim = dim / num_heads;
         let qkv = nn::linear(dim, dim * 3, vb.pp("qkv"))?;
         let proj = nn::linear(dim, dim, vb.pp("proj"))?;
+        let (ln_k, ln_q) = if vb.contains_tensor("ln_k.weight") {
+            let ln_k = candle_nn::rms_norm(head_dim, 1e-6, vb.pp("ln_k"))?;
+            let ln_q = candle_nn::rms_norm(head_dim, 1e-6, vb.pp("ln_q"))?;
+            (Some(ln_k), Some(ln_q))
+        } else {
+            (None, None)
+        };
         Ok(Self {
             head_dim,
             qkv,
             proj,
+            ln_k,
+            ln_q,
         })
     }
 
     pub fn pre_attention(&self, x: &Tensor) -> Result<Qkv> {
         let qkv = self.qkv.forward(x)?;
-        split_qkv(&qkv, self.head_dim)
+        let Qkv { q, k, v } = split_qkv(&qkv, self.head_dim)?;
+        let q = match self.ln_q.as_ref() {
+            None => q,
+            Some(l) => {
+                let (b, t, h) = q.dims3()?;
+                l.forward(&q.reshape((b, t, (), self.head_dim))?)?
+                    .reshape((b, t, h))?
+            }
+        };
+        let k = match self.ln_k.as_ref() {
+            None => k,
+            Some(l) => {
+                let (b, t, h) = k.dims3()?;
+                l.forward(&k.reshape((b, t, (), self.head_dim))?)?
+                    .reshape((b, t, h))?
+            }
+        };
+        Ok(Qkv { q, k, v })
     }
 
     pub fn post_attention(&self, x: &Tensor) -> Result<Tensor> {
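The `pre_attention` change adds the QK normalization used by the SD3.5 checkpoints: query and key activations are RMS-normalized per attention head before attention, and the norms are only built when the checkpoint actually contains an `ln_k.weight` tensor, which keeps SD3-medium loading unchanged. Because `RmsNorm` acts on the last dimension, the flat `(batch, seq, heads * head_dim)` tensor is reshaped to `(batch, seq, heads, head_dim)` for the norm (the `()` lets candle infer the head count) and flattened back. The same reshape trick in isolation, a sketch with made-up dimensions that builds the norm directly via `RmsNorm::new` with a unit weight instead of from a VarBuilder:

```rust
use candle::{DType, Device, Module, Tensor};

fn main() -> candle::Result<()> {
    let dev = Device::Cpu;
    let (b, t, heads, head_dim) = (2, 3, 4, 8);
    let q = Tensor::randn(0f32, 1f32, (b, t, heads * head_dim), &dev)?;
    // A unit weight stands in for the learned ln_q scale.
    let ln_q = candle_nn::RmsNorm::new(Tensor::ones(head_dim, DType::F32, &dev)?, 1e-6);
    let (b_, t_, h) = q.dims3()?;
    // Normalize over head_dim only, then restore the flat layout.
    let q = ln_q
        .forward(&q.reshape((b_, t_, (), head_dim))?)?
        .reshape((b_, t_, h))?;
    assert_eq!(q.dims(), &[b, t, heads * head_dim]);
    Ok(())
}
```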