mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 10:26:33 +00:00

* Add stable diffusion 3 example Add get_qkv_linear to handle different dimensionality in linears Add stable diffusion 3 example Add use_quant_conv and use_post_quant_conv for vae in stable diffusion adapt existing AutoEncoderKLConfig to the change add forward_until_encoder_layer to ClipTextTransformer rename sd3 config to sd3_medium in mmdit; minor clean-up Enable flash-attn for mmdit impl when the feature is enabled. Add sd3 example codebase add document crediting references pass the cargo fmt test pass the clippy test * fix typos * expose cfg_scale and time_shift as options * Replace the sample image with JPG version. Change image output format accordingly. * make meaningful error messages * remove the tail-end assignment in sd3_vae_vb_rename * remove the CUDA requirement * use default_value in clap args * add use_flash_attn to turn on/off flash-attn for MMDiT at runtime * resolve clippy errors and warnings * use default_value_t * Pin the web-sys dependency. * Clippy fix. --------- Co-authored-by: Laurent <laurent.mazare@gmail.com>
186 lines
5.1 KiB
Rust
186 lines
5.1 KiB
Rust
mod clip;
|
|
mod sampling;
|
|
mod vae;
|
|
|
|
use candle::{DType, IndexOp, Tensor};
|
|
use candle_transformers::models::mmdit::model::{Config as MMDiTConfig, MMDiT};
|
|
|
|
use crate::clip::StableDiffusion3TripleClipWithTokenizer;
|
|
use crate::vae::{build_sd3_vae_autoencoder, sd3_vae_vb_rename};
|
|
|
|
use anyhow::{Ok, Result};
|
|
use clap::Parser;
|
|
|
|
#[derive(Parser)]
|
|
#[command(author, version, about, long_about = None)]
|
|
struct Args {
|
|
/// The prompt to be used for image generation.
|
|
#[arg(
|
|
long,
|
|
default_value = "A cute rusty robot holding a candle torch in its hand, \
|
|
with glowing neon text \"LETS GO RUSTY\" displayed on its chest, \
|
|
bright background, high quality, 4k"
|
|
)]
|
|
prompt: String,
|
|
|
|
#[arg(long, default_value = "")]
|
|
uncond_prompt: String,
|
|
|
|
/// Run on CPU rather than on GPU.
|
|
#[arg(long)]
|
|
cpu: bool,
|
|
|
|
/// The CUDA device ID to use.
|
|
#[arg(long, default_value = "0")]
|
|
cuda_device_id: usize,
|
|
|
|
/// Enable tracing (generates a trace-timestamp.json file).
|
|
#[arg(long)]
|
|
tracing: bool,
|
|
|
|
/// Use flash_attn to accelerate attention operation in the MMDiT.
|
|
#[arg(long)]
|
|
use_flash_attn: bool,
|
|
|
|
/// The height in pixels of the generated image.
|
|
#[arg(long, default_value_t = 1024)]
|
|
height: usize,
|
|
|
|
/// The width in pixels of the generated image.
|
|
#[arg(long, default_value_t = 1024)]
|
|
width: usize,
|
|
|
|
/// The seed to use when generating random samples.
|
|
#[arg(long, default_value_t = 28)]
|
|
num_inference_steps: usize,
|
|
|
|
// CFG scale.
|
|
#[arg(long, default_value_t = 4.0)]
|
|
cfg_scale: f64,
|
|
|
|
// Time shift factor (alpha).
|
|
#[arg(long, default_value_t = 3.0)]
|
|
time_shift: f64,
|
|
|
|
/// The seed to use when generating random samples.
|
|
#[arg(long)]
|
|
seed: Option<u64>,
|
|
}
|
|
|
|
fn main() -> Result<()> {
|
|
let args = Args::parse();
|
|
// Your main code here
|
|
run(args)
|
|
}
|
|
|
|
fn run(args: Args) -> Result<()> {
|
|
use tracing_chrome::ChromeLayerBuilder;
|
|
use tracing_subscriber::prelude::*;
|
|
|
|
let Args {
|
|
prompt,
|
|
uncond_prompt,
|
|
cpu,
|
|
cuda_device_id,
|
|
tracing,
|
|
use_flash_attn,
|
|
height,
|
|
width,
|
|
num_inference_steps,
|
|
cfg_scale,
|
|
time_shift,
|
|
seed,
|
|
} = args;
|
|
|
|
let _guard = if tracing {
|
|
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
|
tracing_subscriber::registry().with(chrome_layer).init();
|
|
Some(guard)
|
|
} else {
|
|
None
|
|
};
|
|
|
|
// TODO: Support and test on Metal.
|
|
let device = if cpu {
|
|
candle::Device::Cpu
|
|
} else {
|
|
candle::Device::cuda_if_available(cuda_device_id)?
|
|
};
|
|
|
|
let api = hf_hub::api::sync::Api::new()?;
|
|
let sai_repo = {
|
|
let name = "stabilityai/stable-diffusion-3-medium";
|
|
api.repo(hf_hub::Repo::model(name.to_string()))
|
|
};
|
|
let model_file = sai_repo.get("sd3_medium_incl_clips_t5xxlfp16.safetensors")?;
|
|
let vb_fp16 = unsafe {
|
|
candle_nn::VarBuilder::from_mmaped_safetensors(&[model_file.clone()], DType::F16, &device)?
|
|
};
|
|
|
|
let (context, y) = {
|
|
let vb_fp32 = unsafe {
|
|
candle_nn::VarBuilder::from_mmaped_safetensors(
|
|
&[model_file.clone()],
|
|
DType::F32,
|
|
&device,
|
|
)?
|
|
};
|
|
let mut triple = StableDiffusion3TripleClipWithTokenizer::new(
|
|
vb_fp16.pp("text_encoders"),
|
|
vb_fp32.pp("text_encoders"),
|
|
)?;
|
|
let (context, y) = triple.encode_text_to_embedding(prompt.as_str(), &device)?;
|
|
let (context_uncond, y_uncond) =
|
|
triple.encode_text_to_embedding(uncond_prompt.as_str(), &device)?;
|
|
(
|
|
Tensor::cat(&[context, context_uncond], 0)?,
|
|
Tensor::cat(&[y, y_uncond], 0)?,
|
|
)
|
|
};
|
|
|
|
let x = {
|
|
let mmdit = MMDiT::new(
|
|
&MMDiTConfig::sd3_medium(),
|
|
use_flash_attn,
|
|
vb_fp16.pp("model.diffusion_model"),
|
|
)?;
|
|
|
|
if let Some(seed) = seed {
|
|
device.set_seed(seed)?;
|
|
}
|
|
let start_time = std::time::Instant::now();
|
|
let x = sampling::euler_sample(
|
|
&mmdit,
|
|
&y,
|
|
&context,
|
|
num_inference_steps,
|
|
cfg_scale,
|
|
time_shift,
|
|
height,
|
|
width,
|
|
)?;
|
|
let dt = start_time.elapsed().as_secs_f32();
|
|
println!(
|
|
"Sampling done. {num_inference_steps} steps. {:.2}s. Average rate: {:.2} iter/s",
|
|
dt,
|
|
num_inference_steps as f32 / dt
|
|
);
|
|
x
|
|
};
|
|
|
|
let img = {
|
|
let vb_vae = vb_fp16
|
|
.clone()
|
|
.rename_f(sd3_vae_vb_rename)
|
|
.pp("first_stage_model");
|
|
let autoencoder = build_sd3_vae_autoencoder(vb_vae)?;
|
|
|
|
// Apply TAESD3 scale factor. Seems to be significantly improving the quality of the image.
|
|
// https://github.com/comfyanonymous/ComfyUI/blob/3c60ecd7a83da43d694e26a77ca6b93106891251/nodes.py#L721-L723
|
|
autoencoder.decode(&((x.clone() / 1.5305)? + 0.0609)?)?
|
|
};
|
|
let img = ((img.clamp(-1f32, 1f32)? + 1.0)? * 127.5)?.to_dtype(candle::DType::U8)?;
|
|
candle_examples::save_image(&img.i(0)?, "out.jpg")?;
|
|
Ok(())
|
|
}
|