Add a stable diffusion example (#328)

* Start adding a stable-diffusion example.

* Proper computation of the causal mask.

* Add the chunk operation.

* Work in progress: port the attention module.

* Add some dummy modules for conv2d and group-norm, get the attention module to compile.

* Re-enable the 2d convolution.

* Add the embeddings module.

* Add the resnet module.

* Add the unet blocks.

* Add the unet.

* And add the variational auto-encoder.

* Use the pad function from utils.
This commit is contained in:
Laurent Mazare
2023-08-06 18:49:43 +02:00
committed by GitHub
parent 93cfe5642f
commit d34039e352
14 changed files with 2722 additions and 1 deletions

View File

@ -0,0 +1,445 @@
#![allow(dead_code)]
//! Attention Based Building Blocks
use candle::{IndexOp, Result, Tensor, D};
use candle_nn as nn;
#[derive(Debug)]
struct GeGlu {
proj: nn::Linear,
}
impl GeGlu {
fn new(vs: nn::VarBuilder, dim_in: usize, dim_out: usize) -> Result<Self> {
let proj = nn::linear(dim_in, dim_out * 2, vs.pp("proj"))?;
Ok(Self { proj })
}
}
impl GeGlu {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let hidden_states_and_gate = self.proj.forward(xs)?.chunk(2, D::Minus1)?;
&hidden_states_and_gate[0] * hidden_states_and_gate[1].gelu()?
}
}
/// A feed-forward layer.
#[derive(Debug)]
struct FeedForward {
project_in: GeGlu,
linear: nn::Linear,
}
impl FeedForward {
// The glu parameter in the python code is unused?
// https://github.com/huggingface/diffusers/blob/d3d22ce5a894becb951eec03e663951b28d45135/src/diffusers/models/attention.py#L347
/// Creates a new feed-forward layer based on some given input dimension, some
/// output dimension, and a multiplier to be used for the intermediary layer.
fn new(vs: nn::VarBuilder, dim: usize, dim_out: Option<usize>, mult: usize) -> Result<Self> {
let inner_dim = dim * mult;
let dim_out = dim_out.unwrap_or(dim);
let vs = vs.pp("net");
let project_in = GeGlu::new(vs.pp("0"), dim, inner_dim)?;
let linear = nn::linear(inner_dim, dim_out, vs.pp("2"))?;
Ok(Self { project_in, linear })
}
}
impl FeedForward {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let xs = self.project_in.forward(xs)?;
self.linear.forward(&xs)
}
}
#[derive(Debug)]
struct CrossAttention {
to_q: nn::Linear,
to_k: nn::Linear,
to_v: nn::Linear,
to_out: nn::Linear,
heads: usize,
scale: f64,
slice_size: Option<usize>,
}
impl CrossAttention {
// Defaults should be heads = 8, dim_head = 64, context_dim = None
fn new(
vs: nn::VarBuilder,
query_dim: usize,
context_dim: Option<usize>,
heads: usize,
dim_head: usize,
slice_size: Option<usize>,
) -> Result<Self> {
let inner_dim = dim_head * heads;
let context_dim = context_dim.unwrap_or(query_dim);
let scale = 1.0 / f64::sqrt(dim_head as f64);
let to_q = nn::linear_no_bias(query_dim, inner_dim, vs.pp("to_q"))?;
let to_k = nn::linear_no_bias(context_dim, inner_dim, vs.pp("to_k"))?;
let to_v = nn::linear_no_bias(context_dim, inner_dim, vs.pp("to_v"))?;
let to_out = nn::linear(inner_dim, query_dim, vs.pp("to_out.0"))?;
Ok(Self {
to_q,
to_k,
to_v,
to_out,
heads,
scale,
slice_size,
})
}
fn reshape_heads_to_batch_dim(&self, xs: &Tensor) -> Result<Tensor> {
let (batch_size, seq_len, dim) = xs.dims3()?;
xs.reshape((batch_size, seq_len, self.heads, dim / self.heads))?
.transpose(1, 2)?
.reshape((batch_size * self.heads, seq_len, dim / self.heads))
}
fn reshape_batch_dim_to_heads(&self, xs: &Tensor) -> Result<Tensor> {
let (batch_size, seq_len, dim) = xs.dims3()?;
xs.reshape((batch_size / self.heads, self.heads, seq_len, dim))?
.transpose(1, 2)?
.reshape((batch_size / self.heads, seq_len, dim * self.heads))
}
fn sliced_attention(
&self,
query: &Tensor,
key: &Tensor,
value: &Tensor,
slice_size: usize,
) -> Result<Tensor> {
let batch_size_attention = query.dim(0)?;
let mut hidden_states = Vec::with_capacity(batch_size_attention / slice_size);
for i in 0..batch_size_attention / slice_size {
let start_idx = i * slice_size;
let end_idx = (i + 1) * slice_size;
let xs = query
.i(start_idx..end_idx)?
.matmul(&(key.i(start_idx..end_idx)?.t()? * self.scale)?)?;
let xs = nn::ops::softmax(&xs, D::Minus1)?.matmul(&value.i(start_idx..end_idx)?)?;
hidden_states.push(xs)
}
let hidden_states = Tensor::stack(&hidden_states, 0)?;
self.reshape_batch_dim_to_heads(&hidden_states)
}
fn attention(&self, query: &Tensor, key: &Tensor, value: &Tensor) -> Result<Tensor> {
let xs = query.matmul(&(key.transpose(D::Minus1, D::Minus2)? * self.scale)?)?;
let xs = nn::ops::softmax(&xs, D::Minus1)?.matmul(value)?;
self.reshape_batch_dim_to_heads(&xs)
}
fn forward(&self, xs: &Tensor, context: Option<&Tensor>) -> Result<Tensor> {
let query = self.to_q.forward(xs)?;
let context = context.unwrap_or(xs);
let key = self.to_k.forward(context)?;
let value = self.to_v.forward(context)?;
let query = self.reshape_heads_to_batch_dim(&query)?;
let key = self.reshape_heads_to_batch_dim(&key)?;
let value = self.reshape_heads_to_batch_dim(&value)?;
let xs = match self.slice_size {
None => self.attention(&query, &key, &value)?,
Some(slice_size) => {
if query.dim(0)? / slice_size <= 1 {
self.attention(&query, &key, &value)?
} else {
self.sliced_attention(&query, &key, &value, slice_size)?
}
}
};
self.to_out.forward(&xs)
}
}
/// A basic Transformer block.
#[derive(Debug)]
struct BasicTransformerBlock {
attn1: CrossAttention,
ff: FeedForward,
attn2: CrossAttention,
norm1: nn::LayerNorm,
norm2: nn::LayerNorm,
norm3: nn::LayerNorm,
}
impl BasicTransformerBlock {
fn new(
vs: nn::VarBuilder,
dim: usize,
n_heads: usize,
d_head: usize,
context_dim: Option<usize>,
sliced_attention_size: Option<usize>,
) -> Result<Self> {
let attn1 = CrossAttention::new(
vs.pp("attn1"),
dim,
None,
n_heads,
d_head,
sliced_attention_size,
)?;
let ff = FeedForward::new(vs.pp("ff"), dim, None, 4)?;
let attn2 = CrossAttention::new(
vs.pp("attn2"),
dim,
context_dim,
n_heads,
d_head,
sliced_attention_size,
)?;
let norm1 = nn::layer_norm(dim, 1e-5, vs.pp("norm1"))?;
let norm2 = nn::layer_norm(dim, 1e-5, vs.pp("norm2"))?;
let norm3 = nn::layer_norm(dim, 1e-5, vs.pp("norm3"))?;
Ok(Self {
attn1,
ff,
attn2,
norm1,
norm2,
norm3,
})
}
fn forward(&self, xs: &Tensor, context: Option<&Tensor>) -> Result<Tensor> {
let xs = (self.attn1.forward(&self.norm1.forward(xs)?, None)? + xs)?;
let xs = (self.attn2.forward(&self.norm2.forward(&xs)?, context)? + xs)?;
self.ff.forward(&self.norm3.forward(&xs)?)? + xs
}
}
#[derive(Debug, Clone, Copy)]
pub struct SpatialTransformerConfig {
pub depth: usize,
pub num_groups: usize,
pub context_dim: Option<usize>,
pub sliced_attention_size: Option<usize>,
pub use_linear_projection: bool,
}
impl Default for SpatialTransformerConfig {
fn default() -> Self {
Self {
depth: 1,
num_groups: 32,
context_dim: None,
sliced_attention_size: None,
use_linear_projection: false,
}
}
}
#[derive(Debug)]
enum Proj {
Conv2d(nn::Conv2d),
Linear(nn::Linear),
}
// Aka Transformer2DModel
#[derive(Debug)]
pub struct SpatialTransformer {
norm: nn::GroupNorm,
proj_in: Proj,
transformer_blocks: Vec<BasicTransformerBlock>,
proj_out: Proj,
pub config: SpatialTransformerConfig,
}
impl SpatialTransformer {
pub fn new(
vs: nn::VarBuilder,
in_channels: usize,
n_heads: usize,
d_head: usize,
config: SpatialTransformerConfig,
) -> Result<Self> {
let inner_dim = n_heads * d_head;
let norm = nn::group_norm(config.num_groups, in_channels, 1e-6, vs.pp("norm"))?;
let proj_in = if config.use_linear_projection {
Proj::Linear(nn::linear(in_channels, inner_dim, vs.pp("proj_in"))?)
} else {
Proj::Conv2d(nn::conv2d(
in_channels,
inner_dim,
1,
Default::default(),
vs.pp("proj_in"),
)?)
};
let mut transformer_blocks = vec![];
let vs_tb = vs.pp("transformer_blocks");
for index in 0..config.depth {
let tb = BasicTransformerBlock::new(
vs_tb.pp(&index.to_string()),
inner_dim,
n_heads,
d_head,
config.context_dim,
config.sliced_attention_size,
)?;
transformer_blocks.push(tb)
}
let proj_out = if config.use_linear_projection {
Proj::Linear(nn::linear(in_channels, inner_dim, vs.pp("proj_out"))?)
} else {
Proj::Conv2d(nn::conv2d(
inner_dim,
in_channels,
1,
Default::default(),
vs.pp("proj_out"),
)?)
};
Ok(Self {
norm,
proj_in,
transformer_blocks,
proj_out,
config,
})
}
pub fn forward(&self, xs: &Tensor, context: Option<&Tensor>) -> Result<Tensor> {
let (batch, _channel, height, weight) = xs.dims4()?;
let residual = xs;
let xs = self.norm.forward(xs)?;
let (inner_dim, xs) = match &self.proj_in {
Proj::Conv2d(p) => {
let xs = p.forward(&xs)?;
let inner_dim = xs.dim(1)?;
let xs = xs
.transpose(1, 2)?
.t()?
.reshape((batch, height * weight, inner_dim))?;
(inner_dim, xs)
}
Proj::Linear(p) => {
let inner_dim = xs.dim(1)?;
let xs = xs
.transpose(1, 2)?
.t()?
.reshape((batch, height * weight, inner_dim))?;
(inner_dim, p.forward(&xs)?)
}
};
let mut xs = xs;
for block in self.transformer_blocks.iter() {
xs = block.forward(&xs, context)?
}
let xs = match &self.proj_out {
Proj::Conv2d(p) => p.forward(
&xs.reshape((batch, height, weight, inner_dim))?
.t()?
.transpose(1, 2)?,
)?,
Proj::Linear(p) => p
.forward(&xs)?
.reshape((batch, height, weight, inner_dim))?
.t()?
.transpose(1, 2)?,
};
xs + residual
}
}
/// Configuration for an attention block.
#[derive(Debug, Clone, Copy)]
pub struct AttentionBlockConfig {
pub num_head_channels: Option<usize>,
pub num_groups: usize,
pub rescale_output_factor: f64,
pub eps: f64,
}
impl Default for AttentionBlockConfig {
fn default() -> Self {
Self {
num_head_channels: None,
num_groups: 32,
rescale_output_factor: 1.,
eps: 1e-5,
}
}
}
#[derive(Debug)]
pub struct AttentionBlock {
group_norm: nn::GroupNorm,
query: nn::Linear,
key: nn::Linear,
value: nn::Linear,
proj_attn: nn::Linear,
channels: usize,
num_heads: usize,
config: AttentionBlockConfig,
}
impl AttentionBlock {
pub fn new(vs: nn::VarBuilder, channels: usize, config: AttentionBlockConfig) -> Result<Self> {
let num_head_channels = config.num_head_channels.unwrap_or(channels);
let num_heads = channels / num_head_channels;
let group_norm =
nn::group_norm(config.num_groups, channels, config.eps, vs.pp("group_norm"))?;
let query = nn::linear(channels, channels, vs.pp("query"))?;
let key = nn::linear(channels, channels, vs.pp("key"))?;
let value = nn::linear(channels, channels, vs.pp("value"))?;
let proj_attn = nn::linear(channels, channels, vs.pp("proj_attn"))?;
Ok(Self {
group_norm,
query,
key,
value,
proj_attn,
channels,
num_heads,
config,
})
}
fn transpose_for_scores(&self, xs: Tensor) -> Result<Tensor> {
let (batch, t, h_times_d) = xs.dims3()?;
xs.reshape((batch, t, self.num_heads, h_times_d / self.num_heads))?
.transpose(1, 2)
}
}
impl AttentionBlock {
pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let residual = xs;
let (batch, channel, height, width) = xs.dims4()?;
let xs = self
.group_norm
.forward(xs)?
.reshape((batch, channel, height * width))?
.transpose(1, 2)?;
let query_proj = self.query.forward(&xs)?;
let key_proj = self.key.forward(&xs)?;
let value_proj = self.value.forward(&xs)?;
let query_states = self.transpose_for_scores(query_proj)?;
let key_states = self.transpose_for_scores(key_proj)?;
let value_states = self.transpose_for_scores(value_proj)?;
let scale = f64::powf((self.channels as f64) / (self.num_heads as f64), -0.25);
let attention_scores =
// TODO: Check that this needs two multiplication by `scale`.
(query_states * scale)?.matmul(&(key_states.t()? * scale)?)?;
let attention_probs = nn::ops::softmax(&attention_scores, D::Minus1)?;
let xs = attention_probs.matmul(&value_states)?;
let xs = xs.transpose(1, 2)?.contiguous()?;
let xs = xs.flatten_from(D::Minus2)?;
let xs = self
.proj_attn
.forward(&xs)?
.t()?
.reshape((batch, channel, height, width))?;
(xs + residual)? / self.config.rescale_output_factor
}
}

View File

@ -0,0 +1,304 @@
#![allow(dead_code)]
//! Contrastive Language-Image Pre-Training
//!
//! Contrastive Language-Image Pre-Training (CLIP) is an architecture trained on
//! pairs of images with related texts.
//!
//! https://github.com/openai/CLIP
use candle::{Device, Result, Tensor, D};
#[derive(Debug, Clone, Copy)]
pub enum Activation {
QuickGelu,
Gelu,
}
impl Activation {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
match self {
Activation::QuickGelu => xs * crate::utils::sigmoid(&(xs * 1.702f64)?)?,
Activation::Gelu => xs.gelu(),
}
}
}
#[derive(Debug, Clone)]
pub struct Config {
vocab_size: usize,
embed_dim: usize, // aka config.hidden_size
activation: Activation, // aka config.hidden_act
intermediate_size: usize,
max_position_embeddings: usize,
// The character to use for padding, use EOS when not set.
pad_with: Option<String>,
num_hidden_layers: usize,
num_attention_heads: usize,
#[allow(dead_code)]
projection_dim: usize,
}
impl Config {
// The config details can be found in the "text_config" section of this json file:
// https://huggingface.co/openai/clip-vit-large-patch14/blob/main/config.json
pub fn v1_5() -> Self {
Self {
vocab_size: 49408,
embed_dim: 768,
intermediate_size: 3072,
max_position_embeddings: 77,
pad_with: None,
num_hidden_layers: 12,
num_attention_heads: 12,
projection_dim: 768,
activation: Activation::QuickGelu,
}
}
// https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/text_encoder/config.json
pub fn v2_1() -> Self {
Self {
vocab_size: 49408,
embed_dim: 1024,
intermediate_size: 4096,
max_position_embeddings: 77,
pad_with: Some("!".to_string()),
num_hidden_layers: 23,
num_attention_heads: 16,
projection_dim: 512,
activation: Activation::Gelu,
}
}
}
// CLIP Text Model
// https://github.com/huggingface/transformers/blob/674f750a57431222fa2832503a108df3badf1564/src/transformers/models/clip/modeling_clip.py
#[derive(Debug)]
struct ClipTextEmbeddings {
token_embedding: candle_nn::Embedding,
position_embedding: candle_nn::Embedding,
position_ids: Tensor,
}
impl ClipTextEmbeddings {
fn new(vs: candle_nn::VarBuilder, c: &Config) -> Result<Self> {
let token_embedding =
candle_nn::embedding(c.vocab_size, c.embed_dim, vs.pp("token_embedding"))?;
let position_embedding = candle_nn::embedding(
c.max_position_embeddings,
c.embed_dim,
vs.pp("position_embedding"),
)?;
let position_ids =
Tensor::arange(0u32, c.max_position_embeddings as u32, vs.device())?.unsqueeze(1)?;
Ok(ClipTextEmbeddings {
token_embedding,
position_embedding,
position_ids,
})
}
}
impl ClipTextEmbeddings {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let token_embedding = self.token_embedding.forward(xs)?;
let position_embedding = self.position_embedding.forward(&self.position_ids)?;
token_embedding + position_embedding
}
}
#[derive(Debug)]
struct ClipAttention {
k_proj: candle_nn::Linear,
v_proj: candle_nn::Linear,
q_proj: candle_nn::Linear,
out_proj: candle_nn::Linear,
head_dim: usize,
scale: f64,
num_attention_heads: usize,
}
impl ClipAttention {
fn new(vs: candle_nn::VarBuilder, c: &Config) -> Result<Self> {
let embed_dim = c.embed_dim;
let num_attention_heads = c.num_attention_heads;
let k_proj = candle_nn::linear(embed_dim, embed_dim, vs.pp("k_proj"))?;
let v_proj = candle_nn::linear(embed_dim, embed_dim, vs.pp("v_proj"))?;
let q_proj = candle_nn::linear(embed_dim, embed_dim, vs.pp("q_proj"))?;
let out_proj = candle_nn::linear(embed_dim, embed_dim, vs.pp("out_proj"))?;
let head_dim = embed_dim / num_attention_heads;
let scale = (head_dim as f64).powf(-0.5);
Ok(ClipAttention {
k_proj,
v_proj,
q_proj,
out_proj,
head_dim,
scale,
num_attention_heads,
})
}
fn shape(&self, xs: &Tensor, seq_len: usize, bsz: usize) -> Result<Tensor> {
xs.reshape((bsz, seq_len, self.num_attention_heads, self.head_dim))?
.transpose(1, 2)?
.contiguous()
}
fn forward(&self, xs: &Tensor, causal_attention_mask: &Tensor) -> Result<Tensor> {
let (bsz, seq_len, embed_dim) = xs.dims3()?;
let query_states = (self.q_proj.forward(xs)? * self.scale)?;
let proj_shape = (bsz * self.num_attention_heads, seq_len, self.head_dim);
let query_states = self
.shape(&query_states, seq_len, bsz)?
.reshape(proj_shape)?;
let key_states = self
.shape(&self.k_proj.forward(xs)?, seq_len, bsz)?
.reshape(proj_shape)?;
let value_states = self
.shape(&self.v_proj.forward(xs)?, seq_len, bsz)?
.reshape(proj_shape)?;
let attn_weights = query_states.matmul(&key_states.transpose(1, 2)?)?;
let src_len = key_states.dim(1)?;
let attn_weights =
(attn_weights.reshape((bsz, self.num_attention_heads, seq_len, src_len))?
+ causal_attention_mask)?;
let attn_weights =
attn_weights.reshape((bsz * self.num_attention_heads, seq_len, src_len))?;
let attn_weights = candle_nn::ops::softmax(&attn_weights, D::Minus1)?;
let attn_output = attn_weights.matmul(&value_states)?;
let attn_output = attn_output
.reshape((bsz, self.num_attention_heads, seq_len, self.head_dim))?
.transpose(1, 2)?
.reshape((bsz, seq_len, embed_dim))?;
self.out_proj.forward(&attn_output)
}
}
#[derive(Debug)]
struct ClipMlp {
fc1: candle_nn::Linear,
fc2: candle_nn::Linear,
activation: Activation,
}
impl ClipMlp {
fn new(vs: candle_nn::VarBuilder, c: &Config) -> Result<Self> {
let fc1 = candle_nn::linear(c.embed_dim, c.intermediate_size, vs.pp("fc1"))?;
let fc2 = candle_nn::linear(c.intermediate_size, c.embed_dim, vs.pp("fc2"))?;
Ok(ClipMlp {
fc1,
fc2,
activation: c.activation,
})
}
}
impl ClipMlp {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let xs = self.fc1.forward(xs)?;
self.fc2.forward(&self.activation.forward(&xs)?)
}
}
#[derive(Debug)]
struct ClipEncoderLayer {
self_attn: ClipAttention,
layer_norm1: candle_nn::LayerNorm,
mlp: ClipMlp,
layer_norm2: candle_nn::LayerNorm,
}
impl ClipEncoderLayer {
fn new(vs: candle_nn::VarBuilder, c: &Config) -> Result<Self> {
let self_attn = ClipAttention::new(vs.pp("self_attn"), c)?;
let layer_norm1 = candle_nn::layer_norm(c.embed_dim, 1e-5, vs.pp("layer_norm1"))?;
let mlp = ClipMlp::new(vs.pp("mlp"), c)?;
let layer_norm2 = candle_nn::layer_norm(c.embed_dim, 1e-5, vs.pp("layer_norm2"))?;
Ok(ClipEncoderLayer {
self_attn,
layer_norm1,
mlp,
layer_norm2,
})
}
fn forward(&self, xs: &Tensor, causal_attention_mask: &Tensor) -> Result<Tensor> {
let residual = xs;
let xs = self.layer_norm1.forward(xs)?;
let xs = self.self_attn.forward(&xs, causal_attention_mask)?;
let xs = (xs + residual)?;
let residual = &xs;
let xs = self.layer_norm2.forward(&xs)?;
let xs = self.mlp.forward(&xs)?;
xs + residual
}
}
#[derive(Debug)]
struct ClipEncoder {
layers: Vec<ClipEncoderLayer>,
}
impl ClipEncoder {
fn new(vs: candle_nn::VarBuilder, c: &Config) -> Result<Self> {
let vs = vs.pp("layers");
let mut layers: Vec<ClipEncoderLayer> = Vec::new();
for index in 0..c.num_hidden_layers {
let layer = ClipEncoderLayer::new(vs.pp(&index.to_string()), c)?;
layers.push(layer)
}
Ok(ClipEncoder { layers })
}
fn forward(&self, xs: &Tensor, causal_attention_mask: &Tensor) -> Result<Tensor> {
let mut xs = xs.clone();
for layer in self.layers.iter() {
xs = layer.forward(&xs, causal_attention_mask)?
}
Ok(xs)
}
}
/// A CLIP transformer based model.
#[derive(Debug)]
pub struct ClipTextTransformer {
embeddings: ClipTextEmbeddings,
encoder: ClipEncoder,
final_layer_norm: candle_nn::LayerNorm,
}
impl ClipTextTransformer {
pub fn new(vs: candle_nn::VarBuilder, c: &Config) -> Result<Self> {
let vs = vs.pp("text_model");
let embeddings = ClipTextEmbeddings::new(vs.pp("embeddings"), c)?;
let encoder = ClipEncoder::new(vs.pp("encoder"), c)?;
let final_layer_norm = candle_nn::layer_norm(c.embed_dim, 1e-5, vs.pp("final_layer_norm"))?;
Ok(ClipTextTransformer {
embeddings,
encoder,
final_layer_norm,
})
}
// https://github.com/huggingface/transformers/blob/674f750a57431222fa2832503a108df3badf1564/src/transformers/models/clip/modeling_clip.py#L678
fn build_causal_attention_mask(bsz: usize, seq_len: usize, device: &Device) -> Result<Tensor> {
let mask: Vec<_> = (0..seq_len)
.flat_map(|i| (0..seq_len).map(move |j| u8::from(j > i)))
.collect();
let mask = Tensor::from_slice(&mask, (seq_len, seq_len), device)?;
mask.broadcast_as((bsz, seq_len, seq_len))
}
}
impl ClipTextTransformer {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let (bsz, seq_len) = xs.dims2()?;
let xs = self.embeddings.forward(xs)?;
let causal_attention_mask = Self::build_causal_attention_mask(bsz, seq_len, xs.device())?;
let xs = self.encoder.forward(&xs, &causal_attention_mask)?;
self.final_layer_norm.forward(&xs)
}
}

View File

@ -0,0 +1,65 @@
#![allow(dead_code)]
use candle::{Result, Tensor, D};
use candle_nn as nn;
#[derive(Debug)]
pub struct TimestepEmbedding {
linear_1: nn::Linear,
linear_2: nn::Linear,
}
impl TimestepEmbedding {
// act_fn: "silu"
pub fn new(vs: nn::VarBuilder, channel: usize, time_embed_dim: usize) -> Result<Self> {
let linear_1 = nn::linear(channel, time_embed_dim, vs.pp("linear_1"))?;
let linear_2 = nn::linear(time_embed_dim, time_embed_dim, vs.pp("linear_2"))?;
Ok(Self { linear_1, linear_2 })
}
}
impl TimestepEmbedding {
pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let xs = nn::ops::silu(&self.linear_1.forward(xs)?)?;
self.linear_2.forward(&xs)
}
}
#[derive(Debug)]
pub struct Timesteps {
num_channels: usize,
flip_sin_to_cos: bool,
downscale_freq_shift: f64,
}
impl Timesteps {
pub fn new(num_channels: usize, flip_sin_to_cos: bool, downscale_freq_shift: f64) -> Self {
Self {
num_channels,
flip_sin_to_cos,
downscale_freq_shift,
}
}
}
impl Timesteps {
pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let half_dim = (self.num_channels / 2) as u32;
let exponent =
(Tensor::arange(0, half_dim, xs.device())?.to_dtype(xs.dtype())? * -f64::ln(10000.))?;
let exponent = (exponent / (half_dim as f64 - self.downscale_freq_shift))?;
let emb = exponent.exp()?;
// emb = timesteps[:, None].float() * emb[None, :]
let emb = (xs.unsqueeze(D::Minus1)? * emb.unsqueeze(0)?)?;
let (cos, sin) = (emb.cos()?, emb.sin()?);
let emb = if self.flip_sin_to_cos {
Tensor::cat(&[&cos, &sin], D::Minus1)?
} else {
Tensor::cat(&[&sin, &cos], D::Minus1)?
};
if self.num_channels % 2 == 1 {
crate::utils::pad(&emb) // ([0, 1, 0, 0], 'constant', None)
} else {
Ok(emb)
}
}
}

View File

@ -0,0 +1,30 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
mod attention;
mod clip;
mod embeddings;
mod resnet;
mod unet_2d;
mod unet_2d_blocks;
mod utils;
mod vae;
use anyhow::Result;
use clap::Parser;
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
#[arg(long)]
prompt: String,
}
fn main() -> Result<()> {
let _args = Args::parse();
Ok(())
}

View File

@ -0,0 +1,129 @@
#![allow(dead_code)]
//! ResNet Building Blocks
//!
//! Some Residual Network blocks used in UNet models.
//!
//! Denoising Diffusion Implicit Models, K. He and al, 2015.
//! https://arxiv.org/abs/1512.03385
use candle::{Result, Tensor, D};
use candle_nn as nn;
/// Configuration for a ResNet block.
#[derive(Debug, Clone, Copy)]
pub struct ResnetBlock2DConfig {
/// The number of output channels, defaults to the number of input channels.
pub out_channels: Option<usize>,
pub temb_channels: Option<usize>,
/// The number of groups to use in group normalization.
pub groups: usize,
pub groups_out: Option<usize>,
/// The epsilon to be used in the group normalization operations.
pub eps: f64,
/// Whether to use a 2D convolution in the skip connection. When using None,
/// such a convolution is used if the number of input channels is different from
/// the number of output channels.
pub use_in_shortcut: Option<bool>,
// non_linearity: silu
/// The final output is scaled by dividing by this value.
pub output_scale_factor: f64,
}
impl Default for ResnetBlock2DConfig {
fn default() -> Self {
Self {
out_channels: None,
temb_channels: Some(512),
groups: 32,
groups_out: None,
eps: 1e-6,
use_in_shortcut: None,
output_scale_factor: 1.,
}
}
}
#[derive(Debug)]
pub struct ResnetBlock2D {
norm1: nn::GroupNorm,
conv1: nn::Conv2d,
norm2: nn::GroupNorm,
conv2: nn::Conv2d,
time_emb_proj: Option<nn::Linear>,
conv_shortcut: Option<nn::Conv2d>,
config: ResnetBlock2DConfig,
}
impl ResnetBlock2D {
pub fn new(
vs: nn::VarBuilder,
in_channels: usize,
config: ResnetBlock2DConfig,
) -> Result<Self> {
let out_channels = config.out_channels.unwrap_or(in_channels);
let conv_cfg = nn::Conv2dConfig {
stride: 1,
padding: 1,
};
let norm1 = nn::group_norm(config.groups, in_channels, config.eps, vs.pp("norm1"))?;
let conv1 = nn::conv2d(in_channels, out_channels, 3, conv_cfg, vs.pp("conv1"))?;
let groups_out = config.groups_out.unwrap_or(config.groups);
let norm2 = nn::group_norm(groups_out, out_channels, config.eps, vs.pp("norm2"))?;
let conv2 = nn::conv2d(out_channels, out_channels, 3, conv_cfg, vs.pp("conv2"))?;
let use_in_shortcut = config
.use_in_shortcut
.unwrap_or(in_channels != out_channels);
let conv_shortcut = if use_in_shortcut {
let conv_cfg = nn::Conv2dConfig {
stride: 1,
padding: 0,
};
Some(nn::conv2d(
in_channels,
out_channels,
1,
conv_cfg,
vs.pp("conv_shortcut"),
)?)
} else {
None
};
let time_emb_proj = match config.temb_channels {
None => None,
Some(temb_channels) => Some(nn::linear(
temb_channels,
out_channels,
vs.pp("time_emb_proj"),
)?),
};
Ok(Self {
norm1,
conv1,
norm2,
conv2,
time_emb_proj,
config,
conv_shortcut,
})
}
pub fn forward(&self, xs: &Tensor, temb: Option<&Tensor>) -> Result<Tensor> {
let shortcut_xs = match &self.conv_shortcut {
Some(conv_shortcut) => conv_shortcut.forward(xs)?,
None => xs.clone(),
};
let xs = self.norm1.forward(xs)?;
let xs = self.conv1.forward(&nn::ops::silu(&xs)?)?;
let xs = match (temb, &self.time_emb_proj) {
(Some(temb), Some(time_emb_proj)) => time_emb_proj
.forward(&nn::ops::silu(temb)?)?
.unsqueeze(D::Minus1)?
.unsqueeze(D::Minus1)?
.add(&xs)?,
_ => xs,
};
let xs = self
.conv2
.forward(&nn::ops::silu(&self.norm2.forward(&xs)?)?)?;
(shortcut_xs + xs)? / self.config.output_scale_factor
}
}

View File

@ -0,0 +1,383 @@
#![allow(dead_code)]
//! 2D UNet Denoising Models
//!
//! The 2D Unet models take as input a noisy sample and the current diffusion
//! timestep and return a denoised version of the input.
use crate::embeddings::{TimestepEmbedding, Timesteps};
use crate::unet_2d_blocks::*;
use candle::{DType, Result, Tensor};
use candle_nn as nn;
#[derive(Debug, Clone, Copy)]
pub struct BlockConfig {
pub out_channels: usize,
pub use_cross_attn: bool,
pub attention_head_dim: usize,
}
#[derive(Debug, Clone)]
pub struct UNet2DConditionModelConfig {
pub center_input_sample: bool,
pub flip_sin_to_cos: bool,
pub freq_shift: f64,
pub blocks: Vec<BlockConfig>,
pub layers_per_block: usize,
pub downsample_padding: usize,
pub mid_block_scale_factor: f64,
pub norm_num_groups: usize,
pub norm_eps: f64,
pub cross_attention_dim: usize,
pub sliced_attention_size: Option<usize>,
pub use_linear_projection: bool,
}
impl Default for UNet2DConditionModelConfig {
fn default() -> Self {
Self {
center_input_sample: false,
flip_sin_to_cos: true,
freq_shift: 0.,
blocks: vec![
BlockConfig {
out_channels: 320,
use_cross_attn: true,
attention_head_dim: 8,
},
BlockConfig {
out_channels: 640,
use_cross_attn: true,
attention_head_dim: 8,
},
BlockConfig {
out_channels: 1280,
use_cross_attn: true,
attention_head_dim: 8,
},
BlockConfig {
out_channels: 1280,
use_cross_attn: false,
attention_head_dim: 8,
},
],
layers_per_block: 2,
downsample_padding: 1,
mid_block_scale_factor: 1.,
norm_num_groups: 32,
norm_eps: 1e-5,
cross_attention_dim: 1280,
sliced_attention_size: None,
use_linear_projection: false,
}
}
}
#[derive(Debug)]
pub(crate) enum UNetDownBlock {
Basic(DownBlock2D),
CrossAttn(CrossAttnDownBlock2D),
}
#[derive(Debug)]
enum UNetUpBlock {
Basic(UpBlock2D),
CrossAttn(CrossAttnUpBlock2D),
}
#[derive(Debug)]
pub struct UNet2DConditionModel {
conv_in: nn::Conv2d,
time_proj: Timesteps,
time_embedding: TimestepEmbedding,
down_blocks: Vec<UNetDownBlock>,
mid_block: UNetMidBlock2DCrossAttn,
up_blocks: Vec<UNetUpBlock>,
conv_norm_out: nn::GroupNorm,
conv_out: nn::Conv2d,
config: UNet2DConditionModelConfig,
}
impl UNet2DConditionModel {
pub fn new(
vs: nn::VarBuilder,
in_channels: usize,
out_channels: usize,
config: UNet2DConditionModelConfig,
) -> Result<Self> {
let n_blocks = config.blocks.len();
let b_channels = config.blocks[0].out_channels;
let bl_channels = config.blocks.last().unwrap().out_channels;
let bl_attention_head_dim = config.blocks.last().unwrap().attention_head_dim;
let time_embed_dim = b_channels * 4;
let conv_cfg = nn::Conv2dConfig {
stride: 1,
padding: 1,
};
let conv_in = nn::conv2d(in_channels, b_channels, 3, conv_cfg, vs.pp("conv_in"))?;
let time_proj = Timesteps::new(b_channels, config.flip_sin_to_cos, config.freq_shift);
let time_embedding =
TimestepEmbedding::new(vs.pp("time_embedding"), b_channels, time_embed_dim)?;
let vs_db = vs.pp("down_blocks");
let down_blocks = (0..n_blocks)
.map(|i| {
let BlockConfig {
out_channels,
use_cross_attn,
attention_head_dim,
} = config.blocks[i];
// Enable automatic attention slicing if the config sliced_attention_size is set to 0.
let sliced_attention_size = match config.sliced_attention_size {
Some(0) => Some(attention_head_dim / 2),
_ => config.sliced_attention_size,
};
let in_channels = if i > 0 {
config.blocks[i - 1].out_channels
} else {
b_channels
};
let db_cfg = DownBlock2DConfig {
num_layers: config.layers_per_block,
resnet_eps: config.norm_eps,
resnet_groups: config.norm_num_groups,
add_downsample: i < n_blocks - 1,
downsample_padding: config.downsample_padding,
..Default::default()
};
if use_cross_attn {
let config = CrossAttnDownBlock2DConfig {
downblock: db_cfg,
attn_num_head_channels: attention_head_dim,
cross_attention_dim: config.cross_attention_dim,
sliced_attention_size,
use_linear_projection: config.use_linear_projection,
};
let block = CrossAttnDownBlock2D::new(
vs_db.pp(&i.to_string()),
in_channels,
out_channels,
Some(time_embed_dim),
config,
)?;
Ok(UNetDownBlock::CrossAttn(block))
} else {
let block = DownBlock2D::new(
vs_db.pp(&i.to_string()),
in_channels,
out_channels,
Some(time_embed_dim),
db_cfg,
)?;
Ok(UNetDownBlock::Basic(block))
}
})
.collect::<Result<Vec<_>>>()?;
let mid_cfg = UNetMidBlock2DCrossAttnConfig {
resnet_eps: config.norm_eps,
output_scale_factor: config.mid_block_scale_factor,
cross_attn_dim: config.cross_attention_dim,
attn_num_head_channels: bl_attention_head_dim,
resnet_groups: Some(config.norm_num_groups),
use_linear_projection: config.use_linear_projection,
..Default::default()
};
let mid_block = UNetMidBlock2DCrossAttn::new(
vs.pp("mid_block"),
bl_channels,
Some(time_embed_dim),
mid_cfg,
)?;
let vs_ub = vs.pp("up_blocks");
let up_blocks = (0..n_blocks)
.map(|i| {
let BlockConfig {
out_channels,
use_cross_attn,
attention_head_dim,
} = config.blocks[n_blocks - 1 - i];
// Enable automatic attention slicing if the config sliced_attention_size is set to 0.
let sliced_attention_size = match config.sliced_attention_size {
Some(0) => Some(attention_head_dim / 2),
_ => config.sliced_attention_size,
};
let prev_out_channels = if i > 0 {
config.blocks[n_blocks - i].out_channels
} else {
bl_channels
};
let in_channels = {
let index = if i == n_blocks - 1 {
0
} else {
n_blocks - i - 2
};
config.blocks[index].out_channels
};
let ub_cfg = UpBlock2DConfig {
num_layers: config.layers_per_block + 1,
resnet_eps: config.norm_eps,
resnet_groups: config.norm_num_groups,
add_upsample: i < n_blocks - 1,
..Default::default()
};
if use_cross_attn {
let config = CrossAttnUpBlock2DConfig {
upblock: ub_cfg,
attn_num_head_channels: attention_head_dim,
cross_attention_dim: config.cross_attention_dim,
sliced_attention_size,
use_linear_projection: config.use_linear_projection,
};
let block = CrossAttnUpBlock2D::new(
vs_ub.pp(&i.to_string()),
in_channels,
prev_out_channels,
out_channels,
Some(time_embed_dim),
config,
)?;
Ok(UNetUpBlock::CrossAttn(block))
} else {
let block = UpBlock2D::new(
vs_ub.pp(&i.to_string()),
in_channels,
prev_out_channels,
out_channels,
Some(time_embed_dim),
ub_cfg,
)?;
Ok(UNetUpBlock::Basic(block))
}
})
.collect::<Result<Vec<_>>>()?;
let conv_norm_out = nn::group_norm(
config.norm_num_groups,
b_channels,
config.norm_eps,
vs.pp("conv_norm_out"),
)?;
let conv_out = nn::conv2d(b_channels, out_channels, 3, conv_cfg, vs.pp("conv_out"))?;
Ok(Self {
conv_in,
time_proj,
time_embedding,
down_blocks,
mid_block,
up_blocks,
conv_norm_out,
conv_out,
config,
})
}
}
impl UNet2DConditionModel {
pub fn forward(
&self,
xs: &Tensor,
timestep: f64,
encoder_hidden_states: &Tensor,
) -> Result<Tensor> {
self.forward_with_additional_residuals(xs, timestep, encoder_hidden_states, None, None)
}
pub fn forward_with_additional_residuals(
&self,
xs: &Tensor,
timestep: f64,
encoder_hidden_states: &Tensor,
down_block_additional_residuals: Option<&[Tensor]>,
mid_block_additional_residual: Option<&Tensor>,
) -> Result<Tensor> {
let (bsize, _channels, height, width) = xs.dims4()?;
let device = xs.device();
let n_blocks = self.config.blocks.len();
let num_upsamplers = n_blocks - 1;
let default_overall_up_factor = 2usize.pow(num_upsamplers as u32);
let forward_upsample_size =
height % default_overall_up_factor != 0 || width % default_overall_up_factor != 0;
// 0. center input if necessary
let xs = if self.config.center_input_sample {
((xs * 2.0)? - 1.0)?
} else {
xs.clone()
};
// 1. time
let emb = (Tensor::ones(bsize, DType::F32, device)? * timestep)?;
let emb = self.time_proj.forward(&emb)?;
let emb = self.time_embedding.forward(&emb)?;
// 2. pre-process
let xs = self.conv_in.forward(&xs)?;
// 3. down
let mut down_block_res_xs = vec![xs.clone()];
let mut xs = xs;
for down_block in self.down_blocks.iter() {
let (_xs, res_xs) = match down_block {
UNetDownBlock::Basic(b) => b.forward(&xs, Some(&emb))?,
UNetDownBlock::CrossAttn(b) => {
b.forward(&xs, Some(&emb), Some(encoder_hidden_states))?
}
};
down_block_res_xs.extend(res_xs);
xs = _xs;
}
let new_down_block_res_xs =
if let Some(down_block_additional_residuals) = down_block_additional_residuals {
let mut v = vec![];
// A previous version of this code had a bug because of the addition being made
// in place via += hence modifying the input of the mid block.
for (i, residuals) in down_block_additional_residuals.iter().enumerate() {
v.push((&down_block_res_xs[i] + residuals)?)
}
v
} else {
down_block_res_xs
};
let mut down_block_res_xs = new_down_block_res_xs;
// 4. mid
let xs = self
.mid_block
.forward(&xs, Some(&emb), Some(encoder_hidden_states))?;
let xs = match mid_block_additional_residual {
None => xs,
Some(m) => (m + xs)?,
};
// 5. up
let mut xs = xs;
let mut upsample_size = None;
for (i, up_block) in self.up_blocks.iter().enumerate() {
let n_resnets = match up_block {
UNetUpBlock::Basic(b) => b.resnets.len(),
UNetUpBlock::CrossAttn(b) => b.upblock.resnets.len(),
};
let res_xs = down_block_res_xs.split_off(down_block_res_xs.len() - n_resnets);
if i < n_blocks - 1 && forward_upsample_size {
let (_, _, h, w) = down_block_res_xs.last().unwrap().dims4()?;
upsample_size = Some((h, w))
}
xs = match up_block {
UNetUpBlock::Basic(b) => b.forward(&xs, &res_xs, Some(&emb), upsample_size)?,
UNetUpBlock::CrossAttn(b) => b.forward(
&xs,
&res_xs,
Some(&emb),
upsample_size,
Some(encoder_hidden_states),
)?,
};
}
// 6. post-process
let xs = self.conv_norm_out.forward(&xs)?;
let xs = nn::ops::silu(&xs)?;
self.conv_out.forward(&xs)
}
}

View File

@ -0,0 +1,809 @@
#![allow(dead_code)]
//! 2D UNet Building Blocks
//!
use crate::attention::{
AttentionBlock, AttentionBlockConfig, SpatialTransformer, SpatialTransformerConfig,
};
use crate::resnet::{ResnetBlock2D, ResnetBlock2DConfig};
use candle::{Result, Tensor};
use candle_nn as nn;
#[derive(Debug)]
struct Downsample2D {
conv: Option<nn::Conv2d>,
padding: usize,
}
impl Downsample2D {
fn new(
vs: nn::VarBuilder,
in_channels: usize,
use_conv: bool,
out_channels: usize,
padding: usize,
) -> Result<Self> {
let conv = if use_conv {
let config = nn::Conv2dConfig { stride: 2, padding };
let conv = nn::conv2d(in_channels, out_channels, 3, config, vs.pp("conv"))?;
Some(conv)
} else {
None
};
Ok(Downsample2D { conv, padding })
}
}
impl Downsample2D {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
match &self.conv {
None => crate::utils::avg_pool2d(xs), // [2, 2], [2, 2], [0, 0], false, true, None),
Some(conv) => {
if self.padding == 0 {
let xs = crate::utils::pad(xs)?; // [0, 1, 0, 1], "constant", Some(0.))?;
conv.forward(&xs)
} else {
conv.forward(xs)
}
}
}
}
}
// This does not support the conv-transpose mode.
#[derive(Debug)]
struct Upsample2D {
conv: nn::Conv2d,
}
impl Upsample2D {
fn new(vs: nn::VarBuilder, in_channels: usize, out_channels: usize) -> Result<Self> {
let config = nn::Conv2dConfig {
padding: 1,
..Default::default()
};
let conv = nn::conv2d(in_channels, out_channels, 3, config, vs.pp("conv"))?;
Ok(Self { conv })
}
}
impl Upsample2D {
fn forward(&self, xs: &Tensor, size: Option<(usize, usize)>) -> Result<Tensor> {
let xs = match size {
None => {
// The following does not work and it's tricky to pass no fixed
// dimensions so hack our way around this.
// xs.upsample_nearest2d(&[], Some(2.), Some(2.)
let (_bsize, _channels, _h, _w) = xs.dims4()?;
crate::utils::upsample_nearest2d(xs)? // [2 * h, 2 * w], Some(2.), Some(2.))
}
Some((_h, _w)) => crate::utils::upsample_nearest2d(xs)?, // [h, w], None, None),
};
self.conv.forward(&xs)
}
}
#[derive(Debug, Clone, Copy)]
pub struct DownEncoderBlock2DConfig {
pub num_layers: usize,
pub resnet_eps: f64,
pub resnet_groups: usize,
pub output_scale_factor: f64,
pub add_downsample: bool,
pub downsample_padding: usize,
}
impl Default for DownEncoderBlock2DConfig {
fn default() -> Self {
Self {
num_layers: 1,
resnet_eps: 1e-6,
resnet_groups: 32,
output_scale_factor: 1.,
add_downsample: true,
downsample_padding: 1,
}
}
}
#[derive(Debug)]
pub struct DownEncoderBlock2D {
resnets: Vec<ResnetBlock2D>,
downsampler: Option<Downsample2D>,
pub config: DownEncoderBlock2DConfig,
}
impl DownEncoderBlock2D {
pub fn new(
vs: nn::VarBuilder,
in_channels: usize,
out_channels: usize,
config: DownEncoderBlock2DConfig,
) -> Result<Self> {
let resnets: Vec<_> = {
let vs = vs.pp("resnets");
let conv_cfg = ResnetBlock2DConfig {
eps: config.resnet_eps,
out_channels: Some(out_channels),
groups: config.resnet_groups,
output_scale_factor: config.output_scale_factor,
temb_channels: None,
..Default::default()
};
(0..(config.num_layers))
.map(|i| {
let in_channels = if i == 0 { in_channels } else { out_channels };
ResnetBlock2D::new(vs.pp(&i.to_string()), in_channels, conv_cfg)
})
.collect::<Result<Vec<_>>>()?
};
let downsampler = if config.add_downsample {
let downsample = Downsample2D::new(
vs.pp("downsamplers").pp("0"),
out_channels,
true,
out_channels,
config.downsample_padding,
)?;
Some(downsample)
} else {
None
};
Ok(Self {
resnets,
downsampler,
config,
})
}
}
impl DownEncoderBlock2D {
pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let mut xs = xs.clone();
for resnet in self.resnets.iter() {
xs = resnet.forward(&xs, None)?
}
match &self.downsampler {
Some(downsampler) => downsampler.forward(&xs),
None => Ok(xs),
}
}
}
#[derive(Debug, Clone, Copy)]
pub struct UpDecoderBlock2DConfig {
pub num_layers: usize,
pub resnet_eps: f64,
pub resnet_groups: usize,
pub output_scale_factor: f64,
pub add_upsample: bool,
}
impl Default for UpDecoderBlock2DConfig {
fn default() -> Self {
Self {
num_layers: 1,
resnet_eps: 1e-6,
resnet_groups: 32,
output_scale_factor: 1.,
add_upsample: true,
}
}
}
#[derive(Debug)]
pub struct UpDecoderBlock2D {
resnets: Vec<ResnetBlock2D>,
upsampler: Option<Upsample2D>,
pub config: UpDecoderBlock2DConfig,
}
impl UpDecoderBlock2D {
pub fn new(
vs: nn::VarBuilder,
in_channels: usize,
out_channels: usize,
config: UpDecoderBlock2DConfig,
) -> Result<Self> {
let resnets: Vec<_> = {
let vs = vs.pp("resnets");
let conv_cfg = ResnetBlock2DConfig {
out_channels: Some(out_channels),
eps: config.resnet_eps,
groups: config.resnet_groups,
output_scale_factor: config.output_scale_factor,
temb_channels: None,
..Default::default()
};
(0..(config.num_layers))
.map(|i| {
let in_channels = if i == 0 { in_channels } else { out_channels };
ResnetBlock2D::new(vs.pp(&i.to_string()), in_channels, conv_cfg)
})
.collect::<Result<Vec<_>>>()?
};
let upsampler = if config.add_upsample {
let upsample =
Upsample2D::new(vs.pp("upsamplers").pp("0"), out_channels, out_channels)?;
Some(upsample)
} else {
None
};
Ok(Self {
resnets,
upsampler,
config,
})
}
}
impl UpDecoderBlock2D {
pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let mut xs = xs.clone();
for resnet in self.resnets.iter() {
xs = resnet.forward(&xs, None)?
}
match &self.upsampler {
Some(upsampler) => upsampler.forward(&xs, None),
None => Ok(xs),
}
}
}
#[derive(Debug, Clone, Copy)]
pub struct UNetMidBlock2DConfig {
pub num_layers: usize,
pub resnet_eps: f64,
pub resnet_groups: Option<usize>,
pub attn_num_head_channels: Option<usize>,
// attention_type "default"
pub output_scale_factor: f64,
}
impl Default for UNetMidBlock2DConfig {
fn default() -> Self {
Self {
num_layers: 1,
resnet_eps: 1e-6,
resnet_groups: Some(32),
attn_num_head_channels: Some(1),
output_scale_factor: 1.,
}
}
}
#[derive(Debug)]
pub struct UNetMidBlock2D {
resnet: ResnetBlock2D,
attn_resnets: Vec<(AttentionBlock, ResnetBlock2D)>,
pub config: UNetMidBlock2DConfig,
}
impl UNetMidBlock2D {
pub fn new(
vs: nn::VarBuilder,
in_channels: usize,
temb_channels: Option<usize>,
config: UNetMidBlock2DConfig,
) -> Result<Self> {
let vs_resnets = vs.pp("resnets");
let vs_attns = vs.pp("attentions");
let resnet_groups = config
.resnet_groups
.unwrap_or_else(|| usize::min(in_channels / 4, 32));
let resnet_cfg = ResnetBlock2DConfig {
eps: config.resnet_eps,
groups: resnet_groups,
output_scale_factor: config.output_scale_factor,
temb_channels,
..Default::default()
};
let resnet = ResnetBlock2D::new(vs_resnets.pp("0"), in_channels, resnet_cfg)?;
let attn_cfg = AttentionBlockConfig {
num_head_channels: config.attn_num_head_channels,
num_groups: resnet_groups,
rescale_output_factor: config.output_scale_factor,
eps: config.resnet_eps,
};
let mut attn_resnets = vec![];
for index in 0..config.num_layers {
let attn = AttentionBlock::new(vs_attns.pp(&index.to_string()), in_channels, attn_cfg)?;
let resnet = ResnetBlock2D::new(
vs_resnets.pp(&(index + 1).to_string()),
in_channels,
resnet_cfg,
)?;
attn_resnets.push((attn, resnet))
}
Ok(Self {
resnet,
attn_resnets,
config,
})
}
pub fn forward(&self, xs: &Tensor, temb: Option<&Tensor>) -> Result<Tensor> {
let mut xs = self.resnet.forward(xs, temb)?;
for (attn, resnet) in self.attn_resnets.iter() {
xs = resnet.forward(&attn.forward(&xs)?, temb)?
}
Ok(xs)
}
}
#[derive(Debug, Clone, Copy)]
pub struct UNetMidBlock2DCrossAttnConfig {
pub num_layers: usize,
pub resnet_eps: f64,
pub resnet_groups: Option<usize>,
pub attn_num_head_channels: usize,
// attention_type "default"
pub output_scale_factor: f64,
pub cross_attn_dim: usize,
pub sliced_attention_size: Option<usize>,
pub use_linear_projection: bool,
}
impl Default for UNetMidBlock2DCrossAttnConfig {
fn default() -> Self {
Self {
num_layers: 1,
resnet_eps: 1e-6,
resnet_groups: Some(32),
attn_num_head_channels: 1,
output_scale_factor: 1.,
cross_attn_dim: 1280,
sliced_attention_size: None, // Sliced attention disabled
use_linear_projection: false,
}
}
}
#[derive(Debug)]
pub struct UNetMidBlock2DCrossAttn {
resnet: ResnetBlock2D,
attn_resnets: Vec<(SpatialTransformer, ResnetBlock2D)>,
pub config: UNetMidBlock2DCrossAttnConfig,
}
impl UNetMidBlock2DCrossAttn {
pub fn new(
vs: nn::VarBuilder,
in_channels: usize,
temb_channels: Option<usize>,
config: UNetMidBlock2DCrossAttnConfig,
) -> Result<Self> {
let vs_resnets = vs.pp("resnets");
let vs_attns = vs.pp("attentions");
let resnet_groups = config
.resnet_groups
.unwrap_or_else(|| usize::min(in_channels / 4, 32));
let resnet_cfg = ResnetBlock2DConfig {
eps: config.resnet_eps,
groups: resnet_groups,
output_scale_factor: config.output_scale_factor,
temb_channels,
..Default::default()
};
let resnet = ResnetBlock2D::new(vs_resnets.pp("0"), in_channels, resnet_cfg)?;
let n_heads = config.attn_num_head_channels;
let attn_cfg = SpatialTransformerConfig {
depth: 1,
num_groups: resnet_groups,
context_dim: Some(config.cross_attn_dim),
sliced_attention_size: config.sliced_attention_size,
use_linear_projection: config.use_linear_projection,
};
let mut attn_resnets = vec![];
for index in 0..config.num_layers {
let attn = SpatialTransformer::new(
vs_attns.pp(&index.to_string()),
in_channels,
n_heads,
in_channels / n_heads,
attn_cfg,
)?;
let resnet = ResnetBlock2D::new(
vs_resnets.pp(&(index + 1).to_string()),
in_channels,
resnet_cfg,
)?;
attn_resnets.push((attn, resnet))
}
Ok(Self {
resnet,
attn_resnets,
config,
})
}
pub fn forward(
&self,
xs: &Tensor,
temb: Option<&Tensor>,
encoder_hidden_states: Option<&Tensor>,
) -> Result<Tensor> {
let mut xs = self.resnet.forward(xs, temb)?;
for (attn, resnet) in self.attn_resnets.iter() {
xs = resnet.forward(&attn.forward(&xs, encoder_hidden_states)?, temb)?
}
Ok(xs)
}
}
#[derive(Debug, Clone, Copy)]
pub struct DownBlock2DConfig {
pub num_layers: usize,
pub resnet_eps: f64,
// resnet_time_scale_shift: "default"
// resnet_act_fn: "swish"
pub resnet_groups: usize,
pub output_scale_factor: f64,
pub add_downsample: bool,
pub downsample_padding: usize,
}
impl Default for DownBlock2DConfig {
fn default() -> Self {
Self {
num_layers: 1,
resnet_eps: 1e-6,
resnet_groups: 32,
output_scale_factor: 1.,
add_downsample: true,
downsample_padding: 1,
}
}
}
#[derive(Debug)]
pub struct DownBlock2D {
resnets: Vec<ResnetBlock2D>,
downsampler: Option<Downsample2D>,
pub config: DownBlock2DConfig,
}
impl DownBlock2D {
pub fn new(
vs: nn::VarBuilder,
in_channels: usize,
out_channels: usize,
temb_channels: Option<usize>,
config: DownBlock2DConfig,
) -> Result<Self> {
let vs_resnets = vs.pp("resnets");
let resnet_cfg = ResnetBlock2DConfig {
out_channels: Some(out_channels),
eps: config.resnet_eps,
output_scale_factor: config.output_scale_factor,
temb_channels,
..Default::default()
};
let resnets = (0..config.num_layers)
.map(|i| {
let in_channels = if i == 0 { in_channels } else { out_channels };
ResnetBlock2D::new(vs_resnets.pp(&i.to_string()), in_channels, resnet_cfg)
})
.collect::<Result<Vec<_>>>()?;
let downsampler = if config.add_downsample {
let downsampler = Downsample2D::new(
vs.pp("downsamplers").pp("0"),
out_channels,
true,
out_channels,
config.downsample_padding,
)?;
Some(downsampler)
} else {
None
};
Ok(Self {
resnets,
downsampler,
config,
})
}
pub fn forward(&self, xs: &Tensor, temb: Option<&Tensor>) -> Result<(Tensor, Vec<Tensor>)> {
let mut xs = xs.clone();
let mut output_states = vec![];
for resnet in self.resnets.iter() {
xs = resnet.forward(&xs, temb)?;
output_states.push(xs.clone());
}
let xs = match &self.downsampler {
Some(downsampler) => {
let xs = downsampler.forward(&xs)?;
output_states.push(xs.clone());
xs
}
None => xs,
};
Ok((xs, output_states))
}
}
#[derive(Debug, Clone, Copy)]
pub struct CrossAttnDownBlock2DConfig {
pub downblock: DownBlock2DConfig,
pub attn_num_head_channels: usize,
pub cross_attention_dim: usize,
// attention_type: "default"
pub sliced_attention_size: Option<usize>,
pub use_linear_projection: bool,
}
impl Default for CrossAttnDownBlock2DConfig {
fn default() -> Self {
Self {
downblock: Default::default(),
attn_num_head_channels: 1,
cross_attention_dim: 1280,
sliced_attention_size: None,
use_linear_projection: false,
}
}
}
#[derive(Debug)]
pub struct CrossAttnDownBlock2D {
downblock: DownBlock2D,
attentions: Vec<SpatialTransformer>,
pub config: CrossAttnDownBlock2DConfig,
}
impl CrossAttnDownBlock2D {
pub fn new(
vs: nn::VarBuilder,
in_channels: usize,
out_channels: usize,
temb_channels: Option<usize>,
config: CrossAttnDownBlock2DConfig,
) -> Result<Self> {
let downblock = DownBlock2D::new(
vs.clone(),
in_channels,
out_channels,
temb_channels,
config.downblock,
)?;
let n_heads = config.attn_num_head_channels;
let cfg = SpatialTransformerConfig {
depth: 1,
context_dim: Some(config.cross_attention_dim),
num_groups: config.downblock.resnet_groups,
sliced_attention_size: config.sliced_attention_size,
use_linear_projection: config.use_linear_projection,
};
let vs_attn = vs.pp("attentions");
let attentions = (0..config.downblock.num_layers)
.map(|i| {
SpatialTransformer::new(
vs_attn.pp(&i.to_string()),
out_channels,
n_heads,
out_channels / n_heads,
cfg,
)
})
.collect::<Result<Vec<_>>>()?;
Ok(Self {
downblock,
attentions,
config,
})
}
pub fn forward(
&self,
xs: &Tensor,
temb: Option<&Tensor>,
encoder_hidden_states: Option<&Tensor>,
) -> Result<(Tensor, Vec<Tensor>)> {
let mut output_states = vec![];
let mut xs = xs.clone();
for (resnet, attn) in self.downblock.resnets.iter().zip(self.attentions.iter()) {
xs = resnet.forward(&xs, temb)?;
xs = attn.forward(&xs, encoder_hidden_states)?;
output_states.push(xs.clone());
}
let xs = match &self.downblock.downsampler {
Some(downsampler) => {
let xs = downsampler.forward(&xs)?;
output_states.push(xs.clone());
xs
}
None => xs,
};
Ok((xs, output_states))
}
}
#[derive(Debug, Clone, Copy)]
pub struct UpBlock2DConfig {
pub num_layers: usize,
pub resnet_eps: f64,
// resnet_time_scale_shift: "default"
// resnet_act_fn: "swish"
pub resnet_groups: usize,
pub output_scale_factor: f64,
pub add_upsample: bool,
}
impl Default for UpBlock2DConfig {
fn default() -> Self {
Self {
num_layers: 1,
resnet_eps: 1e-6,
resnet_groups: 32,
output_scale_factor: 1.,
add_upsample: true,
}
}
}
#[derive(Debug)]
pub struct UpBlock2D {
pub resnets: Vec<ResnetBlock2D>,
upsampler: Option<Upsample2D>,
pub config: UpBlock2DConfig,
}
impl UpBlock2D {
pub fn new(
vs: nn::VarBuilder,
in_channels: usize,
prev_output_channels: usize,
out_channels: usize,
temb_channels: Option<usize>,
config: UpBlock2DConfig,
) -> Result<Self> {
let vs_resnets = vs.pp("resnets");
let resnet_cfg = ResnetBlock2DConfig {
out_channels: Some(out_channels),
temb_channels,
eps: config.resnet_eps,
output_scale_factor: config.output_scale_factor,
..Default::default()
};
let resnets = (0..config.num_layers)
.map(|i| {
let res_skip_channels = if i == config.num_layers - 1 {
in_channels
} else {
out_channels
};
let resnet_in_channels = if i == 0 {
prev_output_channels
} else {
out_channels
};
let in_channels = resnet_in_channels + res_skip_channels;
ResnetBlock2D::new(vs_resnets.pp(&i.to_string()), in_channels, resnet_cfg)
})
.collect::<Result<Vec<_>>>()?;
let upsampler = if config.add_upsample {
let upsampler =
Upsample2D::new(vs.pp("upsamplers").pp("0"), out_channels, out_channels)?;
Some(upsampler)
} else {
None
};
Ok(Self {
resnets,
upsampler,
config,
})
}
pub fn forward(
&self,
xs: &Tensor,
res_xs: &[Tensor],
temb: Option<&Tensor>,
upsample_size: Option<(usize, usize)>,
) -> Result<Tensor> {
let mut xs = xs.clone();
for (index, resnet) in self.resnets.iter().enumerate() {
xs = Tensor::cat(&[&xs, &res_xs[res_xs.len() - index - 1]], 1)?;
xs = resnet.forward(&xs, temb)?;
}
match &self.upsampler {
Some(upsampler) => upsampler.forward(&xs, upsample_size),
None => Ok(xs),
}
}
}
#[derive(Debug, Clone, Copy)]
pub struct CrossAttnUpBlock2DConfig {
pub upblock: UpBlock2DConfig,
pub attn_num_head_channels: usize,
pub cross_attention_dim: usize,
// attention_type: "default"
pub sliced_attention_size: Option<usize>,
pub use_linear_projection: bool,
}
impl Default for CrossAttnUpBlock2DConfig {
fn default() -> Self {
Self {
upblock: Default::default(),
attn_num_head_channels: 1,
cross_attention_dim: 1280,
sliced_attention_size: None,
use_linear_projection: false,
}
}
}
#[derive(Debug)]
pub struct CrossAttnUpBlock2D {
pub upblock: UpBlock2D,
pub attentions: Vec<SpatialTransformer>,
pub config: CrossAttnUpBlock2DConfig,
}
impl CrossAttnUpBlock2D {
pub fn new(
vs: nn::VarBuilder,
in_channels: usize,
prev_output_channels: usize,
out_channels: usize,
temb_channels: Option<usize>,
config: CrossAttnUpBlock2DConfig,
) -> Result<Self> {
let upblock = UpBlock2D::new(
vs.clone(),
in_channels,
prev_output_channels,
out_channels,
temb_channels,
config.upblock,
)?;
let n_heads = config.attn_num_head_channels;
let cfg = SpatialTransformerConfig {
depth: 1,
context_dim: Some(config.cross_attention_dim),
num_groups: config.upblock.resnet_groups,
sliced_attention_size: config.sliced_attention_size,
use_linear_projection: config.use_linear_projection,
};
let vs_attn = vs.pp("attentions");
let attentions = (0..config.upblock.num_layers)
.map(|i| {
SpatialTransformer::new(
vs_attn.pp(&i.to_string()),
out_channels,
n_heads,
out_channels / n_heads,
cfg,
)
})
.collect::<Result<Vec<_>>>()?;
Ok(Self {
upblock,
attentions,
config,
})
}
pub fn forward(
&self,
xs: &Tensor,
res_xs: &[Tensor],
temb: Option<&Tensor>,
upsample_size: Option<(usize, usize)>,
encoder_hidden_states: Option<&Tensor>,
) -> Result<Tensor> {
let mut xs = xs.clone();
for (index, resnet) in self.upblock.resnets.iter().enumerate() {
xs = Tensor::cat(&[&xs, &res_xs[res_xs.len() - index - 1]], 1)?;
xs = resnet.forward(&xs, temb)?;
xs = self.attentions[index].forward(&xs, encoder_hidden_states)?;
}
match &self.upblock.upsampler {
Some(upsampler) => upsampler.forward(&xs, upsample_size),
None => Ok(xs),
}
}
}

View File

@ -0,0 +1,17 @@
use candle::{Result, Tensor};
pub fn sigmoid(_: &Tensor) -> Result<Tensor> {
todo!()
}
pub fn avg_pool2d(_: &Tensor) -> Result<Tensor> {
todo!()
}
pub fn pad(_: &Tensor) -> Result<Tensor> {
todo!()
}
pub fn upsample_nearest2d(_: &Tensor) -> Result<Tensor> {
todo!()
}

View File

@ -0,0 +1,378 @@
#![allow(dead_code)]
//! # Variational Auto-Encoder (VAE) Models.
//!
//! Auto-encoder models compress their input to a usually smaller latent space
//! before expanding it back to its original shape. This results in the latent values
//! compressing the original information.
use crate::unet_2d_blocks::{
DownEncoderBlock2D, DownEncoderBlock2DConfig, UNetMidBlock2D, UNetMidBlock2DConfig,
UpDecoderBlock2D, UpDecoderBlock2DConfig,
};
use candle::{Result, Tensor};
use candle_nn as nn;
#[derive(Debug, Clone)]
struct EncoderConfig {
// down_block_types: DownEncoderBlock2D
block_out_channels: Vec<usize>,
layers_per_block: usize,
norm_num_groups: usize,
double_z: bool,
}
impl Default for EncoderConfig {
fn default() -> Self {
Self {
block_out_channels: vec![64],
layers_per_block: 2,
norm_num_groups: 32,
double_z: true,
}
}
}
#[derive(Debug)]
struct Encoder {
conv_in: nn::Conv2d,
down_blocks: Vec<DownEncoderBlock2D>,
mid_block: UNetMidBlock2D,
conv_norm_out: nn::GroupNorm,
conv_out: nn::Conv2d,
#[allow(dead_code)]
config: EncoderConfig,
}
impl Encoder {
fn new(
vs: nn::VarBuilder,
in_channels: usize,
out_channels: usize,
config: EncoderConfig,
) -> Result<Self> {
let conv_cfg = nn::Conv2dConfig {
stride: 1,
padding: 1,
};
let conv_in = nn::conv2d(
in_channels,
config.block_out_channels[0],
3,
conv_cfg,
vs.pp("conv_in"),
)?;
let mut down_blocks = vec![];
let vs_down_blocks = vs.pp("down_blocks");
for index in 0..config.block_out_channels.len() {
let out_channels = config.block_out_channels[index];
let in_channels = if index > 0 {
config.block_out_channels[index - 1]
} else {
config.block_out_channels[0]
};
let is_final = index + 1 == config.block_out_channels.len();
let cfg = DownEncoderBlock2DConfig {
num_layers: config.layers_per_block,
resnet_eps: 1e-6,
resnet_groups: config.norm_num_groups,
add_downsample: !is_final,
downsample_padding: 0,
..Default::default()
};
let down_block = DownEncoderBlock2D::new(
vs_down_blocks.pp(&index.to_string()),
in_channels,
out_channels,
cfg,
)?;
down_blocks.push(down_block)
}
let last_block_out_channels = *config.block_out_channels.last().unwrap();
let mid_cfg = UNetMidBlock2DConfig {
resnet_eps: 1e-6,
output_scale_factor: 1.,
attn_num_head_channels: None,
resnet_groups: Some(config.norm_num_groups),
..Default::default()
};
let mid_block =
UNetMidBlock2D::new(vs.pp("mid_block"), last_block_out_channels, None, mid_cfg)?;
let conv_norm_out = nn::group_norm(
config.norm_num_groups,
last_block_out_channels,
1e-6,
vs.pp("conv_norm_out"),
)?;
let conv_out_channels = if config.double_z {
2 * out_channels
} else {
out_channels
};
let conv_cfg = nn::Conv2dConfig {
padding: 1,
..Default::default()
};
let conv_out = nn::conv2d(
last_block_out_channels,
conv_out_channels,
3,
conv_cfg,
vs.pp("conv_out"),
)?;
Ok(Self {
conv_in,
down_blocks,
mid_block,
conv_norm_out,
conv_out,
config,
})
}
}
impl Encoder {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let mut xs = self.conv_in.forward(xs)?;
for down_block in self.down_blocks.iter() {
xs = down_block.forward(&xs)?
}
let xs = self.mid_block.forward(&xs, None)?;
let xs = self.conv_norm_out.forward(&xs)?;
let xs = nn::ops::silu(&xs)?;
self.conv_out.forward(&xs)
}
}
#[derive(Debug, Clone)]
struct DecoderConfig {
// up_block_types: UpDecoderBlock2D
block_out_channels: Vec<usize>,
layers_per_block: usize,
norm_num_groups: usize,
}
impl Default for DecoderConfig {
fn default() -> Self {
Self {
block_out_channels: vec![64],
layers_per_block: 2,
norm_num_groups: 32,
}
}
}
#[derive(Debug)]
struct Decoder {
conv_in: nn::Conv2d,
up_blocks: Vec<UpDecoderBlock2D>,
mid_block: UNetMidBlock2D,
conv_norm_out: nn::GroupNorm,
conv_out: nn::Conv2d,
#[allow(dead_code)]
config: DecoderConfig,
}
impl Decoder {
fn new(
vs: nn::VarBuilder,
in_channels: usize,
out_channels: usize,
config: DecoderConfig,
) -> Result<Self> {
let n_block_out_channels = config.block_out_channels.len();
let last_block_out_channels = *config.block_out_channels.last().unwrap();
let conv_cfg = nn::Conv2dConfig {
stride: 1,
padding: 1,
};
let conv_in = nn::conv2d(
in_channels,
last_block_out_channels,
3,
conv_cfg,
vs.pp("conv_in"),
)?;
let mid_cfg = UNetMidBlock2DConfig {
resnet_eps: 1e-6,
output_scale_factor: 1.,
attn_num_head_channels: None,
resnet_groups: Some(config.norm_num_groups),
..Default::default()
};
let mid_block =
UNetMidBlock2D::new(vs.pp("mid_block"), last_block_out_channels, None, mid_cfg)?;
let mut up_blocks = vec![];
let vs_up_blocks = vs.pp("up_blocks");
let reversed_block_out_channels: Vec<_> =
config.block_out_channels.iter().copied().rev().collect();
for index in 0..n_block_out_channels {
let out_channels = reversed_block_out_channels[index];
let in_channels = if index > 0 {
reversed_block_out_channels[index - 1]
} else {
reversed_block_out_channels[0]
};
let is_final = index + 1 == n_block_out_channels;
let cfg = UpDecoderBlock2DConfig {
num_layers: config.layers_per_block + 1,
resnet_eps: 1e-6,
resnet_groups: config.norm_num_groups,
add_upsample: !is_final,
..Default::default()
};
let up_block = UpDecoderBlock2D::new(
vs_up_blocks.pp(&index.to_string()),
in_channels,
out_channels,
cfg,
)?;
up_blocks.push(up_block)
}
let conv_norm_out = nn::group_norm(
config.norm_num_groups,
config.block_out_channels[0],
1e-6,
vs.pp("conv_norm_out"),
)?;
let conv_cfg = nn::Conv2dConfig {
padding: 1,
..Default::default()
};
let conv_out = nn::conv2d(
config.block_out_channels[0],
out_channels,
3,
conv_cfg,
vs.pp("conv_out"),
)?;
Ok(Self {
conv_in,
up_blocks,
mid_block,
conv_norm_out,
conv_out,
config,
})
}
}
impl Decoder {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let mut xs = self.mid_block.forward(&self.conv_in.forward(xs)?, None)?;
for up_block in self.up_blocks.iter() {
xs = up_block.forward(&xs)?
}
let xs = self.conv_norm_out.forward(&xs)?;
let xs = nn::ops::silu(&xs)?;
self.conv_out.forward(&xs)
}
}
#[derive(Debug, Clone)]
pub struct AutoEncoderKLConfig {
pub block_out_channels: Vec<usize>,
pub layers_per_block: usize,
pub latent_channels: usize,
pub norm_num_groups: usize,
}
impl Default for AutoEncoderKLConfig {
fn default() -> Self {
Self {
block_out_channels: vec![64],
layers_per_block: 1,
latent_channels: 4,
norm_num_groups: 32,
}
}
}
pub struct DiagonalGaussianDistribution {
mean: Tensor,
std: Tensor,
}
impl DiagonalGaussianDistribution {
pub fn new(parameters: &Tensor) -> Result<Self> {
let mut parameters = parameters.chunk(2, 1)?.into_iter();
let mean = parameters.next().unwrap();
let logvar = parameters.next().unwrap();
let std = (logvar * 0.5)?.exp()?;
Ok(DiagonalGaussianDistribution { mean, std })
}
pub fn sample(&self) -> Result<Tensor> {
let sample = Tensor::randn(0., 1f32, self.mean.shape(), self.mean.device());
&self.mean + &self.std * sample
}
}
// https://github.com/huggingface/diffusers/blob/970e30606c2944e3286f56e8eb6d3dc6d1eb85f7/src/diffusers/models/vae.py#L485
// This implementation is specific to the config used in stable-diffusion-v1-5
// https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/vae/config.json
#[derive(Debug)]
pub struct AutoEncoderKL {
encoder: Encoder,
decoder: Decoder,
quant_conv: nn::Conv2d,
post_quant_conv: nn::Conv2d,
pub config: AutoEncoderKLConfig,
}
impl AutoEncoderKL {
pub fn new(
vs: nn::VarBuilder,
in_channels: usize,
out_channels: usize,
config: AutoEncoderKLConfig,
) -> Result<Self> {
let latent_channels = config.latent_channels;
let encoder_cfg = EncoderConfig {
block_out_channels: config.block_out_channels.clone(),
layers_per_block: config.layers_per_block,
norm_num_groups: config.norm_num_groups,
double_z: true,
};
let encoder = Encoder::new(vs.pp("encoder"), in_channels, latent_channels, encoder_cfg)?;
let decoder_cfg = DecoderConfig {
block_out_channels: config.block_out_channels.clone(),
layers_per_block: config.layers_per_block,
norm_num_groups: config.norm_num_groups,
};
let decoder = Decoder::new(vs.pp("decoder"), latent_channels, out_channels, decoder_cfg)?;
let conv_cfg = Default::default();
let quant_conv = nn::conv2d(
2 * latent_channels,
2 * latent_channels,
1,
conv_cfg,
vs.pp("quant_conv"),
)?;
let post_quant_conv = nn::conv2d(
latent_channels,
latent_channels,
1,
conv_cfg,
vs.pp("post_quant_conv"),
)?;
Ok(Self {
encoder,
decoder,
quant_conv,
post_quant_conv,
config,
})
}
/// Returns the distribution in the latent space.
pub fn encode(&self, xs: &Tensor) -> Result<DiagonalGaussianDistribution> {
let xs = self.encoder.forward(xs)?;
let parameters = self.quant_conv.forward(&xs)?;
DiagonalGaussianDistribution::new(&parameters)
}
/// Takes as input some sampled values.
pub fn decode(&self, xs: &Tensor) -> Result<Tensor> {
let xs = self.post_quant_conv.forward(xs)?;
self.decoder.forward(&xs)
}
}