diff --git a/README.md b/README.md
index 20596fe1..f0c96a46 100644
--- a/README.md
+++ b/README.md
@@ -78,7 +78,7 @@ We also provide a some command line based examples using state of the art models
- [Stable Diffusion](./candle-examples/examples/stable-diffusion/): text to
- image generative model, support for the 1.5, 2.1, and SDXL 1.0 versions.
+ image generative model, support for the 1.5, 2.1, SDXL 1.0 and Turbo versions.
diff --git a/candle-examples/examples/stable-diffusion/main.rs b/candle-examples/examples/stable-diffusion/main.rs
index 3e6de34d..8c3ca2ee 100644
--- a/candle-examples/examples/stable-diffusion/main.rs
+++ b/candle-examples/examples/stable-diffusion/main.rs
@@ -11,8 +11,6 @@ use candle::{DType, Device, IndexOp, Module, Tensor, D};
use clap::Parser;
use tokenizers::Tokenizer;
-const GUIDANCE_SCALE: f64 = 7.5;
-
#[derive(Parser)]
#[command(author, version, about, long_about = None)]
struct Args {
@@ -63,8 +61,8 @@ struct Args {
sliced_attention_size: Option,
/// The number of steps to run the diffusion for.
- #[arg(long, default_value_t = 30)]
- n_steps: usize,
+ #[arg(long)]
+ n_steps: Option,
/// The number of samples to generate.
#[arg(long, default_value_t = 1)]
@@ -87,6 +85,9 @@ struct Args {
#[arg(long)]
use_f16: bool,
+ #[arg(long)]
+ guidance_scale: Option,
+
#[arg(long, value_name = "FILE")]
img2img: Option,
@@ -102,6 +103,7 @@ enum StableDiffusionVersion {
V1_5,
V2_1,
Xl,
+ Turbo,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
@@ -120,12 +122,13 @@ impl StableDiffusionVersion {
Self::Xl => "stabilityai/stable-diffusion-xl-base-1.0",
Self::V2_1 => "stabilityai/stable-diffusion-2-1",
Self::V1_5 => "runwayml/stable-diffusion-v1-5",
+ Self::Turbo => "stabilityai/sdxl-turbo",
}
}
fn unet_file(&self, use_f16: bool) -> &'static str {
match self {
- Self::V1_5 | Self::V2_1 | Self::Xl => {
+ Self::V1_5 | Self::V2_1 | Self::Xl | Self::Turbo => {
if use_f16 {
"unet/diffusion_pytorch_model.fp16.safetensors"
} else {
@@ -137,7 +140,7 @@ impl StableDiffusionVersion {
fn vae_file(&self, use_f16: bool) -> &'static str {
match self {
- Self::V1_5 | Self::V2_1 | Self::Xl => {
+ Self::V1_5 | Self::V2_1 | Self::Xl | Self::Turbo => {
if use_f16 {
"vae/diffusion_pytorch_model.fp16.safetensors"
} else {
@@ -149,7 +152,7 @@ impl StableDiffusionVersion {
fn clip_file(&self, use_f16: bool) -> &'static str {
match self {
- Self::V1_5 | Self::V2_1 | Self::Xl => {
+ Self::V1_5 | Self::V2_1 | Self::Xl | Self::Turbo => {
if use_f16 {
"text_encoder/model.fp16.safetensors"
} else {
@@ -161,7 +164,7 @@ impl StableDiffusionVersion {
fn clip2_file(&self, use_f16: bool) -> &'static str {
match self {
- Self::V1_5 | Self::V2_1 | Self::Xl => {
+ Self::V1_5 | Self::V2_1 | Self::Xl | Self::Turbo => {
if use_f16 {
"text_encoder_2/model.fp16.safetensors"
} else {
@@ -189,7 +192,7 @@ impl ModelFile {
StableDiffusionVersion::V1_5 | StableDiffusionVersion::V2_1 => {
"openai/clip-vit-base-patch32"
}
- StableDiffusionVersion::Xl => {
+ StableDiffusionVersion::Xl | StableDiffusionVersion::Turbo => {
// This seems similar to the patch32 version except some very small
// difference in the split regex.
"openai/clip-vit-large-patch14"
@@ -206,7 +209,11 @@ impl ModelFile {
Self::Vae => {
// Override for SDXL when using f16 weights.
// See https://github.com/huggingface/candle/issues/1060
- if version == StableDiffusionVersion::Xl && use_f16 {
+ if matches!(
+ version,
+ StableDiffusionVersion::Xl | StableDiffusionVersion::Turbo,
+ ) && use_f16
+ {
(
"madebyollin/sdxl-vae-fp16-fix",
"diffusion_pytorch_model.safetensors",
@@ -261,6 +268,7 @@ fn text_embeddings(
use_f16: bool,
device: &Device,
dtype: DType,
+ use_guide_scale: bool,
first: bool,
) -> Result {
let tokenizer_file = if first {
@@ -285,16 +293,6 @@ fn text_embeddings(
}
let tokens = Tensor::new(tokens.as_slice(), device)?.unsqueeze(0)?;
- let mut uncond_tokens = tokenizer
- .encode(uncond_prompt, true)
- .map_err(E::msg)?
- .get_ids()
- .to_vec();
- while uncond_tokens.len() < sd_config.clip.max_position_embeddings {
- uncond_tokens.push(pad_id)
- }
- let uncond_tokens = Tensor::new(uncond_tokens.as_slice(), device)?.unsqueeze(0)?;
-
println!("Building the Clip transformer.");
let clip_weights_file = if first {
ModelFile::Clip
@@ -310,8 +308,24 @@ fn text_embeddings(
let text_model =
stable_diffusion::build_clip_transformer(clip_config, clip_weights, device, DType::F32)?;
let text_embeddings = text_model.forward(&tokens)?;
- let uncond_embeddings = text_model.forward(&uncond_tokens)?;
- let text_embeddings = Tensor::cat(&[uncond_embeddings, text_embeddings], 0)?.to_dtype(dtype)?;
+
+ let text_embeddings = if use_guide_scale {
+ let mut uncond_tokens = tokenizer
+ .encode(uncond_prompt, true)
+ .map_err(E::msg)?
+ .get_ids()
+ .to_vec();
+ while uncond_tokens.len() < sd_config.clip.max_position_embeddings {
+ uncond_tokens.push(pad_id)
+ }
+
+ let uncond_tokens = Tensor::new(uncond_tokens.as_slice(), device)?.unsqueeze(0)?;
+ let uncond_embeddings = text_model.forward(&uncond_tokens)?;
+
+ Tensor::cat(&[uncond_embeddings, text_embeddings], 0)?.to_dtype(dtype)?
+ } else {
+ text_embeddings.to_dtype(dtype)?
+ };
Ok(text_embeddings)
}
@@ -356,6 +370,7 @@ fn run(args: Args) -> Result<()> {
unet_weights,
tracing,
use_f16,
+ guidance_scale,
use_flash_attn,
img2img,
img2img_strength,
@@ -374,6 +389,24 @@ fn run(args: Args) -> Result<()> {
None
};
+ let guidance_scale = match guidance_scale {
+ Some(guidance_scale) => guidance_scale,
+ None => match sd_version {
+ StableDiffusionVersion::V1_5
+ | StableDiffusionVersion::V2_1
+ | StableDiffusionVersion::Xl => 7.5,
+ StableDiffusionVersion::Turbo => 0.,
+ },
+ };
+ let n_steps = match n_steps {
+ Some(n_steps) => n_steps,
+ None => match sd_version {
+ StableDiffusionVersion::V1_5
+ | StableDiffusionVersion::V2_1
+ | StableDiffusionVersion::Xl => 30,
+ StableDiffusionVersion::Turbo => 1,
+ },
+ };
let dtype = if use_f16 { DType::F16 } else { DType::F32 };
let sd_config = match sd_version {
StableDiffusionVersion::V1_5 => {
@@ -385,13 +418,19 @@ fn run(args: Args) -> Result<()> {
StableDiffusionVersion::Xl => {
stable_diffusion::StableDiffusionConfig::sdxl(sliced_attention_size, height, width)
}
+ StableDiffusionVersion::Turbo => stable_diffusion::StableDiffusionConfig::sdxl_turbo(
+ sliced_attention_size,
+ height,
+ width,
+ ),
};
let scheduler = sd_config.build_scheduler(n_steps)?;
let device = candle_examples::device(cpu)?;
+ let use_guide_scale = guidance_scale > 1.0;
let which = match sd_version {
- StableDiffusionVersion::Xl => vec![true, false],
+ StableDiffusionVersion::Xl | StableDiffusionVersion::Turbo => vec![true, false],
_ => vec![true],
};
let text_embeddings = which
@@ -407,10 +446,12 @@ fn run(args: Args) -> Result<()> {
use_f16,
&device,
dtype,
+ use_guide_scale,
*first,
)
})
.collect::>>()?;
+
let text_embeddings = Tensor::cat(&text_embeddings, D::Minus1)?;
println!("{text_embeddings:?}");
@@ -434,11 +475,19 @@ fn run(args: Args) -> Result<()> {
0
};
let bsize = 1;
+
+ let vae_scale = match sd_version {
+ StableDiffusionVersion::V1_5
+ | StableDiffusionVersion::V2_1
+ | StableDiffusionVersion::Xl => 0.18215,
+ StableDiffusionVersion::Turbo => 0.13025,
+ };
+
for idx in 0..num_samples {
let timesteps = scheduler.timesteps();
let latents = match &init_latent_dist {
Some(init_latent_dist) => {
- let latents = (init_latent_dist.sample()? * 0.18215)?.to_device(&device)?;
+ let latents = (init_latent_dist.sample()? * vae_scale)?.to_device(&device)?;
if t_start < timesteps.len() {
let noise = latents.randn_like(0f64, 1f64)?;
scheduler.add_noise(&latents, noise, timesteps[t_start])?
@@ -465,21 +514,31 @@ fn run(args: Args) -> Result<()> {
continue;
}
let start_time = std::time::Instant::now();
- let latent_model_input = Tensor::cat(&[&latents, &latents], 0)?;
+ let latent_model_input = if use_guide_scale {
+ Tensor::cat(&[&latents, &latents], 0)?
+ } else {
+ latents.clone()
+ };
let latent_model_input = scheduler.scale_model_input(latent_model_input, timestep)?;
let noise_pred =
unet.forward(&latent_model_input, timestep as f64, &text_embeddings)?;
- let noise_pred = noise_pred.chunk(2, 0)?;
- let (noise_pred_uncond, noise_pred_text) = (&noise_pred[0], &noise_pred[1]);
- let noise_pred =
- (noise_pred_uncond + ((noise_pred_text - noise_pred_uncond)? * GUIDANCE_SCALE)?)?;
+
+ let noise_pred = if use_guide_scale {
+ let noise_pred = noise_pred.chunk(2, 0)?;
+ let (noise_pred_uncond, noise_pred_text) = (&noise_pred[0], &noise_pred[1]);
+
+ (noise_pred_uncond + ((noise_pred_text - noise_pred_uncond)? * guidance_scale)?)?
+ } else {
+ noise_pred
+ };
+
latents = scheduler.step(&noise_pred, timestep, &latents)?;
let dt = start_time.elapsed().as_secs_f32();
println!("step {}/{n_steps} done, {:.2}s", timestep_index + 1, dt);
if args.intermediary_images {
- let image = vae.decode(&(&latents / 0.18215)?)?;
+ let image = vae.decode(&(&latents / vae_scale)?)?;
let image = ((image / 2.)? + 0.5)?.to_device(&Device::Cpu)?;
let image = (image * 255.)?.to_dtype(DType::U8)?.i(0)?;
let image_filename =
@@ -493,7 +552,7 @@ fn run(args: Args) -> Result<()> {
idx + 1,
num_samples
);
- let image = vae.decode(&(&latents / 0.18215)?)?;
+ let image = vae.decode(&(&latents / vae_scale)?)?;
let image = ((image / 2.)? + 0.5)?.to_device(&Device::Cpu)?;
let image = (image.clamp(0f32, 1.)? * 255.)?.to_dtype(DType::U8)?.i(0)?;
let image_filename = output_filename(&final_image, idx + 1, num_samples, None);
diff --git a/candle-transformers/src/models/stable_diffusion/ddim.rs b/candle-transformers/src/models/stable_diffusion/ddim.rs
index b9426094..d804ed56 100644
--- a/candle-transformers/src/models/stable_diffusion/ddim.rs
+++ b/candle-transformers/src/models/stable_diffusion/ddim.rs
@@ -7,7 +7,9 @@
//!
//! Denoising Diffusion Implicit Models, J. Song et al, 2020.
//! https://arxiv.org/abs/2010.02502
-use super::schedulers::{betas_for_alpha_bar, BetaSchedule, PredictionType, TimestepSpacing};
+use super::schedulers::{
+ betas_for_alpha_bar, BetaSchedule, PredictionType, Scheduler, SchedulerConfig, TimestepSpacing,
+};
use candle::{Result, Tensor};
/// The configuration for the DDIM scheduler.
@@ -48,6 +50,12 @@ impl Default for DDIMSchedulerConfig {
}
}
+impl SchedulerConfig for DDIMSchedulerConfig {
+ fn build(&self, inference_steps: usize) -> Result> {
+ Ok(Box::new(DDIMScheduler::new(inference_steps, *self)?))
+ }
+}
+
/// The DDIM scheduler.
#[derive(Debug, Clone)]
pub struct DDIMScheduler {
@@ -63,7 +71,7 @@ impl DDIMScheduler {
/// Creates a new DDIM scheduler given the number of steps to be
/// used for inference as well as the number of steps that was used
/// during training.
- pub fn new(inference_steps: usize, config: DDIMSchedulerConfig) -> Result {
+ fn new(inference_steps: usize, config: DDIMSchedulerConfig) -> Result {
let step_ratio = config.train_timesteps / inference_steps;
let timesteps: Vec = match config.timestep_spacing {
TimestepSpacing::Leading => (0..(inference_steps))
@@ -115,19 +123,11 @@ impl DDIMScheduler {
config,
})
}
+}
- pub fn timesteps(&self) -> &[usize] {
- self.timesteps.as_slice()
- }
-
- /// Ensures interchangeability with schedulers that need to scale the denoising model input
- /// depending on the current timestep.
- pub fn scale_model_input(&self, sample: Tensor, _timestep: usize) -> Result {
- Ok(sample)
- }
-
+impl Scheduler for DDIMScheduler {
/// Performs a backward step during inference.
- pub fn step(&self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result {
+ fn step(&self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result {
let timestep = if timestep >= self.alphas_cumprod.len() {
timestep - 1
} else {
@@ -186,7 +186,17 @@ impl DDIMScheduler {
}
}
- pub fn add_noise(&self, original: &Tensor, noise: Tensor, timestep: usize) -> Result {
+ /// Ensures interchangeability with schedulers that need to scale the denoising model input
+ /// depending on the current timestep.
+ fn scale_model_input(&self, sample: Tensor, _timestep: usize) -> Result {
+ Ok(sample)
+ }
+
+ fn timesteps(&self) -> &[usize] {
+ self.timesteps.as_slice()
+ }
+
+ fn add_noise(&self, original: &Tensor, noise: Tensor, timestep: usize) -> Result {
let timestep = if timestep >= self.alphas_cumprod.len() {
timestep - 1
} else {
@@ -197,7 +207,7 @@ impl DDIMScheduler {
(original * sqrt_alpha_prod)? + (noise * sqrt_one_minus_alpha_prod)?
}
- pub fn init_noise_sigma(&self) -> f64 {
+ fn init_noise_sigma(&self) -> f64 {
self.init_noise_sigma
}
}
diff --git a/candle-transformers/src/models/stable_diffusion/euler_ancestral_discrete.rs b/candle-transformers/src/models/stable_diffusion/euler_ancestral_discrete.rs
index 7acbf040..85e86e6e 100644
--- a/candle-transformers/src/models/stable_diffusion/euler_ancestral_discrete.rs
+++ b/candle-transformers/src/models/stable_diffusion/euler_ancestral_discrete.rs
@@ -8,7 +8,10 @@
///
/// [kd]: https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L72
use super::{
- schedulers::{betas_for_alpha_bar, BetaSchedule, PredictionType, TimestepSpacing},
+ schedulers::{
+ betas_for_alpha_bar, BetaSchedule, PredictionType, Scheduler, SchedulerConfig,
+ TimestepSpacing,
+ },
utils::interp,
};
use candle::{bail, Error, Result, Tensor};
@@ -43,11 +46,20 @@ impl Default for EulerAncestralDiscreteSchedulerConfig {
steps_offset: 1,
prediction_type: PredictionType::Epsilon,
train_timesteps: 1000,
- timestep_spacing: TimestepSpacing::Trailing,
+ timestep_spacing: TimestepSpacing::Leading,
}
}
}
+impl SchedulerConfig for EulerAncestralDiscreteSchedulerConfig {
+ fn build(&self, inference_steps: usize) -> Result> {
+ Ok(Box::new(EulerAncestralDiscreteScheduler::new(
+ inference_steps,
+ *self,
+ )?))
+ }
+}
+
/// The EulerAncestral Discrete scheduler.
#[derive(Debug, Clone)]
pub struct EulerAncestralDiscreteScheduler {
@@ -138,8 +150,10 @@ impl EulerAncestralDiscreteScheduler {
config,
})
}
+}
- pub fn timesteps(&self) -> &[usize] {
+impl Scheduler for EulerAncestralDiscreteScheduler {
+ fn timesteps(&self) -> &[usize] {
self.timesteps.as_slice()
}
@@ -147,7 +161,7 @@ impl EulerAncestralDiscreteScheduler {
/// depending on the current timestep.
///
/// Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the K-LMS algorithm
- pub fn scale_model_input(&self, sample: Tensor, timestep: usize) -> Result {
+ fn scale_model_input(&self, sample: Tensor, timestep: usize) -> Result {
let step_index = match self.timesteps.iter().position(|&t| t == timestep) {
Some(i) => i,
None => bail!("timestep out of this schedulers bounds: {timestep}"),
@@ -162,7 +176,7 @@ impl EulerAncestralDiscreteScheduler {
}
/// Performs a backward step during inference.
- pub fn step(&self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result {
+ fn step(&self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result {
let step_index = self
.timesteps
.iter()
@@ -197,7 +211,7 @@ impl EulerAncestralDiscreteScheduler {
prev_sample + noise * sigma_up
}
- pub fn add_noise(&self, original: &Tensor, noise: Tensor, timestep: usize) -> Result {
+ fn add_noise(&self, original: &Tensor, noise: Tensor, timestep: usize) -> Result {
let step_index = self
.timesteps
.iter()
@@ -212,7 +226,7 @@ impl EulerAncestralDiscreteScheduler {
original + (noise * *sigma)?
}
- pub fn init_noise_sigma(&self) -> f64 {
+ fn init_noise_sigma(&self) -> f64 {
match self.config.timestep_spacing {
TimestepSpacing::Trailing | TimestepSpacing::Linspace => self.init_noise_sigma,
TimestepSpacing::Leading => (self.init_noise_sigma.powi(2) + 1.0).sqrt(),
diff --git a/candle-transformers/src/models/stable_diffusion/mod.rs b/candle-transformers/src/models/stable_diffusion/mod.rs
index cad24524..30f23975 100644
--- a/candle-transformers/src/models/stable_diffusion/mod.rs
+++ b/candle-transformers/src/models/stable_diffusion/mod.rs
@@ -11,9 +11,13 @@ pub mod unet_2d_blocks;
pub mod utils;
pub mod vae;
+use std::sync::Arc;
+
use candle::{DType, Device, Result};
use candle_nn as nn;
+use self::schedulers::{Scheduler, SchedulerConfig};
+
#[derive(Clone, Debug)]
pub struct StableDiffusionConfig {
pub width: usize,
@@ -22,7 +26,7 @@ pub struct StableDiffusionConfig {
pub clip2: Option,
autoencoder: vae::AutoEncoderKLConfig,
unet: unet_2d::UNet2DConditionModelConfig,
- scheduler: ddim::DDIMSchedulerConfig,
+ scheduler: Arc,
}
impl StableDiffusionConfig {
@@ -76,13 +80,18 @@ impl StableDiffusionConfig {
512
};
- Self {
+ let scheduler = Arc::new(ddim::DDIMSchedulerConfig {
+ prediction_type: schedulers::PredictionType::Epsilon,
+ ..Default::default()
+ });
+
+ StableDiffusionConfig {
width,
height,
clip: clip::Config::v1_5(),
clip2: None,
autoencoder,
- scheduler: Default::default(),
+ scheduler,
unet,
}
}
@@ -125,10 +134,10 @@ impl StableDiffusionConfig {
latent_channels: 4,
norm_num_groups: 32,
};
- let scheduler = ddim::DDIMSchedulerConfig {
+ let scheduler = Arc::new(ddim::DDIMSchedulerConfig {
prediction_type,
..Default::default()
- };
+ });
let height = if let Some(height) = height {
assert_eq!(height % 8, 0, "height has to be divisible by 8");
@@ -144,7 +153,7 @@ impl StableDiffusionConfig {
768
};
- Self {
+ StableDiffusionConfig {
width,
height,
clip: clip::Config::v2_1(),
@@ -206,10 +215,10 @@ impl StableDiffusionConfig {
latent_channels: 4,
norm_num_groups: 32,
};
- let scheduler = ddim::DDIMSchedulerConfig {
+ let scheduler = Arc::new(ddim::DDIMSchedulerConfig {
prediction_type,
..Default::default()
- };
+ });
let height = if let Some(height) = height {
assert_eq!(height % 8, 0, "height has to be divisible by 8");
@@ -225,6 +234,76 @@ impl StableDiffusionConfig {
1024
};
+ StableDiffusionConfig {
+ width,
+ height,
+ clip: clip::Config::sdxl(),
+ clip2: Some(clip::Config::sdxl2()),
+ autoencoder,
+ scheduler,
+ unet,
+ }
+ }
+
+ fn sdxl_turbo_(
+ sliced_attention_size: Option,
+ height: Option,
+ width: Option,
+ prediction_type: schedulers::PredictionType,
+ ) -> Self {
+ let bc = |out_channels, use_cross_attn, attention_head_dim| unet_2d::BlockConfig {
+ out_channels,
+ use_cross_attn,
+ attention_head_dim,
+ };
+ // https://huggingface.co/stabilityai/sdxl-turbo/blob/main/unet/config.json
+ let unet = unet_2d::UNet2DConditionModelConfig {
+ blocks: vec![
+ bc(320, None, 5),
+ bc(640, Some(2), 10),
+ bc(1280, Some(10), 20),
+ ],
+ center_input_sample: false,
+ cross_attention_dim: 2048,
+ downsample_padding: 1,
+ flip_sin_to_cos: true,
+ freq_shift: 0.,
+ layers_per_block: 2,
+ mid_block_scale_factor: 1.,
+ norm_eps: 1e-5,
+ norm_num_groups: 32,
+ sliced_attention_size,
+ use_linear_projection: true,
+ };
+ // https://huggingface.co/stabilityai/sdxl-turbo/blob/main/vae/config.json
+ let autoencoder = vae::AutoEncoderKLConfig {
+ block_out_channels: vec![128, 256, 512, 512],
+ layers_per_block: 2,
+ latent_channels: 4,
+ norm_num_groups: 32,
+ };
+ let scheduler = Arc::new(
+ euler_ancestral_discrete::EulerAncestralDiscreteSchedulerConfig {
+ prediction_type,
+ timestep_spacing: schedulers::TimestepSpacing::Trailing,
+ ..Default::default()
+ },
+ );
+
+ let height = if let Some(height) = height {
+ assert_eq!(height % 8, 0, "height has to be divisible by 8");
+ height
+ } else {
+ 512
+ };
+
+ let width = if let Some(width) = width {
+ assert_eq!(width % 8, 0, "width has to be divisible by 8");
+ width
+ } else {
+ 512
+ };
+
Self {
width,
height,
@@ -250,6 +329,20 @@ impl StableDiffusionConfig {
)
}
+ pub fn sdxl_turbo(
+ sliced_attention_size: Option,
+ height: Option,
+ width: Option,
+ ) -> Self {
+ Self::sdxl_turbo_(
+ sliced_attention_size,
+ height,
+ width,
+ // https://huggingface.co/stabilityai/sdxl-turbo/blob/main/scheduler/scheduler_config.json
+ schedulers::PredictionType::Epsilon,
+ )
+ }
+
pub fn ssd1b(
sliced_attention_size: Option,
height: Option,
@@ -286,9 +379,9 @@ impl StableDiffusionConfig {
latent_channels: 4,
norm_num_groups: 32,
};
- let scheduler = ddim::DDIMSchedulerConfig {
+ let scheduler = Arc::new(ddim::DDIMSchedulerConfig {
..Default::default()
- };
+ });
let height = if let Some(height) = height {
assert_eq!(height % 8, 0, "height has to be divisible by 8");
@@ -348,8 +441,8 @@ impl StableDiffusionConfig {
Ok(unet)
}
- pub fn build_scheduler(&self, n_steps: usize) -> Result {
- ddim::DDIMScheduler::new(n_steps, self.scheduler)
+ pub fn build_scheduler(&self, n_steps: usize) -> Result> {
+ self.scheduler.build(n_steps)
}
}
diff --git a/candle-transformers/src/models/stable_diffusion/schedulers.rs b/candle-transformers/src/models/stable_diffusion/schedulers.rs
index f414bde7..0f0441e0 100644
--- a/candle-transformers/src/models/stable_diffusion/schedulers.rs
+++ b/candle-transformers/src/models/stable_diffusion/schedulers.rs
@@ -3,9 +3,25 @@
//!
//! Noise schedulers can be used to set the trade-off between
//! inference speed and quality.
-
use candle::{Result, Tensor};
+pub trait SchedulerConfig: std::fmt::Debug {
+ fn build(&self, inference_steps: usize) -> Result>;
+}
+
+/// This trait represents a scheduler for the diffusion process.
+pub trait Scheduler {
+ fn timesteps(&self) -> &[usize];
+
+ fn add_noise(&self, original: &Tensor, noise: Tensor, timestep: usize) -> Result;
+
+ fn init_noise_sigma(&self) -> f64;
+
+ fn scale_model_input(&self, sample: Tensor, _timestep: usize) -> Result;
+
+ fn step(&self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result;
+}
+
/// This represents how beta ranges from its minimum value to the maximum
/// during training.
#[derive(Debug, Clone, Copy)]