diff --git a/README.md b/README.md index 20596fe1..f0c96a46 100644 --- a/README.md +++ b/README.md @@ -78,7 +78,7 @@ We also provide a some command line based examples using state of the art models - [Stable Diffusion](./candle-examples/examples/stable-diffusion/): text to - image generative model, support for the 1.5, 2.1, and SDXL 1.0 versions. + image generative model, support for the 1.5, 2.1, SDXL 1.0 and Turbo versions. diff --git a/candle-examples/examples/stable-diffusion/main.rs b/candle-examples/examples/stable-diffusion/main.rs index 3e6de34d..8c3ca2ee 100644 --- a/candle-examples/examples/stable-diffusion/main.rs +++ b/candle-examples/examples/stable-diffusion/main.rs @@ -11,8 +11,6 @@ use candle::{DType, Device, IndexOp, Module, Tensor, D}; use clap::Parser; use tokenizers::Tokenizer; -const GUIDANCE_SCALE: f64 = 7.5; - #[derive(Parser)] #[command(author, version, about, long_about = None)] struct Args { @@ -63,8 +61,8 @@ struct Args { sliced_attention_size: Option, /// The number of steps to run the diffusion for. - #[arg(long, default_value_t = 30)] - n_steps: usize, + #[arg(long)] + n_steps: Option, /// The number of samples to generate. #[arg(long, default_value_t = 1)] @@ -87,6 +85,9 @@ struct Args { #[arg(long)] use_f16: bool, + #[arg(long)] + guidance_scale: Option, + #[arg(long, value_name = "FILE")] img2img: Option, @@ -102,6 +103,7 @@ enum StableDiffusionVersion { V1_5, V2_1, Xl, + Turbo, } #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -120,12 +122,13 @@ impl StableDiffusionVersion { Self::Xl => "stabilityai/stable-diffusion-xl-base-1.0", Self::V2_1 => "stabilityai/stable-diffusion-2-1", Self::V1_5 => "runwayml/stable-diffusion-v1-5", + Self::Turbo => "stabilityai/sdxl-turbo", } } fn unet_file(&self, use_f16: bool) -> &'static str { match self { - Self::V1_5 | Self::V2_1 | Self::Xl => { + Self::V1_5 | Self::V2_1 | Self::Xl | Self::Turbo => { if use_f16 { "unet/diffusion_pytorch_model.fp16.safetensors" } else { @@ -137,7 +140,7 @@ impl StableDiffusionVersion { fn vae_file(&self, use_f16: bool) -> &'static str { match self { - Self::V1_5 | Self::V2_1 | Self::Xl => { + Self::V1_5 | Self::V2_1 | Self::Xl | Self::Turbo => { if use_f16 { "vae/diffusion_pytorch_model.fp16.safetensors" } else { @@ -149,7 +152,7 @@ impl StableDiffusionVersion { fn clip_file(&self, use_f16: bool) -> &'static str { match self { - Self::V1_5 | Self::V2_1 | Self::Xl => { + Self::V1_5 | Self::V2_1 | Self::Xl | Self::Turbo => { if use_f16 { "text_encoder/model.fp16.safetensors" } else { @@ -161,7 +164,7 @@ impl StableDiffusionVersion { fn clip2_file(&self, use_f16: bool) -> &'static str { match self { - Self::V1_5 | Self::V2_1 | Self::Xl => { + Self::V1_5 | Self::V2_1 | Self::Xl | Self::Turbo => { if use_f16 { "text_encoder_2/model.fp16.safetensors" } else { @@ -189,7 +192,7 @@ impl ModelFile { StableDiffusionVersion::V1_5 | StableDiffusionVersion::V2_1 => { "openai/clip-vit-base-patch32" } - StableDiffusionVersion::Xl => { + StableDiffusionVersion::Xl | StableDiffusionVersion::Turbo => { // This seems similar to the patch32 version except some very small // difference in the split regex. "openai/clip-vit-large-patch14" @@ -206,7 +209,11 @@ impl ModelFile { Self::Vae => { // Override for SDXL when using f16 weights. // See https://github.com/huggingface/candle/issues/1060 - if version == StableDiffusionVersion::Xl && use_f16 { + if matches!( + version, + StableDiffusionVersion::Xl | StableDiffusionVersion::Turbo, + ) && use_f16 + { ( "madebyollin/sdxl-vae-fp16-fix", "diffusion_pytorch_model.safetensors", @@ -261,6 +268,7 @@ fn text_embeddings( use_f16: bool, device: &Device, dtype: DType, + use_guide_scale: bool, first: bool, ) -> Result { let tokenizer_file = if first { @@ -285,16 +293,6 @@ fn text_embeddings( } let tokens = Tensor::new(tokens.as_slice(), device)?.unsqueeze(0)?; - let mut uncond_tokens = tokenizer - .encode(uncond_prompt, true) - .map_err(E::msg)? - .get_ids() - .to_vec(); - while uncond_tokens.len() < sd_config.clip.max_position_embeddings { - uncond_tokens.push(pad_id) - } - let uncond_tokens = Tensor::new(uncond_tokens.as_slice(), device)?.unsqueeze(0)?; - println!("Building the Clip transformer."); let clip_weights_file = if first { ModelFile::Clip @@ -310,8 +308,24 @@ fn text_embeddings( let text_model = stable_diffusion::build_clip_transformer(clip_config, clip_weights, device, DType::F32)?; let text_embeddings = text_model.forward(&tokens)?; - let uncond_embeddings = text_model.forward(&uncond_tokens)?; - let text_embeddings = Tensor::cat(&[uncond_embeddings, text_embeddings], 0)?.to_dtype(dtype)?; + + let text_embeddings = if use_guide_scale { + let mut uncond_tokens = tokenizer + .encode(uncond_prompt, true) + .map_err(E::msg)? + .get_ids() + .to_vec(); + while uncond_tokens.len() < sd_config.clip.max_position_embeddings { + uncond_tokens.push(pad_id) + } + + let uncond_tokens = Tensor::new(uncond_tokens.as_slice(), device)?.unsqueeze(0)?; + let uncond_embeddings = text_model.forward(&uncond_tokens)?; + + Tensor::cat(&[uncond_embeddings, text_embeddings], 0)?.to_dtype(dtype)? + } else { + text_embeddings.to_dtype(dtype)? + }; Ok(text_embeddings) } @@ -356,6 +370,7 @@ fn run(args: Args) -> Result<()> { unet_weights, tracing, use_f16, + guidance_scale, use_flash_attn, img2img, img2img_strength, @@ -374,6 +389,24 @@ fn run(args: Args) -> Result<()> { None }; + let guidance_scale = match guidance_scale { + Some(guidance_scale) => guidance_scale, + None => match sd_version { + StableDiffusionVersion::V1_5 + | StableDiffusionVersion::V2_1 + | StableDiffusionVersion::Xl => 7.5, + StableDiffusionVersion::Turbo => 0., + }, + }; + let n_steps = match n_steps { + Some(n_steps) => n_steps, + None => match sd_version { + StableDiffusionVersion::V1_5 + | StableDiffusionVersion::V2_1 + | StableDiffusionVersion::Xl => 30, + StableDiffusionVersion::Turbo => 1, + }, + }; let dtype = if use_f16 { DType::F16 } else { DType::F32 }; let sd_config = match sd_version { StableDiffusionVersion::V1_5 => { @@ -385,13 +418,19 @@ fn run(args: Args) -> Result<()> { StableDiffusionVersion::Xl => { stable_diffusion::StableDiffusionConfig::sdxl(sliced_attention_size, height, width) } + StableDiffusionVersion::Turbo => stable_diffusion::StableDiffusionConfig::sdxl_turbo( + sliced_attention_size, + height, + width, + ), }; let scheduler = sd_config.build_scheduler(n_steps)?; let device = candle_examples::device(cpu)?; + let use_guide_scale = guidance_scale > 1.0; let which = match sd_version { - StableDiffusionVersion::Xl => vec![true, false], + StableDiffusionVersion::Xl | StableDiffusionVersion::Turbo => vec![true, false], _ => vec![true], }; let text_embeddings = which @@ -407,10 +446,12 @@ fn run(args: Args) -> Result<()> { use_f16, &device, dtype, + use_guide_scale, *first, ) }) .collect::>>()?; + let text_embeddings = Tensor::cat(&text_embeddings, D::Minus1)?; println!("{text_embeddings:?}"); @@ -434,11 +475,19 @@ fn run(args: Args) -> Result<()> { 0 }; let bsize = 1; + + let vae_scale = match sd_version { + StableDiffusionVersion::V1_5 + | StableDiffusionVersion::V2_1 + | StableDiffusionVersion::Xl => 0.18215, + StableDiffusionVersion::Turbo => 0.13025, + }; + for idx in 0..num_samples { let timesteps = scheduler.timesteps(); let latents = match &init_latent_dist { Some(init_latent_dist) => { - let latents = (init_latent_dist.sample()? * 0.18215)?.to_device(&device)?; + let latents = (init_latent_dist.sample()? * vae_scale)?.to_device(&device)?; if t_start < timesteps.len() { let noise = latents.randn_like(0f64, 1f64)?; scheduler.add_noise(&latents, noise, timesteps[t_start])? @@ -465,21 +514,31 @@ fn run(args: Args) -> Result<()> { continue; } let start_time = std::time::Instant::now(); - let latent_model_input = Tensor::cat(&[&latents, &latents], 0)?; + let latent_model_input = if use_guide_scale { + Tensor::cat(&[&latents, &latents], 0)? + } else { + latents.clone() + }; let latent_model_input = scheduler.scale_model_input(latent_model_input, timestep)?; let noise_pred = unet.forward(&latent_model_input, timestep as f64, &text_embeddings)?; - let noise_pred = noise_pred.chunk(2, 0)?; - let (noise_pred_uncond, noise_pred_text) = (&noise_pred[0], &noise_pred[1]); - let noise_pred = - (noise_pred_uncond + ((noise_pred_text - noise_pred_uncond)? * GUIDANCE_SCALE)?)?; + + let noise_pred = if use_guide_scale { + let noise_pred = noise_pred.chunk(2, 0)?; + let (noise_pred_uncond, noise_pred_text) = (&noise_pred[0], &noise_pred[1]); + + (noise_pred_uncond + ((noise_pred_text - noise_pred_uncond)? * guidance_scale)?)? + } else { + noise_pred + }; + latents = scheduler.step(&noise_pred, timestep, &latents)?; let dt = start_time.elapsed().as_secs_f32(); println!("step {}/{n_steps} done, {:.2}s", timestep_index + 1, dt); if args.intermediary_images { - let image = vae.decode(&(&latents / 0.18215)?)?; + let image = vae.decode(&(&latents / vae_scale)?)?; let image = ((image / 2.)? + 0.5)?.to_device(&Device::Cpu)?; let image = (image * 255.)?.to_dtype(DType::U8)?.i(0)?; let image_filename = @@ -493,7 +552,7 @@ fn run(args: Args) -> Result<()> { idx + 1, num_samples ); - let image = vae.decode(&(&latents / 0.18215)?)?; + let image = vae.decode(&(&latents / vae_scale)?)?; let image = ((image / 2.)? + 0.5)?.to_device(&Device::Cpu)?; let image = (image.clamp(0f32, 1.)? * 255.)?.to_dtype(DType::U8)?.i(0)?; let image_filename = output_filename(&final_image, idx + 1, num_samples, None); diff --git a/candle-transformers/src/models/stable_diffusion/ddim.rs b/candle-transformers/src/models/stable_diffusion/ddim.rs index b9426094..d804ed56 100644 --- a/candle-transformers/src/models/stable_diffusion/ddim.rs +++ b/candle-transformers/src/models/stable_diffusion/ddim.rs @@ -7,7 +7,9 @@ //! //! Denoising Diffusion Implicit Models, J. Song et al, 2020. //! https://arxiv.org/abs/2010.02502 -use super::schedulers::{betas_for_alpha_bar, BetaSchedule, PredictionType, TimestepSpacing}; +use super::schedulers::{ + betas_for_alpha_bar, BetaSchedule, PredictionType, Scheduler, SchedulerConfig, TimestepSpacing, +}; use candle::{Result, Tensor}; /// The configuration for the DDIM scheduler. @@ -48,6 +50,12 @@ impl Default for DDIMSchedulerConfig { } } +impl SchedulerConfig for DDIMSchedulerConfig { + fn build(&self, inference_steps: usize) -> Result> { + Ok(Box::new(DDIMScheduler::new(inference_steps, *self)?)) + } +} + /// The DDIM scheduler. #[derive(Debug, Clone)] pub struct DDIMScheduler { @@ -63,7 +71,7 @@ impl DDIMScheduler { /// Creates a new DDIM scheduler given the number of steps to be /// used for inference as well as the number of steps that was used /// during training. - pub fn new(inference_steps: usize, config: DDIMSchedulerConfig) -> Result { + fn new(inference_steps: usize, config: DDIMSchedulerConfig) -> Result { let step_ratio = config.train_timesteps / inference_steps; let timesteps: Vec = match config.timestep_spacing { TimestepSpacing::Leading => (0..(inference_steps)) @@ -115,19 +123,11 @@ impl DDIMScheduler { config, }) } +} - pub fn timesteps(&self) -> &[usize] { - self.timesteps.as_slice() - } - - /// Ensures interchangeability with schedulers that need to scale the denoising model input - /// depending on the current timestep. - pub fn scale_model_input(&self, sample: Tensor, _timestep: usize) -> Result { - Ok(sample) - } - +impl Scheduler for DDIMScheduler { /// Performs a backward step during inference. - pub fn step(&self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result { + fn step(&self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result { let timestep = if timestep >= self.alphas_cumprod.len() { timestep - 1 } else { @@ -186,7 +186,17 @@ impl DDIMScheduler { } } - pub fn add_noise(&self, original: &Tensor, noise: Tensor, timestep: usize) -> Result { + /// Ensures interchangeability with schedulers that need to scale the denoising model input + /// depending on the current timestep. + fn scale_model_input(&self, sample: Tensor, _timestep: usize) -> Result { + Ok(sample) + } + + fn timesteps(&self) -> &[usize] { + self.timesteps.as_slice() + } + + fn add_noise(&self, original: &Tensor, noise: Tensor, timestep: usize) -> Result { let timestep = if timestep >= self.alphas_cumprod.len() { timestep - 1 } else { @@ -197,7 +207,7 @@ impl DDIMScheduler { (original * sqrt_alpha_prod)? + (noise * sqrt_one_minus_alpha_prod)? } - pub fn init_noise_sigma(&self) -> f64 { + fn init_noise_sigma(&self) -> f64 { self.init_noise_sigma } } diff --git a/candle-transformers/src/models/stable_diffusion/euler_ancestral_discrete.rs b/candle-transformers/src/models/stable_diffusion/euler_ancestral_discrete.rs index 7acbf040..85e86e6e 100644 --- a/candle-transformers/src/models/stable_diffusion/euler_ancestral_discrete.rs +++ b/candle-transformers/src/models/stable_diffusion/euler_ancestral_discrete.rs @@ -8,7 +8,10 @@ /// /// [kd]: https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L72 use super::{ - schedulers::{betas_for_alpha_bar, BetaSchedule, PredictionType, TimestepSpacing}, + schedulers::{ + betas_for_alpha_bar, BetaSchedule, PredictionType, Scheduler, SchedulerConfig, + TimestepSpacing, + }, utils::interp, }; use candle::{bail, Error, Result, Tensor}; @@ -43,11 +46,20 @@ impl Default for EulerAncestralDiscreteSchedulerConfig { steps_offset: 1, prediction_type: PredictionType::Epsilon, train_timesteps: 1000, - timestep_spacing: TimestepSpacing::Trailing, + timestep_spacing: TimestepSpacing::Leading, } } } +impl SchedulerConfig for EulerAncestralDiscreteSchedulerConfig { + fn build(&self, inference_steps: usize) -> Result> { + Ok(Box::new(EulerAncestralDiscreteScheduler::new( + inference_steps, + *self, + )?)) + } +} + /// The EulerAncestral Discrete scheduler. #[derive(Debug, Clone)] pub struct EulerAncestralDiscreteScheduler { @@ -138,8 +150,10 @@ impl EulerAncestralDiscreteScheduler { config, }) } +} - pub fn timesteps(&self) -> &[usize] { +impl Scheduler for EulerAncestralDiscreteScheduler { + fn timesteps(&self) -> &[usize] { self.timesteps.as_slice() } @@ -147,7 +161,7 @@ impl EulerAncestralDiscreteScheduler { /// depending on the current timestep. /// /// Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the K-LMS algorithm - pub fn scale_model_input(&self, sample: Tensor, timestep: usize) -> Result { + fn scale_model_input(&self, sample: Tensor, timestep: usize) -> Result { let step_index = match self.timesteps.iter().position(|&t| t == timestep) { Some(i) => i, None => bail!("timestep out of this schedulers bounds: {timestep}"), @@ -162,7 +176,7 @@ impl EulerAncestralDiscreteScheduler { } /// Performs a backward step during inference. - pub fn step(&self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result { + fn step(&self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result { let step_index = self .timesteps .iter() @@ -197,7 +211,7 @@ impl EulerAncestralDiscreteScheduler { prev_sample + noise * sigma_up } - pub fn add_noise(&self, original: &Tensor, noise: Tensor, timestep: usize) -> Result { + fn add_noise(&self, original: &Tensor, noise: Tensor, timestep: usize) -> Result { let step_index = self .timesteps .iter() @@ -212,7 +226,7 @@ impl EulerAncestralDiscreteScheduler { original + (noise * *sigma)? } - pub fn init_noise_sigma(&self) -> f64 { + fn init_noise_sigma(&self) -> f64 { match self.config.timestep_spacing { TimestepSpacing::Trailing | TimestepSpacing::Linspace => self.init_noise_sigma, TimestepSpacing::Leading => (self.init_noise_sigma.powi(2) + 1.0).sqrt(), diff --git a/candle-transformers/src/models/stable_diffusion/mod.rs b/candle-transformers/src/models/stable_diffusion/mod.rs index cad24524..30f23975 100644 --- a/candle-transformers/src/models/stable_diffusion/mod.rs +++ b/candle-transformers/src/models/stable_diffusion/mod.rs @@ -11,9 +11,13 @@ pub mod unet_2d_blocks; pub mod utils; pub mod vae; +use std::sync::Arc; + use candle::{DType, Device, Result}; use candle_nn as nn; +use self::schedulers::{Scheduler, SchedulerConfig}; + #[derive(Clone, Debug)] pub struct StableDiffusionConfig { pub width: usize, @@ -22,7 +26,7 @@ pub struct StableDiffusionConfig { pub clip2: Option, autoencoder: vae::AutoEncoderKLConfig, unet: unet_2d::UNet2DConditionModelConfig, - scheduler: ddim::DDIMSchedulerConfig, + scheduler: Arc, } impl StableDiffusionConfig { @@ -76,13 +80,18 @@ impl StableDiffusionConfig { 512 }; - Self { + let scheduler = Arc::new(ddim::DDIMSchedulerConfig { + prediction_type: schedulers::PredictionType::Epsilon, + ..Default::default() + }); + + StableDiffusionConfig { width, height, clip: clip::Config::v1_5(), clip2: None, autoencoder, - scheduler: Default::default(), + scheduler, unet, } } @@ -125,10 +134,10 @@ impl StableDiffusionConfig { latent_channels: 4, norm_num_groups: 32, }; - let scheduler = ddim::DDIMSchedulerConfig { + let scheduler = Arc::new(ddim::DDIMSchedulerConfig { prediction_type, ..Default::default() - }; + }); let height = if let Some(height) = height { assert_eq!(height % 8, 0, "height has to be divisible by 8"); @@ -144,7 +153,7 @@ impl StableDiffusionConfig { 768 }; - Self { + StableDiffusionConfig { width, height, clip: clip::Config::v2_1(), @@ -206,10 +215,10 @@ impl StableDiffusionConfig { latent_channels: 4, norm_num_groups: 32, }; - let scheduler = ddim::DDIMSchedulerConfig { + let scheduler = Arc::new(ddim::DDIMSchedulerConfig { prediction_type, ..Default::default() - }; + }); let height = if let Some(height) = height { assert_eq!(height % 8, 0, "height has to be divisible by 8"); @@ -225,6 +234,76 @@ impl StableDiffusionConfig { 1024 }; + StableDiffusionConfig { + width, + height, + clip: clip::Config::sdxl(), + clip2: Some(clip::Config::sdxl2()), + autoencoder, + scheduler, + unet, + } + } + + fn sdxl_turbo_( + sliced_attention_size: Option, + height: Option, + width: Option, + prediction_type: schedulers::PredictionType, + ) -> Self { + let bc = |out_channels, use_cross_attn, attention_head_dim| unet_2d::BlockConfig { + out_channels, + use_cross_attn, + attention_head_dim, + }; + // https://huggingface.co/stabilityai/sdxl-turbo/blob/main/unet/config.json + let unet = unet_2d::UNet2DConditionModelConfig { + blocks: vec![ + bc(320, None, 5), + bc(640, Some(2), 10), + bc(1280, Some(10), 20), + ], + center_input_sample: false, + cross_attention_dim: 2048, + downsample_padding: 1, + flip_sin_to_cos: true, + freq_shift: 0., + layers_per_block: 2, + mid_block_scale_factor: 1., + norm_eps: 1e-5, + norm_num_groups: 32, + sliced_attention_size, + use_linear_projection: true, + }; + // https://huggingface.co/stabilityai/sdxl-turbo/blob/main/vae/config.json + let autoencoder = vae::AutoEncoderKLConfig { + block_out_channels: vec![128, 256, 512, 512], + layers_per_block: 2, + latent_channels: 4, + norm_num_groups: 32, + }; + let scheduler = Arc::new( + euler_ancestral_discrete::EulerAncestralDiscreteSchedulerConfig { + prediction_type, + timestep_spacing: schedulers::TimestepSpacing::Trailing, + ..Default::default() + }, + ); + + let height = if let Some(height) = height { + assert_eq!(height % 8, 0, "height has to be divisible by 8"); + height + } else { + 512 + }; + + let width = if let Some(width) = width { + assert_eq!(width % 8, 0, "width has to be divisible by 8"); + width + } else { + 512 + }; + Self { width, height, @@ -250,6 +329,20 @@ impl StableDiffusionConfig { ) } + pub fn sdxl_turbo( + sliced_attention_size: Option, + height: Option, + width: Option, + ) -> Self { + Self::sdxl_turbo_( + sliced_attention_size, + height, + width, + // https://huggingface.co/stabilityai/sdxl-turbo/blob/main/scheduler/scheduler_config.json + schedulers::PredictionType::Epsilon, + ) + } + pub fn ssd1b( sliced_attention_size: Option, height: Option, @@ -286,9 +379,9 @@ impl StableDiffusionConfig { latent_channels: 4, norm_num_groups: 32, }; - let scheduler = ddim::DDIMSchedulerConfig { + let scheduler = Arc::new(ddim::DDIMSchedulerConfig { ..Default::default() - }; + }); let height = if let Some(height) = height { assert_eq!(height % 8, 0, "height has to be divisible by 8"); @@ -348,8 +441,8 @@ impl StableDiffusionConfig { Ok(unet) } - pub fn build_scheduler(&self, n_steps: usize) -> Result { - ddim::DDIMScheduler::new(n_steps, self.scheduler) + pub fn build_scheduler(&self, n_steps: usize) -> Result> { + self.scheduler.build(n_steps) } } diff --git a/candle-transformers/src/models/stable_diffusion/schedulers.rs b/candle-transformers/src/models/stable_diffusion/schedulers.rs index f414bde7..0f0441e0 100644 --- a/candle-transformers/src/models/stable_diffusion/schedulers.rs +++ b/candle-transformers/src/models/stable_diffusion/schedulers.rs @@ -3,9 +3,25 @@ //! //! Noise schedulers can be used to set the trade-off between //! inference speed and quality. - use candle::{Result, Tensor}; +pub trait SchedulerConfig: std::fmt::Debug { + fn build(&self, inference_steps: usize) -> Result>; +} + +/// This trait represents a scheduler for the diffusion process. +pub trait Scheduler { + fn timesteps(&self) -> &[usize]; + + fn add_noise(&self, original: &Tensor, noise: Tensor, timestep: usize) -> Result; + + fn init_noise_sigma(&self) -> f64; + + fn scale_model_input(&self, sample: Tensor, _timestep: usize) -> Result; + + fn step(&self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result; +} + /// This represents how beta ranges from its minimum value to the maximum /// during training. #[derive(Debug, Clone, Copy)]