Move the stable-diffusion modeling code so that it's easier to re-use. (#812)

This commit is contained in:
Laurent Mazare
2023-09-11 11:45:57 +01:00
committed by GitHub
parent 84ee870efd
commit d7b9fec849
13 changed files with 28 additions and 28 deletions

View File

@ -4,17 +4,7 @@ extern crate accelerate_src;
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
mod attention;
mod clip;
mod ddim;
mod embeddings;
mod resnet;
mod schedulers;
mod stable_diffusion;
mod unet_2d;
mod unet_2d_blocks;
mod utils;
mod vae;
use candle_transformers::models::stable_diffusion;
use anyhow::{Error as E, Result};
use candle::{DType, Device, IndexOp, Tensor, D};

View File

@ -6,4 +6,5 @@ pub mod falcon;
pub mod llama;
pub mod quantized_llama;
pub mod segment_anything;
pub mod stable_diffusion;
pub mod whisper;

View File

@ -7,7 +7,7 @@
//!
//! Denoising Diffusion Implicit Models, J. Song et al, 2020.
//! https://arxiv.org/abs/2010.02502
use crate::schedulers::{betas_for_alpha_bar, BetaSchedule, PredictionType};
use super::schedulers::{betas_for_alpha_bar, BetaSchedule, PredictionType};
use candle::{Result, Tensor};
/// The configuration for the DDIM scheduler.
@ -67,14 +67,14 @@ impl DDIMScheduler {
.rev()
.collect();
let betas = match config.beta_schedule {
BetaSchedule::ScaledLinear => crate::utils::linspace(
BetaSchedule::ScaledLinear => super::utils::linspace(
config.beta_start.sqrt(),
config.beta_end.sqrt(),
config.train_timesteps,
)?
.sqr()?,
BetaSchedule::Linear => {
crate::utils::linspace(config.beta_start, config.beta_end, config.train_timesteps)?
super::utils::linspace(config.beta_start, config.beta_end, config.train_timesteps)?
}
BetaSchedule::SquaredcosCapV2 => betas_for_alpha_bar(config.train_timesteps, 0.999)?,
};

View File

@ -1,5 +1,14 @@
use crate::schedulers::PredictionType;
use crate::{clip, ddim, unet_2d, vae};
pub mod attention;
pub mod clip;
pub mod ddim;
pub mod embeddings;
pub mod resnet;
pub mod schedulers;
pub mod unet_2d;
pub mod unet_2d_blocks;
pub mod utils;
pub mod vae;
use candle::{DType, Device, Result};
use candle_nn as nn;
@ -80,7 +89,7 @@ impl StableDiffusionConfig {
sliced_attention_size: Option<usize>,
height: Option<usize>,
width: Option<usize>,
prediction_type: PredictionType,
prediction_type: schedulers::PredictionType,
) -> Self {
let bc = |out_channels, use_cross_attn, attention_head_dim| unet_2d::BlockConfig {
out_channels,
@ -154,7 +163,7 @@ impl StableDiffusionConfig {
sliced_attention_size,
height,
width,
PredictionType::VPrediction,
schedulers::PredictionType::VPrediction,
)
}
@ -162,7 +171,7 @@ impl StableDiffusionConfig {
sliced_attention_size: Option<usize>,
height: Option<usize>,
width: Option<usize>,
prediction_type: PredictionType,
prediction_type: schedulers::PredictionType,
) -> Self {
let bc = |out_channels, use_cross_attn, attention_head_dim| unet_2d::BlockConfig {
out_channels,
@ -235,7 +244,7 @@ impl StableDiffusionConfig {
height,
width,
// https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/scheduler/scheduler_config.json
PredictionType::Epsilon,
schedulers::PredictionType::Epsilon,
)
}

View File

@ -4,7 +4,7 @@
//!
//! Deep Residual Learning for Image Recognition, K. He et al., 2015.
//! https://arxiv.org/abs/1512.03385
use crate::utils::{conv2d, Conv2d};
use super::utils::{conv2d, Conv2d};
use candle::{Result, Tensor, D};
use candle_nn as nn;
use candle_nn::Module;

View File

@ -2,9 +2,9 @@
//!
//! The 2D Unet models take as input a noisy sample and the current diffusion
//! timestep and return a denoised version of the input.
use crate::embeddings::{TimestepEmbedding, Timesteps};
use crate::unet_2d_blocks::*;
use crate::utils::{conv2d, Conv2d};
use super::embeddings::{TimestepEmbedding, Timesteps};
use super::unet_2d_blocks::*;
use super::utils::{conv2d, Conv2d};
use candle::{Result, Tensor};
use candle_nn as nn;
use candle_nn::Module;

View File

@ -1,10 +1,10 @@
//! 2D UNet Building Blocks
//!
use crate::attention::{
use super::attention::{
AttentionBlock, AttentionBlockConfig, SpatialTransformer, SpatialTransformerConfig,
};
use crate::resnet::{ResnetBlock2D, ResnetBlock2DConfig};
use crate::utils::{conv2d, Conv2d};
use super::resnet::{ResnetBlock2D, ResnetBlock2DConfig};
use super::utils::{conv2d, Conv2d};
use candle::{Result, Tensor, D};
use candle_nn as nn;

View File

@ -4,7 +4,7 @@
//! Auto-encoder models compress their input to a usually smaller latent space
//! before expanding it back to its original shape. This results in the latent values
//! compressing the original information.
use crate::unet_2d_blocks::{
use super::unet_2d_blocks::{
DownEncoderBlock2D, DownEncoderBlock2DConfig, UNetMidBlock2D, UNetMidBlock2DConfig,
UpDecoderBlock2D, UpDecoderBlock2DConfig,
};