Move the stable-diffusion modeling code so that it's easier to re-use. (#812)

This commit is contained in:
Laurent Mazare
2023-09-11 11:45:57 +01:00
committed by GitHub
parent 84ee870efd
commit d7b9fec849
13 changed files with 28 additions and 28 deletions

View File

@ -6,4 +6,5 @@ pub mod falcon;
pub mod llama;
pub mod quantized_llama;
pub mod segment_anything;
pub mod stable_diffusion;
pub mod whisper;

View File

@ -0,0 +1,545 @@
//! Attention Based Building Blocks
use candle::{DType, IndexOp, Result, Tensor, D};
use candle_nn as nn;
use candle_nn::Module;
#[derive(Debug)]
struct GeGlu {
proj: nn::Linear,
span: tracing::Span,
}
impl GeGlu {
fn new(vs: nn::VarBuilder, dim_in: usize, dim_out: usize) -> Result<Self> {
let proj = nn::linear(dim_in, dim_out * 2, vs.pp("proj"))?;
let span = tracing::span!(tracing::Level::TRACE, "geglu");
Ok(Self { proj, span })
}
}
impl GeGlu {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let hidden_states_and_gate = self.proj.forward(xs)?.chunk(2, D::Minus1)?;
&hidden_states_and_gate[0] * hidden_states_and_gate[1].gelu()?
}
}
/// A feed-forward layer.
#[derive(Debug)]
struct FeedForward {
project_in: GeGlu,
linear: nn::Linear,
span: tracing::Span,
}
impl FeedForward {
// The glu parameter in the python code is unused?
// https://github.com/huggingface/diffusers/blob/d3d22ce5a894becb951eec03e663951b28d45135/src/diffusers/models/attention.py#L347
/// Creates a new feed-forward layer based on some given input dimension, some
/// output dimension, and a multiplier to be used for the intermediary layer.
fn new(vs: nn::VarBuilder, dim: usize, dim_out: Option<usize>, mult: usize) -> Result<Self> {
let inner_dim = dim * mult;
let dim_out = dim_out.unwrap_or(dim);
let vs = vs.pp("net");
let project_in = GeGlu::new(vs.pp("0"), dim, inner_dim)?;
let linear = nn::linear(inner_dim, dim_out, vs.pp("2"))?;
let span = tracing::span!(tracing::Level::TRACE, "ff");
Ok(Self {
project_in,
linear,
span,
})
}
}
impl FeedForward {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let xs = self.project_in.forward(xs)?;
self.linear.forward(&xs)
}
}
#[cfg(feature = "flash-attn")]
fn flash_attn(
q: &Tensor,
k: &Tensor,
v: &Tensor,
softmax_scale: f32,
causal: bool,
) -> Result<Tensor> {
candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal)
}
#[cfg(not(feature = "flash-attn"))]
fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> {
unimplemented!("compile with '--features flash-attn'")
}
#[derive(Debug)]
struct CrossAttention {
to_q: nn::Linear,
to_k: nn::Linear,
to_v: nn::Linear,
to_out: nn::Linear,
heads: usize,
scale: f64,
slice_size: Option<usize>,
span: tracing::Span,
span_attn: tracing::Span,
span_softmax: tracing::Span,
use_flash_attn: bool,
}
impl CrossAttention {
// Defaults should be heads = 8, dim_head = 64, context_dim = None
fn new(
vs: nn::VarBuilder,
query_dim: usize,
context_dim: Option<usize>,
heads: usize,
dim_head: usize,
slice_size: Option<usize>,
use_flash_attn: bool,
) -> Result<Self> {
let inner_dim = dim_head * heads;
let context_dim = context_dim.unwrap_or(query_dim);
let scale = 1.0 / f64::sqrt(dim_head as f64);
let to_q = nn::linear_no_bias(query_dim, inner_dim, vs.pp("to_q"))?;
let to_k = nn::linear_no_bias(context_dim, inner_dim, vs.pp("to_k"))?;
let to_v = nn::linear_no_bias(context_dim, inner_dim, vs.pp("to_v"))?;
let to_out = nn::linear(inner_dim, query_dim, vs.pp("to_out.0"))?;
let span = tracing::span!(tracing::Level::TRACE, "xa");
let span_attn = tracing::span!(tracing::Level::TRACE, "xa-attn");
let span_softmax = tracing::span!(tracing::Level::TRACE, "xa-softmax");
Ok(Self {
to_q,
to_k,
to_v,
to_out,
heads,
scale,
slice_size,
span,
span_attn,
span_softmax,
use_flash_attn,
})
}
fn reshape_heads_to_batch_dim(&self, xs: &Tensor) -> Result<Tensor> {
let (batch_size, seq_len, dim) = xs.dims3()?;
xs.reshape((batch_size, seq_len, self.heads, dim / self.heads))?
.transpose(1, 2)?
.reshape((batch_size * self.heads, seq_len, dim / self.heads))
}
fn reshape_batch_dim_to_heads(&self, xs: &Tensor) -> Result<Tensor> {
let (batch_size, seq_len, dim) = xs.dims3()?;
xs.reshape((batch_size / self.heads, self.heads, seq_len, dim))?
.transpose(1, 2)?
.reshape((batch_size / self.heads, seq_len, dim * self.heads))
}
fn sliced_attention(
&self,
query: &Tensor,
key: &Tensor,
value: &Tensor,
slice_size: usize,
) -> Result<Tensor> {
let batch_size_attention = query.dim(0)?;
let mut hidden_states = Vec::with_capacity(batch_size_attention / slice_size);
let in_dtype = query.dtype();
let query = query.to_dtype(DType::F32)?;
let key = key.to_dtype(DType::F32)?;
let value = value.to_dtype(DType::F32)?;
for i in 0..batch_size_attention / slice_size {
let start_idx = i * slice_size;
let end_idx = (i + 1) * slice_size;
let xs = query
.i(start_idx..end_idx)?
.matmul(&(key.i(start_idx..end_idx)?.t()? * self.scale)?)?;
let xs = nn::ops::softmax(&xs, D::Minus1)?.matmul(&value.i(start_idx..end_idx)?)?;
hidden_states.push(xs)
}
let hidden_states = Tensor::stack(&hidden_states, 0)?.to_dtype(in_dtype)?;
self.reshape_batch_dim_to_heads(&hidden_states)
}
fn attention(&self, query: &Tensor, key: &Tensor, value: &Tensor) -> Result<Tensor> {
let _enter = self.span_attn.enter();
let xs = if self.use_flash_attn {
let init_dtype = query.dtype();
let q = query
.to_dtype(candle::DType::F16)?
.unsqueeze(0)?
.transpose(1, 2)?;
let k = key
.to_dtype(candle::DType::F16)?
.unsqueeze(0)?
.transpose(1, 2)?;
let v = value
.to_dtype(candle::DType::F16)?
.unsqueeze(0)?
.transpose(1, 2)?;
flash_attn(&q, &k, &v, self.scale as f32, false)?
.transpose(1, 2)?
.squeeze(0)?
.to_dtype(init_dtype)?
} else {
let in_dtype = query.dtype();
let query = query.to_dtype(DType::F32)?;
let key = key.to_dtype(DType::F32)?;
let value = value.to_dtype(DType::F32)?;
let xs = query.matmul(&(key.t()? * self.scale)?)?;
let xs = {
let _enter = self.span_softmax.enter();
nn::ops::softmax_last_dim(&xs)?
};
xs.matmul(&value)?.to_dtype(in_dtype)?
};
self.reshape_batch_dim_to_heads(&xs)
}
fn forward(&self, xs: &Tensor, context: Option<&Tensor>) -> Result<Tensor> {
let _enter = self.span.enter();
let query = self.to_q.forward(xs)?;
let context = context.unwrap_or(xs).contiguous()?;
let key = self.to_k.forward(&context)?;
let value = self.to_v.forward(&context)?;
let query = self.reshape_heads_to_batch_dim(&query)?;
let key = self.reshape_heads_to_batch_dim(&key)?;
let value = self.reshape_heads_to_batch_dim(&value)?;
let dim0 = query.dim(0)?;
let slice_size = self.slice_size.and_then(|slice_size| {
if dim0 < slice_size {
None
} else {
Some(slice_size)
}
});
let xs = match slice_size {
None => self.attention(&query, &key, &value)?,
Some(slice_size) => self.sliced_attention(&query, &key, &value, slice_size)?,
};
self.to_out.forward(&xs)
}
}
/// A basic Transformer block.
#[derive(Debug)]
struct BasicTransformerBlock {
attn1: CrossAttention,
ff: FeedForward,
attn2: CrossAttention,
norm1: nn::LayerNorm,
norm2: nn::LayerNorm,
norm3: nn::LayerNorm,
span: tracing::Span,
}
impl BasicTransformerBlock {
fn new(
vs: nn::VarBuilder,
dim: usize,
n_heads: usize,
d_head: usize,
context_dim: Option<usize>,
sliced_attention_size: Option<usize>,
use_flash_attn: bool,
) -> Result<Self> {
let attn1 = CrossAttention::new(
vs.pp("attn1"),
dim,
None,
n_heads,
d_head,
sliced_attention_size,
use_flash_attn,
)?;
let ff = FeedForward::new(vs.pp("ff"), dim, None, 4)?;
let attn2 = CrossAttention::new(
vs.pp("attn2"),
dim,
context_dim,
n_heads,
d_head,
sliced_attention_size,
use_flash_attn,
)?;
let norm1 = nn::layer_norm(dim, 1e-5, vs.pp("norm1"))?;
let norm2 = nn::layer_norm(dim, 1e-5, vs.pp("norm2"))?;
let norm3 = nn::layer_norm(dim, 1e-5, vs.pp("norm3"))?;
let span = tracing::span!(tracing::Level::TRACE, "basic-transformer");
Ok(Self {
attn1,
ff,
attn2,
norm1,
norm2,
norm3,
span,
})
}
fn forward(&self, xs: &Tensor, context: Option<&Tensor>) -> Result<Tensor> {
let _enter = self.span.enter();
let xs = (self.attn1.forward(&self.norm1.forward(xs)?, None)? + xs)?;
let xs = (self.attn2.forward(&self.norm2.forward(&xs)?, context)? + xs)?;
self.ff.forward(&self.norm3.forward(&xs)?)? + xs
}
}
#[derive(Debug, Clone, Copy)]
pub struct SpatialTransformerConfig {
pub depth: usize,
pub num_groups: usize,
pub context_dim: Option<usize>,
pub sliced_attention_size: Option<usize>,
pub use_linear_projection: bool,
}
impl Default for SpatialTransformerConfig {
fn default() -> Self {
Self {
depth: 1,
num_groups: 32,
context_dim: None,
sliced_attention_size: None,
use_linear_projection: false,
}
}
}
#[derive(Debug)]
enum Proj {
Conv2d(nn::Conv2d),
Linear(nn::Linear),
}
// Aka Transformer2DModel
#[derive(Debug)]
pub struct SpatialTransformer {
norm: nn::GroupNorm,
proj_in: Proj,
transformer_blocks: Vec<BasicTransformerBlock>,
proj_out: Proj,
span: tracing::Span,
pub config: SpatialTransformerConfig,
}
impl SpatialTransformer {
pub fn new(
vs: nn::VarBuilder,
in_channels: usize,
n_heads: usize,
d_head: usize,
use_flash_attn: bool,
config: SpatialTransformerConfig,
) -> Result<Self> {
let inner_dim = n_heads * d_head;
let norm = nn::group_norm(config.num_groups, in_channels, 1e-6, vs.pp("norm"))?;
let proj_in = if config.use_linear_projection {
Proj::Linear(nn::linear(in_channels, inner_dim, vs.pp("proj_in"))?)
} else {
Proj::Conv2d(nn::conv2d(
in_channels,
inner_dim,
1,
Default::default(),
vs.pp("proj_in"),
)?)
};
let mut transformer_blocks = vec![];
let vs_tb = vs.pp("transformer_blocks");
for index in 0..config.depth {
let tb = BasicTransformerBlock::new(
vs_tb.pp(&index.to_string()),
inner_dim,
n_heads,
d_head,
config.context_dim,
config.sliced_attention_size,
use_flash_attn,
)?;
transformer_blocks.push(tb)
}
let proj_out = if config.use_linear_projection {
Proj::Linear(nn::linear(in_channels, inner_dim, vs.pp("proj_out"))?)
} else {
Proj::Conv2d(nn::conv2d(
inner_dim,
in_channels,
1,
Default::default(),
vs.pp("proj_out"),
)?)
};
let span = tracing::span!(tracing::Level::TRACE, "spatial-transformer");
Ok(Self {
norm,
proj_in,
transformer_blocks,
proj_out,
span,
config,
})
}
pub fn forward(&self, xs: &Tensor, context: Option<&Tensor>) -> Result<Tensor> {
let _enter = self.span.enter();
let (batch, _channel, height, weight) = xs.dims4()?;
let residual = xs;
let xs = self.norm.forward(xs)?;
let (inner_dim, xs) = match &self.proj_in {
Proj::Conv2d(p) => {
let xs = p.forward(&xs)?;
let inner_dim = xs.dim(1)?;
let xs = xs
.transpose(1, 2)?
.t()?
.reshape((batch, height * weight, inner_dim))?;
(inner_dim, xs)
}
Proj::Linear(p) => {
let inner_dim = xs.dim(1)?;
let xs = xs
.transpose(1, 2)?
.t()?
.reshape((batch, height * weight, inner_dim))?;
(inner_dim, p.forward(&xs)?)
}
};
let mut xs = xs;
for block in self.transformer_blocks.iter() {
xs = block.forward(&xs, context)?
}
let xs = match &self.proj_out {
Proj::Conv2d(p) => p.forward(
&xs.reshape((batch, height, weight, inner_dim))?
.t()?
.transpose(1, 2)?,
)?,
Proj::Linear(p) => p
.forward(&xs)?
.reshape((batch, height, weight, inner_dim))?
.t()?
.transpose(1, 2)?,
};
xs + residual
}
}
/// Configuration for an attention block.
#[derive(Debug, Clone, Copy)]
pub struct AttentionBlockConfig {
pub num_head_channels: Option<usize>,
pub num_groups: usize,
pub rescale_output_factor: f64,
pub eps: f64,
}
impl Default for AttentionBlockConfig {
fn default() -> Self {
Self {
num_head_channels: None,
num_groups: 32,
rescale_output_factor: 1.,
eps: 1e-5,
}
}
}
#[derive(Debug)]
pub struct AttentionBlock {
group_norm: nn::GroupNorm,
query: nn::Linear,
key: nn::Linear,
value: nn::Linear,
proj_attn: nn::Linear,
channels: usize,
num_heads: usize,
span: tracing::Span,
config: AttentionBlockConfig,
}
impl AttentionBlock {
pub fn new(vs: nn::VarBuilder, channels: usize, config: AttentionBlockConfig) -> Result<Self> {
let num_head_channels = config.num_head_channels.unwrap_or(channels);
let num_heads = channels / num_head_channels;
let group_norm =
nn::group_norm(config.num_groups, channels, config.eps, vs.pp("group_norm"))?;
let (q_path, k_path, v_path, out_path) = if vs.contains_tensor("to_q.weight") {
("to_q", "to_k", "to_v", "to_out.0")
} else {
("query", "key", "value", "proj_attn")
};
let query = nn::linear(channels, channels, vs.pp(q_path))?;
let key = nn::linear(channels, channels, vs.pp(k_path))?;
let value = nn::linear(channels, channels, vs.pp(v_path))?;
let proj_attn = nn::linear(channels, channels, vs.pp(out_path))?;
let span = tracing::span!(tracing::Level::TRACE, "attn-block");
Ok(Self {
group_norm,
query,
key,
value,
proj_attn,
channels,
num_heads,
span,
config,
})
}
fn transpose_for_scores(&self, xs: Tensor) -> Result<Tensor> {
let (batch, t, h_times_d) = xs.dims3()?;
xs.reshape((batch, t, self.num_heads, h_times_d / self.num_heads))?
.transpose(1, 2)
}
pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let in_dtype = xs.dtype();
let residual = xs;
let (batch, channel, height, width) = xs.dims4()?;
let xs = self
.group_norm
.forward(xs)?
.reshape((batch, channel, height * width))?
.transpose(1, 2)?;
let query_proj = self.query.forward(&xs)?;
let key_proj = self.key.forward(&xs)?;
let value_proj = self.value.forward(&xs)?;
let query_states = self
.transpose_for_scores(query_proj)?
.to_dtype(DType::F32)?;
let key_states = self.transpose_for_scores(key_proj)?.to_dtype(DType::F32)?;
let value_states = self
.transpose_for_scores(value_proj)?
.to_dtype(DType::F32)?;
let scale = f64::powf((self.channels as f64) / (self.num_heads as f64), -0.25);
let attention_scores =
// TODO: Check that this needs two multiplication by `scale`.
(query_states * scale)?.matmul(&(key_states.t()? * scale)?)?;
let attention_probs = nn::ops::softmax(&attention_scores, D::Minus1)?;
let xs = attention_probs.matmul(&value_states.contiguous()?)?;
let xs = xs.to_dtype(in_dtype)?;
let xs = xs.transpose(1, 2)?.contiguous()?;
let xs = xs.flatten_from(D::Minus2)?;
let xs = self
.proj_attn
.forward(&xs)?
.t()?
.reshape((batch, channel, height, width))?;
(xs + residual)? / self.config.rescale_output_factor
}
}

View File

@ -0,0 +1,339 @@
//! Contrastive Language-Image Pre-Training
//!
//! Contrastive Language-Image Pre-Training (CLIP) is an architecture trained on
//! pairs of images with related texts.
//!
//! https://github.com/openai/CLIP
use candle::{DType, Device, Result, Tensor, D};
use candle_nn as nn;
use candle_nn::Module;
#[derive(Debug, Clone, Copy)]
pub enum Activation {
QuickGelu,
Gelu,
}
impl Activation {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
match self {
Activation::QuickGelu => xs * nn::ops::sigmoid(&(xs * 1.702f64)?)?,
Activation::Gelu => xs.gelu(),
}
}
}
#[derive(Debug, Clone)]
pub struct Config {
vocab_size: usize,
embed_dim: usize, // aka config.hidden_size
activation: Activation, // aka config.hidden_act
intermediate_size: usize,
pub max_position_embeddings: usize,
// The character to use for padding, use EOS when not set.
pub pad_with: Option<String>,
num_hidden_layers: usize,
num_attention_heads: usize,
#[allow(dead_code)]
projection_dim: usize,
}
impl Config {
// The config details can be found in the "text_config" section of this json file:
// https://huggingface.co/openai/clip-vit-large-patch14/blob/main/config.json
pub fn v1_5() -> Self {
Self {
vocab_size: 49408,
embed_dim: 768,
intermediate_size: 3072,
max_position_embeddings: 77,
pad_with: None,
num_hidden_layers: 12,
num_attention_heads: 12,
projection_dim: 768,
activation: Activation::QuickGelu,
}
}
// https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/text_encoder/config.json
pub fn v2_1() -> Self {
Self {
vocab_size: 49408,
embed_dim: 1024,
intermediate_size: 4096,
max_position_embeddings: 77,
pad_with: Some("!".to_string()),
num_hidden_layers: 23,
num_attention_heads: 16,
projection_dim: 512,
activation: Activation::Gelu,
}
}
// https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/text_encoder/config.json
pub fn sdxl() -> Self {
Self {
vocab_size: 49408,
embed_dim: 768,
intermediate_size: 3072,
max_position_embeddings: 77,
pad_with: Some("!".to_string()),
num_hidden_layers: 12,
num_attention_heads: 12,
projection_dim: 768,
activation: Activation::QuickGelu,
}
}
// https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/text_encoder_2/config.json
pub fn sdxl2() -> Self {
Self {
vocab_size: 49408,
embed_dim: 1280,
intermediate_size: 5120,
max_position_embeddings: 77,
pad_with: Some("!".to_string()),
num_hidden_layers: 32,
num_attention_heads: 20,
projection_dim: 1280,
activation: Activation::Gelu,
}
}
}
// CLIP Text Model
// https://github.com/huggingface/transformers/blob/674f750a57431222fa2832503a108df3badf1564/src/transformers/models/clip/modeling_clip.py
#[derive(Debug)]
struct ClipTextEmbeddings {
token_embedding: candle_nn::Embedding,
position_embedding: candle_nn::Embedding,
position_ids: Tensor,
}
impl ClipTextEmbeddings {
fn new(vs: candle_nn::VarBuilder, c: &Config) -> Result<Self> {
let token_embedding =
candle_nn::embedding(c.vocab_size, c.embed_dim, vs.pp("token_embedding"))?;
let position_embedding = candle_nn::embedding(
c.max_position_embeddings,
c.embed_dim,
vs.pp("position_embedding"),
)?;
let position_ids =
Tensor::arange(0u32, c.max_position_embeddings as u32, vs.device())?.unsqueeze(0)?;
Ok(ClipTextEmbeddings {
token_embedding,
position_embedding,
position_ids,
})
}
}
impl ClipTextEmbeddings {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let token_embedding = self.token_embedding.forward(xs)?;
let position_embedding = self.position_embedding.forward(&self.position_ids)?;
token_embedding.broadcast_add(&position_embedding)
}
}
#[derive(Debug)]
struct ClipAttention {
k_proj: candle_nn::Linear,
v_proj: candle_nn::Linear,
q_proj: candle_nn::Linear,
out_proj: candle_nn::Linear,
head_dim: usize,
scale: f64,
num_attention_heads: usize,
}
impl ClipAttention {
fn new(vs: candle_nn::VarBuilder, c: &Config) -> Result<Self> {
let embed_dim = c.embed_dim;
let num_attention_heads = c.num_attention_heads;
let k_proj = candle_nn::linear(embed_dim, embed_dim, vs.pp("k_proj"))?;
let v_proj = candle_nn::linear(embed_dim, embed_dim, vs.pp("v_proj"))?;
let q_proj = candle_nn::linear(embed_dim, embed_dim, vs.pp("q_proj"))?;
let out_proj = candle_nn::linear(embed_dim, embed_dim, vs.pp("out_proj"))?;
let head_dim = embed_dim / num_attention_heads;
let scale = (head_dim as f64).powf(-0.5);
Ok(ClipAttention {
k_proj,
v_proj,
q_proj,
out_proj,
head_dim,
scale,
num_attention_heads,
})
}
fn shape(&self, xs: &Tensor, seq_len: usize, bsz: usize) -> Result<Tensor> {
xs.reshape((bsz, seq_len, self.num_attention_heads, self.head_dim))?
.transpose(1, 2)?
.contiguous()
}
fn forward(&self, xs: &Tensor, causal_attention_mask: &Tensor) -> Result<Tensor> {
let in_dtype = xs.dtype();
let (bsz, seq_len, embed_dim) = xs.dims3()?;
let query_states = (self.q_proj.forward(xs)? * self.scale)?;
let proj_shape = (bsz * self.num_attention_heads, seq_len, self.head_dim);
let query_states = self
.shape(&query_states, seq_len, bsz)?
.reshape(proj_shape)?
.to_dtype(DType::F32)?;
let key_states = self
.shape(&self.k_proj.forward(xs)?, seq_len, bsz)?
.reshape(proj_shape)?
.to_dtype(DType::F32)?;
let value_states = self
.shape(&self.v_proj.forward(xs)?, seq_len, bsz)?
.reshape(proj_shape)?
.to_dtype(DType::F32)?;
let attn_weights = query_states.matmul(&key_states.transpose(1, 2)?)?;
let src_len = key_states.dim(1)?;
let attn_weights = attn_weights
.reshape((bsz, self.num_attention_heads, seq_len, src_len))?
.broadcast_add(causal_attention_mask)?;
let attn_weights =
attn_weights.reshape((bsz * self.num_attention_heads, seq_len, src_len))?;
let attn_weights = candle_nn::ops::softmax(&attn_weights, D::Minus1)?;
let attn_output = attn_weights.matmul(&value_states)?.to_dtype(in_dtype)?;
let attn_output = attn_output
.reshape((bsz, self.num_attention_heads, seq_len, self.head_dim))?
.transpose(1, 2)?
.reshape((bsz, seq_len, embed_dim))?;
self.out_proj.forward(&attn_output)
}
}
#[derive(Debug)]
struct ClipMlp {
fc1: candle_nn::Linear,
fc2: candle_nn::Linear,
activation: Activation,
}
impl ClipMlp {
fn new(vs: candle_nn::VarBuilder, c: &Config) -> Result<Self> {
let fc1 = candle_nn::linear(c.embed_dim, c.intermediate_size, vs.pp("fc1"))?;
let fc2 = candle_nn::linear(c.intermediate_size, c.embed_dim, vs.pp("fc2"))?;
Ok(ClipMlp {
fc1,
fc2,
activation: c.activation,
})
}
}
impl ClipMlp {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let xs = self.fc1.forward(xs)?;
self.fc2.forward(&self.activation.forward(&xs)?)
}
}
#[derive(Debug)]
struct ClipEncoderLayer {
self_attn: ClipAttention,
layer_norm1: candle_nn::LayerNorm,
mlp: ClipMlp,
layer_norm2: candle_nn::LayerNorm,
}
impl ClipEncoderLayer {
fn new(vs: candle_nn::VarBuilder, c: &Config) -> Result<Self> {
let self_attn = ClipAttention::new(vs.pp("self_attn"), c)?;
let layer_norm1 = candle_nn::layer_norm(c.embed_dim, 1e-5, vs.pp("layer_norm1"))?;
let mlp = ClipMlp::new(vs.pp("mlp"), c)?;
let layer_norm2 = candle_nn::layer_norm(c.embed_dim, 1e-5, vs.pp("layer_norm2"))?;
Ok(ClipEncoderLayer {
self_attn,
layer_norm1,
mlp,
layer_norm2,
})
}
fn forward(&self, xs: &Tensor, causal_attention_mask: &Tensor) -> Result<Tensor> {
let residual = xs;
let xs = self.layer_norm1.forward(xs)?;
let xs = self.self_attn.forward(&xs, causal_attention_mask)?;
let xs = (xs + residual)?;
let residual = &xs;
let xs = self.layer_norm2.forward(&xs)?;
let xs = self.mlp.forward(&xs)?;
xs + residual
}
}
#[derive(Debug)]
struct ClipEncoder {
layers: Vec<ClipEncoderLayer>,
}
impl ClipEncoder {
fn new(vs: candle_nn::VarBuilder, c: &Config) -> Result<Self> {
let vs = vs.pp("layers");
let mut layers: Vec<ClipEncoderLayer> = Vec::new();
for index in 0..c.num_hidden_layers {
let layer = ClipEncoderLayer::new(vs.pp(&index.to_string()), c)?;
layers.push(layer)
}
Ok(ClipEncoder { layers })
}
fn forward(&self, xs: &Tensor, causal_attention_mask: &Tensor) -> Result<Tensor> {
let mut xs = xs.clone();
for layer in self.layers.iter() {
xs = layer.forward(&xs, causal_attention_mask)?;
}
Ok(xs)
}
}
/// A CLIP transformer based model.
#[derive(Debug)]
pub struct ClipTextTransformer {
embeddings: ClipTextEmbeddings,
encoder: ClipEncoder,
final_layer_norm: candle_nn::LayerNorm,
}
impl ClipTextTransformer {
pub fn new(vs: candle_nn::VarBuilder, c: &Config) -> Result<Self> {
let vs = vs.pp("text_model");
let embeddings = ClipTextEmbeddings::new(vs.pp("embeddings"), c)?;
let encoder = ClipEncoder::new(vs.pp("encoder"), c)?;
let final_layer_norm = candle_nn::layer_norm(c.embed_dim, 1e-5, vs.pp("final_layer_norm"))?;
Ok(ClipTextTransformer {
embeddings,
encoder,
final_layer_norm,
})
}
// https://github.com/huggingface/transformers/blob/674f750a57431222fa2832503a108df3badf1564/src/transformers/models/clip/modeling_clip.py#L678
fn build_causal_attention_mask(bsz: usize, seq_len: usize, device: &Device) -> Result<Tensor> {
let mask: Vec<_> = (0..seq_len)
.flat_map(|i| (0..seq_len).map(move |j| if j > i { f32::MIN } else { 0. }))
.collect();
let mask = Tensor::from_slice(&mask, (seq_len, seq_len), device)?;
mask.broadcast_as((bsz, seq_len, seq_len))
}
}
impl ClipTextTransformer {
pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let (bsz, seq_len) = xs.dims2()?;
let xs = self.embeddings.forward(xs)?;
let causal_attention_mask = Self::build_causal_attention_mask(bsz, seq_len, xs.device())?;
let xs = self.encoder.forward(&xs, &causal_attention_mask)?;
self.final_layer_norm.forward(&xs)
}
}

View File

@ -0,0 +1,180 @@
//! # Denoising Diffusion Implicit Models
//!
//! The Denoising Diffusion Implicit Models (DDIM) is a simple scheduler
//! similar to Denoising Diffusion Probabilistic Models (DDPM). The DDPM
//! generative process is the reverse of a Markovian process, DDIM generalizes
//! this to non-Markovian guidance.
//!
//! Denoising Diffusion Implicit Models, J. Song et al, 2020.
//! https://arxiv.org/abs/2010.02502
use super::schedulers::{betas_for_alpha_bar, BetaSchedule, PredictionType};
use candle::{Result, Tensor};
/// The configuration for the DDIM scheduler.
#[derive(Debug, Clone, Copy)]
pub struct DDIMSchedulerConfig {
/// The value of beta at the beginning of training.
pub beta_start: f64,
/// The value of beta at the end of training.
pub beta_end: f64,
/// How beta evolved during training.
pub beta_schedule: BetaSchedule,
/// The amount of noise to be added at each step.
pub eta: f64,
/// Adjust the indexes of the inference schedule by this value.
pub steps_offset: usize,
/// prediction type of the scheduler function, one of `epsilon` (predicting
/// the noise of the diffusion process), `sample` (directly predicting the noisy sample`)
/// or `v_prediction` (see section 2.4 https://imagen.research.google/video/paper.pdf)
pub prediction_type: PredictionType,
/// number of diffusion steps used to train the model
pub train_timesteps: usize,
}
impl Default for DDIMSchedulerConfig {
fn default() -> Self {
Self {
beta_start: 0.00085f64,
beta_end: 0.012f64,
beta_schedule: BetaSchedule::ScaledLinear,
eta: 0.,
steps_offset: 1,
prediction_type: PredictionType::Epsilon,
train_timesteps: 1000,
}
}
}
/// The DDIM scheduler.
#[derive(Debug, Clone)]
pub struct DDIMScheduler {
timesteps: Vec<usize>,
alphas_cumprod: Vec<f64>,
step_ratio: usize,
init_noise_sigma: f64,
pub config: DDIMSchedulerConfig,
}
// clip_sample: False, set_alpha_to_one: False
impl DDIMScheduler {
/// Creates a new DDIM scheduler given the number of steps to be
/// used for inference as well as the number of steps that was used
/// during training.
pub fn new(inference_steps: usize, config: DDIMSchedulerConfig) -> Result<Self> {
let step_ratio = config.train_timesteps / inference_steps;
let timesteps: Vec<usize> = (0..(inference_steps))
.map(|s| s * step_ratio + config.steps_offset)
.rev()
.collect();
let betas = match config.beta_schedule {
BetaSchedule::ScaledLinear => super::utils::linspace(
config.beta_start.sqrt(),
config.beta_end.sqrt(),
config.train_timesteps,
)?
.sqr()?,
BetaSchedule::Linear => {
super::utils::linspace(config.beta_start, config.beta_end, config.train_timesteps)?
}
BetaSchedule::SquaredcosCapV2 => betas_for_alpha_bar(config.train_timesteps, 0.999)?,
};
let betas = betas.to_vec1::<f64>()?;
let mut alphas_cumprod = Vec::with_capacity(betas.len());
for &beta in betas.iter() {
let alpha = 1.0 - beta;
alphas_cumprod.push(alpha * *alphas_cumprod.last().unwrap_or(&1f64))
}
Ok(Self {
alphas_cumprod,
timesteps,
step_ratio,
init_noise_sigma: 1.,
config,
})
}
pub fn timesteps(&self) -> &[usize] {
self.timesteps.as_slice()
}
/// Ensures interchangeability with schedulers that need to scale the denoising model input
/// depending on the current timestep.
pub fn scale_model_input(&self, sample: Tensor, _timestep: usize) -> Result<Tensor> {
Ok(sample)
}
/// Performs a backward step during inference.
pub fn step(&self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor> {
let timestep = if timestep >= self.alphas_cumprod.len() {
timestep - 1
} else {
timestep
};
// https://github.com/huggingface/diffusers/blob/6e099e2c8ce4c4f5c7318e970a8c093dc5c7046e/src/diffusers/schedulers/scheduling_ddim.py#L195
let prev_timestep = if timestep > self.step_ratio {
timestep - self.step_ratio
} else {
0
};
let alpha_prod_t = self.alphas_cumprod[timestep];
let alpha_prod_t_prev = self.alphas_cumprod[prev_timestep];
let beta_prod_t = 1. - alpha_prod_t;
let beta_prod_t_prev = 1. - alpha_prod_t_prev;
let (pred_original_sample, pred_epsilon) = match self.config.prediction_type {
PredictionType::Epsilon => {
let pred_original_sample = ((sample - (model_output * beta_prod_t.sqrt())?)?
* (1. / alpha_prod_t.sqrt()))?;
(pred_original_sample, model_output.clone())
}
PredictionType::VPrediction => {
let pred_original_sample =
((sample * alpha_prod_t.sqrt())? - (model_output * beta_prod_t.sqrt())?)?;
let pred_epsilon =
((model_output * alpha_prod_t.sqrt())? + (sample * beta_prod_t.sqrt())?)?;
(pred_original_sample, pred_epsilon)
}
PredictionType::Sample => {
let pred_original_sample = model_output.clone();
let pred_epsilon = ((sample - &pred_original_sample * alpha_prod_t.sqrt())?
* (1. / beta_prod_t.sqrt()))?;
(pred_original_sample, pred_epsilon)
}
};
let variance = (beta_prod_t_prev / beta_prod_t) * (1. - alpha_prod_t / alpha_prod_t_prev);
let std_dev_t = self.config.eta * variance.sqrt();
let pred_sample_direction =
(pred_epsilon * (1. - alpha_prod_t_prev - std_dev_t * std_dev_t).sqrt())?;
let prev_sample =
((pred_original_sample * alpha_prod_t_prev.sqrt())? + pred_sample_direction)?;
if self.config.eta > 0. {
&prev_sample
+ Tensor::randn(
0f32,
std_dev_t as f32,
prev_sample.shape(),
prev_sample.device(),
)?
} else {
Ok(prev_sample)
}
}
pub fn add_noise(&self, original: &Tensor, noise: Tensor, timestep: usize) -> Result<Tensor> {
let timestep = if timestep >= self.alphas_cumprod.len() {
timestep - 1
} else {
timestep
};
let sqrt_alpha_prod = self.alphas_cumprod[timestep].sqrt();
let sqrt_one_minus_alpha_prod = (1.0 - self.alphas_cumprod[timestep]).sqrt();
(original * sqrt_alpha_prod)? + (noise * sqrt_one_minus_alpha_prod)?
}
pub fn init_noise_sigma(&self) -> f64 {
self.init_noise_sigma
}
}

View File

@ -0,0 +1,65 @@
use candle::{Result, Tensor, D};
use candle_nn as nn;
use candle_nn::Module;
#[derive(Debug)]
pub struct TimestepEmbedding {
linear_1: nn::Linear,
linear_2: nn::Linear,
}
impl TimestepEmbedding {
// act_fn: "silu"
pub fn new(vs: nn::VarBuilder, channel: usize, time_embed_dim: usize) -> Result<Self> {
let linear_1 = nn::linear(channel, time_embed_dim, vs.pp("linear_1"))?;
let linear_2 = nn::linear(time_embed_dim, time_embed_dim, vs.pp("linear_2"))?;
Ok(Self { linear_1, linear_2 })
}
}
impl TimestepEmbedding {
pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let xs = nn::ops::silu(&self.linear_1.forward(xs)?)?;
self.linear_2.forward(&xs)
}
}
#[derive(Debug)]
pub struct Timesteps {
num_channels: usize,
flip_sin_to_cos: bool,
downscale_freq_shift: f64,
}
impl Timesteps {
pub fn new(num_channels: usize, flip_sin_to_cos: bool, downscale_freq_shift: f64) -> Self {
Self {
num_channels,
flip_sin_to_cos,
downscale_freq_shift,
}
}
}
impl Timesteps {
pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let half_dim = (self.num_channels / 2) as u32;
let exponent = (Tensor::arange(0, half_dim, xs.device())?.to_dtype(candle::DType::F32)?
* -f64::ln(10000.))?;
let exponent = (exponent / (half_dim as f64 - self.downscale_freq_shift))?;
let emb = exponent.exp()?.to_dtype(xs.dtype())?;
// emb = timesteps[:, None].float() * emb[None, :]
let emb = xs.unsqueeze(D::Minus1)?.broadcast_mul(&emb.unsqueeze(0)?)?;
let (cos, sin) = (emb.cos()?, emb.sin()?);
let emb = if self.flip_sin_to_cos {
Tensor::cat(&[&cos, &sin], D::Minus1)?
} else {
Tensor::cat(&[&sin, &cos], D::Minus1)?
};
if self.num_channels % 2 == 1 {
emb.pad_with_zeros(D::Minus2, 0, 1)
} else {
Ok(emb)
}
}
}

View File

@ -0,0 +1,302 @@
pub mod attention;
pub mod clip;
pub mod ddim;
pub mod embeddings;
pub mod resnet;
pub mod schedulers;
pub mod unet_2d;
pub mod unet_2d_blocks;
pub mod utils;
pub mod vae;
use candle::{DType, Device, Result};
use candle_nn as nn;
#[derive(Clone, Debug)]
pub struct StableDiffusionConfig {
pub width: usize,
pub height: usize,
pub clip: clip::Config,
pub clip2: Option<clip::Config>,
autoencoder: vae::AutoEncoderKLConfig,
unet: unet_2d::UNet2DConditionModelConfig,
scheduler: ddim::DDIMSchedulerConfig,
}
impl StableDiffusionConfig {
pub fn v1_5(
sliced_attention_size: Option<usize>,
height: Option<usize>,
width: Option<usize>,
) -> Self {
let bc = |out_channels, use_cross_attn, attention_head_dim| unet_2d::BlockConfig {
out_channels,
use_cross_attn,
attention_head_dim,
};
// https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/unet/config.json
let unet = unet_2d::UNet2DConditionModelConfig {
blocks: vec![
bc(320, Some(1), 8),
bc(640, Some(1), 8),
bc(1280, Some(1), 8),
bc(1280, None, 8),
],
center_input_sample: false,
cross_attention_dim: 768,
downsample_padding: 1,
flip_sin_to_cos: true,
freq_shift: 0.,
layers_per_block: 2,
mid_block_scale_factor: 1.,
norm_eps: 1e-5,
norm_num_groups: 32,
sliced_attention_size,
use_linear_projection: false,
};
let autoencoder = vae::AutoEncoderKLConfig {
block_out_channels: vec![128, 256, 512, 512],
layers_per_block: 2,
latent_channels: 4,
norm_num_groups: 32,
};
let height = if let Some(height) = height {
assert_eq!(height % 8, 0, "height has to be divisible by 8");
height
} else {
512
};
let width = if let Some(width) = width {
assert_eq!(width % 8, 0, "width has to be divisible by 8");
width
} else {
512
};
Self {
width,
height,
clip: clip::Config::v1_5(),
clip2: None,
autoencoder,
scheduler: Default::default(),
unet,
}
}
fn v2_1_(
sliced_attention_size: Option<usize>,
height: Option<usize>,
width: Option<usize>,
prediction_type: schedulers::PredictionType,
) -> Self {
let bc = |out_channels, use_cross_attn, attention_head_dim| unet_2d::BlockConfig {
out_channels,
use_cross_attn,
attention_head_dim,
};
// https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/unet/config.json
let unet = unet_2d::UNet2DConditionModelConfig {
blocks: vec![
bc(320, Some(1), 5),
bc(640, Some(1), 10),
bc(1280, Some(1), 20),
bc(1280, None, 20),
],
center_input_sample: false,
cross_attention_dim: 1024,
downsample_padding: 1,
flip_sin_to_cos: true,
freq_shift: 0.,
layers_per_block: 2,
mid_block_scale_factor: 1.,
norm_eps: 1e-5,
norm_num_groups: 32,
sliced_attention_size,
use_linear_projection: true,
};
// https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/vae/config.json
let autoencoder = vae::AutoEncoderKLConfig {
block_out_channels: vec![128, 256, 512, 512],
layers_per_block: 2,
latent_channels: 4,
norm_num_groups: 32,
};
let scheduler = ddim::DDIMSchedulerConfig {
prediction_type,
..Default::default()
};
let height = if let Some(height) = height {
assert_eq!(height % 8, 0, "height has to be divisible by 8");
height
} else {
768
};
let width = if let Some(width) = width {
assert_eq!(width % 8, 0, "width has to be divisible by 8");
width
} else {
768
};
Self {
width,
height,
clip: clip::Config::v2_1(),
clip2: None,
autoencoder,
scheduler,
unet,
}
}
pub fn v2_1(
sliced_attention_size: Option<usize>,
height: Option<usize>,
width: Option<usize>,
) -> Self {
// https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/scheduler/scheduler_config.json
Self::v2_1_(
sliced_attention_size,
height,
width,
schedulers::PredictionType::VPrediction,
)
}
fn sdxl_(
sliced_attention_size: Option<usize>,
height: Option<usize>,
width: Option<usize>,
prediction_type: schedulers::PredictionType,
) -> Self {
let bc = |out_channels, use_cross_attn, attention_head_dim| unet_2d::BlockConfig {
out_channels,
use_cross_attn,
attention_head_dim,
};
// https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/unet/config.json
let unet = unet_2d::UNet2DConditionModelConfig {
blocks: vec![
bc(320, None, 5),
bc(640, Some(2), 10),
bc(1280, Some(10), 20),
],
center_input_sample: false,
cross_attention_dim: 2048,
downsample_padding: 1,
flip_sin_to_cos: true,
freq_shift: 0.,
layers_per_block: 2,
mid_block_scale_factor: 1.,
norm_eps: 1e-5,
norm_num_groups: 32,
sliced_attention_size,
use_linear_projection: true,
};
// https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/vae/config.json
let autoencoder = vae::AutoEncoderKLConfig {
block_out_channels: vec![128, 256, 512, 512],
layers_per_block: 2,
latent_channels: 4,
norm_num_groups: 32,
};
let scheduler = ddim::DDIMSchedulerConfig {
prediction_type,
..Default::default()
};
let height = if let Some(height) = height {
assert_eq!(height % 8, 0, "height has to be divisible by 8");
height
} else {
1024
};
let width = if let Some(width) = width {
assert_eq!(width % 8, 0, "width has to be divisible by 8");
width
} else {
1024
};
Self {
width,
height,
clip: clip::Config::sdxl(),
clip2: Some(clip::Config::sdxl2()),
autoencoder,
scheduler,
unet,
}
}
pub fn sdxl(
sliced_attention_size: Option<usize>,
height: Option<usize>,
width: Option<usize>,
) -> Self {
Self::sdxl_(
sliced_attention_size,
height,
width,
// https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/scheduler/scheduler_config.json
schedulers::PredictionType::Epsilon,
)
}
pub fn build_vae<P: AsRef<std::path::Path>>(
&self,
vae_weights: P,
device: &Device,
dtype: DType,
) -> Result<vae::AutoEncoderKL> {
let weights = unsafe { candle::safetensors::MmapedFile::new(vae_weights)? };
let weights = weights.deserialize()?;
let vs_ae = nn::VarBuilder::from_safetensors(vec![weights], dtype, device);
// https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/vae/config.json
let autoencoder = vae::AutoEncoderKL::new(vs_ae, 3, 3, self.autoencoder.clone())?;
Ok(autoencoder)
}
pub fn build_unet<P: AsRef<std::path::Path>>(
&self,
unet_weights: P,
device: &Device,
in_channels: usize,
use_flash_attn: bool,
dtype: DType,
) -> Result<unet_2d::UNet2DConditionModel> {
let weights = unsafe { candle::safetensors::MmapedFile::new(unet_weights)? };
let weights = weights.deserialize()?;
let vs_unet = nn::VarBuilder::from_safetensors(vec![weights], dtype, device);
let unet = unet_2d::UNet2DConditionModel::new(
vs_unet,
in_channels,
4,
use_flash_attn,
self.unet.clone(),
)?;
Ok(unet)
}
pub fn build_scheduler(&self, n_steps: usize) -> Result<ddim::DDIMScheduler> {
ddim::DDIMScheduler::new(n_steps, self.scheduler)
}
}
pub fn build_clip_transformer<P: AsRef<std::path::Path>>(
clip: &clip::Config,
clip_weights: P,
device: &Device,
dtype: DType,
) -> Result<clip::ClipTextTransformer> {
let weights = unsafe { candle::safetensors::MmapedFile::new(clip_weights)? };
let weights = weights.deserialize()?;
let vs = nn::VarBuilder::from_safetensors(vec![weights], dtype, device);
let text_model = clip::ClipTextTransformer::new(vs, clip)?;
Ok(text_model)
}

View File

@ -0,0 +1,138 @@
//! ResNet Building Blocks
//!
//! Some Residual Network blocks used in UNet models.
//!
//! Denoising Diffusion Implicit Models, K. He and al, 2015.
//! https://arxiv.org/abs/1512.03385
use super::utils::{conv2d, Conv2d};
use candle::{Result, Tensor, D};
use candle_nn as nn;
use candle_nn::Module;
/// Configuration for a ResNet block.
#[derive(Debug, Clone, Copy)]
pub struct ResnetBlock2DConfig {
/// The number of output channels, defaults to the number of input channels.
pub out_channels: Option<usize>,
pub temb_channels: Option<usize>,
/// The number of groups to use in group normalization.
pub groups: usize,
pub groups_out: Option<usize>,
/// The epsilon to be used in the group normalization operations.
pub eps: f64,
/// Whether to use a 2D convolution in the skip connection. When using None,
/// such a convolution is used if the number of input channels is different from
/// the number of output channels.
pub use_in_shortcut: Option<bool>,
// non_linearity: silu
/// The final output is scaled by dividing by this value.
pub output_scale_factor: f64,
}
impl Default for ResnetBlock2DConfig {
fn default() -> Self {
Self {
out_channels: None,
temb_channels: Some(512),
groups: 32,
groups_out: None,
eps: 1e-6,
use_in_shortcut: None,
output_scale_factor: 1.,
}
}
}
#[derive(Debug)]
pub struct ResnetBlock2D {
norm1: nn::GroupNorm,
conv1: Conv2d,
norm2: nn::GroupNorm,
conv2: Conv2d,
time_emb_proj: Option<nn::Linear>,
conv_shortcut: Option<Conv2d>,
span: tracing::Span,
config: ResnetBlock2DConfig,
}
impl ResnetBlock2D {
pub fn new(
vs: nn::VarBuilder,
in_channels: usize,
config: ResnetBlock2DConfig,
) -> Result<Self> {
let out_channels = config.out_channels.unwrap_or(in_channels);
let conv_cfg = nn::Conv2dConfig {
stride: 1,
padding: 1,
groups: 1,
dilation: 1,
};
let norm1 = nn::group_norm(config.groups, in_channels, config.eps, vs.pp("norm1"))?;
let conv1 = conv2d(in_channels, out_channels, 3, conv_cfg, vs.pp("conv1"))?;
let groups_out = config.groups_out.unwrap_or(config.groups);
let norm2 = nn::group_norm(groups_out, out_channels, config.eps, vs.pp("norm2"))?;
let conv2 = conv2d(out_channels, out_channels, 3, conv_cfg, vs.pp("conv2"))?;
let use_in_shortcut = config
.use_in_shortcut
.unwrap_or(in_channels != out_channels);
let conv_shortcut = if use_in_shortcut {
let conv_cfg = nn::Conv2dConfig {
stride: 1,
padding: 0,
groups: 1,
dilation: 1,
};
Some(conv2d(
in_channels,
out_channels,
1,
conv_cfg,
vs.pp("conv_shortcut"),
)?)
} else {
None
};
let time_emb_proj = match config.temb_channels {
None => None,
Some(temb_channels) => Some(nn::linear(
temb_channels,
out_channels,
vs.pp("time_emb_proj"),
)?),
};
let span = tracing::span!(tracing::Level::TRACE, "resnet2d");
Ok(Self {
norm1,
conv1,
norm2,
conv2,
time_emb_proj,
span,
config,
conv_shortcut,
})
}
pub fn forward(&self, xs: &Tensor, temb: Option<&Tensor>) -> Result<Tensor> {
let _enter = self.span.enter();
let shortcut_xs = match &self.conv_shortcut {
Some(conv_shortcut) => conv_shortcut.forward(xs)?,
None => xs.clone(),
};
let xs = self.norm1.forward(xs)?;
let xs = self.conv1.forward(&nn::ops::silu(&xs)?)?;
let xs = match (temb, &self.time_emb_proj) {
(Some(temb), Some(time_emb_proj)) => time_emb_proj
.forward(&nn::ops::silu(temb)?)?
.unsqueeze(D::Minus1)?
.unsqueeze(D::Minus1)?
.broadcast_add(&xs)?,
_ => xs,
};
let xs = self
.conv2
.forward(&nn::ops::silu(&self.norm2.forward(&xs)?)?)?;
(shortcut_xs + xs)? / self.config.output_scale_factor
}
}

View File

@ -0,0 +1,45 @@
#![allow(dead_code)]
//! # Diffusion pipelines and models
//!
//! Noise schedulers can be used to set the trade-off between
//! inference speed and quality.
use candle::{Result, Tensor};
/// This represents how beta ranges from its minimum value to the maximum
/// during training.
#[derive(Debug, Clone, Copy)]
pub enum BetaSchedule {
/// Linear interpolation.
Linear,
/// Linear interpolation of the square root of beta.
ScaledLinear,
/// Glide cosine schedule
SquaredcosCapV2,
}
#[derive(Debug, Clone, Copy)]
pub enum PredictionType {
Epsilon,
VPrediction,
Sample,
}
/// Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
/// `(1-beta)` over time from `t = [0,1]`.
///
/// Contains a function `alpha_bar` that takes an argument `t` and transforms it to the cumulative product of `(1-beta)`
/// up to that part of the diffusion process.
pub(crate) fn betas_for_alpha_bar(num_diffusion_timesteps: usize, max_beta: f64) -> Result<Tensor> {
let alpha_bar = |time_step: usize| {
f64::cos((time_step as f64 + 0.008) / 1.008 * std::f64::consts::FRAC_PI_2).powi(2)
};
let mut betas = Vec::with_capacity(num_diffusion_timesteps);
for i in 0..num_diffusion_timesteps {
let t1 = i / num_diffusion_timesteps;
let t2 = (i + 1) / num_diffusion_timesteps;
betas.push((1.0 - alpha_bar(t2) / alpha_bar(t1)).min(max_beta));
}
let betas_len = betas.len();
Tensor::from_vec(betas, betas_len, &candle::Device::Cpu)
}

View File

@ -0,0 +1,401 @@
//! 2D UNet Denoising Models
//!
//! The 2D Unet models take as input a noisy sample and the current diffusion
//! timestep and return a denoised version of the input.
use super::embeddings::{TimestepEmbedding, Timesteps};
use super::unet_2d_blocks::*;
use super::utils::{conv2d, Conv2d};
use candle::{Result, Tensor};
use candle_nn as nn;
use candle_nn::Module;
#[derive(Debug, Clone, Copy)]
pub struct BlockConfig {
pub out_channels: usize,
/// When `None` no cross-attn is used, when `Some(d)` then cross-attn is used and `d` is the
/// number of transformer blocks to be used.
pub use_cross_attn: Option<usize>,
pub attention_head_dim: usize,
}
#[derive(Debug, Clone)]
pub struct UNet2DConditionModelConfig {
pub center_input_sample: bool,
pub flip_sin_to_cos: bool,
pub freq_shift: f64,
pub blocks: Vec<BlockConfig>,
pub layers_per_block: usize,
pub downsample_padding: usize,
pub mid_block_scale_factor: f64,
pub norm_num_groups: usize,
pub norm_eps: f64,
pub cross_attention_dim: usize,
pub sliced_attention_size: Option<usize>,
pub use_linear_projection: bool,
}
impl Default for UNet2DConditionModelConfig {
fn default() -> Self {
Self {
center_input_sample: false,
flip_sin_to_cos: true,
freq_shift: 0.,
blocks: vec![
BlockConfig {
out_channels: 320,
use_cross_attn: Some(1),
attention_head_dim: 8,
},
BlockConfig {
out_channels: 640,
use_cross_attn: Some(1),
attention_head_dim: 8,
},
BlockConfig {
out_channels: 1280,
use_cross_attn: Some(1),
attention_head_dim: 8,
},
BlockConfig {
out_channels: 1280,
use_cross_attn: None,
attention_head_dim: 8,
},
],
layers_per_block: 2,
downsample_padding: 1,
mid_block_scale_factor: 1.,
norm_num_groups: 32,
norm_eps: 1e-5,
cross_attention_dim: 1280,
sliced_attention_size: None,
use_linear_projection: false,
}
}
}
#[derive(Debug)]
pub(crate) enum UNetDownBlock {
Basic(DownBlock2D),
CrossAttn(CrossAttnDownBlock2D),
}
#[derive(Debug)]
enum UNetUpBlock {
Basic(UpBlock2D),
CrossAttn(CrossAttnUpBlock2D),
}
#[derive(Debug)]
pub struct UNet2DConditionModel {
conv_in: Conv2d,
time_proj: Timesteps,
time_embedding: TimestepEmbedding,
down_blocks: Vec<UNetDownBlock>,
mid_block: UNetMidBlock2DCrossAttn,
up_blocks: Vec<UNetUpBlock>,
conv_norm_out: nn::GroupNorm,
conv_out: Conv2d,
span: tracing::Span,
config: UNet2DConditionModelConfig,
}
impl UNet2DConditionModel {
pub fn new(
vs: nn::VarBuilder,
in_channels: usize,
out_channels: usize,
use_flash_attn: bool,
config: UNet2DConditionModelConfig,
) -> Result<Self> {
let n_blocks = config.blocks.len();
let b_channels = config.blocks[0].out_channels;
let bl_channels = config.blocks.last().unwrap().out_channels;
let bl_attention_head_dim = config.blocks.last().unwrap().attention_head_dim;
let time_embed_dim = b_channels * 4;
let conv_cfg = nn::Conv2dConfig {
padding: 1,
..Default::default()
};
let conv_in = conv2d(in_channels, b_channels, 3, conv_cfg, vs.pp("conv_in"))?;
let time_proj = Timesteps::new(b_channels, config.flip_sin_to_cos, config.freq_shift);
let time_embedding =
TimestepEmbedding::new(vs.pp("time_embedding"), b_channels, time_embed_dim)?;
let vs_db = vs.pp("down_blocks");
let down_blocks = (0..n_blocks)
.map(|i| {
let BlockConfig {
out_channels,
use_cross_attn,
attention_head_dim,
} = config.blocks[i];
// Enable automatic attention slicing if the config sliced_attention_size is set to 0.
let sliced_attention_size = match config.sliced_attention_size {
Some(0) => Some(attention_head_dim / 2),
_ => config.sliced_attention_size,
};
let in_channels = if i > 0 {
config.blocks[i - 1].out_channels
} else {
b_channels
};
let db_cfg = DownBlock2DConfig {
num_layers: config.layers_per_block,
resnet_eps: config.norm_eps,
resnet_groups: config.norm_num_groups,
add_downsample: i < n_blocks - 1,
downsample_padding: config.downsample_padding,
..Default::default()
};
if let Some(transformer_layers_per_block) = use_cross_attn {
let config = CrossAttnDownBlock2DConfig {
downblock: db_cfg,
attn_num_head_channels: attention_head_dim,
cross_attention_dim: config.cross_attention_dim,
sliced_attention_size,
use_linear_projection: config.use_linear_projection,
transformer_layers_per_block,
};
let block = CrossAttnDownBlock2D::new(
vs_db.pp(&i.to_string()),
in_channels,
out_channels,
Some(time_embed_dim),
use_flash_attn,
config,
)?;
Ok(UNetDownBlock::CrossAttn(block))
} else {
let block = DownBlock2D::new(
vs_db.pp(&i.to_string()),
in_channels,
out_channels,
Some(time_embed_dim),
db_cfg,
)?;
Ok(UNetDownBlock::Basic(block))
}
})
.collect::<Result<Vec<_>>>()?;
// https://github.com/huggingface/diffusers/blob/a76f2ad538e73b34d5fe7be08c8eb8ab38c7e90c/src/diffusers/models/unet_2d_condition.py#L462
let mid_transformer_layers_per_block = match config.blocks.last() {
None => 1,
Some(block) => block.use_cross_attn.unwrap_or(1),
};
let mid_cfg = UNetMidBlock2DCrossAttnConfig {
resnet_eps: config.norm_eps,
output_scale_factor: config.mid_block_scale_factor,
cross_attn_dim: config.cross_attention_dim,
attn_num_head_channels: bl_attention_head_dim,
resnet_groups: Some(config.norm_num_groups),
use_linear_projection: config.use_linear_projection,
transformer_layers_per_block: mid_transformer_layers_per_block,
..Default::default()
};
let mid_block = UNetMidBlock2DCrossAttn::new(
vs.pp("mid_block"),
bl_channels,
Some(time_embed_dim),
use_flash_attn,
mid_cfg,
)?;
let vs_ub = vs.pp("up_blocks");
let up_blocks = (0..n_blocks)
.map(|i| {
let BlockConfig {
out_channels,
use_cross_attn,
attention_head_dim,
} = config.blocks[n_blocks - 1 - i];
// Enable automatic attention slicing if the config sliced_attention_size is set to 0.
let sliced_attention_size = match config.sliced_attention_size {
Some(0) => Some(attention_head_dim / 2),
_ => config.sliced_attention_size,
};
let prev_out_channels = if i > 0 {
config.blocks[n_blocks - i].out_channels
} else {
bl_channels
};
let in_channels = {
let index = if i == n_blocks - 1 {
0
} else {
n_blocks - i - 2
};
config.blocks[index].out_channels
};
let ub_cfg = UpBlock2DConfig {
num_layers: config.layers_per_block + 1,
resnet_eps: config.norm_eps,
resnet_groups: config.norm_num_groups,
add_upsample: i < n_blocks - 1,
..Default::default()
};
if let Some(transformer_layers_per_block) = use_cross_attn {
let config = CrossAttnUpBlock2DConfig {
upblock: ub_cfg,
attn_num_head_channels: attention_head_dim,
cross_attention_dim: config.cross_attention_dim,
sliced_attention_size,
use_linear_projection: config.use_linear_projection,
transformer_layers_per_block,
};
let block = CrossAttnUpBlock2D::new(
vs_ub.pp(&i.to_string()),
in_channels,
prev_out_channels,
out_channels,
Some(time_embed_dim),
use_flash_attn,
config,
)?;
Ok(UNetUpBlock::CrossAttn(block))
} else {
let block = UpBlock2D::new(
vs_ub.pp(&i.to_string()),
in_channels,
prev_out_channels,
out_channels,
Some(time_embed_dim),
ub_cfg,
)?;
Ok(UNetUpBlock::Basic(block))
}
})
.collect::<Result<Vec<_>>>()?;
let conv_norm_out = nn::group_norm(
config.norm_num_groups,
b_channels,
config.norm_eps,
vs.pp("conv_norm_out"),
)?;
let conv_out = conv2d(b_channels, out_channels, 3, conv_cfg, vs.pp("conv_out"))?;
let span = tracing::span!(tracing::Level::TRACE, "unet2d");
Ok(Self {
conv_in,
time_proj,
time_embedding,
down_blocks,
mid_block,
up_blocks,
conv_norm_out,
conv_out,
span,
config,
})
}
pub fn forward(
&self,
xs: &Tensor,
timestep: f64,
encoder_hidden_states: &Tensor,
) -> Result<Tensor> {
let _enter = self.span.enter();
self.forward_with_additional_residuals(xs, timestep, encoder_hidden_states, None, None)
}
pub fn forward_with_additional_residuals(
&self,
xs: &Tensor,
timestep: f64,
encoder_hidden_states: &Tensor,
down_block_additional_residuals: Option<&[Tensor]>,
mid_block_additional_residual: Option<&Tensor>,
) -> Result<Tensor> {
let (bsize, _channels, height, width) = xs.dims4()?;
let device = xs.device();
let n_blocks = self.config.blocks.len();
let num_upsamplers = n_blocks - 1;
let default_overall_up_factor = 2usize.pow(num_upsamplers as u32);
let forward_upsample_size =
height % default_overall_up_factor != 0 || width % default_overall_up_factor != 0;
// 0. center input if necessary
let xs = if self.config.center_input_sample {
((xs * 2.0)? - 1.0)?
} else {
xs.clone()
};
// 1. time
let emb = (Tensor::ones(bsize, xs.dtype(), device)? * timestep)?;
let emb = self.time_proj.forward(&emb)?;
let emb = self.time_embedding.forward(&emb)?;
// 2. pre-process
let xs = self.conv_in.forward(&xs)?;
// 3. down
let mut down_block_res_xs = vec![xs.clone()];
let mut xs = xs;
for down_block in self.down_blocks.iter() {
let (_xs, res_xs) = match down_block {
UNetDownBlock::Basic(b) => b.forward(&xs, Some(&emb))?,
UNetDownBlock::CrossAttn(b) => {
b.forward(&xs, Some(&emb), Some(encoder_hidden_states))?
}
};
down_block_res_xs.extend(res_xs);
xs = _xs;
}
let new_down_block_res_xs =
if let Some(down_block_additional_residuals) = down_block_additional_residuals {
let mut v = vec![];
// A previous version of this code had a bug because of the addition being made
// in place via += hence modifying the input of the mid block.
for (i, residuals) in down_block_additional_residuals.iter().enumerate() {
v.push((&down_block_res_xs[i] + residuals)?)
}
v
} else {
down_block_res_xs
};
let mut down_block_res_xs = new_down_block_res_xs;
// 4. mid
let xs = self
.mid_block
.forward(&xs, Some(&emb), Some(encoder_hidden_states))?;
let xs = match mid_block_additional_residual {
None => xs,
Some(m) => (m + xs)?,
};
// 5. up
let mut xs = xs;
let mut upsample_size = None;
for (i, up_block) in self.up_blocks.iter().enumerate() {
let n_resnets = match up_block {
UNetUpBlock::Basic(b) => b.resnets.len(),
UNetUpBlock::CrossAttn(b) => b.upblock.resnets.len(),
};
let res_xs = down_block_res_xs.split_off(down_block_res_xs.len() - n_resnets);
if i < n_blocks - 1 && forward_upsample_size {
let (_, _, h, w) = down_block_res_xs.last().unwrap().dims4()?;
upsample_size = Some((h, w))
}
xs = match up_block {
UNetUpBlock::Basic(b) => b.forward(&xs, &res_xs, Some(&emb), upsample_size)?,
UNetUpBlock::CrossAttn(b) => b.forward(
&xs,
&res_xs,
Some(&emb),
upsample_size,
Some(encoder_hidden_states),
)?,
};
}
// 6. post-process
let xs = self.conv_norm_out.forward(&xs)?;
let xs = nn::ops::silu(&xs)?;
self.conv_out.forward(&xs)
}
}

View File

@ -0,0 +1,868 @@
//! 2D UNet Building Blocks
//!
use super::attention::{
AttentionBlock, AttentionBlockConfig, SpatialTransformer, SpatialTransformerConfig,
};
use super::resnet::{ResnetBlock2D, ResnetBlock2DConfig};
use super::utils::{conv2d, Conv2d};
use candle::{Result, Tensor, D};
use candle_nn as nn;
#[derive(Debug)]
struct Downsample2D {
conv: Option<Conv2d>,
padding: usize,
span: tracing::Span,
}
impl Downsample2D {
fn new(
vs: nn::VarBuilder,
in_channels: usize,
use_conv: bool,
out_channels: usize,
padding: usize,
) -> Result<Self> {
let conv = if use_conv {
let config = nn::Conv2dConfig {
stride: 2,
padding,
..Default::default()
};
let conv = conv2d(in_channels, out_channels, 3, config, vs.pp("conv"))?;
Some(conv)
} else {
None
};
let span = tracing::span!(tracing::Level::TRACE, "downsample2d");
Ok(Self {
conv,
padding,
span,
})
}
}
impl Downsample2D {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
match &self.conv {
None => xs.avg_pool2d(2),
Some(conv) => {
if self.padding == 0 {
let xs = xs
.pad_with_zeros(D::Minus1, 0, 1)?
.pad_with_zeros(D::Minus2, 0, 1)?;
conv.forward(&xs)
} else {
conv.forward(xs)
}
}
}
}
}
// This does not support the conv-transpose mode.
#[derive(Debug)]
struct Upsample2D {
conv: Conv2d,
span: tracing::Span,
}
impl Upsample2D {
fn new(vs: nn::VarBuilder, in_channels: usize, out_channels: usize) -> Result<Self> {
let config = nn::Conv2dConfig {
padding: 1,
..Default::default()
};
let conv = conv2d(in_channels, out_channels, 3, config, vs.pp("conv"))?;
let span = tracing::span!(tracing::Level::TRACE, "upsample2d");
Ok(Self { conv, span })
}
}
impl Upsample2D {
fn forward(&self, xs: &Tensor, size: Option<(usize, usize)>) -> Result<Tensor> {
let _enter = self.span.enter();
let xs = match size {
None => {
let (_bsize, _channels, h, w) = xs.dims4()?;
xs.upsample_nearest2d(2 * h, 2 * w)?
}
Some((h, w)) => xs.upsample_nearest2d(h, w)?,
};
self.conv.forward(&xs)
}
}
#[derive(Debug, Clone, Copy)]
pub struct DownEncoderBlock2DConfig {
pub num_layers: usize,
pub resnet_eps: f64,
pub resnet_groups: usize,
pub output_scale_factor: f64,
pub add_downsample: bool,
pub downsample_padding: usize,
}
impl Default for DownEncoderBlock2DConfig {
fn default() -> Self {
Self {
num_layers: 1,
resnet_eps: 1e-6,
resnet_groups: 32,
output_scale_factor: 1.,
add_downsample: true,
downsample_padding: 1,
}
}
}
#[derive(Debug)]
pub struct DownEncoderBlock2D {
resnets: Vec<ResnetBlock2D>,
downsampler: Option<Downsample2D>,
span: tracing::Span,
pub config: DownEncoderBlock2DConfig,
}
impl DownEncoderBlock2D {
pub fn new(
vs: nn::VarBuilder,
in_channels: usize,
out_channels: usize,
config: DownEncoderBlock2DConfig,
) -> Result<Self> {
let resnets: Vec<_> = {
let vs = vs.pp("resnets");
let conv_cfg = ResnetBlock2DConfig {
eps: config.resnet_eps,
out_channels: Some(out_channels),
groups: config.resnet_groups,
output_scale_factor: config.output_scale_factor,
temb_channels: None,
..Default::default()
};
(0..(config.num_layers))
.map(|i| {
let in_channels = if i == 0 { in_channels } else { out_channels };
ResnetBlock2D::new(vs.pp(&i.to_string()), in_channels, conv_cfg)
})
.collect::<Result<Vec<_>>>()?
};
let downsampler = if config.add_downsample {
let downsample = Downsample2D::new(
vs.pp("downsamplers").pp("0"),
out_channels,
true,
out_channels,
config.downsample_padding,
)?;
Some(downsample)
} else {
None
};
let span = tracing::span!(tracing::Level::TRACE, "down-enc2d");
Ok(Self {
resnets,
downsampler,
span,
config,
})
}
}
impl DownEncoderBlock2D {
pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let mut xs = xs.clone();
for resnet in self.resnets.iter() {
xs = resnet.forward(&xs, None)?
}
match &self.downsampler {
Some(downsampler) => downsampler.forward(&xs),
None => Ok(xs),
}
}
}
#[derive(Debug, Clone, Copy)]
pub struct UpDecoderBlock2DConfig {
pub num_layers: usize,
pub resnet_eps: f64,
pub resnet_groups: usize,
pub output_scale_factor: f64,
pub add_upsample: bool,
}
impl Default for UpDecoderBlock2DConfig {
fn default() -> Self {
Self {
num_layers: 1,
resnet_eps: 1e-6,
resnet_groups: 32,
output_scale_factor: 1.,
add_upsample: true,
}
}
}
#[derive(Debug)]
pub struct UpDecoderBlock2D {
resnets: Vec<ResnetBlock2D>,
upsampler: Option<Upsample2D>,
span: tracing::Span,
pub config: UpDecoderBlock2DConfig,
}
impl UpDecoderBlock2D {
pub fn new(
vs: nn::VarBuilder,
in_channels: usize,
out_channels: usize,
config: UpDecoderBlock2DConfig,
) -> Result<Self> {
let resnets: Vec<_> = {
let vs = vs.pp("resnets");
let conv_cfg = ResnetBlock2DConfig {
out_channels: Some(out_channels),
eps: config.resnet_eps,
groups: config.resnet_groups,
output_scale_factor: config.output_scale_factor,
temb_channels: None,
..Default::default()
};
(0..(config.num_layers))
.map(|i| {
let in_channels = if i == 0 { in_channels } else { out_channels };
ResnetBlock2D::new(vs.pp(&i.to_string()), in_channels, conv_cfg)
})
.collect::<Result<Vec<_>>>()?
};
let upsampler = if config.add_upsample {
let upsample =
Upsample2D::new(vs.pp("upsamplers").pp("0"), out_channels, out_channels)?;
Some(upsample)
} else {
None
};
let span = tracing::span!(tracing::Level::TRACE, "up-dec2d");
Ok(Self {
resnets,
upsampler,
span,
config,
})
}
}
impl UpDecoderBlock2D {
pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let mut xs = xs.clone();
for resnet in self.resnets.iter() {
xs = resnet.forward(&xs, None)?
}
match &self.upsampler {
Some(upsampler) => upsampler.forward(&xs, None),
None => Ok(xs),
}
}
}
#[derive(Debug, Clone, Copy)]
pub struct UNetMidBlock2DConfig {
pub num_layers: usize,
pub resnet_eps: f64,
pub resnet_groups: Option<usize>,
pub attn_num_head_channels: Option<usize>,
// attention_type "default"
pub output_scale_factor: f64,
}
impl Default for UNetMidBlock2DConfig {
fn default() -> Self {
Self {
num_layers: 1,
resnet_eps: 1e-6,
resnet_groups: Some(32),
attn_num_head_channels: Some(1),
output_scale_factor: 1.,
}
}
}
#[derive(Debug)]
pub struct UNetMidBlock2D {
resnet: ResnetBlock2D,
attn_resnets: Vec<(AttentionBlock, ResnetBlock2D)>,
span: tracing::Span,
pub config: UNetMidBlock2DConfig,
}
impl UNetMidBlock2D {
pub fn new(
vs: nn::VarBuilder,
in_channels: usize,
temb_channels: Option<usize>,
config: UNetMidBlock2DConfig,
) -> Result<Self> {
let vs_resnets = vs.pp("resnets");
let vs_attns = vs.pp("attentions");
let resnet_groups = config
.resnet_groups
.unwrap_or_else(|| usize::min(in_channels / 4, 32));
let resnet_cfg = ResnetBlock2DConfig {
eps: config.resnet_eps,
groups: resnet_groups,
output_scale_factor: config.output_scale_factor,
temb_channels,
..Default::default()
};
let resnet = ResnetBlock2D::new(vs_resnets.pp("0"), in_channels, resnet_cfg)?;
let attn_cfg = AttentionBlockConfig {
num_head_channels: config.attn_num_head_channels,
num_groups: resnet_groups,
rescale_output_factor: config.output_scale_factor,
eps: config.resnet_eps,
};
let mut attn_resnets = vec![];
for index in 0..config.num_layers {
let attn = AttentionBlock::new(vs_attns.pp(&index.to_string()), in_channels, attn_cfg)?;
let resnet = ResnetBlock2D::new(
vs_resnets.pp(&(index + 1).to_string()),
in_channels,
resnet_cfg,
)?;
attn_resnets.push((attn, resnet))
}
let span = tracing::span!(tracing::Level::TRACE, "mid2d");
Ok(Self {
resnet,
attn_resnets,
span,
config,
})
}
pub fn forward(&self, xs: &Tensor, temb: Option<&Tensor>) -> Result<Tensor> {
let _enter = self.span.enter();
let mut xs = self.resnet.forward(xs, temb)?;
for (attn, resnet) in self.attn_resnets.iter() {
xs = resnet.forward(&attn.forward(&xs)?, temb)?
}
Ok(xs)
}
}
#[derive(Debug, Clone, Copy)]
pub struct UNetMidBlock2DCrossAttnConfig {
pub num_layers: usize,
pub resnet_eps: f64,
pub resnet_groups: Option<usize>,
pub attn_num_head_channels: usize,
// attention_type "default"
pub output_scale_factor: f64,
pub cross_attn_dim: usize,
pub sliced_attention_size: Option<usize>,
pub use_linear_projection: bool,
pub transformer_layers_per_block: usize,
}
impl Default for UNetMidBlock2DCrossAttnConfig {
fn default() -> Self {
Self {
num_layers: 1,
resnet_eps: 1e-6,
resnet_groups: Some(32),
attn_num_head_channels: 1,
output_scale_factor: 1.,
cross_attn_dim: 1280,
sliced_attention_size: None, // Sliced attention disabled
use_linear_projection: false,
transformer_layers_per_block: 1,
}
}
}
#[derive(Debug)]
pub struct UNetMidBlock2DCrossAttn {
resnet: ResnetBlock2D,
attn_resnets: Vec<(SpatialTransformer, ResnetBlock2D)>,
span: tracing::Span,
pub config: UNetMidBlock2DCrossAttnConfig,
}
impl UNetMidBlock2DCrossAttn {
pub fn new(
vs: nn::VarBuilder,
in_channels: usize,
temb_channels: Option<usize>,
use_flash_attn: bool,
config: UNetMidBlock2DCrossAttnConfig,
) -> Result<Self> {
let vs_resnets = vs.pp("resnets");
let vs_attns = vs.pp("attentions");
let resnet_groups = config
.resnet_groups
.unwrap_or_else(|| usize::min(in_channels / 4, 32));
let resnet_cfg = ResnetBlock2DConfig {
eps: config.resnet_eps,
groups: resnet_groups,
output_scale_factor: config.output_scale_factor,
temb_channels,
..Default::default()
};
let resnet = ResnetBlock2D::new(vs_resnets.pp("0"), in_channels, resnet_cfg)?;
let n_heads = config.attn_num_head_channels;
let attn_cfg = SpatialTransformerConfig {
depth: config.transformer_layers_per_block,
num_groups: resnet_groups,
context_dim: Some(config.cross_attn_dim),
sliced_attention_size: config.sliced_attention_size,
use_linear_projection: config.use_linear_projection,
};
let mut attn_resnets = vec![];
for index in 0..config.num_layers {
let attn = SpatialTransformer::new(
vs_attns.pp(&index.to_string()),
in_channels,
n_heads,
in_channels / n_heads,
use_flash_attn,
attn_cfg,
)?;
let resnet = ResnetBlock2D::new(
vs_resnets.pp(&(index + 1).to_string()),
in_channels,
resnet_cfg,
)?;
attn_resnets.push((attn, resnet))
}
let span = tracing::span!(tracing::Level::TRACE, "xa-mid2d");
Ok(Self {
resnet,
attn_resnets,
span,
config,
})
}
pub fn forward(
&self,
xs: &Tensor,
temb: Option<&Tensor>,
encoder_hidden_states: Option<&Tensor>,
) -> Result<Tensor> {
let _enter = self.span.enter();
let mut xs = self.resnet.forward(xs, temb)?;
for (attn, resnet) in self.attn_resnets.iter() {
xs = resnet.forward(&attn.forward(&xs, encoder_hidden_states)?, temb)?
}
Ok(xs)
}
}
#[derive(Debug, Clone, Copy)]
pub struct DownBlock2DConfig {
pub num_layers: usize,
pub resnet_eps: f64,
// resnet_time_scale_shift: "default"
// resnet_act_fn: "swish"
pub resnet_groups: usize,
pub output_scale_factor: f64,
pub add_downsample: bool,
pub downsample_padding: usize,
}
impl Default for DownBlock2DConfig {
fn default() -> Self {
Self {
num_layers: 1,
resnet_eps: 1e-6,
resnet_groups: 32,
output_scale_factor: 1.,
add_downsample: true,
downsample_padding: 1,
}
}
}
#[derive(Debug)]
pub struct DownBlock2D {
resnets: Vec<ResnetBlock2D>,
downsampler: Option<Downsample2D>,
span: tracing::Span,
pub config: DownBlock2DConfig,
}
impl DownBlock2D {
pub fn new(
vs: nn::VarBuilder,
in_channels: usize,
out_channels: usize,
temb_channels: Option<usize>,
config: DownBlock2DConfig,
) -> Result<Self> {
let vs_resnets = vs.pp("resnets");
let resnet_cfg = ResnetBlock2DConfig {
out_channels: Some(out_channels),
eps: config.resnet_eps,
output_scale_factor: config.output_scale_factor,
temb_channels,
..Default::default()
};
let resnets = (0..config.num_layers)
.map(|i| {
let in_channels = if i == 0 { in_channels } else { out_channels };
ResnetBlock2D::new(vs_resnets.pp(&i.to_string()), in_channels, resnet_cfg)
})
.collect::<Result<Vec<_>>>()?;
let downsampler = if config.add_downsample {
let downsampler = Downsample2D::new(
vs.pp("downsamplers").pp("0"),
out_channels,
true,
out_channels,
config.downsample_padding,
)?;
Some(downsampler)
} else {
None
};
let span = tracing::span!(tracing::Level::TRACE, "down2d");
Ok(Self {
resnets,
downsampler,
span,
config,
})
}
pub fn forward(&self, xs: &Tensor, temb: Option<&Tensor>) -> Result<(Tensor, Vec<Tensor>)> {
let _enter = self.span.enter();
let mut xs = xs.clone();
let mut output_states = vec![];
for resnet in self.resnets.iter() {
xs = resnet.forward(&xs, temb)?;
output_states.push(xs.clone());
}
let xs = match &self.downsampler {
Some(downsampler) => {
let xs = downsampler.forward(&xs)?;
output_states.push(xs.clone());
xs
}
None => xs,
};
Ok((xs, output_states))
}
}
#[derive(Debug, Clone, Copy)]
pub struct CrossAttnDownBlock2DConfig {
pub downblock: DownBlock2DConfig,
pub attn_num_head_channels: usize,
pub cross_attention_dim: usize,
// attention_type: "default"
pub sliced_attention_size: Option<usize>,
pub use_linear_projection: bool,
pub transformer_layers_per_block: usize,
}
impl Default for CrossAttnDownBlock2DConfig {
fn default() -> Self {
Self {
downblock: Default::default(),
attn_num_head_channels: 1,
cross_attention_dim: 1280,
sliced_attention_size: None,
use_linear_projection: false,
transformer_layers_per_block: 1,
}
}
}
#[derive(Debug)]
pub struct CrossAttnDownBlock2D {
downblock: DownBlock2D,
attentions: Vec<SpatialTransformer>,
span: tracing::Span,
pub config: CrossAttnDownBlock2DConfig,
}
impl CrossAttnDownBlock2D {
pub fn new(
vs: nn::VarBuilder,
in_channels: usize,
out_channels: usize,
temb_channels: Option<usize>,
use_flash_attn: bool,
config: CrossAttnDownBlock2DConfig,
) -> Result<Self> {
let downblock = DownBlock2D::new(
vs.clone(),
in_channels,
out_channels,
temb_channels,
config.downblock,
)?;
let n_heads = config.attn_num_head_channels;
let cfg = SpatialTransformerConfig {
depth: config.transformer_layers_per_block,
context_dim: Some(config.cross_attention_dim),
num_groups: config.downblock.resnet_groups,
sliced_attention_size: config.sliced_attention_size,
use_linear_projection: config.use_linear_projection,
};
let vs_attn = vs.pp("attentions");
let attentions = (0..config.downblock.num_layers)
.map(|i| {
SpatialTransformer::new(
vs_attn.pp(&i.to_string()),
out_channels,
n_heads,
out_channels / n_heads,
use_flash_attn,
cfg,
)
})
.collect::<Result<Vec<_>>>()?;
let span = tracing::span!(tracing::Level::TRACE, "xa-down2d");
Ok(Self {
downblock,
attentions,
span,
config,
})
}
pub fn forward(
&self,
xs: &Tensor,
temb: Option<&Tensor>,
encoder_hidden_states: Option<&Tensor>,
) -> Result<(Tensor, Vec<Tensor>)> {
let _enter = self.span.enter();
let mut output_states = vec![];
let mut xs = xs.clone();
for (resnet, attn) in self.downblock.resnets.iter().zip(self.attentions.iter()) {
xs = resnet.forward(&xs, temb)?;
xs = attn.forward(&xs, encoder_hidden_states)?;
output_states.push(xs.clone());
}
let xs = match &self.downblock.downsampler {
Some(downsampler) => {
let xs = downsampler.forward(&xs)?;
output_states.push(xs.clone());
xs
}
None => xs,
};
Ok((xs, output_states))
}
}
#[derive(Debug, Clone, Copy)]
pub struct UpBlock2DConfig {
pub num_layers: usize,
pub resnet_eps: f64,
// resnet_time_scale_shift: "default"
// resnet_act_fn: "swish"
pub resnet_groups: usize,
pub output_scale_factor: f64,
pub add_upsample: bool,
}
impl Default for UpBlock2DConfig {
fn default() -> Self {
Self {
num_layers: 1,
resnet_eps: 1e-6,
resnet_groups: 32,
output_scale_factor: 1.,
add_upsample: true,
}
}
}
#[derive(Debug)]
pub struct UpBlock2D {
pub resnets: Vec<ResnetBlock2D>,
upsampler: Option<Upsample2D>,
span: tracing::Span,
pub config: UpBlock2DConfig,
}
impl UpBlock2D {
pub fn new(
vs: nn::VarBuilder,
in_channels: usize,
prev_output_channels: usize,
out_channels: usize,
temb_channels: Option<usize>,
config: UpBlock2DConfig,
) -> Result<Self> {
let vs_resnets = vs.pp("resnets");
let resnet_cfg = ResnetBlock2DConfig {
out_channels: Some(out_channels),
temb_channels,
eps: config.resnet_eps,
output_scale_factor: config.output_scale_factor,
..Default::default()
};
let resnets = (0..config.num_layers)
.map(|i| {
let res_skip_channels = if i == config.num_layers - 1 {
in_channels
} else {
out_channels
};
let resnet_in_channels = if i == 0 {
prev_output_channels
} else {
out_channels
};
let in_channels = resnet_in_channels + res_skip_channels;
ResnetBlock2D::new(vs_resnets.pp(&i.to_string()), in_channels, resnet_cfg)
})
.collect::<Result<Vec<_>>>()?;
let upsampler = if config.add_upsample {
let upsampler =
Upsample2D::new(vs.pp("upsamplers").pp("0"), out_channels, out_channels)?;
Some(upsampler)
} else {
None
};
let span = tracing::span!(tracing::Level::TRACE, "up2d");
Ok(Self {
resnets,
upsampler,
span,
config,
})
}
pub fn forward(
&self,
xs: &Tensor,
res_xs: &[Tensor],
temb: Option<&Tensor>,
upsample_size: Option<(usize, usize)>,
) -> Result<Tensor> {
let _enter = self.span.enter();
let mut xs = xs.clone();
for (index, resnet) in self.resnets.iter().enumerate() {
xs = Tensor::cat(&[&xs, &res_xs[res_xs.len() - index - 1]], 1)?;
xs = xs.contiguous()?;
xs = resnet.forward(&xs, temb)?;
}
match &self.upsampler {
Some(upsampler) => upsampler.forward(&xs, upsample_size),
None => Ok(xs),
}
}
}
#[derive(Debug, Clone, Copy)]
pub struct CrossAttnUpBlock2DConfig {
pub upblock: UpBlock2DConfig,
pub attn_num_head_channels: usize,
pub cross_attention_dim: usize,
// attention_type: "default"
pub sliced_attention_size: Option<usize>,
pub use_linear_projection: bool,
pub transformer_layers_per_block: usize,
}
impl Default for CrossAttnUpBlock2DConfig {
fn default() -> Self {
Self {
upblock: Default::default(),
attn_num_head_channels: 1,
cross_attention_dim: 1280,
sliced_attention_size: None,
use_linear_projection: false,
transformer_layers_per_block: 1,
}
}
}
#[derive(Debug)]
pub struct CrossAttnUpBlock2D {
pub upblock: UpBlock2D,
pub attentions: Vec<SpatialTransformer>,
span: tracing::Span,
pub config: CrossAttnUpBlock2DConfig,
}
impl CrossAttnUpBlock2D {
pub fn new(
vs: nn::VarBuilder,
in_channels: usize,
prev_output_channels: usize,
out_channels: usize,
temb_channels: Option<usize>,
use_flash_attn: bool,
config: CrossAttnUpBlock2DConfig,
) -> Result<Self> {
let upblock = UpBlock2D::new(
vs.clone(),
in_channels,
prev_output_channels,
out_channels,
temb_channels,
config.upblock,
)?;
let n_heads = config.attn_num_head_channels;
let cfg = SpatialTransformerConfig {
depth: config.transformer_layers_per_block,
context_dim: Some(config.cross_attention_dim),
num_groups: config.upblock.resnet_groups,
sliced_attention_size: config.sliced_attention_size,
use_linear_projection: config.use_linear_projection,
};
let vs_attn = vs.pp("attentions");
let attentions = (0..config.upblock.num_layers)
.map(|i| {
SpatialTransformer::new(
vs_attn.pp(&i.to_string()),
out_channels,
n_heads,
out_channels / n_heads,
use_flash_attn,
cfg,
)
})
.collect::<Result<Vec<_>>>()?;
let span = tracing::span!(tracing::Level::TRACE, "xa-up2d");
Ok(Self {
upblock,
attentions,
span,
config,
})
}
pub fn forward(
&self,
xs: &Tensor,
res_xs: &[Tensor],
temb: Option<&Tensor>,
upsample_size: Option<(usize, usize)>,
encoder_hidden_states: Option<&Tensor>,
) -> Result<Tensor> {
let _enter = self.span.enter();
let mut xs = xs.clone();
for (index, resnet) in self.upblock.resnets.iter().enumerate() {
xs = Tensor::cat(&[&xs, &res_xs[res_xs.len() - index - 1]], 1)?;
xs = xs.contiguous()?;
xs = resnet.forward(&xs, temb)?;
xs = self.attentions[index].forward(&xs, encoder_hidden_states)?;
}
match &self.upblock.upsampler {
Some(upsampler) => upsampler.forward(&xs, upsample_size),
None => Ok(xs),
}
}
}

View File

@ -0,0 +1,39 @@
use candle::{Device, Result, Tensor};
use candle_nn::Module;
pub fn linspace(start: f64, stop: f64, steps: usize) -> Result<Tensor> {
if steps < 1 {
candle::bail!("cannot use linspace with steps {steps} <= 1")
}
let delta = (stop - start) / (steps - 1) as f64;
let vs = (0..steps)
.map(|step| start + step as f64 * delta)
.collect::<Vec<_>>();
Tensor::from_vec(vs, steps, &Device::Cpu)
}
// Wrap the conv2d op to provide some tracing.
#[derive(Debug)]
pub struct Conv2d {
inner: candle_nn::Conv2d,
span: tracing::Span,
}
impl Conv2d {
pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
self.inner.forward(x)
}
}
pub fn conv2d(
in_channels: usize,
out_channels: usize,
kernel_size: usize,
cfg: candle_nn::Conv2dConfig,
vs: candle_nn::VarBuilder,
) -> Result<Conv2d> {
let span = tracing::span!(tracing::Level::TRACE, "conv2d");
let inner = candle_nn::conv2d(in_channels, out_channels, kernel_size, cfg, vs)?;
Ok(Conv2d { inner, span })
}

View File

@ -0,0 +1,379 @@
#![allow(dead_code)]
//! # Variational Auto-Encoder (VAE) Models.
//!
//! Auto-encoder models compress their input to a usually smaller latent space
//! before expanding it back to its original shape. This results in the latent values
//! compressing the original information.
use super::unet_2d_blocks::{
DownEncoderBlock2D, DownEncoderBlock2DConfig, UNetMidBlock2D, UNetMidBlock2DConfig,
UpDecoderBlock2D, UpDecoderBlock2DConfig,
};
use candle::{Result, Tensor};
use candle_nn as nn;
use candle_nn::Module;
#[derive(Debug, Clone)]
struct EncoderConfig {
// down_block_types: DownEncoderBlock2D
block_out_channels: Vec<usize>,
layers_per_block: usize,
norm_num_groups: usize,
double_z: bool,
}
impl Default for EncoderConfig {
fn default() -> Self {
Self {
block_out_channels: vec![64],
layers_per_block: 2,
norm_num_groups: 32,
double_z: true,
}
}
}
#[derive(Debug)]
struct Encoder {
conv_in: nn::Conv2d,
down_blocks: Vec<DownEncoderBlock2D>,
mid_block: UNetMidBlock2D,
conv_norm_out: nn::GroupNorm,
conv_out: nn::Conv2d,
#[allow(dead_code)]
config: EncoderConfig,
}
impl Encoder {
fn new(
vs: nn::VarBuilder,
in_channels: usize,
out_channels: usize,
config: EncoderConfig,
) -> Result<Self> {
let conv_cfg = nn::Conv2dConfig {
padding: 1,
..Default::default()
};
let conv_in = nn::conv2d(
in_channels,
config.block_out_channels[0],
3,
conv_cfg,
vs.pp("conv_in"),
)?;
let mut down_blocks = vec![];
let vs_down_blocks = vs.pp("down_blocks");
for index in 0..config.block_out_channels.len() {
let out_channels = config.block_out_channels[index];
let in_channels = if index > 0 {
config.block_out_channels[index - 1]
} else {
config.block_out_channels[0]
};
let is_final = index + 1 == config.block_out_channels.len();
let cfg = DownEncoderBlock2DConfig {
num_layers: config.layers_per_block,
resnet_eps: 1e-6,
resnet_groups: config.norm_num_groups,
add_downsample: !is_final,
downsample_padding: 0,
..Default::default()
};
let down_block = DownEncoderBlock2D::new(
vs_down_blocks.pp(&index.to_string()),
in_channels,
out_channels,
cfg,
)?;
down_blocks.push(down_block)
}
let last_block_out_channels = *config.block_out_channels.last().unwrap();
let mid_cfg = UNetMidBlock2DConfig {
resnet_eps: 1e-6,
output_scale_factor: 1.,
attn_num_head_channels: None,
resnet_groups: Some(config.norm_num_groups),
..Default::default()
};
let mid_block =
UNetMidBlock2D::new(vs.pp("mid_block"), last_block_out_channels, None, mid_cfg)?;
let conv_norm_out = nn::group_norm(
config.norm_num_groups,
last_block_out_channels,
1e-6,
vs.pp("conv_norm_out"),
)?;
let conv_out_channels = if config.double_z {
2 * out_channels
} else {
out_channels
};
let conv_cfg = nn::Conv2dConfig {
padding: 1,
..Default::default()
};
let conv_out = nn::conv2d(
last_block_out_channels,
conv_out_channels,
3,
conv_cfg,
vs.pp("conv_out"),
)?;
Ok(Self {
conv_in,
down_blocks,
mid_block,
conv_norm_out,
conv_out,
config,
})
}
}
impl Encoder {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let mut xs = self.conv_in.forward(xs)?;
for down_block in self.down_blocks.iter() {
xs = down_block.forward(&xs)?
}
let xs = self.mid_block.forward(&xs, None)?;
let xs = self.conv_norm_out.forward(&xs)?;
let xs = nn::ops::silu(&xs)?;
self.conv_out.forward(&xs)
}
}
#[derive(Debug, Clone)]
struct DecoderConfig {
// up_block_types: UpDecoderBlock2D
block_out_channels: Vec<usize>,
layers_per_block: usize,
norm_num_groups: usize,
}
impl Default for DecoderConfig {
fn default() -> Self {
Self {
block_out_channels: vec![64],
layers_per_block: 2,
norm_num_groups: 32,
}
}
}
#[derive(Debug)]
struct Decoder {
conv_in: nn::Conv2d,
up_blocks: Vec<UpDecoderBlock2D>,
mid_block: UNetMidBlock2D,
conv_norm_out: nn::GroupNorm,
conv_out: nn::Conv2d,
#[allow(dead_code)]
config: DecoderConfig,
}
impl Decoder {
fn new(
vs: nn::VarBuilder,
in_channels: usize,
out_channels: usize,
config: DecoderConfig,
) -> Result<Self> {
let n_block_out_channels = config.block_out_channels.len();
let last_block_out_channels = *config.block_out_channels.last().unwrap();
let conv_cfg = nn::Conv2dConfig {
padding: 1,
..Default::default()
};
let conv_in = nn::conv2d(
in_channels,
last_block_out_channels,
3,
conv_cfg,
vs.pp("conv_in"),
)?;
let mid_cfg = UNetMidBlock2DConfig {
resnet_eps: 1e-6,
output_scale_factor: 1.,
attn_num_head_channels: None,
resnet_groups: Some(config.norm_num_groups),
..Default::default()
};
let mid_block =
UNetMidBlock2D::new(vs.pp("mid_block"), last_block_out_channels, None, mid_cfg)?;
let mut up_blocks = vec![];
let vs_up_blocks = vs.pp("up_blocks");
let reversed_block_out_channels: Vec<_> =
config.block_out_channels.iter().copied().rev().collect();
for index in 0..n_block_out_channels {
let out_channels = reversed_block_out_channels[index];
let in_channels = if index > 0 {
reversed_block_out_channels[index - 1]
} else {
reversed_block_out_channels[0]
};
let is_final = index + 1 == n_block_out_channels;
let cfg = UpDecoderBlock2DConfig {
num_layers: config.layers_per_block + 1,
resnet_eps: 1e-6,
resnet_groups: config.norm_num_groups,
add_upsample: !is_final,
..Default::default()
};
let up_block = UpDecoderBlock2D::new(
vs_up_blocks.pp(&index.to_string()),
in_channels,
out_channels,
cfg,
)?;
up_blocks.push(up_block)
}
let conv_norm_out = nn::group_norm(
config.norm_num_groups,
config.block_out_channels[0],
1e-6,
vs.pp("conv_norm_out"),
)?;
let conv_cfg = nn::Conv2dConfig {
padding: 1,
..Default::default()
};
let conv_out = nn::conv2d(
config.block_out_channels[0],
out_channels,
3,
conv_cfg,
vs.pp("conv_out"),
)?;
Ok(Self {
conv_in,
up_blocks,
mid_block,
conv_norm_out,
conv_out,
config,
})
}
}
impl Decoder {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let mut xs = self.mid_block.forward(&self.conv_in.forward(xs)?, None)?;
for up_block in self.up_blocks.iter() {
xs = up_block.forward(&xs)?
}
let xs = self.conv_norm_out.forward(&xs)?;
let xs = nn::ops::silu(&xs)?;
self.conv_out.forward(&xs)
}
}
#[derive(Debug, Clone)]
pub struct AutoEncoderKLConfig {
pub block_out_channels: Vec<usize>,
pub layers_per_block: usize,
pub latent_channels: usize,
pub norm_num_groups: usize,
}
impl Default for AutoEncoderKLConfig {
fn default() -> Self {
Self {
block_out_channels: vec![64],
layers_per_block: 1,
latent_channels: 4,
norm_num_groups: 32,
}
}
}
pub struct DiagonalGaussianDistribution {
mean: Tensor,
std: Tensor,
}
impl DiagonalGaussianDistribution {
pub fn new(parameters: &Tensor) -> Result<Self> {
let mut parameters = parameters.chunk(2, 1)?.into_iter();
let mean = parameters.next().unwrap();
let logvar = parameters.next().unwrap();
let std = (logvar * 0.5)?.exp()?;
Ok(DiagonalGaussianDistribution { mean, std })
}
pub fn sample(&self) -> Result<Tensor> {
let sample = Tensor::randn(0., 1f32, self.mean.shape(), self.mean.device());
&self.mean + &self.std * sample
}
}
// https://github.com/huggingface/diffusers/blob/970e30606c2944e3286f56e8eb6d3dc6d1eb85f7/src/diffusers/models/vae.py#L485
// This implementation is specific to the config used in stable-diffusion-v1-5
// https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/vae/config.json
#[derive(Debug)]
pub struct AutoEncoderKL {
encoder: Encoder,
decoder: Decoder,
quant_conv: nn::Conv2d,
post_quant_conv: nn::Conv2d,
pub config: AutoEncoderKLConfig,
}
impl AutoEncoderKL {
pub fn new(
vs: nn::VarBuilder,
in_channels: usize,
out_channels: usize,
config: AutoEncoderKLConfig,
) -> Result<Self> {
let latent_channels = config.latent_channels;
let encoder_cfg = EncoderConfig {
block_out_channels: config.block_out_channels.clone(),
layers_per_block: config.layers_per_block,
norm_num_groups: config.norm_num_groups,
double_z: true,
};
let encoder = Encoder::new(vs.pp("encoder"), in_channels, latent_channels, encoder_cfg)?;
let decoder_cfg = DecoderConfig {
block_out_channels: config.block_out_channels.clone(),
layers_per_block: config.layers_per_block,
norm_num_groups: config.norm_num_groups,
};
let decoder = Decoder::new(vs.pp("decoder"), latent_channels, out_channels, decoder_cfg)?;
let conv_cfg = Default::default();
let quant_conv = nn::conv2d(
2 * latent_channels,
2 * latent_channels,
1,
conv_cfg,
vs.pp("quant_conv"),
)?;
let post_quant_conv = nn::conv2d(
latent_channels,
latent_channels,
1,
conv_cfg,
vs.pp("post_quant_conv"),
)?;
Ok(Self {
encoder,
decoder,
quant_conv,
post_quant_conv,
config,
})
}
/// Returns the distribution in the latent space.
pub fn encode(&self, xs: &Tensor) -> Result<DiagonalGaussianDistribution> {
let xs = self.encoder.forward(xs)?;
let parameters = self.quant_conv.forward(&xs)?;
DiagonalGaussianDistribution::new(&parameters)
}
/// Takes as input some sampled values.
pub fn decode(&self, xs: &Tensor) -> Result<Tensor> {
let xs = self.post_quant_conv.forward(xs)?;
self.decoder.forward(&xs)
}
}