Stable diffusion 3.5 support. (#2578)

* Stable diffusion 3.5 support.

* Clippy fixes.

* CFG fix (see the classifier-free guidance sketch after this list).

* Remove some unnecessary clones.

* Avoid duplicating some of the code.
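
The "CFG fix" above concerns classifier-free guidance: the example batches the conditional and unconditional embeddings (the `Tensor::cat(&[context, context_uncond], 0)` calls in the diff below), runs the MMDiT once over both rows, and blends the two predictions. Here is a minimal sketch of that blend in candle, assuming the row layout just described; it is illustrative, not the literal patch from this commit:

```rust
use candle::{Result, Tensor};

/// Blend a batched prediction: row 0 is conditional, row 1 unconditional.
/// guided = uncond + cfg_scale * (cond - uncond)
fn apply_cfg(noise_pred: &Tensor, cfg_scale: f64) -> Result<Tensor> {
    let cond = noise_pred.narrow(0, 0, 1)?;
    let uncond = noise_pred.narrow(0, 1, 1)?;
    let delta = ((cond - &uncond)? * cfg_scale)?;
    uncond + delta
}
```

With `cfg_scale = 1.0` the blend reduces to the conditional prediction alone, which is why the turbo model defaults to 1.0 further down.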
Author: Laurent Mazare
Date: 2024-10-27 10:01:04 +01:00 (committed by GitHub)
Parent: 07849aa595
Commit: 37e0ab8c64

5 changed files with 209 additions and 85 deletions


@@ -11,6 +11,25 @@ use crate::vae::{build_sd3_vae_autoencoder, sd3_vae_vb_rename};
 use anyhow::{Ok, Result};
 use clap::Parser;
 
+#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
+enum Which {
+    #[value(name = "3-medium")]
+    V3Medium,
+    #[value(name = "3.5-large")]
+    V3_5Large,
+    #[value(name = "3.5-large-turbo")]
+    V3_5LargeTurbo,
+}
+
+impl Which {
+    fn is_3_5(&self) -> bool {
+        match self {
+            Self::V3Medium => false,
+            Self::V3_5Large | Self::V3_5LargeTurbo => true,
+        }
+    }
+}
+
 #[derive(Parser)]
 #[command(author, version, about, long_about = None)]
 struct Args {
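
The `#[value(name = ...)]` attributes in the hunk above are what let users type `--which 3.5-large` even though `3.5` cannot appear in a Rust identifier. A self-contained illustration of the mapping, with a hypothetical `MiniArgs` standing in for the real `Args`:

```rust
use clap::Parser;

#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
enum Which {
    #[value(name = "3-medium")]
    V3Medium,
    #[value(name = "3.5-large")]
    V3_5Large,
    #[value(name = "3.5-large-turbo")]
    V3_5LargeTurbo,
}

/// Trimmed-down stand-in for the example's Args struct.
#[derive(Parser)]
struct MiniArgs {
    #[arg(long, default_value = "3-medium")]
    which: Which,
}

fn main() {
    // clap resolves the CLI token via the #[value(name = ...)] attribute.
    let args = MiniArgs::parse_from(["demo", "--which", "3.5-large-turbo"]);
    assert_eq!(args.which, Which::V3_5LargeTurbo);
    // Omitting the flag falls back to the default_value.
    let args = MiniArgs::parse_from(["demo"]);
    assert_eq!(args.which, Which::V3Medium);
}
```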
@@ -30,10 +49,6 @@ struct Args {
     #[arg(long)]
     cpu: bool,
 
-    /// The GPU device ID to use.
-    #[arg(long, default_value_t = 0)]
-    gpu_device_id: usize,
-
     /// Enable tracing (generates a trace-timestamp.json file).
     #[arg(long)]
     tracing: bool,
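
The `--gpu-device-id` flag goes away because device selection moves to the shared `candle_examples::device` helper (see the last hunk, where the hand-rolled fallback chain is deleted). For reference, a minimal standalone version of that removed chain, assuming device ordinal 0:

```rust
use candle::{Device, Result};

/// CUDA if available, else Metal, else CPU — mirroring the block
/// removed in the last hunk (with the device id fixed to 0).
fn pick_device(force_cpu: bool) -> Result<Device> {
    if force_cpu {
        Ok(Device::Cpu)
    } else if candle::utils::cuda_is_available() {
        Device::new_cuda(0)
    } else if candle::utils::metal_is_available() {
        Device::new_metal(0)
    } else {
        Ok(Device::Cpu)
    }
}
```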
@@ -50,13 +65,17 @@ struct Args {
     #[arg(long, default_value_t = 1024)]
     width: usize,
 
+    /// The model to use.
+    #[arg(long, default_value = "3-medium")]
+    which: Which,
+
     /// The seed to use when generating random samples.
-    #[arg(long, default_value_t = 28)]
-    num_inference_steps: usize,
+    #[arg(long)]
+    num_inference_steps: Option<usize>,
 
     // CFG scale.
-    #[arg(long, default_value_t = 4.0)]
-    cfg_scale: f64,
+    #[arg(long)]
+    cfg_scale: Option<f64>,
 
     // Time shift factor (alpha).
     #[arg(long, default_value_t = 3.0)]
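
Turning `num_inference_steps` and `cfg_scale` into `Option`s is what enables per-model defaults: the base models keep 28 steps / CFG 4.0, while the turbo distillation wants 4 steps / CFG 1.0 (see the `unwrap_or` calls in the last hunk). A self-contained sketch of that resolution, with the enum repeated so the snippet compiles on its own:

```rust
#[derive(Clone, Copy, Debug, PartialEq)]
enum Which {
    V3Medium,
    V3_5Large,
    V3_5LargeTurbo,
}

/// Resolve optional CLI values against per-model defaults,
/// mirroring the unwrap_or logic in the diff.
fn resolve(which: Which, steps: Option<usize>, cfg: Option<f64>) -> (usize, f64) {
    let default_steps = match which {
        Which::V3Medium | Which::V3_5Large => 28,
        Which::V3_5LargeTurbo => 4,
    };
    let default_cfg = match which {
        Which::V3Medium | Which::V3_5Large => 4.0,
        Which::V3_5LargeTurbo => 1.0,
    };
    (steps.unwrap_or(default_steps), cfg.unwrap_or(default_cfg))
}

fn main() {
    // No flags passed: the turbo model gets its few-step, low-guidance defaults.
    assert_eq!(resolve(Which::V3_5LargeTurbo, None, None), (4, 1.0));
    // Explicit flags always win over the defaults.
    assert_eq!(resolve(Which::V3Medium, Some(50), None), (50, 4.0));
    assert_eq!(resolve(Which::V3_5Large, None, Some(7.0)), (28, 7.0));
}
```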
@@ -68,12 +87,6 @@ struct Args {
 }
 
 fn main() -> Result<()> {
-    let args = Args::parse();
-    // Your main code here
-    run(args)
-}
-
-fn run(args: Args) -> Result<()> {
     use tracing_chrome::ChromeLayerBuilder;
     use tracing_subscriber::prelude::*;
@@ -81,7 +94,6 @@ fn run(args: Args) -> Result<()> {
         prompt,
         uncond_prompt,
         cpu,
-        gpu_device_id,
         tracing,
         use_flash_attn,
         height,
@@ -90,7 +102,8 @@ fn run(args: Args) -> Result<()> {
         cfg_scale,
         time_shift,
         seed,
-    } = args;
+        which,
+    } = Args::parse();
 
     let _guard = if tracing {
         let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
@@ -100,87 +113,110 @@ fn run(args: Args) -> Result<()> {
         None
     };
 
-    let device = if cpu {
-        candle::Device::Cpu
-    } else if candle::utils::cuda_is_available() {
-        candle::Device::new_cuda(gpu_device_id)?
-    } else if candle::utils::metal_is_available() {
-        candle::Device::new_metal(gpu_device_id)?
-    } else {
-        candle::Device::Cpu
-    };
+    let device = candle_examples::device(cpu)?;
+    let default_inference_steps = match which {
+        Which::V3_5Large => 28,
+        Which::V3_5LargeTurbo => 4,
+        Which::V3Medium => 28,
+    };
+    let num_inference_steps = num_inference_steps.unwrap_or(default_inference_steps);
+    let default_cfg_scale = match which {
+        Which::V3_5Large => 4.0,
+        Which::V3_5LargeTurbo => 1.0,
+        Which::V3Medium => 4.0,
+    };
+    let cfg_scale = cfg_scale.unwrap_or(default_cfg_scale);
 
     let api = hf_hub::api::sync::Api::new()?;
-    let sai_repo = {
-        let name = "stabilityai/stable-diffusion-3-medium";
-        api.repo(hf_hub::Repo::model(name.to_string()))
-    };
-    let model_file = sai_repo.get("sd3_medium_incl_clips_t5xxlfp16.safetensors")?;
-    let vb_fp16 = unsafe {
-        candle_nn::VarBuilder::from_mmaped_safetensors(&[model_file.clone()], DType::F16, &device)?
-    };
-
-    let (context, y) = {
-        let vb_fp32 = unsafe {
-            candle_nn::VarBuilder::from_mmaped_safetensors(
-                &[model_file.clone()],
-                DType::F32,
-                &device,
-            )?
-        };
-        let mut triple = StableDiffusion3TripleClipWithTokenizer::new(
-            vb_fp16.pp("text_encoders"),
-            vb_fp32.pp("text_encoders"),
-        )?;
-        let (context, y) = triple.encode_text_to_embedding(prompt.as_str(), &device)?;
-        let (context_uncond, y_uncond) =
-            triple.encode_text_to_embedding(uncond_prompt.as_str(), &device)?;
-        (
-            Tensor::cat(&[context, context_uncond], 0)?,
-            Tensor::cat(&[y, y_uncond], 0)?,
-        )
-    };
+    let (mmdit_config, mut triple, vb) = if which.is_3_5() {
+        let sai_repo = {
+            let name = match which {
+                Which::V3_5Large => "stabilityai/stable-diffusion-3.5-large",
+                Which::V3_5LargeTurbo => "stabilityai/stable-diffusion-3.5-large-turbo",
+                Which::V3Medium => unreachable!(),
+            };
+            api.repo(hf_hub::Repo::model(name.to_string()))
+        };
+        let clip_g_file = sai_repo.get("text_encoders/clip_g.safetensors")?;
+        let clip_l_file = sai_repo.get("text_encoders/clip_l.safetensors")?;
+        let t5xxl_file = sai_repo.get("text_encoders/t5xxl_fp16.safetensors")?;
+        let model_file = {
+            let model_file = match which {
+                Which::V3_5Large => "sd3.5_large.safetensors",
+                Which::V3_5LargeTurbo => "sd3.5_large_turbo.safetensors",
+                Which::V3Medium => unreachable!(),
+            };
+            sai_repo.get(model_file)?
+        };
+        let triple = StableDiffusion3TripleClipWithTokenizer::new_split(
+            &clip_g_file,
+            &clip_l_file,
+            &t5xxl_file,
+            &device,
+        )?;
+        let vb = unsafe {
+            candle_nn::VarBuilder::from_mmaped_safetensors(&[model_file], DType::F16, &device)?
+        };
+        (MMDiTConfig::sd3_5_large(), triple, vb)
+    } else {
+        let sai_repo = {
+            let name = "stabilityai/stable-diffusion-3-medium";
+            api.repo(hf_hub::Repo::model(name.to_string()))
+        };
+        let model_file = sai_repo.get("sd3_medium_incl_clips_t5xxlfp16.safetensors")?;
+        let vb_fp16 = unsafe {
+            candle_nn::VarBuilder::from_mmaped_safetensors(&[&model_file], DType::F16, &device)?
+        };
+        let vb_fp32 = unsafe {
+            candle_nn::VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)?
+        };
+        let triple = StableDiffusion3TripleClipWithTokenizer::new(
+            vb_fp16.pp("text_encoders"),
+            vb_fp32.pp("text_encoders"),
+        )?;
+        (MMDiTConfig::sd3_medium(), triple, vb_fp16)
+    };
+    let (context, y) = triple.encode_text_to_embedding(prompt.as_str(), &device)?;
+    let (context_uncond, y_uncond) =
+        triple.encode_text_to_embedding(uncond_prompt.as_str(), &device)?;
+    let context = Tensor::cat(&[context, context_uncond], 0)?;
+    let y = Tensor::cat(&[y, y_uncond], 0)?;
 
-    let x = {
-        let mmdit = MMDiT::new(
-            &MMDiTConfig::sd3_medium(),
-            use_flash_attn,
-            vb_fp16.pp("model.diffusion_model"),
-        )?;
+    let mmdit = MMDiT::new(
+        &mmdit_config,
+        use_flash_attn,
+        vb.pp("model.diffusion_model"),
+    )?;
 
-        if let Some(seed) = seed {
-            device.set_seed(seed)?;
-        }
-        let start_time = std::time::Instant::now();
-        let x = sampling::euler_sample(
-            &mmdit,
-            &y,
-            &context,
-            num_inference_steps,
-            cfg_scale,
-            time_shift,
-            height,
-            width,
-        )?;
-        let dt = start_time.elapsed().as_secs_f32();
-        println!(
-            "Sampling done. {num_inference_steps} steps. {:.2}s. Average rate: {:.2} iter/s",
-            dt,
-            num_inference_steps as f32 / dt
-        );
-        x
-    };
+    if let Some(seed) = seed {
+        device.set_seed(seed)?;
+    }
+    let start_time = std::time::Instant::now();
+    let x = sampling::euler_sample(
+        &mmdit,
+        &y,
+        &context,
+        num_inference_steps,
+        cfg_scale,
+        time_shift,
+        height,
+        width,
+    )?;
+    let dt = start_time.elapsed().as_secs_f32();
+    println!(
+        "Sampling done. {num_inference_steps} steps. {:.2}s. Average rate: {:.2} iter/s",
+        dt,
+        num_inference_steps as f32 / dt
+    );
 
     let img = {
-        let vb_vae = vb_fp16
-            .clone()
-            .rename_f(sd3_vae_vb_rename)
-            .pp("first_stage_model");
+        let vb_vae = vb.rename_f(sd3_vae_vb_rename).pp("first_stage_model");
         let autoencoder = build_sd3_vae_autoencoder(vb_vae)?;
         // Apply TAESD3 scale factor. Seems to be significantly improving the quality of the image.
         // https://github.com/comfyanonymous/ComfyUI/blob/3c60ecd7a83da43d694e26a77ca6b93106891251/nodes.py#L721-L723
-        autoencoder.decode(&((x.clone() / 1.5305)? + 0.0609)?)?
+        autoencoder.decode(&((x / 1.5305)? + 0.0609)?)?
     };
     let img = ((img.clamp(-1f32, 1f32)? + 1.0)? * 127.5)?.to_dtype(candle::DType::U8)?;
     candle_examples::save_image(&img.i(0)?, "out.jpg")?;
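
Two pieces the code path above leans on, sketched standalone. SD3/SD3.5 are rectified-flow models, so `sampling::euler_sample` amounts to repeated Euler steps x ← x + v·Δσ down a decreasing (time-shifted) sigma schedule, with the CFG'd MMDiT supplying v; and the decode call un-scales the latent (x / 1.5305 + 0.0609, the TAESD3 factors cited in the comment) before the VAE. The following is a schematic under those assumptions, with the model stubbed as a closure; it is not candle's actual implementation:

```rust
use candle::{Device, Result, Tensor};

/// Schematic Euler sampler for a rectified-flow model. `model` stands in for
/// the guided MMDiT velocity prediction; `sigmas` is a decreasing schedule
/// from 1.0 to 0.0 (time_shift would warp its spacing).
fn euler_sample_sketch(
    model: impl Fn(&Tensor, f64) -> Result<Tensor>,
    mut x: Tensor,
    sigmas: &[f64],
) -> Result<Tensor> {
    for w in sigmas.windows(2) {
        let v = model(&x, w[0])?;
        // One Euler step along the predicted velocity: x += v * (sigma_next - sigma).
        x = (x + (v * (w[1] - w[0]))?)?;
    }
    Ok(x)
}

/// The un-scaling applied to the final latent before the VAE decode.
fn unscale_sd3_latent(x: &Tensor) -> Result<Tensor> {
    (x / 1.5305)? + 0.0609
}

fn main() -> Result<()> {
    let device = Device::Cpu;
    // SD3 latents have 16 channels at 1/8 the pixel resolution (1024 / 8 = 128).
    let x = Tensor::randn(0f32, 1f32, (1, 16, 128, 128), &device)?;
    // Identity "model" just to exercise the loop; a real run calls the MMDiT.
    let x = euler_sample_sketch(|x, _sigma| Ok(x.clone()), x, &[1.0, 0.5, 0.0])?;
    println!("{:?}", unscale_sd3_latent(&x)?.shape());
    Ok(())
}
```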