mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Move the stable-diffusion modeling code so that it's easier to re-use. (#812)
This commit is contained in:
@ -4,17 +4,7 @@ extern crate accelerate_src;
|
|||||||
#[cfg(feature = "mkl")]
|
#[cfg(feature = "mkl")]
|
||||||
extern crate intel_mkl_src;
|
extern crate intel_mkl_src;
|
||||||
|
|
||||||
mod attention;
|
use candle_transformers::models::stable_diffusion;
|
||||||
mod clip;
|
|
||||||
mod ddim;
|
|
||||||
mod embeddings;
|
|
||||||
mod resnet;
|
|
||||||
mod schedulers;
|
|
||||||
mod stable_diffusion;
|
|
||||||
mod unet_2d;
|
|
||||||
mod unet_2d_blocks;
|
|
||||||
mod utils;
|
|
||||||
mod vae;
|
|
||||||
|
|
||||||
use anyhow::{Error as E, Result};
|
use anyhow::{Error as E, Result};
|
||||||
use candle::{DType, Device, IndexOp, Tensor, D};
|
use candle::{DType, Device, IndexOp, Tensor, D};
|
||||||
|
@ -6,4 +6,5 @@ pub mod falcon;
|
|||||||
pub mod llama;
|
pub mod llama;
|
||||||
pub mod quantized_llama;
|
pub mod quantized_llama;
|
||||||
pub mod segment_anything;
|
pub mod segment_anything;
|
||||||
|
pub mod stable_diffusion;
|
||||||
pub mod whisper;
|
pub mod whisper;
|
||||||
|
@ -7,7 +7,7 @@
|
|||||||
//!
|
//!
|
||||||
//! Denoising Diffusion Implicit Models, J. Song et al, 2020.
|
//! Denoising Diffusion Implicit Models, J. Song et al, 2020.
|
||||||
//! https://arxiv.org/abs/2010.02502
|
//! https://arxiv.org/abs/2010.02502
|
||||||
use crate::schedulers::{betas_for_alpha_bar, BetaSchedule, PredictionType};
|
use super::schedulers::{betas_for_alpha_bar, BetaSchedule, PredictionType};
|
||||||
use candle::{Result, Tensor};
|
use candle::{Result, Tensor};
|
||||||
|
|
||||||
/// The configuration for the DDIM scheduler.
|
/// The configuration for the DDIM scheduler.
|
||||||
@ -67,14 +67,14 @@ impl DDIMScheduler {
|
|||||||
.rev()
|
.rev()
|
||||||
.collect();
|
.collect();
|
||||||
let betas = match config.beta_schedule {
|
let betas = match config.beta_schedule {
|
||||||
BetaSchedule::ScaledLinear => crate::utils::linspace(
|
BetaSchedule::ScaledLinear => super::utils::linspace(
|
||||||
config.beta_start.sqrt(),
|
config.beta_start.sqrt(),
|
||||||
config.beta_end.sqrt(),
|
config.beta_end.sqrt(),
|
||||||
config.train_timesteps,
|
config.train_timesteps,
|
||||||
)?
|
)?
|
||||||
.sqr()?,
|
.sqr()?,
|
||||||
BetaSchedule::Linear => {
|
BetaSchedule::Linear => {
|
||||||
crate::utils::linspace(config.beta_start, config.beta_end, config.train_timesteps)?
|
super::utils::linspace(config.beta_start, config.beta_end, config.train_timesteps)?
|
||||||
}
|
}
|
||||||
BetaSchedule::SquaredcosCapV2 => betas_for_alpha_bar(config.train_timesteps, 0.999)?,
|
BetaSchedule::SquaredcosCapV2 => betas_for_alpha_bar(config.train_timesteps, 0.999)?,
|
||||||
};
|
};
|
@ -1,5 +1,14 @@
|
|||||||
use crate::schedulers::PredictionType;
|
pub mod attention;
|
||||||
use crate::{clip, ddim, unet_2d, vae};
|
pub mod clip;
|
||||||
|
pub mod ddim;
|
||||||
|
pub mod embeddings;
|
||||||
|
pub mod resnet;
|
||||||
|
pub mod schedulers;
|
||||||
|
pub mod unet_2d;
|
||||||
|
pub mod unet_2d_blocks;
|
||||||
|
pub mod utils;
|
||||||
|
pub mod vae;
|
||||||
|
|
||||||
use candle::{DType, Device, Result};
|
use candle::{DType, Device, Result};
|
||||||
use candle_nn as nn;
|
use candle_nn as nn;
|
||||||
|
|
||||||
@ -80,7 +89,7 @@ impl StableDiffusionConfig {
|
|||||||
sliced_attention_size: Option<usize>,
|
sliced_attention_size: Option<usize>,
|
||||||
height: Option<usize>,
|
height: Option<usize>,
|
||||||
width: Option<usize>,
|
width: Option<usize>,
|
||||||
prediction_type: PredictionType,
|
prediction_type: schedulers::PredictionType,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
let bc = |out_channels, use_cross_attn, attention_head_dim| unet_2d::BlockConfig {
|
let bc = |out_channels, use_cross_attn, attention_head_dim| unet_2d::BlockConfig {
|
||||||
out_channels,
|
out_channels,
|
||||||
@ -154,7 +163,7 @@ impl StableDiffusionConfig {
|
|||||||
sliced_attention_size,
|
sliced_attention_size,
|
||||||
height,
|
height,
|
||||||
width,
|
width,
|
||||||
PredictionType::VPrediction,
|
schedulers::PredictionType::VPrediction,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -162,7 +171,7 @@ impl StableDiffusionConfig {
|
|||||||
sliced_attention_size: Option<usize>,
|
sliced_attention_size: Option<usize>,
|
||||||
height: Option<usize>,
|
height: Option<usize>,
|
||||||
width: Option<usize>,
|
width: Option<usize>,
|
||||||
prediction_type: PredictionType,
|
prediction_type: schedulers::PredictionType,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
let bc = |out_channels, use_cross_attn, attention_head_dim| unet_2d::BlockConfig {
|
let bc = |out_channels, use_cross_attn, attention_head_dim| unet_2d::BlockConfig {
|
||||||
out_channels,
|
out_channels,
|
||||||
@ -235,7 +244,7 @@ impl StableDiffusionConfig {
|
|||||||
height,
|
height,
|
||||||
width,
|
width,
|
||||||
// https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/scheduler/scheduler_config.json
|
// https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/scheduler/scheduler_config.json
|
||||||
PredictionType::Epsilon,
|
schedulers::PredictionType::Epsilon,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
@ -4,7 +4,7 @@
|
|||||||
//!
|
//!
|
||||||
//! Denoising Diffusion Implicit Models, K. He and al, 2015.
|
//! Denoising Diffusion Implicit Models, K. He and al, 2015.
|
||||||
//! https://arxiv.org/abs/1512.03385
|
//! https://arxiv.org/abs/1512.03385
|
||||||
use crate::utils::{conv2d, Conv2d};
|
use super::utils::{conv2d, Conv2d};
|
||||||
use candle::{Result, Tensor, D};
|
use candle::{Result, Tensor, D};
|
||||||
use candle_nn as nn;
|
use candle_nn as nn;
|
||||||
use candle_nn::Module;
|
use candle_nn::Module;
|
@ -2,9 +2,9 @@
|
|||||||
//!
|
//!
|
||||||
//! The 2D Unet models take as input a noisy sample and the current diffusion
|
//! The 2D Unet models take as input a noisy sample and the current diffusion
|
||||||
//! timestep and return a denoised version of the input.
|
//! timestep and return a denoised version of the input.
|
||||||
use crate::embeddings::{TimestepEmbedding, Timesteps};
|
use super::embeddings::{TimestepEmbedding, Timesteps};
|
||||||
use crate::unet_2d_blocks::*;
|
use super::unet_2d_blocks::*;
|
||||||
use crate::utils::{conv2d, Conv2d};
|
use super::utils::{conv2d, Conv2d};
|
||||||
use candle::{Result, Tensor};
|
use candle::{Result, Tensor};
|
||||||
use candle_nn as nn;
|
use candle_nn as nn;
|
||||||
use candle_nn::Module;
|
use candle_nn::Module;
|
@ -1,10 +1,10 @@
|
|||||||
//! 2D UNet Building Blocks
|
//! 2D UNet Building Blocks
|
||||||
//!
|
//!
|
||||||
use crate::attention::{
|
use super::attention::{
|
||||||
AttentionBlock, AttentionBlockConfig, SpatialTransformer, SpatialTransformerConfig,
|
AttentionBlock, AttentionBlockConfig, SpatialTransformer, SpatialTransformerConfig,
|
||||||
};
|
};
|
||||||
use crate::resnet::{ResnetBlock2D, ResnetBlock2DConfig};
|
use super::resnet::{ResnetBlock2D, ResnetBlock2DConfig};
|
||||||
use crate::utils::{conv2d, Conv2d};
|
use super::utils::{conv2d, Conv2d};
|
||||||
use candle::{Result, Tensor, D};
|
use candle::{Result, Tensor, D};
|
||||||
use candle_nn as nn;
|
use candle_nn as nn;
|
||||||
|
|
@ -4,7 +4,7 @@
|
|||||||
//! Auto-encoder models compress their input to a usually smaller latent space
|
//! Auto-encoder models compress their input to a usually smaller latent space
|
||||||
//! before expanding it back to its original shape. This results in the latent values
|
//! before expanding it back to its original shape. This results in the latent values
|
||||||
//! compressing the original information.
|
//! compressing the original information.
|
||||||
use crate::unet_2d_blocks::{
|
use super::unet_2d_blocks::{
|
||||||
DownEncoderBlock2D, DownEncoderBlock2DConfig, UNetMidBlock2D, UNetMidBlock2DConfig,
|
DownEncoderBlock2D, DownEncoderBlock2DConfig, UNetMidBlock2D, UNetMidBlock2DConfig,
|
||||||
UpDecoderBlock2D, UpDecoderBlock2DConfig,
|
UpDecoderBlock2D, UpDecoderBlock2DConfig,
|
||||||
};
|
};
|
Reference in New Issue
Block a user