Move the stable-diffusion modeling code so that it's easier to re-use. (#812)

This commit is contained in:
Laurent Mazare
2023-09-11 11:45:57 +01:00
committed by GitHub
parent 84ee870efd
commit d7b9fec849
13 changed files with 28 additions and 28 deletions

View File

@ -4,17 +4,7 @@ extern crate accelerate_src;
#[cfg(feature = "mkl")] #[cfg(feature = "mkl")]
extern crate intel_mkl_src; extern crate intel_mkl_src;
mod attention; use candle_transformers::models::stable_diffusion;
mod clip;
mod ddim;
mod embeddings;
mod resnet;
mod schedulers;
mod stable_diffusion;
mod unet_2d;
mod unet_2d_blocks;
mod utils;
mod vae;
use anyhow::{Error as E, Result}; use anyhow::{Error as E, Result};
use candle::{DType, Device, IndexOp, Tensor, D}; use candle::{DType, Device, IndexOp, Tensor, D};

View File

@ -6,4 +6,5 @@ pub mod falcon;
pub mod llama; pub mod llama;
pub mod quantized_llama; pub mod quantized_llama;
pub mod segment_anything; pub mod segment_anything;
pub mod stable_diffusion;
pub mod whisper; pub mod whisper;

View File

@ -7,7 +7,7 @@
//! //!
//! Denoising Diffusion Implicit Models, J. Song et al, 2020. //! Denoising Diffusion Implicit Models, J. Song et al, 2020.
//! https://arxiv.org/abs/2010.02502 //! https://arxiv.org/abs/2010.02502
use crate::schedulers::{betas_for_alpha_bar, BetaSchedule, PredictionType}; use super::schedulers::{betas_for_alpha_bar, BetaSchedule, PredictionType};
use candle::{Result, Tensor}; use candle::{Result, Tensor};
/// The configuration for the DDIM scheduler. /// The configuration for the DDIM scheduler.
@ -67,14 +67,14 @@ impl DDIMScheduler {
.rev() .rev()
.collect(); .collect();
let betas = match config.beta_schedule { let betas = match config.beta_schedule {
BetaSchedule::ScaledLinear => crate::utils::linspace( BetaSchedule::ScaledLinear => super::utils::linspace(
config.beta_start.sqrt(), config.beta_start.sqrt(),
config.beta_end.sqrt(), config.beta_end.sqrt(),
config.train_timesteps, config.train_timesteps,
)? )?
.sqr()?, .sqr()?,
BetaSchedule::Linear => { BetaSchedule::Linear => {
crate::utils::linspace(config.beta_start, config.beta_end, config.train_timesteps)? super::utils::linspace(config.beta_start, config.beta_end, config.train_timesteps)?
} }
BetaSchedule::SquaredcosCapV2 => betas_for_alpha_bar(config.train_timesteps, 0.999)?, BetaSchedule::SquaredcosCapV2 => betas_for_alpha_bar(config.train_timesteps, 0.999)?,
}; };

View File

@ -1,5 +1,14 @@
use crate::schedulers::PredictionType; pub mod attention;
use crate::{clip, ddim, unet_2d, vae}; pub mod clip;
pub mod ddim;
pub mod embeddings;
pub mod resnet;
pub mod schedulers;
pub mod unet_2d;
pub mod unet_2d_blocks;
pub mod utils;
pub mod vae;
use candle::{DType, Device, Result}; use candle::{DType, Device, Result};
use candle_nn as nn; use candle_nn as nn;
@ -80,7 +89,7 @@ impl StableDiffusionConfig {
sliced_attention_size: Option<usize>, sliced_attention_size: Option<usize>,
height: Option<usize>, height: Option<usize>,
width: Option<usize>, width: Option<usize>,
prediction_type: PredictionType, prediction_type: schedulers::PredictionType,
) -> Self { ) -> Self {
let bc = |out_channels, use_cross_attn, attention_head_dim| unet_2d::BlockConfig { let bc = |out_channels, use_cross_attn, attention_head_dim| unet_2d::BlockConfig {
out_channels, out_channels,
@ -154,7 +163,7 @@ impl StableDiffusionConfig {
sliced_attention_size, sliced_attention_size,
height, height,
width, width,
PredictionType::VPrediction, schedulers::PredictionType::VPrediction,
) )
} }
@ -162,7 +171,7 @@ impl StableDiffusionConfig {
sliced_attention_size: Option<usize>, sliced_attention_size: Option<usize>,
height: Option<usize>, height: Option<usize>,
width: Option<usize>, width: Option<usize>,
prediction_type: PredictionType, prediction_type: schedulers::PredictionType,
) -> Self { ) -> Self {
let bc = |out_channels, use_cross_attn, attention_head_dim| unet_2d::BlockConfig { let bc = |out_channels, use_cross_attn, attention_head_dim| unet_2d::BlockConfig {
out_channels, out_channels,
@ -235,7 +244,7 @@ impl StableDiffusionConfig {
height, height,
width, width,
// https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/scheduler/scheduler_config.json // https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/scheduler/scheduler_config.json
PredictionType::Epsilon, schedulers::PredictionType::Epsilon,
) )
} }

View File

@ -4,7 +4,7 @@
//! //!
//! Denoising Diffusion Implicit Models, K. He and al, 2015. //! Denoising Diffusion Implicit Models, K. He and al, 2015.
//! https://arxiv.org/abs/1512.03385 //! https://arxiv.org/abs/1512.03385
use crate::utils::{conv2d, Conv2d}; use super::utils::{conv2d, Conv2d};
use candle::{Result, Tensor, D}; use candle::{Result, Tensor, D};
use candle_nn as nn; use candle_nn as nn;
use candle_nn::Module; use candle_nn::Module;

View File

@ -2,9 +2,9 @@
//! //!
//! The 2D Unet models take as input a noisy sample and the current diffusion //! The 2D Unet models take as input a noisy sample and the current diffusion
//! timestep and return a denoised version of the input. //! timestep and return a denoised version of the input.
use crate::embeddings::{TimestepEmbedding, Timesteps}; use super::embeddings::{TimestepEmbedding, Timesteps};
use crate::unet_2d_blocks::*; use super::unet_2d_blocks::*;
use crate::utils::{conv2d, Conv2d}; use super::utils::{conv2d, Conv2d};
use candle::{Result, Tensor}; use candle::{Result, Tensor};
use candle_nn as nn; use candle_nn as nn;
use candle_nn::Module; use candle_nn::Module;

View File

@ -1,10 +1,10 @@
//! 2D UNet Building Blocks //! 2D UNet Building Blocks
//! //!
use crate::attention::{ use super::attention::{
AttentionBlock, AttentionBlockConfig, SpatialTransformer, SpatialTransformerConfig, AttentionBlock, AttentionBlockConfig, SpatialTransformer, SpatialTransformerConfig,
}; };
use crate::resnet::{ResnetBlock2D, ResnetBlock2DConfig}; use super::resnet::{ResnetBlock2D, ResnetBlock2DConfig};
use crate::utils::{conv2d, Conv2d}; use super::utils::{conv2d, Conv2d};
use candle::{Result, Tensor, D}; use candle::{Result, Tensor, D};
use candle_nn as nn; use candle_nn as nn;

View File

@ -4,7 +4,7 @@
//! Auto-encoder models compress their input to a usually smaller latent space //! Auto-encoder models compress their input to a usually smaller latent space
//! before expanding it back to its original shape. This results in the latent values //! before expanding it back to its original shape. This results in the latent values
//! compressing the original information. //! compressing the original information.
use crate::unet_2d_blocks::{ use super::unet_2d_blocks::{
DownEncoderBlock2D, DownEncoderBlock2DConfig, UNetMidBlock2D, UNetMidBlock2DConfig, DownEncoderBlock2D, DownEncoderBlock2DConfig, UNetMidBlock2D, UNetMidBlock2DConfig,
UpDecoderBlock2D, UpDecoderBlock2DConfig, UpDecoderBlock2D, UpDecoderBlock2DConfig,
}; };