Stable Diffusion Turbo Support (#1395)

* Add support for SD Turbo

* Set Leading as the default timestep spacing in the Euler ancestral discrete scheduler

* Use version-appropriate default values for n_steps and guidance_scale (a single step and guidance 0 for Turbo).

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
Author: Edwin Cheng
Date: 2023-12-03 15:37:10 +08:00 (committed by GitHub)
Parent: dd40edfe73
Commit: 37bf1ed012

6 changed files with 259 additions and 67 deletions
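To try the new version end to end, an invocation along these lines should work (a sketch: the --sd-version flag comes from the example's clap interface rather than from this diff; the Turbo defaults introduced below, a single step and guidance scale 0, then apply automatically):

    cargo run --example stable-diffusion --release -- \
      --sd-version turbo --prompt "a rusty robot holding a fire torch"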

File: README.md

@ -78,7 +78,7 @@ We also provide some command line based examples using state of the art models
<img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/quantized/assets/aoc.gif" width="600">
- [Stable Diffusion](./candle-examples/examples/stable-diffusion/): text to
-  image generative model, support for the 1.5, 2.1, and SDXL 1.0 versions.
+  image generative model, support for the 1.5, 2.1, SDXL 1.0 and Turbo versions.
<img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg" width="200">

File: candle-examples/examples/stable-diffusion/main.rs

@ -11,8 +11,6 @@ use candle::{DType, Device, IndexOp, Module, Tensor, D};
use clap::Parser;
use tokenizers::Tokenizer;
-const GUIDANCE_SCALE: f64 = 7.5;
#[derive(Parser)]
#[command(author, version, about, long_about = None)]
struct Args {
@ -63,8 +61,8 @@ struct Args {
sliced_attention_size: Option<usize>,
/// The number of steps to run the diffusion for.
-#[arg(long, default_value_t = 30)]
-n_steps: usize,
+#[arg(long)]
+n_steps: Option<usize>,
/// The number of samples to generate.
#[arg(long, default_value_t = 1)]
@ -87,6 +85,9 @@ struct Args {
#[arg(long)]
use_f16: bool,
+#[arg(long)]
+guidance_scale: Option<f64>,
#[arg(long, value_name = "FILE")]
img2img: Option<String>,
@ -102,6 +103,7 @@ enum StableDiffusionVersion {
V1_5,
V2_1,
Xl,
+Turbo,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
@ -120,12 +122,13 @@ impl StableDiffusionVersion {
Self::Xl => "stabilityai/stable-diffusion-xl-base-1.0",
Self::V2_1 => "stabilityai/stable-diffusion-2-1",
Self::V1_5 => "runwayml/stable-diffusion-v1-5",
+Self::Turbo => "stabilityai/sdxl-turbo",
}
}
fn unet_file(&self, use_f16: bool) -> &'static str {
match self {
-Self::V1_5 | Self::V2_1 | Self::Xl => {
+Self::V1_5 | Self::V2_1 | Self::Xl | Self::Turbo => {
if use_f16 {
"unet/diffusion_pytorch_model.fp16.safetensors"
} else {
@ -137,7 +140,7 @@ impl StableDiffusionVersion {
fn vae_file(&self, use_f16: bool) -> &'static str {
match self {
-Self::V1_5 | Self::V2_1 | Self::Xl => {
+Self::V1_5 | Self::V2_1 | Self::Xl | Self::Turbo => {
if use_f16 {
"vae/diffusion_pytorch_model.fp16.safetensors"
} else {
@ -149,7 +152,7 @@ impl StableDiffusionVersion {
fn clip_file(&self, use_f16: bool) -> &'static str {
match self {
-Self::V1_5 | Self::V2_1 | Self::Xl => {
+Self::V1_5 | Self::V2_1 | Self::Xl | Self::Turbo => {
if use_f16 {
"text_encoder/model.fp16.safetensors"
} else {
@ -161,7 +164,7 @@ impl StableDiffusionVersion {
fn clip2_file(&self, use_f16: bool) -> &'static str {
match self {
-Self::V1_5 | Self::V2_1 | Self::Xl => {
+Self::V1_5 | Self::V2_1 | Self::Xl | Self::Turbo => {
if use_f16 {
"text_encoder_2/model.fp16.safetensors"
} else {
@ -189,7 +192,7 @@ impl ModelFile {
StableDiffusionVersion::V1_5 | StableDiffusionVersion::V2_1 => {
"openai/clip-vit-base-patch32"
}
-StableDiffusionVersion::Xl => {
+StableDiffusionVersion::Xl | StableDiffusionVersion::Turbo => {
// This seems similar to the patch32 version except some very small
// difference in the split regex.
"openai/clip-vit-large-patch14"
@ -206,7 +209,11 @@ impl ModelFile {
Self::Vae => {
// Override for SDXL when using f16 weights.
// See https://github.com/huggingface/candle/issues/1060
-if version == StableDiffusionVersion::Xl && use_f16 {
+if matches!(
+version,
+StableDiffusionVersion::Xl | StableDiffusionVersion::Turbo,
+) && use_f16
+{
(
"madebyollin/sdxl-vae-fp16-fix",
"diffusion_pytorch_model.safetensors",
@ -261,6 +268,7 @@ fn text_embeddings(
use_f16: bool,
device: &Device,
dtype: DType,
+use_guide_scale: bool,
first: bool,
) -> Result<Tensor> {
let tokenizer_file = if first {
@ -285,16 +293,6 @@ fn text_embeddings(
}
let tokens = Tensor::new(tokens.as_slice(), device)?.unsqueeze(0)?;
-let mut uncond_tokens = tokenizer
-.encode(uncond_prompt, true)
-.map_err(E::msg)?
-.get_ids()
-.to_vec();
-while uncond_tokens.len() < sd_config.clip.max_position_embeddings {
-uncond_tokens.push(pad_id)
-}
-let uncond_tokens = Tensor::new(uncond_tokens.as_slice(), device)?.unsqueeze(0)?;
println!("Building the Clip transformer.");
let clip_weights_file = if first {
ModelFile::Clip
@ -310,8 +308,24 @@ fn text_embeddings(
let text_model =
stable_diffusion::build_clip_transformer(clip_config, clip_weights, device, DType::F32)?;
let text_embeddings = text_model.forward(&tokens)?;
-let uncond_embeddings = text_model.forward(&uncond_tokens)?;
-let text_embeddings = Tensor::cat(&[uncond_embeddings, text_embeddings], 0)?.to_dtype(dtype)?;
+let text_embeddings = if use_guide_scale {
+let mut uncond_tokens = tokenizer
+.encode(uncond_prompt, true)
+.map_err(E::msg)?
+.get_ids()
+.to_vec();
+while uncond_tokens.len() < sd_config.clip.max_position_embeddings {
+uncond_tokens.push(pad_id)
+}
+let uncond_tokens = Tensor::new(uncond_tokens.as_slice(), device)?.unsqueeze(0)?;
+let uncond_embeddings = text_model.forward(&uncond_tokens)?;
+Tensor::cat(&[uncond_embeddings, text_embeddings], 0)?.to_dtype(dtype)?
+} else {
+text_embeddings.to_dtype(dtype)?
+};
Ok(text_embeddings)
}
@ -356,6 +370,7 @@ fn run(args: Args) -> Result<()> {
unet_weights,
tracing,
use_f16,
+guidance_scale,
use_flash_attn,
img2img,
img2img_strength,
@ -374,6 +389,24 @@ fn run(args: Args) -> Result<()> {
None
};
+let guidance_scale = match guidance_scale {
+Some(guidance_scale) => guidance_scale,
+None => match sd_version {
+StableDiffusionVersion::V1_5
+| StableDiffusionVersion::V2_1
+| StableDiffusionVersion::Xl => 7.5,
+StableDiffusionVersion::Turbo => 0.,
+},
+};
+let n_steps = match n_steps {
+Some(n_steps) => n_steps,
+None => match sd_version {
+StableDiffusionVersion::V1_5
+| StableDiffusionVersion::V2_1
+| StableDiffusionVersion::Xl => 30,
+StableDiffusionVersion::Turbo => 1,
+},
+};
let dtype = if use_f16 { DType::F16 } else { DType::F32 };
let sd_config = match sd_version {
StableDiffusionVersion::V1_5 => {
@ -385,13 +418,19 @@ fn run(args: Args) -> Result<()> {
StableDiffusionVersion::Xl => {
stable_diffusion::StableDiffusionConfig::sdxl(sliced_attention_size, height, width)
}
+StableDiffusionVersion::Turbo => stable_diffusion::StableDiffusionConfig::sdxl_turbo(
+sliced_attention_size,
+height,
+width,
+),
};
let scheduler = sd_config.build_scheduler(n_steps)?;
let device = candle_examples::device(cpu)?;
+let use_guide_scale = guidance_scale > 1.0;
let which = match sd_version {
-StableDiffusionVersion::Xl => vec![true, false],
+StableDiffusionVersion::Xl | StableDiffusionVersion::Turbo => vec![true, false],
_ => vec![true],
};
let text_embeddings = which
@ -407,10 +446,12 @@ fn run(args: Args) -> Result<()> {
use_f16,
&device,
dtype,
+use_guide_scale,
*first,
)
})
.collect::<Result<Vec<_>>>()?;
let text_embeddings = Tensor::cat(&text_embeddings, D::Minus1)?;
println!("{text_embeddings:?}");
@ -434,11 +475,19 @@ fn run(args: Args) -> Result<()> {
0
};
let bsize = 1;
+let vae_scale = match sd_version {
+StableDiffusionVersion::V1_5
+| StableDiffusionVersion::V2_1
+| StableDiffusionVersion::Xl => 0.18215,
+StableDiffusionVersion::Turbo => 0.13025,
+};
for idx in 0..num_samples {
let timesteps = scheduler.timesteps();
let latents = match &init_latent_dist {
Some(init_latent_dist) => {
-let latents = (init_latent_dist.sample()? * 0.18215)?.to_device(&device)?;
+let latents = (init_latent_dist.sample()? * vae_scale)?.to_device(&device)?;
if t_start < timesteps.len() {
let noise = latents.randn_like(0f64, 1f64)?;
scheduler.add_noise(&latents, noise, timesteps[t_start])?
@ -465,21 +514,31 @@ fn run(args: Args) -> Result<()> {
continue;
}
let start_time = std::time::Instant::now();
-let latent_model_input = Tensor::cat(&[&latents, &latents], 0)?;
+let latent_model_input = if use_guide_scale {
+Tensor::cat(&[&latents, &latents], 0)?
+} else {
+latents.clone()
+};
let latent_model_input = scheduler.scale_model_input(latent_model_input, timestep)?;
let noise_pred =
unet.forward(&latent_model_input, timestep as f64, &text_embeddings)?;
-let noise_pred = noise_pred.chunk(2, 0)?;
-let (noise_pred_uncond, noise_pred_text) = (&noise_pred[0], &noise_pred[1]);
-let noise_pred =
-(noise_pred_uncond + ((noise_pred_text - noise_pred_uncond)? * GUIDANCE_SCALE)?)?;
+let noise_pred = if use_guide_scale {
+let noise_pred = noise_pred.chunk(2, 0)?;
+let (noise_pred_uncond, noise_pred_text) = (&noise_pred[0], &noise_pred[1]);
+(noise_pred_uncond + ((noise_pred_text - noise_pred_uncond)? * guidance_scale)?)?
+} else {
+noise_pred
+};
latents = scheduler.step(&noise_pred, timestep, &latents)?;
let dt = start_time.elapsed().as_secs_f32();
println!("step {}/{n_steps} done, {:.2}s", timestep_index + 1, dt);
if args.intermediary_images {
-let image = vae.decode(&(&latents / 0.18215)?)?;
+let image = vae.decode(&(&latents / vae_scale)?)?;
let image = ((image / 2.)? + 0.5)?.to_device(&Device::Cpu)?;
let image = (image * 255.)?.to_dtype(DType::U8)?.i(0)?;
let image_filename =
@ -493,7 +552,7 @@ fn run(args: Args) -> Result<()> {
idx + 1,
num_samples
);
-let image = vae.decode(&(&latents / 0.18215)?)?;
+let image = vae.decode(&(&latents / vae_scale)?)?;
let image = ((image / 2.)? + 0.5)?.to_device(&Device::Cpu)?;
let image = (image.clamp(0f32, 1.)? * 255.)?.to_dtype(DType::U8)?.i(0)?;
let image_filename = output_filename(&final_image, idx + 1, num_samples, None);
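For reference, the branch above is standard classifier-free guidance. A minimal sketch of the update rule, with names of our choosing rather than the example's:

    // pred = uncond + scale * (text - uncond): push the prediction away from
    // the unconditional output, towards the text-conditioned one. The example
    // only applies this when scale > 1.0; otherwise (e.g. Turbo's default
    // scale of 0) it runs a single conditioned pass and skips the blend.
    fn apply_guidance(uncond: f64, text: f64, scale: f64) -> f64 {
        uncond + scale * (text - uncond)
    }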

File: candle-transformers/src/models/stable_diffusion/ddim.rs

@ -7,7 +7,9 @@
//!
//! Denoising Diffusion Implicit Models, J. Song et al, 2020.
//! https://arxiv.org/abs/2010.02502
-use super::schedulers::{betas_for_alpha_bar, BetaSchedule, PredictionType, TimestepSpacing};
+use super::schedulers::{
+betas_for_alpha_bar, BetaSchedule, PredictionType, Scheduler, SchedulerConfig, TimestepSpacing,
+};
use candle::{Result, Tensor};
/// The configuration for the DDIM scheduler.
@ -48,6 +50,12 @@ impl Default for DDIMSchedulerConfig {
}
}
+impl SchedulerConfig for DDIMSchedulerConfig {
+fn build(&self, inference_steps: usize) -> Result<Box<dyn Scheduler>> {
+Ok(Box::new(DDIMScheduler::new(inference_steps, *self)?))
+}
+}
/// The DDIM scheduler.
#[derive(Debug, Clone)]
pub struct DDIMScheduler {
@ -63,7 +71,7 @@ impl DDIMScheduler {
/// Creates a new DDIM scheduler given the number of steps to be
/// used for inference as well as the number of steps that was used
/// during training.
-pub fn new(inference_steps: usize, config: DDIMSchedulerConfig) -> Result<Self> {
+fn new(inference_steps: usize, config: DDIMSchedulerConfig) -> Result<Self> {
let step_ratio = config.train_timesteps / inference_steps;
let timesteps: Vec<usize> = match config.timestep_spacing {
TimestepSpacing::Leading => (0..(inference_steps))
@ -115,19 +123,11 @@ impl DDIMScheduler {
config,
})
}
+}
-pub fn timesteps(&self) -> &[usize] {
-self.timesteps.as_slice()
-}
-/// Ensures interchangeability with schedulers that need to scale the denoising model input
-/// depending on the current timestep.
-pub fn scale_model_input(&self, sample: Tensor, _timestep: usize) -> Result<Tensor> {
-Ok(sample)
-}
+impl Scheduler for DDIMScheduler {
/// Performs a backward step during inference.
-pub fn step(&self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor> {
+fn step(&self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor> {
let timestep = if timestep >= self.alphas_cumprod.len() {
timestep - 1
} else {
@ -186,7 +186,17 @@ impl DDIMScheduler {
}
}
-pub fn add_noise(&self, original: &Tensor, noise: Tensor, timestep: usize) -> Result<Tensor> {
+/// Ensures interchangeability with schedulers that need to scale the denoising model input
+/// depending on the current timestep.
+fn scale_model_input(&self, sample: Tensor, _timestep: usize) -> Result<Tensor> {
+Ok(sample)
+}
+fn timesteps(&self) -> &[usize] {
+self.timesteps.as_slice()
+}
+fn add_noise(&self, original: &Tensor, noise: Tensor, timestep: usize) -> Result<Tensor> {
let timestep = if timestep >= self.alphas_cumprod.len() {
timestep - 1
} else {
@ -197,7 +207,7 @@ impl DDIMScheduler {
(original * sqrt_alpha_prod)? + (noise * sqrt_one_minus_alpha_prod)?
}
-pub fn init_noise_sigma(&self) -> f64 {
+fn init_noise_sigma(&self) -> f64 {
self.init_noise_sigma
}
}
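Since DDIMScheduler::new is no longer public, downstream code now goes through the config trait. A minimal sketch, assuming the candle-transformers module paths shown in this diff:

    use candle_transformers::models::stable_diffusion::ddim::DDIMSchedulerConfig;
    use candle_transformers::models::stable_diffusion::schedulers::{Scheduler, SchedulerConfig};

    fn build_ddim() -> candle::Result<Box<dyn Scheduler>> {
        // build() hides the concrete scheduler behind a trait object.
        DDIMSchedulerConfig::default().build(30)
    }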

File: candle-transformers/src/models/stable_diffusion/euler_ancestral_discrete.rs

@ -8,7 +8,10 @@
///
/// [kd]: https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L72
use super::{
-schedulers::{betas_for_alpha_bar, BetaSchedule, PredictionType, TimestepSpacing},
+schedulers::{
+betas_for_alpha_bar, BetaSchedule, PredictionType, Scheduler, SchedulerConfig,
+TimestepSpacing,
+},
utils::interp,
};
use candle::{bail, Error, Result, Tensor};
@ -43,11 +46,20 @@ impl Default for EulerAncestralDiscreteSchedulerConfig {
steps_offset: 1,
prediction_type: PredictionType::Epsilon,
train_timesteps: 1000,
-timestep_spacing: TimestepSpacing::Trailing,
+timestep_spacing: TimestepSpacing::Leading,
}
}
}
+impl SchedulerConfig for EulerAncestralDiscreteSchedulerConfig {
+fn build(&self, inference_steps: usize) -> Result<Box<dyn Scheduler>> {
+Ok(Box::new(EulerAncestralDiscreteScheduler::new(
+inference_steps,
+*self,
+)?))
+}
+}
/// The EulerAncestral Discrete scheduler.
#[derive(Debug, Clone)]
pub struct EulerAncestralDiscreteScheduler {
@ -138,8 +150,10 @@ impl EulerAncestralDiscreteScheduler {
config,
})
}
+}
-pub fn timesteps(&self) -> &[usize] {
+impl Scheduler for EulerAncestralDiscreteScheduler {
+fn timesteps(&self) -> &[usize] {
self.timesteps.as_slice()
}
@ -147,7 +161,7 @@ impl EulerAncestralDiscreteScheduler {
/// depending on the current timestep.
///
/// Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the K-LMS algorithm
-pub fn scale_model_input(&self, sample: Tensor, timestep: usize) -> Result<Tensor> {
+fn scale_model_input(&self, sample: Tensor, timestep: usize) -> Result<Tensor> {
let step_index = match self.timesteps.iter().position(|&t| t == timestep) {
Some(i) => i,
None => bail!("timestep out of this schedulers bounds: {timestep}"),
@ -162,7 +176,7 @@ impl EulerAncestralDiscreteScheduler {
}
/// Performs a backward step during inference.
-pub fn step(&self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor> {
+fn step(&self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor> {
let step_index = self
.timesteps
.iter()
@ -197,7 +211,7 @@ impl EulerAncestralDiscreteScheduler {
prev_sample + noise * sigma_up
}
-pub fn add_noise(&self, original: &Tensor, noise: Tensor, timestep: usize) -> Result<Tensor> {
+fn add_noise(&self, original: &Tensor, noise: Tensor, timestep: usize) -> Result<Tensor> {
let step_index = self
.timesteps
.iter()
@ -212,7 +226,7 @@ impl EulerAncestralDiscreteScheduler {
original + (noise * *sigma)?
}
-pub fn init_noise_sigma(&self) -> f64 {
+fn init_noise_sigma(&self) -> f64 {
match self.config.timestep_spacing {
TimestepSpacing::Trailing | TimestepSpacing::Linspace => self.init_noise_sigma,
TimestepSpacing::Leading => (self.init_noise_sigma.powi(2) + 1.0).sqrt(),

File: candle-transformers/src/models/stable_diffusion/mod.rs

@ -11,9 +11,13 @@ pub mod unet_2d_blocks;
pub mod utils;
pub mod vae;
+use std::sync::Arc;
use candle::{DType, Device, Result};
use candle_nn as nn;
+use self::schedulers::{Scheduler, SchedulerConfig};
#[derive(Clone, Debug)]
pub struct StableDiffusionConfig {
pub width: usize,
@ -22,7 +26,7 @@ pub struct StableDiffusionConfig {
pub clip2: Option<clip::Config>,
autoencoder: vae::AutoEncoderKLConfig,
unet: unet_2d::UNet2DConditionModelConfig,
-scheduler: ddim::DDIMSchedulerConfig,
+scheduler: Arc<dyn SchedulerConfig>,
}
impl StableDiffusionConfig {
@ -76,13 +80,18 @@ impl StableDiffusionConfig {
512
};
-Self {
+let scheduler = Arc::new(ddim::DDIMSchedulerConfig {
+prediction_type: schedulers::PredictionType::Epsilon,
+..Default::default()
+});
+StableDiffusionConfig {
width,
height,
clip: clip::Config::v1_5(),
clip2: None,
autoencoder,
-scheduler: Default::default(),
+scheduler,
unet,
}
}
@ -125,10 +134,10 @@ impl StableDiffusionConfig {
latent_channels: 4,
norm_num_groups: 32,
};
-let scheduler = ddim::DDIMSchedulerConfig {
+let scheduler = Arc::new(ddim::DDIMSchedulerConfig {
prediction_type,
..Default::default()
-};
+});
let height = if let Some(height) = height {
assert_eq!(height % 8, 0, "height has to be divisible by 8");
@ -144,7 +153,7 @@ impl StableDiffusionConfig {
768
};
-Self {
+StableDiffusionConfig {
width,
height,
clip: clip::Config::v2_1(),
@ -206,10 +215,10 @@ impl StableDiffusionConfig {
latent_channels: 4,
norm_num_groups: 32,
};
-let scheduler = ddim::DDIMSchedulerConfig {
+let scheduler = Arc::new(ddim::DDIMSchedulerConfig {
prediction_type,
..Default::default()
-};
+});
let height = if let Some(height) = height {
assert_eq!(height % 8, 0, "height has to be divisible by 8");
@ -225,6 +234,76 @@ impl StableDiffusionConfig {
1024
};
StableDiffusionConfig {
width,
height,
clip: clip::Config::sdxl(),
clip2: Some(clip::Config::sdxl2()),
autoencoder,
scheduler,
unet,
}
}
+fn sdxl_turbo_(
+sliced_attention_size: Option<usize>,
+height: Option<usize>,
+width: Option<usize>,
+prediction_type: schedulers::PredictionType,
+) -> Self {
+let bc = |out_channels, use_cross_attn, attention_head_dim| unet_2d::BlockConfig {
+out_channels,
+use_cross_attn,
+attention_head_dim,
+};
+// https://huggingface.co/stabilityai/sdxl-turbo/blob/main/unet/config.json
+let unet = unet_2d::UNet2DConditionModelConfig {
+blocks: vec![
+bc(320, None, 5),
+bc(640, Some(2), 10),
+bc(1280, Some(10), 20),
+],
+center_input_sample: false,
+cross_attention_dim: 2048,
+downsample_padding: 1,
+flip_sin_to_cos: true,
+freq_shift: 0.,
+layers_per_block: 2,
+mid_block_scale_factor: 1.,
+norm_eps: 1e-5,
+norm_num_groups: 32,
+sliced_attention_size,
+use_linear_projection: true,
+};
+// https://huggingface.co/stabilityai/sdxl-turbo/blob/main/vae/config.json
+let autoencoder = vae::AutoEncoderKLConfig {
+block_out_channels: vec![128, 256, 512, 512],
+layers_per_block: 2,
+latent_channels: 4,
+norm_num_groups: 32,
+};
+let scheduler = Arc::new(
+euler_ancestral_discrete::EulerAncestralDiscreteSchedulerConfig {
+prediction_type,
+timestep_spacing: schedulers::TimestepSpacing::Trailing,
+..Default::default()
+},
+);
+let height = if let Some(height) = height {
+assert_eq!(height % 8, 0, "height has to be divisible by 8");
+height
+} else {
+512
+};
+let width = if let Some(width) = width {
+assert_eq!(width % 8, 0, "width has to be divisible by 8");
+width
+} else {
+512
+};
+Self {
+width,
+height,
@ -250,6 +329,20 @@ impl StableDiffusionConfig {
)
}
+pub fn sdxl_turbo(
+sliced_attention_size: Option<usize>,
+height: Option<usize>,
+width: Option<usize>,
+) -> Self {
+Self::sdxl_turbo_(
+sliced_attention_size,
+height,
+width,
+// https://huggingface.co/stabilityai/sdxl-turbo/blob/main/scheduler/scheduler_config.json
+schedulers::PredictionType::Epsilon,
+)
+}
pub fn ssd1b(
sliced_attention_size: Option<usize>,
height: Option<usize>,
@ -286,9 +379,9 @@ impl StableDiffusionConfig {
latent_channels: 4,
norm_num_groups: 32,
};
-let scheduler = ddim::DDIMSchedulerConfig {
+let scheduler = Arc::new(ddim::DDIMSchedulerConfig {
..Default::default()
-};
+});
let height = if let Some(height) = height {
assert_eq!(height % 8, 0, "height has to be divisible by 8");
@ -348,8 +441,8 @@ impl StableDiffusionConfig {
Ok(unet)
}
-pub fn build_scheduler(&self, n_steps: usize) -> Result<ddim::DDIMScheduler> {
-ddim::DDIMScheduler::new(n_steps, self.scheduler)
+pub fn build_scheduler(&self, n_steps: usize) -> Result<Box<dyn Scheduler>> {
+self.scheduler.build(n_steps)
}
}
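Taken together, the example's Turbo setup now reduces to roughly the following (a condensed sketch of the calls above, omitting the UNet and VAE):

    use candle_transformers::models::stable_diffusion::schedulers::Scheduler;
    use candle_transformers::models::stable_diffusion::StableDiffusionConfig;

    fn setup_turbo() -> candle::Result<()> {
        // 512x512 defaults; internally this picks the Euler ancestral
        // scheduler with trailing timestep spacing, per sdxl_turbo_ above.
        let sd_config = StableDiffusionConfig::sdxl_turbo(None, None, None);
        // A single inference step, the example's default for Turbo.
        let scheduler = sd_config.build_scheduler(1)?;
        for &timestep in scheduler.timesteps() {
            // ... UNet forward pass, then scheduler.step(...) ...
            let _ = timestep;
        }
        Ok(())
    }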

File: candle-transformers/src/models/stable_diffusion/schedulers.rs

@ -3,9 +3,25 @@
//!
//! Noise schedulers can be used to set the trade-off between
//! inference speed and quality.
use candle::{Result, Tensor};
+pub trait SchedulerConfig: std::fmt::Debug {
+fn build(&self, inference_steps: usize) -> Result<Box<dyn Scheduler>>;
+}
+/// This trait represents a scheduler for the diffusion process.
+pub trait Scheduler {
+fn timesteps(&self) -> &[usize];
+fn add_noise(&self, original: &Tensor, noise: Tensor, timestep: usize) -> Result<Tensor>;
+fn init_noise_sigma(&self) -> f64;
+fn scale_model_input(&self, sample: Tensor, _timestep: usize) -> Result<Tensor>;
+fn step(&self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor>;
+}
/// This represents how beta ranges from its minimum value to the maximum
/// during training.
#[derive(Debug, Clone, Copy)]
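These two traits are the extension point: a new scheduler only has to implement them to become usable through StableDiffusionConfig and build_scheduler. A hypothetical, deliberately trivial implementation (all names ours, not part of this commit):

    use candle::{Result, Tensor};
    use candle_transformers::models::stable_diffusion::schedulers::{Scheduler, SchedulerConfig};

    #[derive(Debug, Clone, Copy)]
    struct ToyConfig;

    impl SchedulerConfig for ToyConfig {
        fn build(&self, inference_steps: usize) -> Result<Box<dyn Scheduler>> {
            let timesteps = (0..inference_steps).rev().collect();
            Ok(Box::new(ToyScheduler { timesteps }))
        }
    }

    struct ToyScheduler {
        timesteps: Vec<usize>,
    }

    impl Scheduler for ToyScheduler {
        fn timesteps(&self) -> &[usize] {
            &self.timesteps
        }
        // Placeholder arithmetic; a real scheduler would follow its noise schedule.
        fn add_noise(&self, original: &Tensor, noise: Tensor, _timestep: usize) -> Result<Tensor> {
            original + noise
        }
        fn init_noise_sigma(&self) -> f64 {
            1.0
        }
        fn scale_model_input(&self, sample: Tensor, _timestep: usize) -> Result<Tensor> {
            Ok(sample)
        }
        fn step(&self, model_output: &Tensor, _timestep: usize, sample: &Tensor) -> Result<Tensor> {
            sample - model_output
        }
    }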