Stable diffusion 3.5 support. (#2578)

* Stable diffusion 3.5 support.

* Clippy fixes.

* CFG fix.

* Remove some unnecessary clones.

* Avoid duplicating some of the code.
Author: Laurent Mazare
Date: 2024-10-27 10:01:04 +01:00
Committed by: GitHub
Parent: 07849aa595
Commit: 37e0ab8c64
5 changed files with 209 additions and 85 deletions
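The "CFG fix" bullet refers to classifier-free guidance: the example batches the conditional and unconditional prompt embeddings together (the Tensor::cat(&[context, context_uncond], 0) call in the diff below), runs the MMDiT once on the doubled batch, and blends the two noise predictions with cfg_scale. A minimal sketch of that blend, assuming a [conditional, unconditional] batch order and using a hypothetical helper name:

use candle::Tensor;

// Hypothetical helper, not part of this commit: classifier-free guidance blends
// uncond + cfg_scale * (cond - uncond), written here as an affine combination.
fn apply_cfg(noise_pred: &Tensor, cfg_scale: f64) -> candle::Result<Tensor> {
    // Assumes the doubled batch holds [conditional, unconditional] along dim 0.
    let cond = noise_pred.narrow(0, 0, 1)?;
    let uncond = noise_pred.narrow(0, 1, 1)?;
    (cond * cfg_scale)? - (uncond * (cfg_scale - 1.0))?
}

With the turbo default of cfg_scale = 1.0 introduced below, the blend reduces to the conditional prediction alone.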

View File

@@ -1,6 +1,7 @@
 use anyhow::{Error as E, Ok, Result};
 use candle::{DType, IndexOp, Module, Tensor, D};
 use candle_transformers::models::{stable_diffusion, t5};
+use std::path::PathBuf;
 use tokenizers::tokenizer::Tokenizer;
 struct ClipWithTokenizer {
@@ -130,6 +131,53 @@ pub struct StableDiffusion3TripleClipWithTokenizer {
 }
 impl StableDiffusion3TripleClipWithTokenizer {
+    pub fn new_split(
+        clip_g_file: &PathBuf,
+        clip_l_file: &PathBuf,
+        t5xxl_file: &PathBuf,
+        device: &candle::Device,
+    ) -> Result<Self> {
+        let vb_clip_g = unsafe {
+            candle_nn::VarBuilder::from_mmaped_safetensors(&[clip_g_file], DType::F16, device)?
+        };
+        let vb_clip_l = unsafe {
+            candle_nn::VarBuilder::from_mmaped_safetensors(&[clip_l_file], DType::F16, device)?
+        };
+        let vb_t5 = unsafe {
+            candle_nn::VarBuilder::from_mmaped_safetensors(&[t5xxl_file], DType::F32, device)?
+        };
+        let max_position_embeddings = 77usize;
+        let clip_l = ClipWithTokenizer::new(
+            vb_clip_l,
+            stable_diffusion::clip::Config::sdxl(),
+            "openai/clip-vit-large-patch14",
+            max_position_embeddings,
+        )?;
+        let text_projection =
+            candle_nn::linear_no_bias(1280, 1280, vb_clip_g.pp("text_projection"))?;
+        let clip_g = ClipWithTokenizer::new(
+            vb_clip_g,
+            stable_diffusion::clip::Config::sdxl2(),
+            "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k",
+            max_position_embeddings,
+        )?;
+        // Current T5 implementation does not support fp16, so we use fp32 VarBuilder for T5.
+        // This is a temporary workaround until the T5 implementation is updated to support fp16.
+        // Also see:
+        // https://github.com/huggingface/candle/issues/2480
+        // https://github.com/huggingface/candle/pull/2481
+        let t5 = T5WithTokenizer::new(vb_t5, max_position_embeddings)?;
+        Ok(Self {
+            clip_l,
+            clip_g,
+            clip_g_text_projection: text_projection,
+            t5,
+        })
+    }
     pub fn new(vb_fp16: candle_nn::VarBuilder, vb_fp32: candle_nn::VarBuilder) -> Result<Self> {
         let max_position_embeddings = 77usize;
         let clip_l = ClipWithTokenizer::new(
@@ -158,7 +206,6 @@ impl StableDiffusion3TripleClipWithTokenizer {
         // https://github.com/huggingface/candle/issues/2480
         // https://github.com/huggingface/candle/pull/2481
         let t5 = T5WithTokenizer::new(vb_fp32.pp("t5xxl.transformer"), max_position_embeddings)?;
         Ok(Self {
             clip_l,
             clip_g,
@@ -195,7 +242,6 @@ impl StableDiffusion3TripleClipWithTokenizer {
             .encode_text_to_embedding(prompt, device)?
             .to_dtype(DType::F16)?;
         let context = Tensor::cat(&[&clip_embeddings_concat, &t5_embeddings], D::Minus2)?;
         Ok((context, y))
     }
 }
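The fp16/fp32 split above exists only because the T5 encoder currently runs in f32 (see the linked issue #2480); its embeddings are cast back to f16 before being concatenated with the f16 CLIP embeddings, which is what the .to_dtype(DType::F16) call in the last hunk does. A minimal sketch of that round-trip, using a hypothetical helper name:

use candle::{DType, Tensor, D};

// Hypothetical helper, not part of this commit: bridge the f32 T5 embeddings down
// to f16 so they can be concatenated with the f16 CLIP embeddings.
fn concat_text_embeddings(clip_f16: &Tensor, t5_f32: &Tensor) -> candle::Result<Tensor> {
    let t5_f16 = t5_f32.to_dtype(DType::F16)?;
    // Concatenate along the sequence dimension, CLIP embeddings first, as in the diff.
    Tensor::cat(&[clip_f16, &t5_f16], D::Minus2)
}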

View File

@@ -11,6 +11,25 @@ use crate::vae::{build_sd3_vae_autoencoder, sd3_vae_vb_rename};
 use anyhow::{Ok, Result};
 use clap::Parser;
+#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
+enum Which {
+    #[value(name = "3-medium")]
+    V3Medium,
+    #[value(name = "3.5-large")]
+    V3_5Large,
+    #[value(name = "3.5-large-turbo")]
+    V3_5LargeTurbo,
+}
+impl Which {
+    fn is_3_5(&self) -> bool {
+        match self {
+            Self::V3Medium => false,
+            Self::V3_5Large | Self::V3_5LargeTurbo => true,
+        }
+    }
+}
 #[derive(Parser)]
 #[command(author, version, about, long_about = None)]
 struct Args {
@@ -30,10 +49,6 @@ struct Args {
     #[arg(long)]
     cpu: bool,
-    /// The GPU device ID to use.
-    #[arg(long, default_value_t = 0)]
-    gpu_device_id: usize,
     /// Enable tracing (generates a trace-timestamp.json file).
     #[arg(long)]
     tracing: bool,
@@ -50,13 +65,17 @@ struct Args {
     #[arg(long, default_value_t = 1024)]
     width: usize,
+    /// The model to use.
+    #[arg(long, default_value = "3-medium")]
+    which: Which,
     /// The seed to use when generating random samples.
-    #[arg(long, default_value_t = 28)]
-    num_inference_steps: usize,
+    #[arg(long)]
+    num_inference_steps: Option<usize>,
     // CFG scale.
-    #[arg(long, default_value_t = 4.0)]
-    cfg_scale: f64,
+    #[arg(long)]
+    cfg_scale: Option<f64>,
     // Time shift factor (alpha).
     #[arg(long, default_value_t = 3.0)]
@@ -68,12 +87,6 @@ struct Args {
 }
 fn main() -> Result<()> {
-    let args = Args::parse();
-    // Your main code here
-    run(args)
-}
-fn run(args: Args) -> Result<()> {
     use tracing_chrome::ChromeLayerBuilder;
     use tracing_subscriber::prelude::*;
@@ -81,7 +94,6 @@ fn run(args: Args) -> Result<()> {
         prompt,
         uncond_prompt,
         cpu,
-        gpu_device_id,
         tracing,
         use_flash_attn,
         height,
@@ -90,7 +102,8 @@ fn run(args: Args) -> Result<()> {
         cfg_scale,
         time_shift,
         seed,
-    } = args;
+        which,
+    } = Args::parse();
     let _guard = if tracing {
         let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
@@ -100,52 +113,80 @@ fn run(args: Args) -> Result<()> {
         None
     };
-    let device = if cpu {
-        candle::Device::Cpu
-    } else if candle::utils::cuda_is_available() {
-        candle::Device::new_cuda(gpu_device_id)?
-    } else if candle::utils::metal_is_available() {
-        candle::Device::new_metal(gpu_device_id)?
-    } else {
-        candle::Device::Cpu
-    };
+    let device = candle_examples::device(cpu)?;
+    let default_inference_steps = match which {
+        Which::V3_5Large => 28,
+        Which::V3_5LargeTurbo => 4,
+        Which::V3Medium => 28,
+    };
+    let num_inference_steps = num_inference_steps.unwrap_or(default_inference_steps);
+    let default_cfg_scale = match which {
+        Which::V3_5Large => 4.0,
+        Which::V3_5LargeTurbo => 1.0,
+        Which::V3Medium => 4.0,
+    };
+    let cfg_scale = cfg_scale.unwrap_or(default_cfg_scale);
     let api = hf_hub::api::sync::Api::new()?;
+    let (mmdit_config, mut triple, vb) = if which.is_3_5() {
+        let sai_repo = {
+            let name = match which {
+                Which::V3_5Large => "stabilityai/stable-diffusion-3.5-large",
+                Which::V3_5LargeTurbo => "stabilityai/stable-diffusion-3.5-large-turbo",
+                Which::V3Medium => unreachable!(),
+            };
+            api.repo(hf_hub::Repo::model(name.to_string()))
+        };
+        let clip_g_file = sai_repo.get("text_encoders/clip_g.safetensors")?;
+        let clip_l_file = sai_repo.get("text_encoders/clip_l.safetensors")?;
+        let t5xxl_file = sai_repo.get("text_encoders/t5xxl_fp16.safetensors")?;
+        let model_file = {
+            let model_file = match which {
+                Which::V3_5Large => "sd3.5_large.safetensors",
+                Which::V3_5LargeTurbo => "sd3.5_large_turbo.safetensors",
+                Which::V3Medium => unreachable!(),
+            };
+            sai_repo.get(model_file)?
+        };
+        let triple = StableDiffusion3TripleClipWithTokenizer::new_split(
+            &clip_g_file,
+            &clip_l_file,
+            &t5xxl_file,
+            &device,
+        )?;
+        let vb = unsafe {
+            candle_nn::VarBuilder::from_mmaped_safetensors(&[model_file], DType::F16, &device)?
+        };
+        (MMDiTConfig::sd3_5_large(), triple, vb)
+    } else {
         let sai_repo = {
             let name = "stabilityai/stable-diffusion-3-medium";
             api.repo(hf_hub::Repo::model(name.to_string()))
         };
         let model_file = sai_repo.get("sd3_medium_incl_clips_t5xxlfp16.safetensors")?;
         let vb_fp16 = unsafe {
-            candle_nn::VarBuilder::from_mmaped_safetensors(&[model_file.clone()], DType::F16, &device)?
+            candle_nn::VarBuilder::from_mmaped_safetensors(&[&model_file], DType::F16, &device)?
         };
-    let (context, y) = {
         let vb_fp32 = unsafe {
-            candle_nn::VarBuilder::from_mmaped_safetensors(
-                &[model_file.clone()],
-                DType::F32,
-                &device,
-            )?
+            candle_nn::VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)?
         };
-        let mut triple = StableDiffusion3TripleClipWithTokenizer::new(
+        let triple = StableDiffusion3TripleClipWithTokenizer::new(
             vb_fp16.pp("text_encoders"),
             vb_fp32.pp("text_encoders"),
         )?;
+        (MMDiTConfig::sd3_medium(), triple, vb_fp16)
+    };
     let (context, y) = triple.encode_text_to_embedding(prompt.as_str(), &device)?;
     let (context_uncond, y_uncond) =
         triple.encode_text_to_embedding(uncond_prompt.as_str(), &device)?;
-        (
-            Tensor::cat(&[context, context_uncond], 0)?,
-            Tensor::cat(&[y, y_uncond], 0)?,
-        )
-    };
+    let context = Tensor::cat(&[context, context_uncond], 0)?;
+    let y = Tensor::cat(&[y, y_uncond], 0)?;
-    let x = {
     let mmdit = MMDiT::new(
-        &MMDiTConfig::sd3_medium(),
+        &mmdit_config,
         use_flash_attn,
-        vb_fp16.pp("model.diffusion_model"),
+        vb.pp("model.diffusion_model"),
     )?;
     if let Some(seed) = seed {
@@ -168,19 +209,14 @@ fn run(args: Args) -> Result<()> {
         dt,
         num_inference_steps as f32 / dt
     );
-        x
-    };
     let img = {
-        let vb_vae = vb_fp16
-            .clone()
-            .rename_f(sd3_vae_vb_rename)
-            .pp("first_stage_model");
+        let vb_vae = vb.rename_f(sd3_vae_vb_rename).pp("first_stage_model");
         let autoencoder = build_sd3_vae_autoencoder(vb_vae)?;
         // Apply TAESD3 scale factor. Seems to be significantly improving the quality of the image.
         // https://github.com/comfyanonymous/ComfyUI/blob/3c60ecd7a83da43d694e26a77ca6b93106891251/nodes.py#L721-L723
-        autoencoder.decode(&((x.clone() / 1.5305)? + 0.0609)?)?
+        autoencoder.decode(&((x / 1.5305)? + 0.0609)?)?
     };
     let img = ((img.clamp(-1f32, 1f32)? + 1.0)? * 127.5)?.to_dtype(candle::DType::U8)?;
     candle_examples::save_image(&img.i(0)?, "out.jpg")?;

View File

@@ -30,7 +30,7 @@ pub fn euler_sample(
         let timestep = (*s_curr) * 1000.0;
         let noise_pred = mmdit.forward(
-            &Tensor::cat(&[x.clone(), x.clone()], 0)?,
+            &Tensor::cat(&[&x, &x], 0)?,
             &Tensor::full(timestep as f32, (2,), x.device())?.contiguous()?,
             y,
             context,

View File

@@ -36,6 +36,20 @@ impl Config {
             frequency_embedding_size: 256,
         }
     }
+    pub fn sd3_5_large() -> Self {
+        Self {
+            patch_size: 2,
+            in_channels: 16,
+            out_channels: 16,
+            depth: 38,
+            head_size: 64,
+            adm_in_channels: 2048,
+            pos_embed_max_size: 192,
+            context_embed_size: 4096,
+            frequency_embedding_size: 256,
+        }
+    }
 }
 pub struct MMDiT {
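For a sense of scale, and assuming the MMDiT hidden width is derived as depth * head_size (the usual MMDiT convention; the computation is not shown in this hunk), the new config runs at 38 * 64 = 2432 channels, against 24 * 64 = 1536 for sd3_medium (whose depth of 24 is also not shown here):

// Hedged sketch, not from the diff: the width implied by the config fields above.
fn hidden_width(depth: usize, head_size: usize) -> usize {
    depth * head_size
}
// hidden_width(38, 64) == 2432 for sd3_5_large; hidden_width(24, 64) == 1536 for sd3_medium.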

View File

@@ -56,6 +56,8 @@ impl QkvOnlyAttnProjections {
 pub struct AttnProjections {
     head_dim: usize,
     qkv: nn::Linear,
+    ln_k: Option<candle_nn::RmsNorm>,
+    ln_q: Option<candle_nn::RmsNorm>,
     proj: nn::Linear,
 }
@@ -64,16 +66,42 @@ impl AttnProjections {
         let head_dim = dim / num_heads;
         let qkv = nn::linear(dim, dim * 3, vb.pp("qkv"))?;
         let proj = nn::linear(dim, dim, vb.pp("proj"))?;
+        let (ln_k, ln_q) = if vb.contains_tensor("ln_k.weight") {
+            let ln_k = candle_nn::rms_norm(head_dim, 1e-6, vb.pp("ln_k"))?;
+            let ln_q = candle_nn::rms_norm(head_dim, 1e-6, vb.pp("ln_q"))?;
+            (Some(ln_k), Some(ln_q))
+        } else {
+            (None, None)
+        };
         Ok(Self {
             head_dim,
             qkv,
             proj,
+            ln_k,
+            ln_q,
         })
     }
     pub fn pre_attention(&self, x: &Tensor) -> Result<Qkv> {
         let qkv = self.qkv.forward(x)?;
-        split_qkv(&qkv, self.head_dim)
+        let Qkv { q, k, v } = split_qkv(&qkv, self.head_dim)?;
+        let q = match self.ln_q.as_ref() {
+            None => q,
+            Some(l) => {
+                let (b, t, h) = q.dims3()?;
+                l.forward(&q.reshape((b, t, (), self.head_dim))?)?
+                    .reshape((b, t, h))?
+            }
+        };
+        let k = match self.ln_k.as_ref() {
+            None => k,
+            Some(l) => {
+                let (b, t, h) = k.dims3()?;
+                l.forward(&k.reshape((b, t, (), self.head_dim))?)?
+                    .reshape((b, t, h))?
+            }
+        };
+        Ok(Qkv { q, k, v })
     }
     pub fn post_attention(&self, x: &Tensor) -> Result<Tensor> {
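The ln_q / ln_k branches above are the SD3.5 per-head QK normalization: q and k are reshaped from (b, t, heads * head_dim) to (b, t, heads, head_dim), RMS-normalized over the trailing dimension, and flattened back. A minimal sketch of the same computation written out with raw tensor ops (illustrative only; the real layer also applies the learned ln_q / ln_k weights):

use candle::{Tensor, D};

// Hypothetical helper, not part of this commit: per-head RMS normalization over
// the trailing head_dim slice of a packed (b, t, heads * head_dim) projection.
fn rms_norm_per_head(x: &Tensor, head_dim: usize, eps: f64) -> candle::Result<Tensor> {
    let (b, t, h) = x.dims3()?;
    let x = x.reshape((b, t, h / head_dim, head_dim))?;
    let denom = (x.sqr()?.mean_keepdim(D::Minus1)? + eps)?.sqrt()?;
    let x = x.broadcast_div(&denom)?;
    // The learned per-channel weight (ln_q.weight / ln_k.weight) would be multiplied in here.
    x.reshape((b, t, h))
}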