mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 02:16:37 +00:00
Stable Diffusion Turbo Support (#1395)
* Add support for SD Turbo * Set Leading as default in euler_ancestral discrete * Use the appropriate default values for n_steps and guidance_scale. --------- Co-authored-by: Laurent <laurent.mazare@gmail.com>
This commit is contained in:
@ -78,7 +78,7 @@ We also provide a some command line based examples using state of the art models
|
||||
<img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/quantized/assets/aoc.gif" width="600">
|
||||
|
||||
- [Stable Diffusion](./candle-examples/examples/stable-diffusion/): text to
|
||||
image generative model, support for the 1.5, 2.1, and SDXL 1.0 versions.
|
||||
image generative model, support for the 1.5, 2.1, SDXL 1.0 and Turbo versions.
|
||||
|
||||
<img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg" width="200">
|
||||
|
||||
|
@ -11,8 +11,6 @@ use candle::{DType, Device, IndexOp, Module, Tensor, D};
|
||||
use clap::Parser;
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
const GUIDANCE_SCALE: f64 = 7.5;
|
||||
|
||||
#[derive(Parser)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
struct Args {
|
||||
@ -63,8 +61,8 @@ struct Args {
|
||||
sliced_attention_size: Option<usize>,
|
||||
|
||||
/// The number of steps to run the diffusion for.
|
||||
#[arg(long, default_value_t = 30)]
|
||||
n_steps: usize,
|
||||
#[arg(long)]
|
||||
n_steps: Option<usize>,
|
||||
|
||||
/// The number of samples to generate.
|
||||
#[arg(long, default_value_t = 1)]
|
||||
@ -87,6 +85,9 @@ struct Args {
|
||||
#[arg(long)]
|
||||
use_f16: bool,
|
||||
|
||||
#[arg(long)]
|
||||
guidance_scale: Option<f64>,
|
||||
|
||||
#[arg(long, value_name = "FILE")]
|
||||
img2img: Option<String>,
|
||||
|
||||
@ -102,6 +103,7 @@ enum StableDiffusionVersion {
|
||||
V1_5,
|
||||
V2_1,
|
||||
Xl,
|
||||
Turbo,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
@ -120,12 +122,13 @@ impl StableDiffusionVersion {
|
||||
Self::Xl => "stabilityai/stable-diffusion-xl-base-1.0",
|
||||
Self::V2_1 => "stabilityai/stable-diffusion-2-1",
|
||||
Self::V1_5 => "runwayml/stable-diffusion-v1-5",
|
||||
Self::Turbo => "stabilityai/sdxl-turbo",
|
||||
}
|
||||
}
|
||||
|
||||
fn unet_file(&self, use_f16: bool) -> &'static str {
|
||||
match self {
|
||||
Self::V1_5 | Self::V2_1 | Self::Xl => {
|
||||
Self::V1_5 | Self::V2_1 | Self::Xl | Self::Turbo => {
|
||||
if use_f16 {
|
||||
"unet/diffusion_pytorch_model.fp16.safetensors"
|
||||
} else {
|
||||
@ -137,7 +140,7 @@ impl StableDiffusionVersion {
|
||||
|
||||
fn vae_file(&self, use_f16: bool) -> &'static str {
|
||||
match self {
|
||||
Self::V1_5 | Self::V2_1 | Self::Xl => {
|
||||
Self::V1_5 | Self::V2_1 | Self::Xl | Self::Turbo => {
|
||||
if use_f16 {
|
||||
"vae/diffusion_pytorch_model.fp16.safetensors"
|
||||
} else {
|
||||
@ -149,7 +152,7 @@ impl StableDiffusionVersion {
|
||||
|
||||
fn clip_file(&self, use_f16: bool) -> &'static str {
|
||||
match self {
|
||||
Self::V1_5 | Self::V2_1 | Self::Xl => {
|
||||
Self::V1_5 | Self::V2_1 | Self::Xl | Self::Turbo => {
|
||||
if use_f16 {
|
||||
"text_encoder/model.fp16.safetensors"
|
||||
} else {
|
||||
@ -161,7 +164,7 @@ impl StableDiffusionVersion {
|
||||
|
||||
fn clip2_file(&self, use_f16: bool) -> &'static str {
|
||||
match self {
|
||||
Self::V1_5 | Self::V2_1 | Self::Xl => {
|
||||
Self::V1_5 | Self::V2_1 | Self::Xl | Self::Turbo => {
|
||||
if use_f16 {
|
||||
"text_encoder_2/model.fp16.safetensors"
|
||||
} else {
|
||||
@ -189,7 +192,7 @@ impl ModelFile {
|
||||
StableDiffusionVersion::V1_5 | StableDiffusionVersion::V2_1 => {
|
||||
"openai/clip-vit-base-patch32"
|
||||
}
|
||||
StableDiffusionVersion::Xl => {
|
||||
StableDiffusionVersion::Xl | StableDiffusionVersion::Turbo => {
|
||||
// This seems similar to the patch32 version except some very small
|
||||
// difference in the split regex.
|
||||
"openai/clip-vit-large-patch14"
|
||||
@ -206,7 +209,11 @@ impl ModelFile {
|
||||
Self::Vae => {
|
||||
// Override for SDXL when using f16 weights.
|
||||
// See https://github.com/huggingface/candle/issues/1060
|
||||
if version == StableDiffusionVersion::Xl && use_f16 {
|
||||
if matches!(
|
||||
version,
|
||||
StableDiffusionVersion::Xl | StableDiffusionVersion::Turbo,
|
||||
) && use_f16
|
||||
{
|
||||
(
|
||||
"madebyollin/sdxl-vae-fp16-fix",
|
||||
"diffusion_pytorch_model.safetensors",
|
||||
@ -261,6 +268,7 @@ fn text_embeddings(
|
||||
use_f16: bool,
|
||||
device: &Device,
|
||||
dtype: DType,
|
||||
use_guide_scale: bool,
|
||||
first: bool,
|
||||
) -> Result<Tensor> {
|
||||
let tokenizer_file = if first {
|
||||
@ -285,16 +293,6 @@ fn text_embeddings(
|
||||
}
|
||||
let tokens = Tensor::new(tokens.as_slice(), device)?.unsqueeze(0)?;
|
||||
|
||||
let mut uncond_tokens = tokenizer
|
||||
.encode(uncond_prompt, true)
|
||||
.map_err(E::msg)?
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
while uncond_tokens.len() < sd_config.clip.max_position_embeddings {
|
||||
uncond_tokens.push(pad_id)
|
||||
}
|
||||
let uncond_tokens = Tensor::new(uncond_tokens.as_slice(), device)?.unsqueeze(0)?;
|
||||
|
||||
println!("Building the Clip transformer.");
|
||||
let clip_weights_file = if first {
|
||||
ModelFile::Clip
|
||||
@ -310,8 +308,24 @@ fn text_embeddings(
|
||||
let text_model =
|
||||
stable_diffusion::build_clip_transformer(clip_config, clip_weights, device, DType::F32)?;
|
||||
let text_embeddings = text_model.forward(&tokens)?;
|
||||
let uncond_embeddings = text_model.forward(&uncond_tokens)?;
|
||||
let text_embeddings = Tensor::cat(&[uncond_embeddings, text_embeddings], 0)?.to_dtype(dtype)?;
|
||||
|
||||
let text_embeddings = if use_guide_scale {
|
||||
let mut uncond_tokens = tokenizer
|
||||
.encode(uncond_prompt, true)
|
||||
.map_err(E::msg)?
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
while uncond_tokens.len() < sd_config.clip.max_position_embeddings {
|
||||
uncond_tokens.push(pad_id)
|
||||
}
|
||||
|
||||
let uncond_tokens = Tensor::new(uncond_tokens.as_slice(), device)?.unsqueeze(0)?;
|
||||
let uncond_embeddings = text_model.forward(&uncond_tokens)?;
|
||||
|
||||
Tensor::cat(&[uncond_embeddings, text_embeddings], 0)?.to_dtype(dtype)?
|
||||
} else {
|
||||
text_embeddings.to_dtype(dtype)?
|
||||
};
|
||||
Ok(text_embeddings)
|
||||
}
|
||||
|
||||
@ -356,6 +370,7 @@ fn run(args: Args) -> Result<()> {
|
||||
unet_weights,
|
||||
tracing,
|
||||
use_f16,
|
||||
guidance_scale,
|
||||
use_flash_attn,
|
||||
img2img,
|
||||
img2img_strength,
|
||||
@ -374,6 +389,24 @@ fn run(args: Args) -> Result<()> {
|
||||
None
|
||||
};
|
||||
|
||||
let guidance_scale = match guidance_scale {
|
||||
Some(guidance_scale) => guidance_scale,
|
||||
None => match sd_version {
|
||||
StableDiffusionVersion::V1_5
|
||||
| StableDiffusionVersion::V2_1
|
||||
| StableDiffusionVersion::Xl => 7.5,
|
||||
StableDiffusionVersion::Turbo => 0.,
|
||||
},
|
||||
};
|
||||
let n_steps = match n_steps {
|
||||
Some(n_steps) => n_steps,
|
||||
None => match sd_version {
|
||||
StableDiffusionVersion::V1_5
|
||||
| StableDiffusionVersion::V2_1
|
||||
| StableDiffusionVersion::Xl => 30,
|
||||
StableDiffusionVersion::Turbo => 1,
|
||||
},
|
||||
};
|
||||
let dtype = if use_f16 { DType::F16 } else { DType::F32 };
|
||||
let sd_config = match sd_version {
|
||||
StableDiffusionVersion::V1_5 => {
|
||||
@ -385,13 +418,19 @@ fn run(args: Args) -> Result<()> {
|
||||
StableDiffusionVersion::Xl => {
|
||||
stable_diffusion::StableDiffusionConfig::sdxl(sliced_attention_size, height, width)
|
||||
}
|
||||
StableDiffusionVersion::Turbo => stable_diffusion::StableDiffusionConfig::sdxl_turbo(
|
||||
sliced_attention_size,
|
||||
height,
|
||||
width,
|
||||
),
|
||||
};
|
||||
|
||||
let scheduler = sd_config.build_scheduler(n_steps)?;
|
||||
let device = candle_examples::device(cpu)?;
|
||||
let use_guide_scale = guidance_scale > 1.0;
|
||||
|
||||
let which = match sd_version {
|
||||
StableDiffusionVersion::Xl => vec![true, false],
|
||||
StableDiffusionVersion::Xl | StableDiffusionVersion::Turbo => vec![true, false],
|
||||
_ => vec![true],
|
||||
};
|
||||
let text_embeddings = which
|
||||
@ -407,10 +446,12 @@ fn run(args: Args) -> Result<()> {
|
||||
use_f16,
|
||||
&device,
|
||||
dtype,
|
||||
use_guide_scale,
|
||||
*first,
|
||||
)
|
||||
})
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
|
||||
let text_embeddings = Tensor::cat(&text_embeddings, D::Minus1)?;
|
||||
println!("{text_embeddings:?}");
|
||||
|
||||
@ -434,11 +475,19 @@ fn run(args: Args) -> Result<()> {
|
||||
0
|
||||
};
|
||||
let bsize = 1;
|
||||
|
||||
let vae_scale = match sd_version {
|
||||
StableDiffusionVersion::V1_5
|
||||
| StableDiffusionVersion::V2_1
|
||||
| StableDiffusionVersion::Xl => 0.18215,
|
||||
StableDiffusionVersion::Turbo => 0.13025,
|
||||
};
|
||||
|
||||
for idx in 0..num_samples {
|
||||
let timesteps = scheduler.timesteps();
|
||||
let latents = match &init_latent_dist {
|
||||
Some(init_latent_dist) => {
|
||||
let latents = (init_latent_dist.sample()? * 0.18215)?.to_device(&device)?;
|
||||
let latents = (init_latent_dist.sample()? * vae_scale)?.to_device(&device)?;
|
||||
if t_start < timesteps.len() {
|
||||
let noise = latents.randn_like(0f64, 1f64)?;
|
||||
scheduler.add_noise(&latents, noise, timesteps[t_start])?
|
||||
@ -465,21 +514,31 @@ fn run(args: Args) -> Result<()> {
|
||||
continue;
|
||||
}
|
||||
let start_time = std::time::Instant::now();
|
||||
let latent_model_input = Tensor::cat(&[&latents, &latents], 0)?;
|
||||
let latent_model_input = if use_guide_scale {
|
||||
Tensor::cat(&[&latents, &latents], 0)?
|
||||
} else {
|
||||
latents.clone()
|
||||
};
|
||||
|
||||
let latent_model_input = scheduler.scale_model_input(latent_model_input, timestep)?;
|
||||
let noise_pred =
|
||||
unet.forward(&latent_model_input, timestep as f64, &text_embeddings)?;
|
||||
let noise_pred = noise_pred.chunk(2, 0)?;
|
||||
let (noise_pred_uncond, noise_pred_text) = (&noise_pred[0], &noise_pred[1]);
|
||||
let noise_pred =
|
||||
(noise_pred_uncond + ((noise_pred_text - noise_pred_uncond)? * GUIDANCE_SCALE)?)?;
|
||||
|
||||
let noise_pred = if use_guide_scale {
|
||||
let noise_pred = noise_pred.chunk(2, 0)?;
|
||||
let (noise_pred_uncond, noise_pred_text) = (&noise_pred[0], &noise_pred[1]);
|
||||
|
||||
(noise_pred_uncond + ((noise_pred_text - noise_pred_uncond)? * guidance_scale)?)?
|
||||
} else {
|
||||
noise_pred
|
||||
};
|
||||
|
||||
latents = scheduler.step(&noise_pred, timestep, &latents)?;
|
||||
let dt = start_time.elapsed().as_secs_f32();
|
||||
println!("step {}/{n_steps} done, {:.2}s", timestep_index + 1, dt);
|
||||
|
||||
if args.intermediary_images {
|
||||
let image = vae.decode(&(&latents / 0.18215)?)?;
|
||||
let image = vae.decode(&(&latents / vae_scale)?)?;
|
||||
let image = ((image / 2.)? + 0.5)?.to_device(&Device::Cpu)?;
|
||||
let image = (image * 255.)?.to_dtype(DType::U8)?.i(0)?;
|
||||
let image_filename =
|
||||
@ -493,7 +552,7 @@ fn run(args: Args) -> Result<()> {
|
||||
idx + 1,
|
||||
num_samples
|
||||
);
|
||||
let image = vae.decode(&(&latents / 0.18215)?)?;
|
||||
let image = vae.decode(&(&latents / vae_scale)?)?;
|
||||
let image = ((image / 2.)? + 0.5)?.to_device(&Device::Cpu)?;
|
||||
let image = (image.clamp(0f32, 1.)? * 255.)?.to_dtype(DType::U8)?.i(0)?;
|
||||
let image_filename = output_filename(&final_image, idx + 1, num_samples, None);
|
||||
|
@ -7,7 +7,9 @@
|
||||
//!
|
||||
//! Denoising Diffusion Implicit Models, J. Song et al, 2020.
|
||||
//! https://arxiv.org/abs/2010.02502
|
||||
use super::schedulers::{betas_for_alpha_bar, BetaSchedule, PredictionType, TimestepSpacing};
|
||||
use super::schedulers::{
|
||||
betas_for_alpha_bar, BetaSchedule, PredictionType, Scheduler, SchedulerConfig, TimestepSpacing,
|
||||
};
|
||||
use candle::{Result, Tensor};
|
||||
|
||||
/// The configuration for the DDIM scheduler.
|
||||
@ -48,6 +50,12 @@ impl Default for DDIMSchedulerConfig {
|
||||
}
|
||||
}
|
||||
|
||||
impl SchedulerConfig for DDIMSchedulerConfig {
|
||||
fn build(&self, inference_steps: usize) -> Result<Box<dyn Scheduler>> {
|
||||
Ok(Box::new(DDIMScheduler::new(inference_steps, *self)?))
|
||||
}
|
||||
}
|
||||
|
||||
/// The DDIM scheduler.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct DDIMScheduler {
|
||||
@ -63,7 +71,7 @@ impl DDIMScheduler {
|
||||
/// Creates a new DDIM scheduler given the number of steps to be
|
||||
/// used for inference as well as the number of steps that was used
|
||||
/// during training.
|
||||
pub fn new(inference_steps: usize, config: DDIMSchedulerConfig) -> Result<Self> {
|
||||
fn new(inference_steps: usize, config: DDIMSchedulerConfig) -> Result<Self> {
|
||||
let step_ratio = config.train_timesteps / inference_steps;
|
||||
let timesteps: Vec<usize> = match config.timestep_spacing {
|
||||
TimestepSpacing::Leading => (0..(inference_steps))
|
||||
@ -115,19 +123,11 @@ impl DDIMScheduler {
|
||||
config,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub fn timesteps(&self) -> &[usize] {
|
||||
self.timesteps.as_slice()
|
||||
}
|
||||
|
||||
/// Ensures interchangeability with schedulers that need to scale the denoising model input
|
||||
/// depending on the current timestep.
|
||||
pub fn scale_model_input(&self, sample: Tensor, _timestep: usize) -> Result<Tensor> {
|
||||
Ok(sample)
|
||||
}
|
||||
|
||||
impl Scheduler for DDIMScheduler {
|
||||
/// Performs a backward step during inference.
|
||||
pub fn step(&self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor> {
|
||||
fn step(&self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor> {
|
||||
let timestep = if timestep >= self.alphas_cumprod.len() {
|
||||
timestep - 1
|
||||
} else {
|
||||
@ -186,7 +186,17 @@ impl DDIMScheduler {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn add_noise(&self, original: &Tensor, noise: Tensor, timestep: usize) -> Result<Tensor> {
|
||||
/// Ensures interchangeability with schedulers that need to scale the denoising model input
|
||||
/// depending on the current timestep.
|
||||
fn scale_model_input(&self, sample: Tensor, _timestep: usize) -> Result<Tensor> {
|
||||
Ok(sample)
|
||||
}
|
||||
|
||||
fn timesteps(&self) -> &[usize] {
|
||||
self.timesteps.as_slice()
|
||||
}
|
||||
|
||||
fn add_noise(&self, original: &Tensor, noise: Tensor, timestep: usize) -> Result<Tensor> {
|
||||
let timestep = if timestep >= self.alphas_cumprod.len() {
|
||||
timestep - 1
|
||||
} else {
|
||||
@ -197,7 +207,7 @@ impl DDIMScheduler {
|
||||
(original * sqrt_alpha_prod)? + (noise * sqrt_one_minus_alpha_prod)?
|
||||
}
|
||||
|
||||
pub fn init_noise_sigma(&self) -> f64 {
|
||||
fn init_noise_sigma(&self) -> f64 {
|
||||
self.init_noise_sigma
|
||||
}
|
||||
}
|
||||
|
@ -8,7 +8,10 @@
|
||||
///
|
||||
/// [kd]: https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L72
|
||||
use super::{
|
||||
schedulers::{betas_for_alpha_bar, BetaSchedule, PredictionType, TimestepSpacing},
|
||||
schedulers::{
|
||||
betas_for_alpha_bar, BetaSchedule, PredictionType, Scheduler, SchedulerConfig,
|
||||
TimestepSpacing,
|
||||
},
|
||||
utils::interp,
|
||||
};
|
||||
use candle::{bail, Error, Result, Tensor};
|
||||
@ -43,11 +46,20 @@ impl Default for EulerAncestralDiscreteSchedulerConfig {
|
||||
steps_offset: 1,
|
||||
prediction_type: PredictionType::Epsilon,
|
||||
train_timesteps: 1000,
|
||||
timestep_spacing: TimestepSpacing::Trailing,
|
||||
timestep_spacing: TimestepSpacing::Leading,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl SchedulerConfig for EulerAncestralDiscreteSchedulerConfig {
|
||||
fn build(&self, inference_steps: usize) -> Result<Box<dyn Scheduler>> {
|
||||
Ok(Box::new(EulerAncestralDiscreteScheduler::new(
|
||||
inference_steps,
|
||||
*self,
|
||||
)?))
|
||||
}
|
||||
}
|
||||
|
||||
/// The EulerAncestral Discrete scheduler.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct EulerAncestralDiscreteScheduler {
|
||||
@ -138,8 +150,10 @@ impl EulerAncestralDiscreteScheduler {
|
||||
config,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub fn timesteps(&self) -> &[usize] {
|
||||
impl Scheduler for EulerAncestralDiscreteScheduler {
|
||||
fn timesteps(&self) -> &[usize] {
|
||||
self.timesteps.as_slice()
|
||||
}
|
||||
|
||||
@ -147,7 +161,7 @@ impl EulerAncestralDiscreteScheduler {
|
||||
/// depending on the current timestep.
|
||||
///
|
||||
/// Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the K-LMS algorithm
|
||||
pub fn scale_model_input(&self, sample: Tensor, timestep: usize) -> Result<Tensor> {
|
||||
fn scale_model_input(&self, sample: Tensor, timestep: usize) -> Result<Tensor> {
|
||||
let step_index = match self.timesteps.iter().position(|&t| t == timestep) {
|
||||
Some(i) => i,
|
||||
None => bail!("timestep out of this schedulers bounds: {timestep}"),
|
||||
@ -162,7 +176,7 @@ impl EulerAncestralDiscreteScheduler {
|
||||
}
|
||||
|
||||
/// Performs a backward step during inference.
|
||||
pub fn step(&self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor> {
|
||||
fn step(&self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor> {
|
||||
let step_index = self
|
||||
.timesteps
|
||||
.iter()
|
||||
@ -197,7 +211,7 @@ impl EulerAncestralDiscreteScheduler {
|
||||
prev_sample + noise * sigma_up
|
||||
}
|
||||
|
||||
pub fn add_noise(&self, original: &Tensor, noise: Tensor, timestep: usize) -> Result<Tensor> {
|
||||
fn add_noise(&self, original: &Tensor, noise: Tensor, timestep: usize) -> Result<Tensor> {
|
||||
let step_index = self
|
||||
.timesteps
|
||||
.iter()
|
||||
@ -212,7 +226,7 @@ impl EulerAncestralDiscreteScheduler {
|
||||
original + (noise * *sigma)?
|
||||
}
|
||||
|
||||
pub fn init_noise_sigma(&self) -> f64 {
|
||||
fn init_noise_sigma(&self) -> f64 {
|
||||
match self.config.timestep_spacing {
|
||||
TimestepSpacing::Trailing | TimestepSpacing::Linspace => self.init_noise_sigma,
|
||||
TimestepSpacing::Leading => (self.init_noise_sigma.powi(2) + 1.0).sqrt(),
|
||||
|
@ -11,9 +11,13 @@ pub mod unet_2d_blocks;
|
||||
pub mod utils;
|
||||
pub mod vae;
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use candle::{DType, Device, Result};
|
||||
use candle_nn as nn;
|
||||
|
||||
use self::schedulers::{Scheduler, SchedulerConfig};
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct StableDiffusionConfig {
|
||||
pub width: usize,
|
||||
@ -22,7 +26,7 @@ pub struct StableDiffusionConfig {
|
||||
pub clip2: Option<clip::Config>,
|
||||
autoencoder: vae::AutoEncoderKLConfig,
|
||||
unet: unet_2d::UNet2DConditionModelConfig,
|
||||
scheduler: ddim::DDIMSchedulerConfig,
|
||||
scheduler: Arc<dyn SchedulerConfig>,
|
||||
}
|
||||
|
||||
impl StableDiffusionConfig {
|
||||
@ -76,13 +80,18 @@ impl StableDiffusionConfig {
|
||||
512
|
||||
};
|
||||
|
||||
Self {
|
||||
let scheduler = Arc::new(ddim::DDIMSchedulerConfig {
|
||||
prediction_type: schedulers::PredictionType::Epsilon,
|
||||
..Default::default()
|
||||
});
|
||||
|
||||
StableDiffusionConfig {
|
||||
width,
|
||||
height,
|
||||
clip: clip::Config::v1_5(),
|
||||
clip2: None,
|
||||
autoencoder,
|
||||
scheduler: Default::default(),
|
||||
scheduler,
|
||||
unet,
|
||||
}
|
||||
}
|
||||
@ -125,10 +134,10 @@ impl StableDiffusionConfig {
|
||||
latent_channels: 4,
|
||||
norm_num_groups: 32,
|
||||
};
|
||||
let scheduler = ddim::DDIMSchedulerConfig {
|
||||
let scheduler = Arc::new(ddim::DDIMSchedulerConfig {
|
||||
prediction_type,
|
||||
..Default::default()
|
||||
};
|
||||
});
|
||||
|
||||
let height = if let Some(height) = height {
|
||||
assert_eq!(height % 8, 0, "height has to be divisible by 8");
|
||||
@ -144,7 +153,7 @@ impl StableDiffusionConfig {
|
||||
768
|
||||
};
|
||||
|
||||
Self {
|
||||
StableDiffusionConfig {
|
||||
width,
|
||||
height,
|
||||
clip: clip::Config::v2_1(),
|
||||
@ -206,10 +215,10 @@ impl StableDiffusionConfig {
|
||||
latent_channels: 4,
|
||||
norm_num_groups: 32,
|
||||
};
|
||||
let scheduler = ddim::DDIMSchedulerConfig {
|
||||
let scheduler = Arc::new(ddim::DDIMSchedulerConfig {
|
||||
prediction_type,
|
||||
..Default::default()
|
||||
};
|
||||
});
|
||||
|
||||
let height = if let Some(height) = height {
|
||||
assert_eq!(height % 8, 0, "height has to be divisible by 8");
|
||||
@ -225,6 +234,76 @@ impl StableDiffusionConfig {
|
||||
1024
|
||||
};
|
||||
|
||||
StableDiffusionConfig {
|
||||
width,
|
||||
height,
|
||||
clip: clip::Config::sdxl(),
|
||||
clip2: Some(clip::Config::sdxl2()),
|
||||
autoencoder,
|
||||
scheduler,
|
||||
unet,
|
||||
}
|
||||
}
|
||||
|
||||
fn sdxl_turbo_(
|
||||
sliced_attention_size: Option<usize>,
|
||||
height: Option<usize>,
|
||||
width: Option<usize>,
|
||||
prediction_type: schedulers::PredictionType,
|
||||
) -> Self {
|
||||
let bc = |out_channels, use_cross_attn, attention_head_dim| unet_2d::BlockConfig {
|
||||
out_channels,
|
||||
use_cross_attn,
|
||||
attention_head_dim,
|
||||
};
|
||||
// https://huggingface.co/stabilityai/sdxl-turbo/blob/main/unet/config.json
|
||||
let unet = unet_2d::UNet2DConditionModelConfig {
|
||||
blocks: vec![
|
||||
bc(320, None, 5),
|
||||
bc(640, Some(2), 10),
|
||||
bc(1280, Some(10), 20),
|
||||
],
|
||||
center_input_sample: false,
|
||||
cross_attention_dim: 2048,
|
||||
downsample_padding: 1,
|
||||
flip_sin_to_cos: true,
|
||||
freq_shift: 0.,
|
||||
layers_per_block: 2,
|
||||
mid_block_scale_factor: 1.,
|
||||
norm_eps: 1e-5,
|
||||
norm_num_groups: 32,
|
||||
sliced_attention_size,
|
||||
use_linear_projection: true,
|
||||
};
|
||||
// https://huggingface.co/stabilityai/sdxl-turbo/blob/main/vae/config.json
|
||||
let autoencoder = vae::AutoEncoderKLConfig {
|
||||
block_out_channels: vec![128, 256, 512, 512],
|
||||
layers_per_block: 2,
|
||||
latent_channels: 4,
|
||||
norm_num_groups: 32,
|
||||
};
|
||||
let scheduler = Arc::new(
|
||||
euler_ancestral_discrete::EulerAncestralDiscreteSchedulerConfig {
|
||||
prediction_type,
|
||||
timestep_spacing: schedulers::TimestepSpacing::Trailing,
|
||||
..Default::default()
|
||||
},
|
||||
);
|
||||
|
||||
let height = if let Some(height) = height {
|
||||
assert_eq!(height % 8, 0, "height has to be divisible by 8");
|
||||
height
|
||||
} else {
|
||||
512
|
||||
};
|
||||
|
||||
let width = if let Some(width) = width {
|
||||
assert_eq!(width % 8, 0, "width has to be divisible by 8");
|
||||
width
|
||||
} else {
|
||||
512
|
||||
};
|
||||
|
||||
Self {
|
||||
width,
|
||||
height,
|
||||
@ -250,6 +329,20 @@ impl StableDiffusionConfig {
|
||||
)
|
||||
}
|
||||
|
||||
pub fn sdxl_turbo(
|
||||
sliced_attention_size: Option<usize>,
|
||||
height: Option<usize>,
|
||||
width: Option<usize>,
|
||||
) -> Self {
|
||||
Self::sdxl_turbo_(
|
||||
sliced_attention_size,
|
||||
height,
|
||||
width,
|
||||
// https://huggingface.co/stabilityai/sdxl-turbo/blob/main/scheduler/scheduler_config.json
|
||||
schedulers::PredictionType::Epsilon,
|
||||
)
|
||||
}
|
||||
|
||||
pub fn ssd1b(
|
||||
sliced_attention_size: Option<usize>,
|
||||
height: Option<usize>,
|
||||
@ -286,9 +379,9 @@ impl StableDiffusionConfig {
|
||||
latent_channels: 4,
|
||||
norm_num_groups: 32,
|
||||
};
|
||||
let scheduler = ddim::DDIMSchedulerConfig {
|
||||
let scheduler = Arc::new(ddim::DDIMSchedulerConfig {
|
||||
..Default::default()
|
||||
};
|
||||
});
|
||||
|
||||
let height = if let Some(height) = height {
|
||||
assert_eq!(height % 8, 0, "height has to be divisible by 8");
|
||||
@ -348,8 +441,8 @@ impl StableDiffusionConfig {
|
||||
Ok(unet)
|
||||
}
|
||||
|
||||
pub fn build_scheduler(&self, n_steps: usize) -> Result<ddim::DDIMScheduler> {
|
||||
ddim::DDIMScheduler::new(n_steps, self.scheduler)
|
||||
pub fn build_scheduler(&self, n_steps: usize) -> Result<Box<dyn Scheduler>> {
|
||||
self.scheduler.build(n_steps)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -3,9 +3,25 @@
|
||||
//!
|
||||
//! Noise schedulers can be used to set the trade-off between
|
||||
//! inference speed and quality.
|
||||
|
||||
use candle::{Result, Tensor};
|
||||
|
||||
pub trait SchedulerConfig: std::fmt::Debug {
|
||||
fn build(&self, inference_steps: usize) -> Result<Box<dyn Scheduler>>;
|
||||
}
|
||||
|
||||
/// This trait represents a scheduler for the diffusion process.
|
||||
pub trait Scheduler {
|
||||
fn timesteps(&self) -> &[usize];
|
||||
|
||||
fn add_noise(&self, original: &Tensor, noise: Tensor, timestep: usize) -> Result<Tensor>;
|
||||
|
||||
fn init_noise_sigma(&self) -> f64;
|
||||
|
||||
fn scale_model_input(&self, sample: Tensor, _timestep: usize) -> Result<Tensor>;
|
||||
|
||||
fn step(&self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor>;
|
||||
}
|
||||
|
||||
/// This represents how beta ranges from its minimum value to the maximum
|
||||
/// during training.
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
|
Reference in New Issue
Block a user