Mirror of https://github.com/huggingface/candle.git, synced 2025-06-16 18:48:51 +00:00
Add a stable diffusion example (#328)
* Start adding a stable-diffusion example.
* Proper computation of the causal mask.
* Add the chunk operation.
* Work in progress: port the attention module.
* Add some dummy modules for conv2d and group-norm, get the attention module to compile.
* Re-enable the 2d convolution.
* Add the embeddings module.
* Add the resnet module.
* Add the unet blocks.
* Add the unet.
* And add the variational auto-encoder.
* Use the pad function from utils.
@@ -548,6 +548,32 @@ impl Tensor {
        }
    }

    /// Split a tensor into the specified number of chunks; this may return fewer chunks than
    /// specified.
    pub fn chunk<D: Dim>(&self, chunks: usize, dim: D) -> Result<Vec<Self>> {
        let dim = dim.to_index(self.shape(), "chunk")?;
        let size = self.dim(dim)?;
        if size < chunks {
            (0..size).map(|i| self.narrow(dim, i, 1)).collect()
        } else {
            let chunk_size = size / chunks;
            let cnt_additional = size % chunks;
            let mut tensors = vec![];
            let mut sum_chunk_size = 0;
            for i in 0..chunks {
                let chunk_size = if i < cnt_additional {
                    chunk_size + 1
                } else {
                    chunk_size
                };
                let tensor = self.narrow(dim, sum_chunk_size, chunk_size)?;
                tensors.push(tensor);
                sum_chunk_size += chunk_size
            }
            Ok(tensors)
        }
    }

    /// Returns a new tensor that is a narrowed version of the input, the dimension `dim`
    /// ranges from `start` to `start + len`.
    pub fn narrow<D: Dim>(&self, dim: D, start: usize, len: usize) -> Result<Self> {
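As an aside, here is a minimal plain-Rust sketch of the size arithmetic used by `chunk` above; `chunk_sizes` is a hypothetical helper for illustration, not part of the candle API.

// The remainder `size % chunks` is spread over the first few chunks, one extra
// element each, so every element is covered exactly once.
fn chunk_sizes(size: usize, chunks: usize) -> Vec<usize> {
    if size < chunks {
        // One chunk per element, mirroring the `size < chunks` branch above.
        return vec![1; size];
    }
    let base = size / chunks;
    let extra = size % chunks;
    (0..chunks)
        .map(|i| if i < extra { base + 1 } else { base })
        .collect()
}

fn main() {
    // 10 elements split into 4 chunks -> [3, 3, 2, 2], summing back to 10.
    let sizes = chunk_sizes(10, 4);
    assert_eq!(sizes, vec![3, 3, 2, 2]);
    assert_eq!(sizes.iter().sum::<usize>(), 10);
}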
candle-examples/examples/stable-diffusion/attention.rs (new file, 445 lines)
@@ -0,0 +1,445 @@
#![allow(dead_code)]
|
||||
//! Attention Based Building Blocks
|
||||
use candle::{IndexOp, Result, Tensor, D};
|
||||
use candle_nn as nn;
|
||||
|
||||
#[derive(Debug)]
|
||||
struct GeGlu {
|
||||
proj: nn::Linear,
|
||||
}
|
||||
|
||||
impl GeGlu {
|
||||
fn new(vs: nn::VarBuilder, dim_in: usize, dim_out: usize) -> Result<Self> {
|
||||
let proj = nn::linear(dim_in, dim_out * 2, vs.pp("proj"))?;
|
||||
Ok(Self { proj })
|
||||
}
|
||||
}
|
||||
|
||||
impl GeGlu {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let hidden_states_and_gate = self.proj.forward(xs)?.chunk(2, D::Minus1)?;
|
||||
&hidden_states_and_gate[0] * hidden_states_and_gate[1].gelu()?
|
||||
}
|
||||
}
|
||||
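To make the gating explicit, a small plain-Rust sketch of what `GeGlu::forward` computes once the projection has produced `2 * dim_out` features: the first half is multiplied element-wise by GELU of the second half. The helper names are illustrative, and the GELU below uses the tanh approximation.

// Illustrative only: the projected vector is split in two halves ("hidden states"
// and "gate"), and the result is hidden * gelu(gate), matching chunk(2, -1) above.
fn gelu_tanh(x: f32) -> f32 {
    // tanh approximation of GELU (assumption: close enough for illustration).
    0.5 * x * (1.0 + ((2.0 / std::f32::consts::PI).sqrt() * (x + 0.044715 * x.powi(3))).tanh())
}

fn geglu(projected: &[f32]) -> Vec<f32> {
    let half = projected.len() / 2;
    let (hidden, gate) = projected.split_at(half);
    hidden.iter().zip(gate).map(|(h, g)| h * gelu_tanh(*g)).collect()
}

fn main() {
    let out = geglu(&[1.0, 2.0, 0.5, -0.5]);
    assert_eq!(out.len(), 2);
    println!("{out:?}");
}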
|
||||
/// A feed-forward layer.
|
||||
#[derive(Debug)]
|
||||
struct FeedForward {
|
||||
project_in: GeGlu,
|
||||
linear: nn::Linear,
|
||||
}
|
||||
|
||||
impl FeedForward {
|
||||
// The glu parameter in the python code is unused?
|
||||
// https://github.com/huggingface/diffusers/blob/d3d22ce5a894becb951eec03e663951b28d45135/src/diffusers/models/attention.py#L347
|
||||
/// Creates a new feed-forward layer based on some given input dimension, some
|
||||
/// output dimension, and a multiplier to be used for the intermediary layer.
|
||||
fn new(vs: nn::VarBuilder, dim: usize, dim_out: Option<usize>, mult: usize) -> Result<Self> {
|
||||
let inner_dim = dim * mult;
|
||||
let dim_out = dim_out.unwrap_or(dim);
|
||||
let vs = vs.pp("net");
|
||||
let project_in = GeGlu::new(vs.pp("0"), dim, inner_dim)?;
|
||||
let linear = nn::linear(inner_dim, dim_out, vs.pp("2"))?;
|
||||
Ok(Self { project_in, linear })
|
||||
}
|
||||
}
|
||||
|
||||
impl FeedForward {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let xs = self.project_in.forward(xs)?;
|
||||
self.linear.forward(&xs)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct CrossAttention {
|
||||
to_q: nn::Linear,
|
||||
to_k: nn::Linear,
|
||||
to_v: nn::Linear,
|
||||
to_out: nn::Linear,
|
||||
heads: usize,
|
||||
scale: f64,
|
||||
slice_size: Option<usize>,
|
||||
}
|
||||
|
||||
impl CrossAttention {
|
||||
// Defaults should be heads = 8, dim_head = 64, context_dim = None
|
||||
fn new(
|
||||
vs: nn::VarBuilder,
|
||||
query_dim: usize,
|
||||
context_dim: Option<usize>,
|
||||
heads: usize,
|
||||
dim_head: usize,
|
||||
slice_size: Option<usize>,
|
||||
) -> Result<Self> {
|
||||
let inner_dim = dim_head * heads;
|
||||
let context_dim = context_dim.unwrap_or(query_dim);
|
||||
let scale = 1.0 / f64::sqrt(dim_head as f64);
|
||||
let to_q = nn::linear_no_bias(query_dim, inner_dim, vs.pp("to_q"))?;
|
||||
let to_k = nn::linear_no_bias(context_dim, inner_dim, vs.pp("to_k"))?;
|
||||
let to_v = nn::linear_no_bias(context_dim, inner_dim, vs.pp("to_v"))?;
|
||||
let to_out = nn::linear(inner_dim, query_dim, vs.pp("to_out.0"))?;
|
||||
Ok(Self {
|
||||
to_q,
|
||||
to_k,
|
||||
to_v,
|
||||
to_out,
|
||||
heads,
|
||||
scale,
|
||||
slice_size,
|
||||
})
|
||||
}
|
||||
|
||||
fn reshape_heads_to_batch_dim(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let (batch_size, seq_len, dim) = xs.dims3()?;
|
||||
xs.reshape((batch_size, seq_len, self.heads, dim / self.heads))?
|
||||
.transpose(1, 2)?
|
||||
.reshape((batch_size * self.heads, seq_len, dim / self.heads))
|
||||
}
|
||||
|
||||
fn reshape_batch_dim_to_heads(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let (batch_size, seq_len, dim) = xs.dims3()?;
|
||||
xs.reshape((batch_size / self.heads, self.heads, seq_len, dim))?
|
||||
.transpose(1, 2)?
|
||||
.reshape((batch_size / self.heads, seq_len, dim * self.heads))
|
||||
}
|
||||
|
||||
fn sliced_attention(
|
||||
&self,
|
||||
query: &Tensor,
|
||||
key: &Tensor,
|
||||
value: &Tensor,
|
||||
slice_size: usize,
|
||||
) -> Result<Tensor> {
|
||||
let batch_size_attention = query.dim(0)?;
|
||||
let mut hidden_states = Vec::with_capacity(batch_size_attention / slice_size);
|
||||
|
||||
for i in 0..batch_size_attention / slice_size {
|
||||
let start_idx = i * slice_size;
|
||||
let end_idx = (i + 1) * slice_size;
|
||||
|
||||
let xs = query
|
||||
.i(start_idx..end_idx)?
|
||||
.matmul(&(key.i(start_idx..end_idx)?.t()? * self.scale)?)?;
|
||||
let xs = nn::ops::softmax(&xs, D::Minus1)?.matmul(&value.i(start_idx..end_idx)?)?;
|
||||
hidden_states.push(xs)
|
||||
}
|
||||
let hidden_states = Tensor::stack(&hidden_states, 0)?;
|
||||
self.reshape_batch_dim_to_heads(&hidden_states)
|
||||
}
|
||||
|
||||
fn attention(&self, query: &Tensor, key: &Tensor, value: &Tensor) -> Result<Tensor> {
|
||||
let xs = query.matmul(&(key.transpose(D::Minus1, D::Minus2)? * self.scale)?)?;
|
||||
let xs = nn::ops::softmax(&xs, D::Minus1)?.matmul(value)?;
|
||||
self.reshape_batch_dim_to_heads(&xs)
|
||||
}
|
||||
|
||||
fn forward(&self, xs: &Tensor, context: Option<&Tensor>) -> Result<Tensor> {
|
||||
let query = self.to_q.forward(xs)?;
|
||||
let context = context.unwrap_or(xs);
|
||||
let key = self.to_k.forward(context)?;
|
||||
let value = self.to_v.forward(context)?;
|
||||
let query = self.reshape_heads_to_batch_dim(&query)?;
|
||||
let key = self.reshape_heads_to_batch_dim(&key)?;
|
||||
let value = self.reshape_heads_to_batch_dim(&value)?;
|
||||
let xs = match self.slice_size {
|
||||
None => self.attention(&query, &key, &value)?,
|
||||
Some(slice_size) => {
|
||||
if query.dim(0)? / slice_size <= 1 {
|
||||
self.attention(&query, &key, &value)?
|
||||
} else {
|
||||
self.sliced_attention(&query, &key, &value, slice_size)?
|
||||
}
|
||||
}
|
||||
};
|
||||
self.to_out.forward(&xs)
|
||||
}
|
||||
}
|
||||
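For reference, the math behind `attention` above for a single head, written out over row-major `Vec<f32>` matrices; this is an illustrative sketch, not the candle implementation.

// Single-head scaled dot-product attention: softmax(q * k^T * scale) * v.
fn softmax_row(row: &mut [f32]) {
    let max = row.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let mut sum = 0.0;
    for x in row.iter_mut() {
        *x = (*x - max).exp();
        sum += *x;
    }
    for x in row.iter_mut() {
        *x /= sum;
    }
}

// q: (n, d), k: (m, d), v: (m, d); returns (n, d).
fn attention(q: &[f32], k: &[f32], v: &[f32], n: usize, m: usize, d: usize) -> Vec<f32> {
    let scale = 1.0 / (d as f32).sqrt();
    let mut out = vec![0.0; n * d];
    for i in 0..n {
        let mut scores: Vec<f32> = (0..m)
            .map(|j| (0..d).map(|c| q[i * d + c] * k[j * d + c]).sum::<f32>() * scale)
            .collect();
        softmax_row(&mut scores);
        for j in 0..m {
            for c in 0..d {
                out[i * d + c] += scores[j] * v[j * d + c];
            }
        }
    }
    out
}

fn main() {
    let out = attention(&[1.0, 0.0], &[1.0, 0.0, 0.0, 1.0], &[1.0, 2.0, 3.0, 4.0], 1, 2, 2);
    assert_eq!(out.len(), 2);
}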
|
||||
/// A basic Transformer block.
|
||||
#[derive(Debug)]
|
||||
struct BasicTransformerBlock {
|
||||
attn1: CrossAttention,
|
||||
ff: FeedForward,
|
||||
attn2: CrossAttention,
|
||||
norm1: nn::LayerNorm,
|
||||
norm2: nn::LayerNorm,
|
||||
norm3: nn::LayerNorm,
|
||||
}
|
||||
|
||||
impl BasicTransformerBlock {
|
||||
fn new(
|
||||
vs: nn::VarBuilder,
|
||||
dim: usize,
|
||||
n_heads: usize,
|
||||
d_head: usize,
|
||||
context_dim: Option<usize>,
|
||||
sliced_attention_size: Option<usize>,
|
||||
) -> Result<Self> {
|
||||
let attn1 = CrossAttention::new(
|
||||
vs.pp("attn1"),
|
||||
dim,
|
||||
None,
|
||||
n_heads,
|
||||
d_head,
|
||||
sliced_attention_size,
|
||||
)?;
|
||||
let ff = FeedForward::new(vs.pp("ff"), dim, None, 4)?;
|
||||
let attn2 = CrossAttention::new(
|
||||
vs.pp("attn2"),
|
||||
dim,
|
||||
context_dim,
|
||||
n_heads,
|
||||
d_head,
|
||||
sliced_attention_size,
|
||||
)?;
|
||||
let norm1 = nn::layer_norm(dim, 1e-5, vs.pp("norm1"))?;
|
||||
let norm2 = nn::layer_norm(dim, 1e-5, vs.pp("norm2"))?;
|
||||
let norm3 = nn::layer_norm(dim, 1e-5, vs.pp("norm3"))?;
|
||||
Ok(Self {
|
||||
attn1,
|
||||
ff,
|
||||
attn2,
|
||||
norm1,
|
||||
norm2,
|
||||
norm3,
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(&self, xs: &Tensor, context: Option<&Tensor>) -> Result<Tensor> {
|
||||
let xs = (self.attn1.forward(&self.norm1.forward(xs)?, None)? + xs)?;
|
||||
let xs = (self.attn2.forward(&self.norm2.forward(&xs)?, context)? + xs)?;
|
||||
self.ff.forward(&self.norm3.forward(&xs)?)? + xs
|
||||
}
|
||||
}
|
||||
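Schematically, `BasicTransformerBlock::forward` is a pre-norm residual stack; the sketch below collapses the three layer norms and both attention layers into generic closures to show only the wiring (illustrative, not the candle API).

// Each sub-layer sees a normalized input and its output is added back onto the
// un-normalized residual stream; shapes are unchanged throughout.
fn basic_transformer_block(
    xs: Vec<f32>,
    norm: impl Fn(&[f32]) -> Vec<f32>,
    attn: impl Fn(&[f32]) -> Vec<f32>,
    ff: impl Fn(&[f32]) -> Vec<f32>,
) -> Vec<f32> {
    let add = |a: &[f32], b: &[f32]| a.iter().zip(b).map(|(x, y)| x + y).collect::<Vec<_>>();
    let xs = add(&attn(&norm(&xs)), &xs); // attn1 (self-attention) + residual
    let xs = add(&attn(&norm(&xs)), &xs); // attn2 (cross-attention, context omitted here) + residual
    add(&ff(&norm(&xs)), &xs)             // feed-forward + residual
}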
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct SpatialTransformerConfig {
|
||||
pub depth: usize,
|
||||
pub num_groups: usize,
|
||||
pub context_dim: Option<usize>,
|
||||
pub sliced_attention_size: Option<usize>,
|
||||
pub use_linear_projection: bool,
|
||||
}
|
||||
|
||||
impl Default for SpatialTransformerConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
depth: 1,
|
||||
num_groups: 32,
|
||||
context_dim: None,
|
||||
sliced_attention_size: None,
|
||||
use_linear_projection: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
enum Proj {
|
||||
Conv2d(nn::Conv2d),
|
||||
Linear(nn::Linear),
|
||||
}
|
||||
|
||||
// Aka Transformer2DModel
|
||||
#[derive(Debug)]
|
||||
pub struct SpatialTransformer {
|
||||
norm: nn::GroupNorm,
|
||||
proj_in: Proj,
|
||||
transformer_blocks: Vec<BasicTransformerBlock>,
|
||||
proj_out: Proj,
|
||||
pub config: SpatialTransformerConfig,
|
||||
}
|
||||
|
||||
impl SpatialTransformer {
|
||||
pub fn new(
|
||||
vs: nn::VarBuilder,
|
||||
in_channels: usize,
|
||||
n_heads: usize,
|
||||
d_head: usize,
|
||||
config: SpatialTransformerConfig,
|
||||
) -> Result<Self> {
|
||||
let inner_dim = n_heads * d_head;
|
||||
let norm = nn::group_norm(config.num_groups, in_channels, 1e-6, vs.pp("norm"))?;
|
||||
let proj_in = if config.use_linear_projection {
|
||||
Proj::Linear(nn::linear(in_channels, inner_dim, vs.pp("proj_in"))?)
|
||||
} else {
|
||||
Proj::Conv2d(nn::conv2d(
|
||||
in_channels,
|
||||
inner_dim,
|
||||
1,
|
||||
Default::default(),
|
||||
vs.pp("proj_in"),
|
||||
)?)
|
||||
};
|
||||
let mut transformer_blocks = vec![];
|
||||
let vs_tb = vs.pp("transformer_blocks");
|
||||
for index in 0..config.depth {
|
||||
let tb = BasicTransformerBlock::new(
|
||||
vs_tb.pp(&index.to_string()),
|
||||
inner_dim,
|
||||
n_heads,
|
||||
d_head,
|
||||
config.context_dim,
|
||||
config.sliced_attention_size,
|
||||
)?;
|
||||
transformer_blocks.push(tb)
|
||||
}
|
||||
let proj_out = if config.use_linear_projection {
|
||||
Proj::Linear(nn::linear(in_channels, inner_dim, vs.pp("proj_out"))?)
|
||||
} else {
|
||||
Proj::Conv2d(nn::conv2d(
|
||||
inner_dim,
|
||||
in_channels,
|
||||
1,
|
||||
Default::default(),
|
||||
vs.pp("proj_out"),
|
||||
)?)
|
||||
};
|
||||
Ok(Self {
|
||||
norm,
|
||||
proj_in,
|
||||
transformer_blocks,
|
||||
proj_out,
|
||||
config,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn forward(&self, xs: &Tensor, context: Option<&Tensor>) -> Result<Tensor> {
|
||||
let (batch, _channel, height, weight) = xs.dims4()?;
|
||||
let residual = xs;
|
||||
let xs = self.norm.forward(xs)?;
|
||||
let (inner_dim, xs) = match &self.proj_in {
|
||||
Proj::Conv2d(p) => {
|
||||
let xs = p.forward(&xs)?;
|
||||
let inner_dim = xs.dim(1)?;
|
||||
let xs = xs
|
||||
.transpose(1, 2)?
|
||||
.t()?
|
||||
.reshape((batch, height * weight, inner_dim))?;
|
||||
(inner_dim, xs)
|
||||
}
|
||||
Proj::Linear(p) => {
|
||||
let inner_dim = xs.dim(1)?;
|
||||
let xs = xs
|
||||
.transpose(1, 2)?
|
||||
.t()?
|
||||
.reshape((batch, height * weight, inner_dim))?;
|
||||
(inner_dim, p.forward(&xs)?)
|
||||
}
|
||||
};
|
||||
let mut xs = xs;
|
||||
for block in self.transformer_blocks.iter() {
|
||||
xs = block.forward(&xs, context)?
|
||||
}
|
||||
let xs = match &self.proj_out {
|
||||
Proj::Conv2d(p) => p.forward(
|
||||
&xs.reshape((batch, height, weight, inner_dim))?
|
||||
.t()?
|
||||
.transpose(1, 2)?,
|
||||
)?,
|
||||
Proj::Linear(p) => p
|
||||
.forward(&xs)?
|
||||
.reshape((batch, height, weight, inner_dim))?
|
||||
.t()?
|
||||
.transpose(1, 2)?,
|
||||
};
|
||||
xs + residual
|
||||
}
|
||||
}
|
||||
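The conv path above flattens NCHW activations into a sequence of per-pixel tokens before the transformer blocks; a plain-Rust sketch of that index mapping, for illustration only.

// (batch, channels, h, w) -> (batch, h * w, channels), matching the reshape above.
fn nchw_to_tokens(xs: &[f32], b: usize, c: usize, h: usize, w: usize) -> Vec<f32> {
    let mut out = vec![0.0; b * h * w * c];
    for bi in 0..b {
        for ci in 0..c {
            for hi in 0..h {
                for wi in 0..w {
                    let src = ((bi * c + ci) * h + hi) * w + wi;
                    let dst = (bi * h * w + hi * w + wi) * c + ci;
                    out[dst] = xs[src];
                }
            }
        }
    }
    out
}

fn main() {
    // A 1x2x1x2 tensor: channels-first input becomes one [c0, c1] token per pixel.
    let out = nchw_to_tokens(&[1.0, 2.0, 3.0, 4.0], 1, 2, 1, 2);
    assert_eq!(out, vec![1.0, 3.0, 2.0, 4.0]);
}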
|
||||
/// Configuration for an attention block.
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct AttentionBlockConfig {
|
||||
pub num_head_channels: Option<usize>,
|
||||
pub num_groups: usize,
|
||||
pub rescale_output_factor: f64,
|
||||
pub eps: f64,
|
||||
}
|
||||
|
||||
impl Default for AttentionBlockConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
num_head_channels: None,
|
||||
num_groups: 32,
|
||||
rescale_output_factor: 1.,
|
||||
eps: 1e-5,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct AttentionBlock {
|
||||
group_norm: nn::GroupNorm,
|
||||
query: nn::Linear,
|
||||
key: nn::Linear,
|
||||
value: nn::Linear,
|
||||
proj_attn: nn::Linear,
|
||||
channels: usize,
|
||||
num_heads: usize,
|
||||
config: AttentionBlockConfig,
|
||||
}
|
||||
|
||||
impl AttentionBlock {
|
||||
pub fn new(vs: nn::VarBuilder, channels: usize, config: AttentionBlockConfig) -> Result<Self> {
|
||||
let num_head_channels = config.num_head_channels.unwrap_or(channels);
|
||||
let num_heads = channels / num_head_channels;
|
||||
let group_norm =
|
||||
nn::group_norm(config.num_groups, channels, config.eps, vs.pp("group_norm"))?;
|
||||
let query = nn::linear(channels, channels, vs.pp("query"))?;
|
||||
let key = nn::linear(channels, channels, vs.pp("key"))?;
|
||||
let value = nn::linear(channels, channels, vs.pp("value"))?;
|
||||
let proj_attn = nn::linear(channels, channels, vs.pp("proj_attn"))?;
|
||||
Ok(Self {
|
||||
group_norm,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
proj_attn,
|
||||
channels,
|
||||
num_heads,
|
||||
config,
|
||||
})
|
||||
}
|
||||
|
||||
fn transpose_for_scores(&self, xs: Tensor) -> Result<Tensor> {
|
||||
let (batch, t, h_times_d) = xs.dims3()?;
|
||||
xs.reshape((batch, t, self.num_heads, h_times_d / self.num_heads))?
|
||||
.transpose(1, 2)
|
||||
}
|
||||
}
|
||||
|
||||
impl AttentionBlock {
|
||||
pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let residual = xs;
|
||||
let (batch, channel, height, width) = xs.dims4()?;
|
||||
let xs = self
|
||||
.group_norm
|
||||
.forward(xs)?
|
||||
.reshape((batch, channel, height * width))?
|
||||
.transpose(1, 2)?;
|
||||
|
||||
let query_proj = self.query.forward(&xs)?;
|
||||
let key_proj = self.key.forward(&xs)?;
|
||||
let value_proj = self.value.forward(&xs)?;
|
||||
|
||||
let query_states = self.transpose_for_scores(query_proj)?;
|
||||
let key_states = self.transpose_for_scores(key_proj)?;
|
||||
let value_states = self.transpose_for_scores(value_proj)?;
|
||||
|
||||
let scale = f64::powf((self.channels as f64) / (self.num_heads as f64), -0.25);
|
||||
let attention_scores =
|
||||
// TODO: Check that this needs two multiplications by `scale`.
|
||||
(query_states * scale)?.matmul(&(key_states.t()? * scale)?)?;
|
||||
let attention_probs = nn::ops::softmax(&attention_scores, D::Minus1)?;
|
||||
|
||||
let xs = attention_probs.matmul(&value_states)?;
|
||||
let xs = xs.transpose(1, 2)?.contiguous()?;
|
||||
let xs = xs.flatten_from(D::Minus2)?;
|
||||
let xs = self
|
||||
.proj_attn
|
||||
.forward(&xs)?
|
||||
.t()?
|
||||
.reshape((batch, channel, height, width))?;
|
||||
(xs + residual)? / self.config.rescale_output_factor
|
||||
}
|
||||
}
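A quick check of the double `scale` multiplication above: applying `d^(-1/4)` to both queries and keys is equivalent to the usual `1/sqrt(head_dim)` scaling of the attention scores.

fn main() {
    let (channels, num_heads) = (512.0_f64, 8.0_f64);
    let d = channels / num_heads; // head_dim
    let scale = d.powf(-0.25);
    // scale applied to q and to k multiplies each q.k score by d^(-1/2).
    assert!((scale * scale - 1.0 / d.sqrt()).abs() < 1e-12);
}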
candle-examples/examples/stable-diffusion/clip.rs (new file, 304 lines)
@@ -0,0 +1,304 @@
#![allow(dead_code)]
|
||||
//! Contrastive Language-Image Pre-Training
|
||||
//!
|
||||
//! Contrastive Language-Image Pre-Training (CLIP) is an architecture trained on
|
||||
//! pairs of images with related texts.
|
||||
//!
|
||||
//! https://github.com/openai/CLIP
|
||||
use candle::{Device, Result, Tensor, D};
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub enum Activation {
|
||||
QuickGelu,
|
||||
Gelu,
|
||||
}
|
||||
|
||||
impl Activation {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
match self {
|
||||
Activation::QuickGelu => xs * crate::utils::sigmoid(&(xs * 1.702f64)?)?,
|
||||
Activation::Gelu => xs.gelu(),
|
||||
}
|
||||
}
|
||||
}
|
||||
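For clarity, the QuickGELU variant used above in scalar form (illustrative plain Rust; the real code applies it tensor-wide).

fn sigmoid(x: f32) -> f32 {
    1.0 / (1.0 + (-x).exp())
}

// QuickGELU: x * sigmoid(1.702 * x).
fn quick_gelu(x: f32) -> f32 {
    x * sigmoid(1.702 * x)
}

fn main() {
    // Close to the exact GELU for moderate inputs, e.g. quick_gelu(1.0) is roughly 0.846.
    println!("{}", quick_gelu(1.0));
}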
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Config {
|
||||
vocab_size: usize,
|
||||
embed_dim: usize, // aka config.hidden_size
|
||||
activation: Activation, // aka config.hidden_act
|
||||
intermediate_size: usize,
|
||||
max_position_embeddings: usize,
|
||||
// The character to use for padding, use EOS when not set.
|
||||
pad_with: Option<String>,
|
||||
num_hidden_layers: usize,
|
||||
num_attention_heads: usize,
|
||||
#[allow(dead_code)]
|
||||
projection_dim: usize,
|
||||
}
|
||||
|
||||
impl Config {
|
||||
// The config details can be found in the "text_config" section of this json file:
|
||||
// https://huggingface.co/openai/clip-vit-large-patch14/blob/main/config.json
|
||||
pub fn v1_5() -> Self {
|
||||
Self {
|
||||
vocab_size: 49408,
|
||||
embed_dim: 768,
|
||||
intermediate_size: 3072,
|
||||
max_position_embeddings: 77,
|
||||
pad_with: None,
|
||||
num_hidden_layers: 12,
|
||||
num_attention_heads: 12,
|
||||
projection_dim: 768,
|
||||
activation: Activation::QuickGelu,
|
||||
}
|
||||
}
|
||||
|
||||
// https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/text_encoder/config.json
|
||||
pub fn v2_1() -> Self {
|
||||
Self {
|
||||
vocab_size: 49408,
|
||||
embed_dim: 1024,
|
||||
intermediate_size: 4096,
|
||||
max_position_embeddings: 77,
|
||||
pad_with: Some("!".to_string()),
|
||||
num_hidden_layers: 23,
|
||||
num_attention_heads: 16,
|
||||
projection_dim: 512,
|
||||
activation: Activation::Gelu,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// CLIP Text Model
|
||||
// https://github.com/huggingface/transformers/blob/674f750a57431222fa2832503a108df3badf1564/src/transformers/models/clip/modeling_clip.py
|
||||
#[derive(Debug)]
|
||||
struct ClipTextEmbeddings {
|
||||
token_embedding: candle_nn::Embedding,
|
||||
position_embedding: candle_nn::Embedding,
|
||||
position_ids: Tensor,
|
||||
}
|
||||
|
||||
impl ClipTextEmbeddings {
|
||||
fn new(vs: candle_nn::VarBuilder, c: &Config) -> Result<Self> {
|
||||
let token_embedding =
|
||||
candle_nn::embedding(c.vocab_size, c.embed_dim, vs.pp("token_embedding"))?;
|
||||
let position_embedding = candle_nn::embedding(
|
||||
c.max_position_embeddings,
|
||||
c.embed_dim,
|
||||
vs.pp("position_embedding"),
|
||||
)?;
|
||||
let position_ids =
|
||||
Tensor::arange(0u32, c.max_position_embeddings as u32, vs.device())?.unsqueeze(1)?;
|
||||
Ok(ClipTextEmbeddings {
|
||||
token_embedding,
|
||||
position_embedding,
|
||||
position_ids,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl ClipTextEmbeddings {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let token_embedding = self.token_embedding.forward(xs)?;
|
||||
let position_embedding = self.position_embedding.forward(&self.position_ids)?;
|
||||
token_embedding + position_embedding
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct ClipAttention {
|
||||
k_proj: candle_nn::Linear,
|
||||
v_proj: candle_nn::Linear,
|
||||
q_proj: candle_nn::Linear,
|
||||
out_proj: candle_nn::Linear,
|
||||
head_dim: usize,
|
||||
scale: f64,
|
||||
num_attention_heads: usize,
|
||||
}
|
||||
|
||||
impl ClipAttention {
|
||||
fn new(vs: candle_nn::VarBuilder, c: &Config) -> Result<Self> {
|
||||
let embed_dim = c.embed_dim;
|
||||
let num_attention_heads = c.num_attention_heads;
|
||||
let k_proj = candle_nn::linear(embed_dim, embed_dim, vs.pp("k_proj"))?;
|
||||
let v_proj = candle_nn::linear(embed_dim, embed_dim, vs.pp("v_proj"))?;
|
||||
let q_proj = candle_nn::linear(embed_dim, embed_dim, vs.pp("q_proj"))?;
|
||||
let out_proj = candle_nn::linear(embed_dim, embed_dim, vs.pp("out_proj"))?;
|
||||
let head_dim = embed_dim / num_attention_heads;
|
||||
let scale = (head_dim as f64).powf(-0.5);
|
||||
Ok(ClipAttention {
|
||||
k_proj,
|
||||
v_proj,
|
||||
q_proj,
|
||||
out_proj,
|
||||
head_dim,
|
||||
scale,
|
||||
num_attention_heads,
|
||||
})
|
||||
}
|
||||
|
||||
fn shape(&self, xs: &Tensor, seq_len: usize, bsz: usize) -> Result<Tensor> {
|
||||
xs.reshape((bsz, seq_len, self.num_attention_heads, self.head_dim))?
|
||||
.transpose(1, 2)?
|
||||
.contiguous()
|
||||
}
|
||||
|
||||
fn forward(&self, xs: &Tensor, causal_attention_mask: &Tensor) -> Result<Tensor> {
|
||||
let (bsz, seq_len, embed_dim) = xs.dims3()?;
|
||||
let query_states = (self.q_proj.forward(xs)? * self.scale)?;
|
||||
let proj_shape = (bsz * self.num_attention_heads, seq_len, self.head_dim);
|
||||
let query_states = self
|
||||
.shape(&query_states, seq_len, bsz)?
|
||||
.reshape(proj_shape)?;
|
||||
let key_states = self
|
||||
.shape(&self.k_proj.forward(xs)?, seq_len, bsz)?
|
||||
.reshape(proj_shape)?;
|
||||
let value_states = self
|
||||
.shape(&self.v_proj.forward(xs)?, seq_len, bsz)?
|
||||
.reshape(proj_shape)?;
|
||||
let attn_weights = query_states.matmul(&key_states.transpose(1, 2)?)?;
|
||||
|
||||
let src_len = key_states.dim(1)?;
|
||||
let attn_weights =
|
||||
(attn_weights.reshape((bsz, self.num_attention_heads, seq_len, src_len))?
|
||||
+ causal_attention_mask)?;
|
||||
let attn_weights =
|
||||
attn_weights.reshape((bsz * self.num_attention_heads, seq_len, src_len))?;
|
||||
let attn_weights = candle_nn::ops::softmax(&attn_weights, D::Minus1)?;
|
||||
|
||||
let attn_output = attn_weights.matmul(&value_states)?;
|
||||
let attn_output = attn_output
|
||||
.reshape((bsz, self.num_attention_heads, seq_len, self.head_dim))?
|
||||
.transpose(1, 2)?
|
||||
.reshape((bsz, seq_len, embed_dim))?;
|
||||
self.out_proj.forward(&attn_output)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct ClipMlp {
|
||||
fc1: candle_nn::Linear,
|
||||
fc2: candle_nn::Linear,
|
||||
activation: Activation,
|
||||
}
|
||||
|
||||
impl ClipMlp {
|
||||
fn new(vs: candle_nn::VarBuilder, c: &Config) -> Result<Self> {
|
||||
let fc1 = candle_nn::linear(c.embed_dim, c.intermediate_size, vs.pp("fc1"))?;
|
||||
let fc2 = candle_nn::linear(c.intermediate_size, c.embed_dim, vs.pp("fc2"))?;
|
||||
Ok(ClipMlp {
|
||||
fc1,
|
||||
fc2,
|
||||
activation: c.activation,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl ClipMlp {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let xs = self.fc1.forward(xs)?;
|
||||
self.fc2.forward(&self.activation.forward(&xs)?)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct ClipEncoderLayer {
|
||||
self_attn: ClipAttention,
|
||||
layer_norm1: candle_nn::LayerNorm,
|
||||
mlp: ClipMlp,
|
||||
layer_norm2: candle_nn::LayerNorm,
|
||||
}
|
||||
|
||||
impl ClipEncoderLayer {
|
||||
fn new(vs: candle_nn::VarBuilder, c: &Config) -> Result<Self> {
|
||||
let self_attn = ClipAttention::new(vs.pp("self_attn"), c)?;
|
||||
let layer_norm1 = candle_nn::layer_norm(c.embed_dim, 1e-5, vs.pp("layer_norm1"))?;
|
||||
let mlp = ClipMlp::new(vs.pp("mlp"), c)?;
|
||||
let layer_norm2 = candle_nn::layer_norm(c.embed_dim, 1e-5, vs.pp("layer_norm2"))?;
|
||||
Ok(ClipEncoderLayer {
|
||||
self_attn,
|
||||
layer_norm1,
|
||||
mlp,
|
||||
layer_norm2,
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(&self, xs: &Tensor, causal_attention_mask: &Tensor) -> Result<Tensor> {
|
||||
let residual = xs;
|
||||
let xs = self.layer_norm1.forward(xs)?;
|
||||
let xs = self.self_attn.forward(&xs, causal_attention_mask)?;
|
||||
let xs = (xs + residual)?;
|
||||
|
||||
let residual = &xs;
|
||||
let xs = self.layer_norm2.forward(&xs)?;
|
||||
let xs = self.mlp.forward(&xs)?;
|
||||
xs + residual
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct ClipEncoder {
|
||||
layers: Vec<ClipEncoderLayer>,
|
||||
}
|
||||
|
||||
impl ClipEncoder {
|
||||
fn new(vs: candle_nn::VarBuilder, c: &Config) -> Result<Self> {
|
||||
let vs = vs.pp("layers");
|
||||
let mut layers: Vec<ClipEncoderLayer> = Vec::new();
|
||||
for index in 0..c.num_hidden_layers {
|
||||
let layer = ClipEncoderLayer::new(vs.pp(&index.to_string()), c)?;
|
||||
layers.push(layer)
|
||||
}
|
||||
Ok(ClipEncoder { layers })
|
||||
}
|
||||
|
||||
fn forward(&self, xs: &Tensor, causal_attention_mask: &Tensor) -> Result<Tensor> {
|
||||
let mut xs = xs.clone();
|
||||
for layer in self.layers.iter() {
|
||||
xs = layer.forward(&xs, causal_attention_mask)?
|
||||
}
|
||||
Ok(xs)
|
||||
}
|
||||
}
|
||||
|
||||
/// A CLIP transformer based model.
|
||||
#[derive(Debug)]
|
||||
pub struct ClipTextTransformer {
|
||||
embeddings: ClipTextEmbeddings,
|
||||
encoder: ClipEncoder,
|
||||
final_layer_norm: candle_nn::LayerNorm,
|
||||
}
|
||||
|
||||
impl ClipTextTransformer {
|
||||
pub fn new(vs: candle_nn::VarBuilder, c: &Config) -> Result<Self> {
|
||||
let vs = vs.pp("text_model");
|
||||
let embeddings = ClipTextEmbeddings::new(vs.pp("embeddings"), c)?;
|
||||
let encoder = ClipEncoder::new(vs.pp("encoder"), c)?;
|
||||
let final_layer_norm = candle_nn::layer_norm(c.embed_dim, 1e-5, vs.pp("final_layer_norm"))?;
|
||||
Ok(ClipTextTransformer {
|
||||
embeddings,
|
||||
encoder,
|
||||
final_layer_norm,
|
||||
})
|
||||
}
|
||||
|
||||
// https://github.com/huggingface/transformers/blob/674f750a57431222fa2832503a108df3badf1564/src/transformers/models/clip/modeling_clip.py#L678
|
||||
fn build_causal_attention_mask(bsz: usize, seq_len: usize, device: &Device) -> Result<Tensor> {
|
||||
let mask: Vec<_> = (0..seq_len)
|
||||
.flat_map(|i| (0..seq_len).map(move |j| u8::from(j > i)))
|
||||
.collect();
|
||||
let mask = Tensor::from_slice(&mask, (seq_len, seq_len), device)?;
|
||||
mask.broadcast_as((bsz, seq_len, seq_len))
|
||||
}
|
||||
}
|
||||
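The causal mask built above, restated as plain data for a small sequence length; a causal attention implementation typically turns the flagged positions into large negative biases before the softmax.

// Position i may only attend to positions j <= i, so entries with j > i are flagged.
fn causal_mask(seq_len: usize) -> Vec<u8> {
    (0..seq_len)
        .flat_map(|i| (0..seq_len).map(move |j| u8::from(j > i)))
        .collect()
}

fn main() {
    // seq_len = 3 ->
    // 0 1 1
    // 0 0 1
    // 0 0 0
    assert_eq!(causal_mask(3), vec![0, 1, 1, 0, 0, 1, 0, 0, 0]);
}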
|
||||
impl ClipTextTransformer {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let (bsz, seq_len) = xs.dims2()?;
|
||||
let xs = self.embeddings.forward(xs)?;
|
||||
let causal_attention_mask = Self::build_causal_attention_mask(bsz, seq_len, xs.device())?;
|
||||
let xs = self.encoder.forward(&xs, &causal_attention_mask)?;
|
||||
self.final_layer_norm.forward(&xs)
|
||||
}
|
||||
}
candle-examples/examples/stable-diffusion/embeddings.rs (new file, 65 lines)
@@ -0,0 +1,65 @@
#![allow(dead_code)]
use candle::{Result, Tensor, D};
use candle_nn as nn;

#[derive(Debug)]
pub struct TimestepEmbedding {
    linear_1: nn::Linear,
    linear_2: nn::Linear,
}

impl TimestepEmbedding {
    // act_fn: "silu"
    pub fn new(vs: nn::VarBuilder, channel: usize, time_embed_dim: usize) -> Result<Self> {
        let linear_1 = nn::linear(channel, time_embed_dim, vs.pp("linear_1"))?;
        let linear_2 = nn::linear(time_embed_dim, time_embed_dim, vs.pp("linear_2"))?;
        Ok(Self { linear_1, linear_2 })
    }
}

impl TimestepEmbedding {
    pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let xs = nn::ops::silu(&self.linear_1.forward(xs)?)?;
        self.linear_2.forward(&xs)
    }
}

#[derive(Debug)]
pub struct Timesteps {
    num_channels: usize,
    flip_sin_to_cos: bool,
    downscale_freq_shift: f64,
}

impl Timesteps {
    pub fn new(num_channels: usize, flip_sin_to_cos: bool, downscale_freq_shift: f64) -> Self {
        Self {
            num_channels,
            flip_sin_to_cos,
            downscale_freq_shift,
        }
    }
}

impl Timesteps {
    pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let half_dim = (self.num_channels / 2) as u32;
        let exponent =
            (Tensor::arange(0, half_dim, xs.device())?.to_dtype(xs.dtype())? * -f64::ln(10000.))?;
        let exponent = (exponent / (half_dim as f64 - self.downscale_freq_shift))?;
        let emb = exponent.exp()?;
        // emb = timesteps[:, None].float() * emb[None, :]
        let emb = (xs.unsqueeze(D::Minus1)? * emb.unsqueeze(0)?)?;
        let (cos, sin) = (emb.cos()?, emb.sin()?);
        let emb = if self.flip_sin_to_cos {
            Tensor::cat(&[&cos, &sin], D::Minus1)?
        } else {
            Tensor::cat(&[&sin, &cos], D::Minus1)?
        };
        if self.num_channels % 2 == 1 {
            crate::utils::pad(&emb) // ([0, 1, 0, 0], 'constant', None)
        } else {
            Ok(emb)
        }
    }
}
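A plain-Rust sketch of the sinusoidal embedding that `Timesteps::forward` computes for a single scalar timestep, taking `downscale_freq_shift` as 0 for brevity; names are illustrative, not candle APIs.

// half_dim exponentially spaced frequencies, then sin and cos parts are concatenated
// (cos first when flip_sin_to_cos is set).
fn timestep_embedding(t: f32, num_channels: usize, flip_sin_to_cos: bool) -> Vec<f32> {
    let half_dim = num_channels / 2;
    let freqs: Vec<f32> = (0..half_dim)
        .map(|i| (-(10000f32.ln()) * i as f32 / half_dim as f32).exp())
        .collect();
    let args: Vec<f32> = freqs.iter().map(|f| t * f).collect();
    let (sin, cos): (Vec<f32>, Vec<f32>) = args.iter().map(|a| (a.sin(), a.cos())).unzip();
    if flip_sin_to_cos {
        [cos, sin].concat()
    } else {
        [sin, cos].concat()
    }
}

fn main() {
    let emb = timestep_embedding(500.0, 320, true);
    assert_eq!(emb.len(), 320);
}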
candle-examples/examples/stable-diffusion/main.rs (new file, 30 lines)
@@ -0,0 +1,30 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

mod attention;
mod clip;
mod embeddings;
mod resnet;
mod unet_2d;
mod unet_2d_blocks;
mod utils;
mod vae;

use anyhow::Result;
use clap::Parser;

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    #[arg(long)]
    prompt: String,
}

fn main() -> Result<()> {
    let _args = Args::parse();
    Ok(())
}
candle-examples/examples/stable-diffusion/resnet.rs (new file, 129 lines)
@@ -0,0 +1,129 @@
#![allow(dead_code)]
//! ResNet Building Blocks
//!
//! Some Residual Network blocks used in UNet models.
//!
//! Deep Residual Learning for Image Recognition, K. He et al., 2015.
//! https://arxiv.org/abs/1512.03385
use candle::{Result, Tensor, D};
use candle_nn as nn;

/// Configuration for a ResNet block.
#[derive(Debug, Clone, Copy)]
pub struct ResnetBlock2DConfig {
    /// The number of output channels, defaults to the number of input channels.
    pub out_channels: Option<usize>,
    pub temb_channels: Option<usize>,
    /// The number of groups to use in group normalization.
    pub groups: usize,
    pub groups_out: Option<usize>,
    /// The epsilon to be used in the group normalization operations.
    pub eps: f64,
    /// Whether to use a 2D convolution in the skip connection. When using None,
    /// such a convolution is used if the number of input channels is different from
    /// the number of output channels.
    pub use_in_shortcut: Option<bool>,
    // non_linearity: silu
    /// The final output is scaled by dividing by this value.
    pub output_scale_factor: f64,
}

impl Default for ResnetBlock2DConfig {
    fn default() -> Self {
        Self {
            out_channels: None,
            temb_channels: Some(512),
            groups: 32,
            groups_out: None,
            eps: 1e-6,
            use_in_shortcut: None,
            output_scale_factor: 1.,
        }
    }
}

#[derive(Debug)]
pub struct ResnetBlock2D {
    norm1: nn::GroupNorm,
    conv1: nn::Conv2d,
    norm2: nn::GroupNorm,
    conv2: nn::Conv2d,
    time_emb_proj: Option<nn::Linear>,
    conv_shortcut: Option<nn::Conv2d>,
    config: ResnetBlock2DConfig,
}

impl ResnetBlock2D {
    pub fn new(
        vs: nn::VarBuilder,
        in_channels: usize,
        config: ResnetBlock2DConfig,
    ) -> Result<Self> {
        let out_channels = config.out_channels.unwrap_or(in_channels);
        let conv_cfg = nn::Conv2dConfig {
            stride: 1,
            padding: 1,
        };
        let norm1 = nn::group_norm(config.groups, in_channels, config.eps, vs.pp("norm1"))?;
        let conv1 = nn::conv2d(in_channels, out_channels, 3, conv_cfg, vs.pp("conv1"))?;
        let groups_out = config.groups_out.unwrap_or(config.groups);
        let norm2 = nn::group_norm(groups_out, out_channels, config.eps, vs.pp("norm2"))?;
        let conv2 = nn::conv2d(out_channels, out_channels, 3, conv_cfg, vs.pp("conv2"))?;
        let use_in_shortcut = config
            .use_in_shortcut
            .unwrap_or(in_channels != out_channels);
        let conv_shortcut = if use_in_shortcut {
            let conv_cfg = nn::Conv2dConfig {
                stride: 1,
                padding: 0,
            };
            Some(nn::conv2d(
                in_channels,
                out_channels,
                1,
                conv_cfg,
                vs.pp("conv_shortcut"),
            )?)
        } else {
            None
        };
        let time_emb_proj = match config.temb_channels {
            None => None,
            Some(temb_channels) => Some(nn::linear(
                temb_channels,
                out_channels,
                vs.pp("time_emb_proj"),
            )?),
        };
        Ok(Self {
            norm1,
            conv1,
            norm2,
            conv2,
            time_emb_proj,
            config,
            conv_shortcut,
        })
    }

    pub fn forward(&self, xs: &Tensor, temb: Option<&Tensor>) -> Result<Tensor> {
        let shortcut_xs = match &self.conv_shortcut {
            Some(conv_shortcut) => conv_shortcut.forward(xs)?,
            None => xs.clone(),
        };
        let xs = self.norm1.forward(xs)?;
        let xs = self.conv1.forward(&nn::ops::silu(&xs)?)?;
        let xs = match (temb, &self.time_emb_proj) {
            (Some(temb), Some(time_emb_proj)) => time_emb_proj
                .forward(&nn::ops::silu(temb)?)?
                .unsqueeze(D::Minus1)?
                .unsqueeze(D::Minus1)?
                .add(&xs)?,
            _ => xs,
        };
        let xs = self
            .conv2
            .forward(&nn::ops::silu(&self.norm2.forward(&xs)?)?)?;
        (shortcut_xs + xs)? / self.config.output_scale_factor
    }
}
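Schematically, `ResnetBlock2D::forward` adds the (optionally convolved) skip path to the main path and divides by `output_scale_factor`; a scalar sketch of that wiring, for illustration only.

fn resnet_block(
    x: f32,
    main_path: impl Fn(f32) -> f32,
    shortcut: impl Fn(f32) -> f32,
    output_scale_factor: f32,
) -> f32 {
    (shortcut(x) + main_path(x)) / output_scale_factor
}

fn main() {
    // With an identity shortcut and unit scale this is a plain residual connection.
    let y = resnet_block(1.5, |v| v * 2.0, |v| v, 1.0);
    assert_eq!(y, 4.5);
}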
candle-examples/examples/stable-diffusion/unet_2d.rs (new file, 383 lines)
@@ -0,0 +1,383 @@
#![allow(dead_code)]
|
||||
//! 2D UNet Denoising Models
|
||||
//!
|
||||
//! The 2D Unet models take as input a noisy sample and the current diffusion
|
||||
//! timestep and return a denoised version of the input.
|
||||
use crate::embeddings::{TimestepEmbedding, Timesteps};
|
||||
use crate::unet_2d_blocks::*;
|
||||
use candle::{DType, Result, Tensor};
|
||||
use candle_nn as nn;
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct BlockConfig {
|
||||
pub out_channels: usize,
|
||||
pub use_cross_attn: bool,
|
||||
pub attention_head_dim: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct UNet2DConditionModelConfig {
|
||||
pub center_input_sample: bool,
|
||||
pub flip_sin_to_cos: bool,
|
||||
pub freq_shift: f64,
|
||||
pub blocks: Vec<BlockConfig>,
|
||||
pub layers_per_block: usize,
|
||||
pub downsample_padding: usize,
|
||||
pub mid_block_scale_factor: f64,
|
||||
pub norm_num_groups: usize,
|
||||
pub norm_eps: f64,
|
||||
pub cross_attention_dim: usize,
|
||||
pub sliced_attention_size: Option<usize>,
|
||||
pub use_linear_projection: bool,
|
||||
}
|
||||
|
||||
impl Default for UNet2DConditionModelConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
center_input_sample: false,
|
||||
flip_sin_to_cos: true,
|
||||
freq_shift: 0.,
|
||||
blocks: vec![
|
||||
BlockConfig {
|
||||
out_channels: 320,
|
||||
use_cross_attn: true,
|
||||
attention_head_dim: 8,
|
||||
},
|
||||
BlockConfig {
|
||||
out_channels: 640,
|
||||
use_cross_attn: true,
|
||||
attention_head_dim: 8,
|
||||
},
|
||||
BlockConfig {
|
||||
out_channels: 1280,
|
||||
use_cross_attn: true,
|
||||
attention_head_dim: 8,
|
||||
},
|
||||
BlockConfig {
|
||||
out_channels: 1280,
|
||||
use_cross_attn: false,
|
||||
attention_head_dim: 8,
|
||||
},
|
||||
],
|
||||
layers_per_block: 2,
|
||||
downsample_padding: 1,
|
||||
mid_block_scale_factor: 1.,
|
||||
norm_num_groups: 32,
|
||||
norm_eps: 1e-5,
|
||||
cross_attention_dim: 1280,
|
||||
sliced_attention_size: None,
|
||||
use_linear_projection: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
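Two derived quantities worth noting: `UNet2DConditionModel::new` below sets the time embedding width to four times the first block's channel count, and the default config above is the usual 320/640/1280/1280 ladder with cross-attention in all but the last (lowest-resolution) block. A trivial check:

fn main() {
    let block_out_channels = [320usize, 640, 1280, 1280];
    let time_embed_dim = block_out_channels[0] * 4;
    assert_eq!(time_embed_dim, 1280);
}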
|
||||
#[derive(Debug)]
|
||||
pub(crate) enum UNetDownBlock {
|
||||
Basic(DownBlock2D),
|
||||
CrossAttn(CrossAttnDownBlock2D),
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
enum UNetUpBlock {
|
||||
Basic(UpBlock2D),
|
||||
CrossAttn(CrossAttnUpBlock2D),
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct UNet2DConditionModel {
|
||||
conv_in: nn::Conv2d,
|
||||
time_proj: Timesteps,
|
||||
time_embedding: TimestepEmbedding,
|
||||
down_blocks: Vec<UNetDownBlock>,
|
||||
mid_block: UNetMidBlock2DCrossAttn,
|
||||
up_blocks: Vec<UNetUpBlock>,
|
||||
conv_norm_out: nn::GroupNorm,
|
||||
conv_out: nn::Conv2d,
|
||||
config: UNet2DConditionModelConfig,
|
||||
}
|
||||
|
||||
impl UNet2DConditionModel {
|
||||
pub fn new(
|
||||
vs: nn::VarBuilder,
|
||||
in_channels: usize,
|
||||
out_channels: usize,
|
||||
config: UNet2DConditionModelConfig,
|
||||
) -> Result<Self> {
|
||||
let n_blocks = config.blocks.len();
|
||||
let b_channels = config.blocks[0].out_channels;
|
||||
let bl_channels = config.blocks.last().unwrap().out_channels;
|
||||
let bl_attention_head_dim = config.blocks.last().unwrap().attention_head_dim;
|
||||
let time_embed_dim = b_channels * 4;
|
||||
let conv_cfg = nn::Conv2dConfig {
|
||||
stride: 1,
|
||||
padding: 1,
|
||||
};
|
||||
let conv_in = nn::conv2d(in_channels, b_channels, 3, conv_cfg, vs.pp("conv_in"))?;
|
||||
|
||||
let time_proj = Timesteps::new(b_channels, config.flip_sin_to_cos, config.freq_shift);
|
||||
let time_embedding =
|
||||
TimestepEmbedding::new(vs.pp("time_embedding"), b_channels, time_embed_dim)?;
|
||||
|
||||
let vs_db = vs.pp("down_blocks");
|
||||
let down_blocks = (0..n_blocks)
|
||||
.map(|i| {
|
||||
let BlockConfig {
|
||||
out_channels,
|
||||
use_cross_attn,
|
||||
attention_head_dim,
|
||||
} = config.blocks[i];
|
||||
|
||||
// Enable automatic attention slicing if the config sliced_attention_size is set to 0.
|
||||
let sliced_attention_size = match config.sliced_attention_size {
|
||||
Some(0) => Some(attention_head_dim / 2),
|
||||
_ => config.sliced_attention_size,
|
||||
};
|
||||
|
||||
let in_channels = if i > 0 {
|
||||
config.blocks[i - 1].out_channels
|
||||
} else {
|
||||
b_channels
|
||||
};
|
||||
let db_cfg = DownBlock2DConfig {
|
||||
num_layers: config.layers_per_block,
|
||||
resnet_eps: config.norm_eps,
|
||||
resnet_groups: config.norm_num_groups,
|
||||
add_downsample: i < n_blocks - 1,
|
||||
downsample_padding: config.downsample_padding,
|
||||
..Default::default()
|
||||
};
|
||||
if use_cross_attn {
|
||||
let config = CrossAttnDownBlock2DConfig {
|
||||
downblock: db_cfg,
|
||||
attn_num_head_channels: attention_head_dim,
|
||||
cross_attention_dim: config.cross_attention_dim,
|
||||
sliced_attention_size,
|
||||
use_linear_projection: config.use_linear_projection,
|
||||
};
|
||||
let block = CrossAttnDownBlock2D::new(
|
||||
vs_db.pp(&i.to_string()),
|
||||
in_channels,
|
||||
out_channels,
|
||||
Some(time_embed_dim),
|
||||
config,
|
||||
)?;
|
||||
Ok(UNetDownBlock::CrossAttn(block))
|
||||
} else {
|
||||
let block = DownBlock2D::new(
|
||||
vs_db.pp(&i.to_string()),
|
||||
in_channels,
|
||||
out_channels,
|
||||
Some(time_embed_dim),
|
||||
db_cfg,
|
||||
)?;
|
||||
Ok(UNetDownBlock::Basic(block))
|
||||
}
|
||||
})
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
|
||||
let mid_cfg = UNetMidBlock2DCrossAttnConfig {
|
||||
resnet_eps: config.norm_eps,
|
||||
output_scale_factor: config.mid_block_scale_factor,
|
||||
cross_attn_dim: config.cross_attention_dim,
|
||||
attn_num_head_channels: bl_attention_head_dim,
|
||||
resnet_groups: Some(config.norm_num_groups),
|
||||
use_linear_projection: config.use_linear_projection,
|
||||
..Default::default()
|
||||
};
|
||||
let mid_block = UNetMidBlock2DCrossAttn::new(
|
||||
vs.pp("mid_block"),
|
||||
bl_channels,
|
||||
Some(time_embed_dim),
|
||||
mid_cfg,
|
||||
)?;
|
||||
|
||||
let vs_ub = vs.pp("up_blocks");
|
||||
let up_blocks = (0..n_blocks)
|
||||
.map(|i| {
|
||||
let BlockConfig {
|
||||
out_channels,
|
||||
use_cross_attn,
|
||||
attention_head_dim,
|
||||
} = config.blocks[n_blocks - 1 - i];
|
||||
|
||||
// Enable automatic attention slicing if the config sliced_attention_size is set to 0.
|
||||
let sliced_attention_size = match config.sliced_attention_size {
|
||||
Some(0) => Some(attention_head_dim / 2),
|
||||
_ => config.sliced_attention_size,
|
||||
};
|
||||
|
||||
let prev_out_channels = if i > 0 {
|
||||
config.blocks[n_blocks - i].out_channels
|
||||
} else {
|
||||
bl_channels
|
||||
};
|
||||
let in_channels = {
|
||||
let index = if i == n_blocks - 1 {
|
||||
0
|
||||
} else {
|
||||
n_blocks - i - 2
|
||||
};
|
||||
config.blocks[index].out_channels
|
||||
};
|
||||
let ub_cfg = UpBlock2DConfig {
|
||||
num_layers: config.layers_per_block + 1,
|
||||
resnet_eps: config.norm_eps,
|
||||
resnet_groups: config.norm_num_groups,
|
||||
add_upsample: i < n_blocks - 1,
|
||||
..Default::default()
|
||||
};
|
||||
if use_cross_attn {
|
||||
let config = CrossAttnUpBlock2DConfig {
|
||||
upblock: ub_cfg,
|
||||
attn_num_head_channels: attention_head_dim,
|
||||
cross_attention_dim: config.cross_attention_dim,
|
||||
sliced_attention_size,
|
||||
use_linear_projection: config.use_linear_projection,
|
||||
};
|
||||
let block = CrossAttnUpBlock2D::new(
|
||||
vs_ub.pp(&i.to_string()),
|
||||
in_channels,
|
||||
prev_out_channels,
|
||||
out_channels,
|
||||
Some(time_embed_dim),
|
||||
config,
|
||||
)?;
|
||||
Ok(UNetUpBlock::CrossAttn(block))
|
||||
} else {
|
||||
let block = UpBlock2D::new(
|
||||
vs_ub.pp(&i.to_string()),
|
||||
in_channels,
|
||||
prev_out_channels,
|
||||
out_channels,
|
||||
Some(time_embed_dim),
|
||||
ub_cfg,
|
||||
)?;
|
||||
Ok(UNetUpBlock::Basic(block))
|
||||
}
|
||||
})
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
|
||||
let conv_norm_out = nn::group_norm(
|
||||
config.norm_num_groups,
|
||||
b_channels,
|
||||
config.norm_eps,
|
||||
vs.pp("conv_norm_out"),
|
||||
)?;
|
||||
let conv_out = nn::conv2d(b_channels, out_channels, 3, conv_cfg, vs.pp("conv_out"))?;
|
||||
Ok(Self {
|
||||
conv_in,
|
||||
time_proj,
|
||||
time_embedding,
|
||||
down_blocks,
|
||||
mid_block,
|
||||
up_blocks,
|
||||
conv_norm_out,
|
||||
conv_out,
|
||||
config,
|
||||
})
|
||||
}
|
||||
}
|
||||
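The channel bookkeeping in `new` above, restated over plain arrays: down block `i` consumes the previous block's output, while up block `i` walks the config in reverse and also receives the skip ("prev") channels from the mirrored down block. Illustrative only:

fn main() {
    let out = [320usize, 640, 1280, 1280];
    let n = out.len();
    for i in 0..n {
        let down_in = if i > 0 { out[i - 1] } else { out[0] };
        let up_out = out[n - 1 - i];
        let prev = if i > 0 { out[n - i] } else { out[n - 1] };
        let up_in = out[if i == n - 1 { 0 } else { n - i - 2 }];
        println!("down {i}: {down_in} -> {}, up {i}: in {up_in}, prev {prev}, out {up_out}", out[i]);
    }
}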
|
||||
impl UNet2DConditionModel {
|
||||
pub fn forward(
|
||||
&self,
|
||||
xs: &Tensor,
|
||||
timestep: f64,
|
||||
encoder_hidden_states: &Tensor,
|
||||
) -> Result<Tensor> {
|
||||
self.forward_with_additional_residuals(xs, timestep, encoder_hidden_states, None, None)
|
||||
}
|
||||
|
||||
pub fn forward_with_additional_residuals(
|
||||
&self,
|
||||
xs: &Tensor,
|
||||
timestep: f64,
|
||||
encoder_hidden_states: &Tensor,
|
||||
down_block_additional_residuals: Option<&[Tensor]>,
|
||||
mid_block_additional_residual: Option<&Tensor>,
|
||||
) -> Result<Tensor> {
|
||||
let (bsize, _channels, height, width) = xs.dims4()?;
|
||||
let device = xs.device();
|
||||
let n_blocks = self.config.blocks.len();
|
||||
let num_upsamplers = n_blocks - 1;
|
||||
let default_overall_up_factor = 2usize.pow(num_upsamplers as u32);
|
||||
let forward_upsample_size =
|
||||
height % default_overall_up_factor != 0 || width % default_overall_up_factor != 0;
|
||||
// 0. center input if necessary
|
||||
let xs = if self.config.center_input_sample {
|
||||
((xs * 2.0)? - 1.0)?
|
||||
} else {
|
||||
xs.clone()
|
||||
};
|
||||
// 1. time
|
||||
let emb = (Tensor::ones(bsize, DType::F32, device)? * timestep)?;
|
||||
let emb = self.time_proj.forward(&emb)?;
|
||||
let emb = self.time_embedding.forward(&emb)?;
|
||||
// 2. pre-process
|
||||
let xs = self.conv_in.forward(&xs)?;
|
||||
// 3. down
|
||||
let mut down_block_res_xs = vec![xs.clone()];
|
||||
let mut xs = xs;
|
||||
for down_block in self.down_blocks.iter() {
|
||||
let (_xs, res_xs) = match down_block {
|
||||
UNetDownBlock::Basic(b) => b.forward(&xs, Some(&emb))?,
|
||||
UNetDownBlock::CrossAttn(b) => {
|
||||
b.forward(&xs, Some(&emb), Some(encoder_hidden_states))?
|
||||
}
|
||||
};
|
||||
down_block_res_xs.extend(res_xs);
|
||||
xs = _xs;
|
||||
}
|
||||
|
||||
let new_down_block_res_xs =
|
||||
if let Some(down_block_additional_residuals) = down_block_additional_residuals {
|
||||
let mut v = vec![];
|
||||
// A previous version of this code had a bug because the addition was made
// in place via `+=`, hence modifying the input of the mid block.
|
||||
for (i, residuals) in down_block_additional_residuals.iter().enumerate() {
|
||||
v.push((&down_block_res_xs[i] + residuals)?)
|
||||
}
|
||||
v
|
||||
} else {
|
||||
down_block_res_xs
|
||||
};
|
||||
let mut down_block_res_xs = new_down_block_res_xs;
|
||||
|
||||
// 4. mid
|
||||
let xs = self
|
||||
.mid_block
|
||||
.forward(&xs, Some(&emb), Some(encoder_hidden_states))?;
|
||||
let xs = match mid_block_additional_residual {
|
||||
None => xs,
|
||||
Some(m) => (m + xs)?,
|
||||
};
|
||||
// 5. up
|
||||
let mut xs = xs;
|
||||
let mut upsample_size = None;
|
||||
for (i, up_block) in self.up_blocks.iter().enumerate() {
|
||||
let n_resnets = match up_block {
|
||||
UNetUpBlock::Basic(b) => b.resnets.len(),
|
||||
UNetUpBlock::CrossAttn(b) => b.upblock.resnets.len(),
|
||||
};
|
||||
let res_xs = down_block_res_xs.split_off(down_block_res_xs.len() - n_resnets);
|
||||
if i < n_blocks - 1 && forward_upsample_size {
|
||||
let (_, _, h, w) = down_block_res_xs.last().unwrap().dims4()?;
|
||||
upsample_size = Some((h, w))
|
||||
}
|
||||
xs = match up_block {
|
||||
UNetUpBlock::Basic(b) => b.forward(&xs, &res_xs, Some(&emb), upsample_size)?,
|
||||
UNetUpBlock::CrossAttn(b) => b.forward(
|
||||
&xs,
|
||||
&res_xs,
|
||||
Some(&emb),
|
||||
upsample_size,
|
||||
Some(encoder_hidden_states),
|
||||
)?,
|
||||
};
|
||||
}
|
||||
// 6. post-process
|
||||
let xs = self.conv_norm_out.forward(&xs)?;
|
||||
let xs = nn::ops::silu(&xs)?;
|
||||
self.conv_out.forward(&xs)
|
||||
}
|
||||
}
|
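The `forward_upsample_size` test in `forward` above boils down to a divisibility check against `2^(n_blocks - 1)`; a small standalone check:

fn main() {
    let n_blocks = 4usize;
    let default_overall_up_factor = 2usize.pow((n_blocks - 1) as u32);
    assert_eq!(default_overall_up_factor, 8);
    let (height, width) = (96usize, 100usize);
    let forward_upsample_size =
        height % default_overall_up_factor != 0 || width % default_overall_up_factor != 0;
    assert!(forward_upsample_size); // 100 is not a multiple of 8
}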
candle-examples/examples/stable-diffusion/unet_2d_blocks.rs (new file, 809 lines)
@@ -0,0 +1,809 @@
#![allow(dead_code)]
|
||||
//! 2D UNet Building Blocks
|
||||
//!
|
||||
use crate::attention::{
|
||||
AttentionBlock, AttentionBlockConfig, SpatialTransformer, SpatialTransformerConfig,
|
||||
};
|
||||
use crate::resnet::{ResnetBlock2D, ResnetBlock2DConfig};
|
||||
use candle::{Result, Tensor};
|
||||
use candle_nn as nn;
|
||||
|
||||
#[derive(Debug)]
|
||||
struct Downsample2D {
|
||||
conv: Option<nn::Conv2d>,
|
||||
padding: usize,
|
||||
}
|
||||
|
||||
impl Downsample2D {
|
||||
fn new(
|
||||
vs: nn::VarBuilder,
|
||||
in_channels: usize,
|
||||
use_conv: bool,
|
||||
out_channels: usize,
|
||||
padding: usize,
|
||||
) -> Result<Self> {
|
||||
let conv = if use_conv {
|
||||
let config = nn::Conv2dConfig { stride: 2, padding };
|
||||
let conv = nn::conv2d(in_channels, out_channels, 3, config, vs.pp("conv"))?;
|
||||
Some(conv)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
Ok(Downsample2D { conv, padding })
|
||||
}
|
||||
}
|
||||
|
||||
impl Downsample2D {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
match &self.conv {
|
||||
None => crate::utils::avg_pool2d(xs), // [2, 2], [2, 2], [0, 0], false, true, None),
|
||||
Some(conv) => {
|
||||
if self.padding == 0 {
|
||||
let xs = crate::utils::pad(xs)?; // [0, 1, 0, 1], "constant", Some(0.))?;
|
||||
conv.forward(&xs)
|
||||
} else {
|
||||
conv.forward(xs)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
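For the strided convolution used by `Downsample2D` above, the usual output-size formula explains both branches of `forward`: with the default padding of 1 a feature map is halved, and in the padding == 0 case the extra (0, 1, 0, 1) pad applied first restores the same halving. A quick check (plain Rust, illustrative):

// Output size of a conv: (n + 2 * padding - kernel) / stride + 1, integer division.
fn conv_out(n: usize, padding: usize, kernel: usize, stride: usize) -> usize {
    (n + 2 * padding - kernel) / stride + 1
}

fn main() {
    assert_eq!(conv_out(64, 1, 3, 2), 32);
    assert_eq!(conv_out(64 + 1, 0, 3, 2), 32); // padding == 0 path: pad right/bottom by 1 first
}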
|
||||
// This does not support the conv-transpose mode.
|
||||
#[derive(Debug)]
|
||||
struct Upsample2D {
|
||||
conv: nn::Conv2d,
|
||||
}
|
||||
|
||||
impl Upsample2D {
|
||||
fn new(vs: nn::VarBuilder, in_channels: usize, out_channels: usize) -> Result<Self> {
|
||||
let config = nn::Conv2dConfig {
|
||||
padding: 1,
|
||||
..Default::default()
|
||||
};
|
||||
let conv = nn::conv2d(in_channels, out_channels, 3, config, vs.pp("conv"))?;
|
||||
Ok(Self { conv })
|
||||
}
|
||||
}
|
||||
|
||||
impl Upsample2D {
|
||||
fn forward(&self, xs: &Tensor, size: Option<(usize, usize)>) -> Result<Tensor> {
|
||||
let xs = match size {
|
||||
None => {
|
||||
// The following does not work, and it's tricky to pass no fixed
// dimensions, so we hack our way around this.
|
||||
// xs.upsample_nearest2d(&[], Some(2.), Some(2.)
|
||||
let (_bsize, _channels, _h, _w) = xs.dims4()?;
|
||||
crate::utils::upsample_nearest2d(xs)? // [2 * h, 2 * w], Some(2.), Some(2.))
|
||||
}
|
||||
Some((_h, _w)) => crate::utils::upsample_nearest2d(xs)?, // [h, w], None, None),
|
||||
};
|
||||
self.conv.forward(&xs)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct DownEncoderBlock2DConfig {
|
||||
pub num_layers: usize,
|
||||
pub resnet_eps: f64,
|
||||
pub resnet_groups: usize,
|
||||
pub output_scale_factor: f64,
|
||||
pub add_downsample: bool,
|
||||
pub downsample_padding: usize,
|
||||
}
|
||||
|
||||
impl Default for DownEncoderBlock2DConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
num_layers: 1,
|
||||
resnet_eps: 1e-6,
|
||||
resnet_groups: 32,
|
||||
output_scale_factor: 1.,
|
||||
add_downsample: true,
|
||||
downsample_padding: 1,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct DownEncoderBlock2D {
|
||||
resnets: Vec<ResnetBlock2D>,
|
||||
downsampler: Option<Downsample2D>,
|
||||
pub config: DownEncoderBlock2DConfig,
|
||||
}
|
||||
|
||||
impl DownEncoderBlock2D {
|
||||
pub fn new(
|
||||
vs: nn::VarBuilder,
|
||||
in_channels: usize,
|
||||
out_channels: usize,
|
||||
config: DownEncoderBlock2DConfig,
|
||||
) -> Result<Self> {
|
||||
let resnets: Vec<_> = {
|
||||
let vs = vs.pp("resnets");
|
||||
let conv_cfg = ResnetBlock2DConfig {
|
||||
eps: config.resnet_eps,
|
||||
out_channels: Some(out_channels),
|
||||
groups: config.resnet_groups,
|
||||
output_scale_factor: config.output_scale_factor,
|
||||
temb_channels: None,
|
||||
..Default::default()
|
||||
};
|
||||
(0..(config.num_layers))
|
||||
.map(|i| {
|
||||
let in_channels = if i == 0 { in_channels } else { out_channels };
|
||||
ResnetBlock2D::new(vs.pp(&i.to_string()), in_channels, conv_cfg)
|
||||
})
|
||||
.collect::<Result<Vec<_>>>()?
|
||||
};
|
||||
let downsampler = if config.add_downsample {
|
||||
let downsample = Downsample2D::new(
|
||||
vs.pp("downsamplers").pp("0"),
|
||||
out_channels,
|
||||
true,
|
||||
out_channels,
|
||||
config.downsample_padding,
|
||||
)?;
|
||||
Some(downsample)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
Ok(Self {
|
||||
resnets,
|
||||
downsampler,
|
||||
config,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl DownEncoderBlock2D {
|
||||
pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let mut xs = xs.clone();
|
||||
for resnet in self.resnets.iter() {
|
||||
xs = resnet.forward(&xs, None)?
|
||||
}
|
||||
match &self.downsampler {
|
||||
Some(downsampler) => downsampler.forward(&xs),
|
||||
None => Ok(xs),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct UpDecoderBlock2DConfig {
|
||||
pub num_layers: usize,
|
||||
pub resnet_eps: f64,
|
||||
pub resnet_groups: usize,
|
||||
pub output_scale_factor: f64,
|
||||
pub add_upsample: bool,
|
||||
}
|
||||
|
||||
impl Default for UpDecoderBlock2DConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
num_layers: 1,
|
||||
resnet_eps: 1e-6,
|
||||
resnet_groups: 32,
|
||||
output_scale_factor: 1.,
|
||||
add_upsample: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct UpDecoderBlock2D {
|
||||
resnets: Vec<ResnetBlock2D>,
|
||||
upsampler: Option<Upsample2D>,
|
||||
pub config: UpDecoderBlock2DConfig,
|
||||
}
|
||||
|
||||
impl UpDecoderBlock2D {
|
||||
pub fn new(
|
||||
vs: nn::VarBuilder,
|
||||
in_channels: usize,
|
||||
out_channels: usize,
|
||||
config: UpDecoderBlock2DConfig,
|
||||
) -> Result<Self> {
|
||||
let resnets: Vec<_> = {
|
||||
let vs = vs.pp("resnets");
|
||||
let conv_cfg = ResnetBlock2DConfig {
|
||||
out_channels: Some(out_channels),
|
||||
eps: config.resnet_eps,
|
||||
groups: config.resnet_groups,
|
||||
output_scale_factor: config.output_scale_factor,
|
||||
temb_channels: None,
|
||||
..Default::default()
|
||||
};
|
||||
(0..(config.num_layers))
|
||||
.map(|i| {
|
||||
let in_channels = if i == 0 { in_channels } else { out_channels };
|
||||
ResnetBlock2D::new(vs.pp(&i.to_string()), in_channels, conv_cfg)
|
||||
})
|
||||
.collect::<Result<Vec<_>>>()?
|
||||
};
|
||||
let upsampler = if config.add_upsample {
|
||||
let upsample =
|
||||
Upsample2D::new(vs.pp("upsamplers").pp("0"), out_channels, out_channels)?;
|
||||
Some(upsample)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
Ok(Self {
|
||||
resnets,
|
||||
upsampler,
|
||||
config,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl UpDecoderBlock2D {
|
||||
pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let mut xs = xs.clone();
|
||||
for resnet in self.resnets.iter() {
|
||||
xs = resnet.forward(&xs, None)?
|
||||
}
|
||||
match &self.upsampler {
|
||||
Some(upsampler) => upsampler.forward(&xs, None),
|
||||
None => Ok(xs),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct UNetMidBlock2DConfig {
|
||||
pub num_layers: usize,
|
||||
pub resnet_eps: f64,
|
||||
pub resnet_groups: Option<usize>,
|
||||
pub attn_num_head_channels: Option<usize>,
|
||||
// attention_type "default"
|
||||
pub output_scale_factor: f64,
|
||||
}
|
||||
|
||||
impl Default for UNetMidBlock2DConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
num_layers: 1,
|
||||
resnet_eps: 1e-6,
|
||||
resnet_groups: Some(32),
|
||||
attn_num_head_channels: Some(1),
|
||||
output_scale_factor: 1.,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct UNetMidBlock2D {
|
||||
resnet: ResnetBlock2D,
|
||||
attn_resnets: Vec<(AttentionBlock, ResnetBlock2D)>,
|
||||
pub config: UNetMidBlock2DConfig,
|
||||
}
|
||||
|
||||
impl UNetMidBlock2D {
|
||||
pub fn new(
|
||||
vs: nn::VarBuilder,
|
||||
in_channels: usize,
|
||||
temb_channels: Option<usize>,
|
||||
config: UNetMidBlock2DConfig,
|
||||
) -> Result<Self> {
|
||||
let vs_resnets = vs.pp("resnets");
|
||||
let vs_attns = vs.pp("attentions");
|
||||
let resnet_groups = config
|
||||
.resnet_groups
|
||||
.unwrap_or_else(|| usize::min(in_channels / 4, 32));
|
||||
let resnet_cfg = ResnetBlock2DConfig {
|
||||
eps: config.resnet_eps,
|
||||
groups: resnet_groups,
|
||||
output_scale_factor: config.output_scale_factor,
|
||||
temb_channels,
|
||||
..Default::default()
|
||||
};
|
||||
let resnet = ResnetBlock2D::new(vs_resnets.pp("0"), in_channels, resnet_cfg)?;
|
||||
let attn_cfg = AttentionBlockConfig {
|
||||
num_head_channels: config.attn_num_head_channels,
|
||||
num_groups: resnet_groups,
|
||||
rescale_output_factor: config.output_scale_factor,
|
||||
eps: config.resnet_eps,
|
||||
};
|
||||
let mut attn_resnets = vec![];
|
||||
for index in 0..config.num_layers {
|
||||
let attn = AttentionBlock::new(vs_attns.pp(&index.to_string()), in_channels, attn_cfg)?;
|
||||
let resnet = ResnetBlock2D::new(
|
||||
vs_resnets.pp(&(index + 1).to_string()),
|
||||
in_channels,
|
||||
resnet_cfg,
|
||||
)?;
|
||||
attn_resnets.push((attn, resnet))
|
||||
}
|
||||
Ok(Self {
|
||||
resnet,
|
||||
attn_resnets,
|
||||
config,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn forward(&self, xs: &Tensor, temb: Option<&Tensor>) -> Result<Tensor> {
|
||||
let mut xs = self.resnet.forward(xs, temb)?;
|
||||
for (attn, resnet) in self.attn_resnets.iter() {
|
||||
xs = resnet.forward(&attn.forward(&xs)?, temb)?
|
||||
}
|
||||
Ok(xs)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct UNetMidBlock2DCrossAttnConfig {
|
||||
pub num_layers: usize,
|
||||
pub resnet_eps: f64,
|
||||
pub resnet_groups: Option<usize>,
|
||||
pub attn_num_head_channels: usize,
|
||||
// attention_type "default"
|
||||
pub output_scale_factor: f64,
|
||||
pub cross_attn_dim: usize,
|
||||
pub sliced_attention_size: Option<usize>,
|
||||
pub use_linear_projection: bool,
|
||||
}
|
||||
|
||||
impl Default for UNetMidBlock2DCrossAttnConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
num_layers: 1,
|
||||
resnet_eps: 1e-6,
|
||||
resnet_groups: Some(32),
|
||||
attn_num_head_channels: 1,
|
||||
output_scale_factor: 1.,
|
||||
cross_attn_dim: 1280,
|
||||
sliced_attention_size: None, // Sliced attention disabled
|
||||
use_linear_projection: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct UNetMidBlock2DCrossAttn {
|
||||
resnet: ResnetBlock2D,
|
||||
attn_resnets: Vec<(SpatialTransformer, ResnetBlock2D)>,
|
||||
pub config: UNetMidBlock2DCrossAttnConfig,
|
||||
}
|
||||
|
||||
impl UNetMidBlock2DCrossAttn {
|
||||
pub fn new(
|
||||
vs: nn::VarBuilder,
|
||||
in_channels: usize,
|
||||
temb_channels: Option<usize>,
|
||||
config: UNetMidBlock2DCrossAttnConfig,
|
||||
) -> Result<Self> {
|
||||
let vs_resnets = vs.pp("resnets");
|
||||
let vs_attns = vs.pp("attentions");
|
||||
let resnet_groups = config
|
||||
.resnet_groups
|
||||
.unwrap_or_else(|| usize::min(in_channels / 4, 32));
|
||||
let resnet_cfg = ResnetBlock2DConfig {
|
||||
eps: config.resnet_eps,
|
||||
groups: resnet_groups,
|
||||
output_scale_factor: config.output_scale_factor,
|
||||
temb_channels,
|
||||
..Default::default()
|
||||
};
|
||||
let resnet = ResnetBlock2D::new(vs_resnets.pp("0"), in_channels, resnet_cfg)?;
|
||||
let n_heads = config.attn_num_head_channels;
|
||||
let attn_cfg = SpatialTransformerConfig {
|
||||
depth: 1,
|
||||
num_groups: resnet_groups,
|
||||
context_dim: Some(config.cross_attn_dim),
|
||||
sliced_attention_size: config.sliced_attention_size,
|
||||
use_linear_projection: config.use_linear_projection,
|
||||
};
|
||||
let mut attn_resnets = vec![];
|
||||
for index in 0..config.num_layers {
|
||||
let attn = SpatialTransformer::new(
|
||||
vs_attns.pp(&index.to_string()),
|
||||
in_channels,
|
||||
n_heads,
|
||||
in_channels / n_heads,
|
||||
attn_cfg,
|
||||
)?;
|
||||
let resnet = ResnetBlock2D::new(
|
||||
vs_resnets.pp(&(index + 1).to_string()),
|
||||
in_channels,
|
||||
resnet_cfg,
|
||||
)?;
|
||||
attn_resnets.push((attn, resnet))
|
||||
}
|
||||
Ok(Self {
|
||||
resnet,
|
||||
attn_resnets,
|
||||
config,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn forward(
|
||||
&self,
|
||||
xs: &Tensor,
|
||||
temb: Option<&Tensor>,
|
||||
encoder_hidden_states: Option<&Tensor>,
|
||||
) -> Result<Tensor> {
|
||||
let mut xs = self.resnet.forward(xs, temb)?;
|
||||
for (attn, resnet) in self.attn_resnets.iter() {
|
||||
xs = resnet.forward(&attn.forward(&xs, encoder_hidden_states)?, temb)?
|
||||
}
|
||||
Ok(xs)
|
||||
}
|
||||
}
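The cross-attention variant is used the same way, except that `encoder_hidden_states` carries the conditioning context (for stable diffusion, the text encoder output) whose last dimension must match `cross_attn_dim`. A sketch with illustrative channel counts, not values from this diff:

fn run_cross_attn_mid_block(
    vs: nn::VarBuilder,
    xs: &Tensor,
    temb: &Tensor,
    context: &Tensor,
) -> Result<Tensor> {
    // `cross_attn_dim` (1280 by default) must match the last dimension of `context`.
    let cfg = UNetMidBlock2DCrossAttnConfig::default();
    let mid = UNetMidBlock2DCrossAttn::new(vs.pp("mid_block"), 1280, Some(1280), cfg)?;
    mid.forward(xs, Some(temb), Some(context))
}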
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct DownBlock2DConfig {
|
||||
pub num_layers: usize,
|
||||
pub resnet_eps: f64,
|
||||
// resnet_time_scale_shift: "default"
|
||||
// resnet_act_fn: "swish"
|
||||
pub resnet_groups: usize,
|
||||
pub output_scale_factor: f64,
|
||||
pub add_downsample: bool,
|
||||
pub downsample_padding: usize,
|
||||
}
|
||||
|
||||
impl Default for DownBlock2DConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
num_layers: 1,
|
||||
resnet_eps: 1e-6,
|
||||
resnet_groups: 32,
|
||||
output_scale_factor: 1.,
|
||||
add_downsample: true,
|
||||
downsample_padding: 1,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct DownBlock2D {
|
||||
resnets: Vec<ResnetBlock2D>,
|
||||
downsampler: Option<Downsample2D>,
|
||||
pub config: DownBlock2DConfig,
|
||||
}
|
||||
|
||||
impl DownBlock2D {
|
||||
pub fn new(
|
||||
vs: nn::VarBuilder,
|
||||
in_channels: usize,
|
||||
out_channels: usize,
|
||||
temb_channels: Option<usize>,
|
||||
config: DownBlock2DConfig,
|
||||
) -> Result<Self> {
|
||||
let vs_resnets = vs.pp("resnets");
|
||||
let resnet_cfg = ResnetBlock2DConfig {
|
||||
out_channels: Some(out_channels),
|
||||
eps: config.resnet_eps,
|
||||
output_scale_factor: config.output_scale_factor,
|
||||
temb_channels,
|
||||
..Default::default()
|
||||
};
|
||||
let resnets = (0..config.num_layers)
|
||||
.map(|i| {
|
||||
let in_channels = if i == 0 { in_channels } else { out_channels };
|
||||
ResnetBlock2D::new(vs_resnets.pp(&i.to_string()), in_channels, resnet_cfg)
|
||||
})
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
let downsampler = if config.add_downsample {
|
||||
let downsampler = Downsample2D::new(
|
||||
vs.pp("downsamplers").pp("0"),
|
||||
out_channels,
|
||||
true,
|
||||
out_channels,
|
||||
config.downsample_padding,
|
||||
)?;
|
||||
Some(downsampler)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
Ok(Self {
|
||||
resnets,
|
||||
downsampler,
|
||||
config,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn forward(&self, xs: &Tensor, temb: Option<&Tensor>) -> Result<(Tensor, Vec<Tensor>)> {
|
||||
let mut xs = xs.clone();
|
||||
let mut output_states = vec![];
|
||||
for resnet in self.resnets.iter() {
|
||||
xs = resnet.forward(&xs, temb)?;
|
||||
output_states.push(xs.clone());
|
||||
}
|
||||
let xs = match &self.downsampler {
|
||||
Some(downsampler) => {
|
||||
let xs = downsampler.forward(&xs)?;
|
||||
output_states.push(xs.clone());
|
||||
xs
|
||||
}
|
||||
None => xs,
|
||||
};
|
||||
Ok((xs, output_states))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct CrossAttnDownBlock2DConfig {
|
||||
pub downblock: DownBlock2DConfig,
|
||||
pub attn_num_head_channels: usize,
|
||||
pub cross_attention_dim: usize,
|
||||
// attention_type: "default"
|
||||
pub sliced_attention_size: Option<usize>,
|
||||
pub use_linear_projection: bool,
|
||||
}
|
||||
|
||||
impl Default for CrossAttnDownBlock2DConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
downblock: Default::default(),
|
||||
attn_num_head_channels: 1,
|
||||
cross_attention_dim: 1280,
|
||||
sliced_attention_size: None,
|
||||
use_linear_projection: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct CrossAttnDownBlock2D {
|
||||
downblock: DownBlock2D,
|
||||
attentions: Vec<SpatialTransformer>,
|
||||
pub config: CrossAttnDownBlock2DConfig,
|
||||
}
|
||||
|
||||
impl CrossAttnDownBlock2D {
|
||||
pub fn new(
|
||||
vs: nn::VarBuilder,
|
||||
in_channels: usize,
|
||||
out_channels: usize,
|
||||
temb_channels: Option<usize>,
|
||||
config: CrossAttnDownBlock2DConfig,
|
||||
) -> Result<Self> {
|
||||
let downblock = DownBlock2D::new(
|
||||
vs.clone(),
|
||||
in_channels,
|
||||
out_channels,
|
||||
temb_channels,
|
||||
config.downblock,
|
||||
)?;
|
||||
let n_heads = config.attn_num_head_channels;
|
||||
let cfg = SpatialTransformerConfig {
|
||||
depth: 1,
|
||||
context_dim: Some(config.cross_attention_dim),
|
||||
num_groups: config.downblock.resnet_groups,
|
||||
sliced_attention_size: config.sliced_attention_size,
|
||||
use_linear_projection: config.use_linear_projection,
|
||||
};
|
||||
let vs_attn = vs.pp("attentions");
|
||||
let attentions = (0..config.downblock.num_layers)
|
||||
.map(|i| {
|
||||
SpatialTransformer::new(
|
||||
vs_attn.pp(&i.to_string()),
|
||||
out_channels,
|
||||
n_heads,
|
||||
out_channels / n_heads,
|
||||
cfg,
|
||||
)
|
||||
})
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
Ok(Self {
|
||||
downblock,
|
||||
attentions,
|
||||
config,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn forward(
|
||||
&self,
|
||||
xs: &Tensor,
|
||||
temb: Option<&Tensor>,
|
||||
encoder_hidden_states: Option<&Tensor>,
|
||||
) -> Result<(Tensor, Vec<Tensor>)> {
|
||||
let mut output_states = vec![];
|
||||
let mut xs = xs.clone();
|
||||
for (resnet, attn) in self.downblock.resnets.iter().zip(self.attentions.iter()) {
|
||||
xs = resnet.forward(&xs, temb)?;
|
||||
xs = attn.forward(&xs, encoder_hidden_states)?;
|
||||
output_states.push(xs.clone());
|
||||
}
|
||||
let xs = match &self.downblock.downsampler {
|
||||
Some(downsampler) => {
|
||||
let xs = downsampler.forward(&xs)?;
|
||||
output_states.push(xs.clone());
|
||||
xs
|
||||
}
|
||||
None => xs,
|
||||
};
|
||||
Ok((xs, output_states))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct UpBlock2DConfig {
|
||||
pub num_layers: usize,
|
||||
pub resnet_eps: f64,
|
||||
// resnet_time_scale_shift: "default"
|
||||
// resnet_act_fn: "swish"
|
||||
pub resnet_groups: usize,
|
||||
pub output_scale_factor: f64,
|
||||
pub add_upsample: bool,
|
||||
}
|
||||
|
||||
impl Default for UpBlock2DConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
num_layers: 1,
|
||||
resnet_eps: 1e-6,
|
||||
resnet_groups: 32,
|
||||
output_scale_factor: 1.,
|
||||
add_upsample: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct UpBlock2D {
|
||||
pub resnets: Vec<ResnetBlock2D>,
|
||||
upsampler: Option<Upsample2D>,
|
||||
pub config: UpBlock2DConfig,
|
||||
}
|
||||
|
||||
impl UpBlock2D {
|
||||
pub fn new(
|
||||
vs: nn::VarBuilder,
|
||||
in_channels: usize,
|
||||
prev_output_channels: usize,
|
||||
out_channels: usize,
|
||||
temb_channels: Option<usize>,
|
||||
config: UpBlock2DConfig,
|
||||
) -> Result<Self> {
|
||||
let vs_resnets = vs.pp("resnets");
|
||||
let resnet_cfg = ResnetBlock2DConfig {
|
||||
out_channels: Some(out_channels),
|
||||
temb_channels,
|
||||
eps: config.resnet_eps,
|
||||
output_scale_factor: config.output_scale_factor,
|
||||
..Default::default()
|
||||
};
|
||||
let resnets = (0..config.num_layers)
|
||||
.map(|i| {
|
||||
let res_skip_channels = if i == config.num_layers - 1 {
|
||||
in_channels
|
||||
} else {
|
||||
out_channels
|
||||
};
|
||||
let resnet_in_channels = if i == 0 {
|
||||
prev_output_channels
|
||||
} else {
|
||||
out_channels
|
||||
};
|
||||
let in_channels = resnet_in_channels + res_skip_channels;
|
||||
ResnetBlock2D::new(vs_resnets.pp(&i.to_string()), in_channels, resnet_cfg)
|
||||
})
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
let upsampler = if config.add_upsample {
|
||||
let upsampler =
|
||||
Upsample2D::new(vs.pp("upsamplers").pp("0"), out_channels, out_channels)?;
|
||||
Some(upsampler)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
Ok(Self {
|
||||
resnets,
|
||||
upsampler,
|
||||
config,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn forward(
|
||||
&self,
|
||||
xs: &Tensor,
|
||||
res_xs: &[Tensor],
|
||||
temb: Option<&Tensor>,
|
||||
upsample_size: Option<(usize, usize)>,
|
||||
) -> Result<Tensor> {
|
||||
let mut xs = xs.clone();
|
||||
for (index, resnet) in self.resnets.iter().enumerate() {
|
||||
xs = Tensor::cat(&[&xs, &res_xs[res_xs.len() - index - 1]], 1)?;
|
||||
xs = resnet.forward(&xs, temb)?;
|
||||
}
|
||||
match &self.upsampler {
|
||||
Some(upsampler) => upsampler.forward(&xs, upsample_size),
|
||||
None => Ok(xs),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct CrossAttnUpBlock2DConfig {
|
||||
pub upblock: UpBlock2DConfig,
|
||||
pub attn_num_head_channels: usize,
|
||||
pub cross_attention_dim: usize,
|
||||
// attention_type: "default"
|
||||
pub sliced_attention_size: Option<usize>,
|
||||
pub use_linear_projection: bool,
|
||||
}
|
||||
|
||||
impl Default for CrossAttnUpBlock2DConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
upblock: Default::default(),
|
||||
attn_num_head_channels: 1,
|
||||
cross_attention_dim: 1280,
|
||||
sliced_attention_size: None,
|
||||
use_linear_projection: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct CrossAttnUpBlock2D {
|
||||
pub upblock: UpBlock2D,
|
||||
pub attentions: Vec<SpatialTransformer>,
|
||||
pub config: CrossAttnUpBlock2DConfig,
|
||||
}
|
||||
|
||||
impl CrossAttnUpBlock2D {
|
||||
pub fn new(
|
||||
vs: nn::VarBuilder,
|
||||
in_channels: usize,
|
||||
prev_output_channels: usize,
|
||||
out_channels: usize,
|
||||
temb_channels: Option<usize>,
|
||||
config: CrossAttnUpBlock2DConfig,
|
||||
) -> Result<Self> {
|
||||
let upblock = UpBlock2D::new(
|
||||
vs.clone(),
|
||||
in_channels,
|
||||
prev_output_channels,
|
||||
out_channels,
|
||||
temb_channels,
|
||||
config.upblock,
|
||||
)?;
|
||||
let n_heads = config.attn_num_head_channels;
|
||||
let cfg = SpatialTransformerConfig {
|
||||
depth: 1,
|
||||
context_dim: Some(config.cross_attention_dim),
|
||||
num_groups: config.upblock.resnet_groups,
|
||||
sliced_attention_size: config.sliced_attention_size,
|
||||
use_linear_projection: config.use_linear_projection,
|
||||
};
|
||||
let vs_attn = vs.pp("attentions");
|
||||
let attentions = (0..config.upblock.num_layers)
|
||||
.map(|i| {
|
||||
SpatialTransformer::new(
|
||||
vs_attn.pp(&i.to_string()),
|
||||
out_channels,
|
||||
n_heads,
|
||||
out_channels / n_heads,
|
||||
cfg,
|
||||
)
|
||||
})
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
Ok(Self {
|
||||
upblock,
|
||||
attentions,
|
||||
config,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn forward(
|
||||
&self,
|
||||
xs: &Tensor,
|
||||
res_xs: &[Tensor],
|
||||
temb: Option<&Tensor>,
|
||||
upsample_size: Option<(usize, usize)>,
|
||||
encoder_hidden_states: Option<&Tensor>,
|
||||
) -> Result<Tensor> {
|
||||
let mut xs = xs.clone();
|
||||
for (index, resnet) in self.upblock.resnets.iter().enumerate() {
|
||||
xs = Tensor::cat(&[&xs, &res_xs[res_xs.len() - index - 1]], 1)?;
|
||||
xs = resnet.forward(&xs, temb)?;
|
||||
xs = self.attentions[index].forward(&xs, encoder_hidden_states)?;
|
||||
}
|
||||
match &self.upblock.upsampler {
|
||||
Some(upsampler) => upsampler.forward(&xs, upsample_size),
|
||||
None => Ok(xs),
|
||||
}
|
||||
}
|
||||
}
|
17
candle-examples/examples/stable-diffusion/utils.rs
Normal file
@ -0,0 +1,17 @@
|
||||
use candle::{Result, Tensor};
|
||||
|
||||
pub fn sigmoid(_: &Tensor) -> Result<Tensor> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
pub fn avg_pool2d(_: &Tensor) -> Result<Tensor> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
pub fn pad(_: &Tensor) -> Result<Tensor> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
pub fn upsample_nearest2d(_: &Tensor) -> Result<Tensor> {
|
||||
todo!()
|
||||
}
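The four helpers above are `todo!()` placeholders for ops that are not implemented at this point. As an illustration only (not part of this commit), a sigmoid could already be expressed with existing candle tensor ops, in the same style as the `silu` helper added to candle-nn further down in this diff:

pub fn sigmoid(xs: &Tensor) -> Result<Tensor> {
    // 1 / (1 + exp(-x))
    xs.ones_like()? / (xs.neg()?.exp()? + 1.0)?
}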
|
378
candle-examples/examples/stable-diffusion/vae.rs
Normal file
@ -0,0 +1,378 @@
|
||||
#![allow(dead_code)]
|
||||
//! # Variational Auto-Encoder (VAE) Models.
|
||||
//!
|
||||
//! Auto-encoder models compress their input into a (usually smaller) latent space
|
||||
//! before expanding it back to its original shape, so the latent values act as a
|
||||
//! compressed representation of the original input.
|
||||
use crate::unet_2d_blocks::{
|
||||
DownEncoderBlock2D, DownEncoderBlock2DConfig, UNetMidBlock2D, UNetMidBlock2DConfig,
|
||||
UpDecoderBlock2D, UpDecoderBlock2DConfig,
|
||||
};
|
||||
use candle::{Result, Tensor};
|
||||
use candle_nn as nn;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct EncoderConfig {
|
||||
// down_block_types: DownEncoderBlock2D
|
||||
block_out_channels: Vec<usize>,
|
||||
layers_per_block: usize,
|
||||
norm_num_groups: usize,
|
||||
double_z: bool,
|
||||
}
|
||||
|
||||
impl Default for EncoderConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
block_out_channels: vec![64],
|
||||
layers_per_block: 2,
|
||||
norm_num_groups: 32,
|
||||
double_z: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct Encoder {
|
||||
conv_in: nn::Conv2d,
|
||||
down_blocks: Vec<DownEncoderBlock2D>,
|
||||
mid_block: UNetMidBlock2D,
|
||||
conv_norm_out: nn::GroupNorm,
|
||||
conv_out: nn::Conv2d,
|
||||
#[allow(dead_code)]
|
||||
config: EncoderConfig,
|
||||
}
|
||||
|
||||
impl Encoder {
|
||||
fn new(
|
||||
vs: nn::VarBuilder,
|
||||
in_channels: usize,
|
||||
out_channels: usize,
|
||||
config: EncoderConfig,
|
||||
) -> Result<Self> {
|
||||
let conv_cfg = nn::Conv2dConfig {
|
||||
stride: 1,
|
||||
padding: 1,
|
||||
};
|
||||
let conv_in = nn::conv2d(
|
||||
in_channels,
|
||||
config.block_out_channels[0],
|
||||
3,
|
||||
conv_cfg,
|
||||
vs.pp("conv_in"),
|
||||
)?;
|
||||
let mut down_blocks = vec![];
|
||||
let vs_down_blocks = vs.pp("down_blocks");
|
||||
for index in 0..config.block_out_channels.len() {
|
||||
let out_channels = config.block_out_channels[index];
|
||||
let in_channels = if index > 0 {
|
||||
config.block_out_channels[index - 1]
|
||||
} else {
|
||||
config.block_out_channels[0]
|
||||
};
|
||||
let is_final = index + 1 == config.block_out_channels.len();
|
||||
let cfg = DownEncoderBlock2DConfig {
|
||||
num_layers: config.layers_per_block,
|
||||
resnet_eps: 1e-6,
|
||||
resnet_groups: config.norm_num_groups,
|
||||
add_downsample: !is_final,
|
||||
downsample_padding: 0,
|
||||
..Default::default()
|
||||
};
|
||||
let down_block = DownEncoderBlock2D::new(
|
||||
vs_down_blocks.pp(&index.to_string()),
|
||||
in_channels,
|
||||
out_channels,
|
||||
cfg,
|
||||
)?;
|
||||
down_blocks.push(down_block)
|
||||
}
|
||||
let last_block_out_channels = *config.block_out_channels.last().unwrap();
|
||||
let mid_cfg = UNetMidBlock2DConfig {
|
||||
resnet_eps: 1e-6,
|
||||
output_scale_factor: 1.,
|
||||
attn_num_head_channels: None,
|
||||
resnet_groups: Some(config.norm_num_groups),
|
||||
..Default::default()
|
||||
};
|
||||
let mid_block =
|
||||
UNetMidBlock2D::new(vs.pp("mid_block"), last_block_out_channels, None, mid_cfg)?;
|
||||
let conv_norm_out = nn::group_norm(
|
||||
config.norm_num_groups,
|
||||
last_block_out_channels,
|
||||
1e-6,
|
||||
vs.pp("conv_norm_out"),
|
||||
)?;
|
||||
let conv_out_channels = if config.double_z {
|
||||
2 * out_channels
|
||||
} else {
|
||||
out_channels
|
||||
};
|
||||
let conv_cfg = nn::Conv2dConfig {
|
||||
padding: 1,
|
||||
..Default::default()
|
||||
};
|
||||
let conv_out = nn::conv2d(
|
||||
last_block_out_channels,
|
||||
conv_out_channels,
|
||||
3,
|
||||
conv_cfg,
|
||||
vs.pp("conv_out"),
|
||||
)?;
|
||||
Ok(Self {
|
||||
conv_in,
|
||||
down_blocks,
|
||||
mid_block,
|
||||
conv_norm_out,
|
||||
conv_out,
|
||||
config,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Encoder {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let mut xs = self.conv_in.forward(xs)?;
|
||||
for down_block in self.down_blocks.iter() {
|
||||
xs = down_block.forward(&xs)?
|
||||
}
|
||||
let xs = self.mid_block.forward(&xs, None)?;
|
||||
let xs = self.conv_norm_out.forward(&xs)?;
|
||||
let xs = nn::ops::silu(&xs)?;
|
||||
self.conv_out.forward(&xs)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct DecoderConfig {
|
||||
// up_block_types: UpDecoderBlock2D
|
||||
block_out_channels: Vec<usize>,
|
||||
layers_per_block: usize,
|
||||
norm_num_groups: usize,
|
||||
}
|
||||
|
||||
impl Default for DecoderConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
block_out_channels: vec![64],
|
||||
layers_per_block: 2,
|
||||
norm_num_groups: 32,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct Decoder {
|
||||
conv_in: nn::Conv2d,
|
||||
up_blocks: Vec<UpDecoderBlock2D>,
|
||||
mid_block: UNetMidBlock2D,
|
||||
conv_norm_out: nn::GroupNorm,
|
||||
conv_out: nn::Conv2d,
|
||||
#[allow(dead_code)]
|
||||
config: DecoderConfig,
|
||||
}
|
||||
|
||||
impl Decoder {
|
||||
fn new(
|
||||
vs: nn::VarBuilder,
|
||||
in_channels: usize,
|
||||
out_channels: usize,
|
||||
config: DecoderConfig,
|
||||
) -> Result<Self> {
|
||||
let n_block_out_channels = config.block_out_channels.len();
|
||||
let last_block_out_channels = *config.block_out_channels.last().unwrap();
|
||||
let conv_cfg = nn::Conv2dConfig {
|
||||
stride: 1,
|
||||
padding: 1,
|
||||
};
|
||||
let conv_in = nn::conv2d(
|
||||
in_channels,
|
||||
last_block_out_channels,
|
||||
3,
|
||||
conv_cfg,
|
||||
vs.pp("conv_in"),
|
||||
)?;
|
||||
let mid_cfg = UNetMidBlock2DConfig {
|
||||
resnet_eps: 1e-6,
|
||||
output_scale_factor: 1.,
|
||||
attn_num_head_channels: None,
|
||||
resnet_groups: Some(config.norm_num_groups),
|
||||
..Default::default()
|
||||
};
|
||||
let mid_block =
|
||||
UNetMidBlock2D::new(vs.pp("mid_block"), last_block_out_channels, None, mid_cfg)?;
|
||||
let mut up_blocks = vec![];
|
||||
let vs_up_blocks = vs.pp("up_blocks");
|
||||
let reversed_block_out_channels: Vec<_> =
|
||||
config.block_out_channels.iter().copied().rev().collect();
|
||||
for index in 0..n_block_out_channels {
|
||||
let out_channels = reversed_block_out_channels[index];
|
||||
let in_channels = if index > 0 {
|
||||
reversed_block_out_channels[index - 1]
|
||||
} else {
|
||||
reversed_block_out_channels[0]
|
||||
};
|
||||
let is_final = index + 1 == n_block_out_channels;
|
||||
let cfg = UpDecoderBlock2DConfig {
|
||||
num_layers: config.layers_per_block + 1,
|
||||
resnet_eps: 1e-6,
|
||||
resnet_groups: config.norm_num_groups,
|
||||
add_upsample: !is_final,
|
||||
..Default::default()
|
||||
};
|
||||
let up_block = UpDecoderBlock2D::new(
|
||||
vs_up_blocks.pp(&index.to_string()),
|
||||
in_channels,
|
||||
out_channels,
|
||||
cfg,
|
||||
)?;
|
||||
up_blocks.push(up_block)
|
||||
}
|
||||
let conv_norm_out = nn::group_norm(
|
||||
config.norm_num_groups,
|
||||
config.block_out_channels[0],
|
||||
1e-6,
|
||||
vs.pp("conv_norm_out"),
|
||||
)?;
|
||||
let conv_cfg = nn::Conv2dConfig {
|
||||
padding: 1,
|
||||
..Default::default()
|
||||
};
|
||||
let conv_out = nn::conv2d(
|
||||
config.block_out_channels[0],
|
||||
out_channels,
|
||||
3,
|
||||
conv_cfg,
|
||||
vs.pp("conv_out"),
|
||||
)?;
|
||||
Ok(Self {
|
||||
conv_in,
|
||||
up_blocks,
|
||||
mid_block,
|
||||
conv_norm_out,
|
||||
conv_out,
|
||||
config,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Decoder {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let mut xs = self.mid_block.forward(&self.conv_in.forward(xs)?, None)?;
|
||||
for up_block in self.up_blocks.iter() {
|
||||
xs = up_block.forward(&xs)?
|
||||
}
|
||||
let xs = self.conv_norm_out.forward(&xs)?;
|
||||
let xs = nn::ops::silu(&xs)?;
|
||||
self.conv_out.forward(&xs)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct AutoEncoderKLConfig {
|
||||
pub block_out_channels: Vec<usize>,
|
||||
pub layers_per_block: usize,
|
||||
pub latent_channels: usize,
|
||||
pub norm_num_groups: usize,
|
||||
}
|
||||
|
||||
impl Default for AutoEncoderKLConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
block_out_channels: vec![64],
|
||||
layers_per_block: 1,
|
||||
latent_channels: 4,
|
||||
norm_num_groups: 32,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct DiagonalGaussianDistribution {
|
||||
mean: Tensor,
|
||||
std: Tensor,
|
||||
}
|
||||
|
||||
impl DiagonalGaussianDistribution {
|
||||
pub fn new(parameters: &Tensor) -> Result<Self> {
|
||||
let mut parameters = parameters.chunk(2, 1)?.into_iter();
|
||||
let mean = parameters.next().unwrap();
|
||||
let logvar = parameters.next().unwrap();
|
||||
let std = (logvar * 0.5)?.exp()?;
|
||||
Ok(DiagonalGaussianDistribution { mean, std })
|
||||
}
|
||||
|
||||
pub fn sample(&self) -> Result<Tensor> {
|
||||
let sample = Tensor::randn(0., 1f32, self.mean.shape(), self.mean.device());
|
||||
&self.mean + &self.std * sample
|
||||
}
|
||||
}
|
||||
|
||||
// https://github.com/huggingface/diffusers/blob/970e30606c2944e3286f56e8eb6d3dc6d1eb85f7/src/diffusers/models/vae.py#L485
|
||||
// This implementation is specific to the config used in stable-diffusion-v1-5
|
||||
// https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/vae/config.json
|
||||
#[derive(Debug)]
|
||||
pub struct AutoEncoderKL {
|
||||
encoder: Encoder,
|
||||
decoder: Decoder,
|
||||
quant_conv: nn::Conv2d,
|
||||
post_quant_conv: nn::Conv2d,
|
||||
pub config: AutoEncoderKLConfig,
|
||||
}
|
||||
|
||||
impl AutoEncoderKL {
|
||||
pub fn new(
|
||||
vs: nn::VarBuilder,
|
||||
in_channels: usize,
|
||||
out_channels: usize,
|
||||
config: AutoEncoderKLConfig,
|
||||
) -> Result<Self> {
|
||||
let latent_channels = config.latent_channels;
|
||||
let encoder_cfg = EncoderConfig {
|
||||
block_out_channels: config.block_out_channels.clone(),
|
||||
layers_per_block: config.layers_per_block,
|
||||
norm_num_groups: config.norm_num_groups,
|
||||
double_z: true,
|
||||
};
|
||||
let encoder = Encoder::new(vs.pp("encoder"), in_channels, latent_channels, encoder_cfg)?;
|
||||
let decoder_cfg = DecoderConfig {
|
||||
block_out_channels: config.block_out_channels.clone(),
|
||||
layers_per_block: config.layers_per_block,
|
||||
norm_num_groups: config.norm_num_groups,
|
||||
};
|
||||
let decoder = Decoder::new(vs.pp("decoder"), latent_channels, out_channels, decoder_cfg)?;
|
||||
let conv_cfg = Default::default();
|
||||
let quant_conv = nn::conv2d(
|
||||
2 * latent_channels,
|
||||
2 * latent_channels,
|
||||
1,
|
||||
conv_cfg,
|
||||
vs.pp("quant_conv"),
|
||||
)?;
|
||||
let post_quant_conv = nn::conv2d(
|
||||
latent_channels,
|
||||
latent_channels,
|
||||
1,
|
||||
conv_cfg,
|
||||
vs.pp("post_quant_conv"),
|
||||
)?;
|
||||
Ok(Self {
|
||||
encoder,
|
||||
decoder,
|
||||
quant_conv,
|
||||
post_quant_conv,
|
||||
config,
|
||||
})
|
||||
}
|
||||
|
||||
/// Returns the distribution in the latent space.
|
||||
pub fn encode(&self, xs: &Tensor) -> Result<DiagonalGaussianDistribution> {
|
||||
let xs = self.encoder.forward(xs)?;
|
||||
let parameters = self.quant_conv.forward(&xs)?;
|
||||
DiagonalGaussianDistribution::new(¶meters)
|
||||
}
|
||||
|
||||
/// Takes as input some sampled values.
|
||||
pub fn decode(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let xs = self.post_quant_conv.forward(xs)?;
|
||||
self.decoder.forward(&xs)
|
||||
}
|
||||
}
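`DiagonalGaussianDistribution` splits the encoder output into a mean and a log-variance along the channel dimension (via `chunk`) and samples with the reparameterization mean + std * eps, where std = exp(0.5 * logvar). A hypothetical end-to-end sketch; the `vae` and `img` arguments are assumptions:

fn vae_roundtrip(vae: &AutoEncoderKL, img: &Tensor) -> Result<Tensor> {
    // Encode to a distribution over the latent space, draw one sample, decode it back.
    let dist = vae.encode(img)?;
    let latents = dist.sample()?;
    vae.decode(&latents)
}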
|
candle-nn/src/conv.rs
@ -48,3 +48,84 @@ impl Conv1d {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub struct Conv2dConfig {
|
||||
pub padding: usize,
|
||||
pub stride: usize,
|
||||
}
|
||||
|
||||
impl Default for Conv2dConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
padding: 0,
|
||||
stride: 1,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
#[derive(Debug)]
|
||||
pub struct Conv2d {
|
||||
weight: Tensor,
|
||||
bias: Option<Tensor>,
|
||||
config: Conv2dConfig,
|
||||
}
|
||||
|
||||
impl Conv2d {
|
||||
pub fn new(weight: Tensor, bias: Option<Tensor>, config: Conv2dConfig) -> Self {
|
||||
Self {
|
||||
weight,
|
||||
bias,
|
||||
config,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn config(&self) -> &Conv2dConfig {
|
||||
&self.config
|
||||
}
|
||||
|
||||
pub fn forward(&self, _x: &Tensor) -> Result<Tensor> {
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn conv1d(
|
||||
in_channels: usize,
|
||||
out_channels: usize,
|
||||
kernel_size: usize,
|
||||
cfg: Conv1dConfig,
|
||||
vs: crate::VarBuilder,
|
||||
) -> Result<Conv1d> {
|
||||
let init_ws = crate::init::DEFAULT_KAIMING_NORMAL;
|
||||
let ws = vs.get_or_init((out_channels, in_channels, kernel_size), "weight", init_ws)?;
|
||||
let bound = 1. / (in_channels as f64).sqrt();
|
||||
let init_bs = crate::Init::Uniform {
|
||||
lo: -bound,
|
||||
up: bound,
|
||||
};
|
||||
let bs = vs.get_or_init(out_channels, "bias", init_bs)?;
|
||||
Ok(Conv1d::new(ws, Some(bs), cfg))
|
||||
}
|
||||
|
||||
pub fn conv2d(
|
||||
in_channels: usize,
|
||||
out_channels: usize,
|
||||
kernel_size: usize,
|
||||
cfg: Conv2dConfig,
|
||||
vs: crate::VarBuilder,
|
||||
) -> Result<Conv2d> {
|
||||
let init_ws = crate::init::DEFAULT_KAIMING_NORMAL;
|
||||
let ws = vs.get_or_init(
|
||||
(out_channels, in_channels, kernel_size, kernel_size),
|
||||
"weight",
|
||||
init_ws,
|
||||
)?;
|
||||
let bound = 1. / (in_channels as f64).sqrt();
|
||||
let init_bs = crate::Init::Uniform {
|
||||
lo: -bound,
|
||||
up: bound,
|
||||
};
|
||||
let bs = vs.get_or_init(out_channels, "bias", init_bs)?;
|
||||
Ok(Conv2d::new(ws, Some(bs), cfg))
|
||||
}
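`Conv2d::forward` is still a `todo!()` at this point, so the builder above only registers the weight and bias variables. A hypothetical usage sketch with illustrative channel counts:

fn example_conv(vs: crate::VarBuilder) -> Result<Conv2d> {
    // 3x3 kernel, stride 1, padding 1 -- the same configuration the VAE encoder uses
    // for its `conv_in` layer.
    let cfg = Conv2dConfig { padding: 1, stride: 1 };
    conv2d(3, 64, 3, cfg, vs)
}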
|
||||
|
48
candle-nn/src/group_norm.rs
Normal file
@ -0,0 +1,48 @@
|
||||
//! Group Normalization.
|
||||
//!
|
||||
//! This layer applies Group Normalization over a mini-batch of inputs.
|
||||
use candle::{Result, Tensor};
|
||||
|
||||
// This group norm version handles both a weight and a bias, and subtracts the mean.
|
||||
#[allow(dead_code)]
|
||||
#[derive(Debug)]
|
||||
pub struct GroupNorm {
|
||||
weight: Tensor,
|
||||
bias: Tensor,
|
||||
eps: f64,
|
||||
num_channels: usize,
|
||||
num_groups: usize,
|
||||
}
|
||||
|
||||
impl GroupNorm {
|
||||
pub fn new(
|
||||
weight: Tensor,
|
||||
bias: Tensor,
|
||||
num_channels: usize,
|
||||
num_groups: usize,
|
||||
eps: f64,
|
||||
) -> Self {
|
||||
Self {
|
||||
weight,
|
||||
bias,
|
||||
eps,
|
||||
num_channels,
|
||||
num_groups,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn forward(&self, _: &Tensor) -> Result<Tensor> {
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn group_norm(
|
||||
num_channels: usize,
|
||||
num_groups: usize,
|
||||
eps: f64,
|
||||
vb: crate::VarBuilder,
|
||||
) -> Result<GroupNorm> {
|
||||
let weight = vb.get_or_init(num_channels, "weight", crate::Init::Const(1.))?;
|
||||
let bias = vb.get_or_init(num_channels, "bias", crate::Init::Const(0.))?;
|
||||
Ok(GroupNorm::new(weight, bias, num_channels, num_groups, eps))
|
||||
}
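`GroupNorm::forward` is also left as a `todo!()` here. For reference, a possible implementation for a (batch, channels, h, w) input could look like the sketch below; this is an assumption about how the op might eventually be written, not code from this commit.

impl GroupNorm {
    pub fn forward_sketch(&self, xs: &Tensor) -> Result<Tensor> {
        let (b, c, h, w) = xs.dims4()?;
        // Normalize each group of channels together with the spatial dimensions.
        let xs = xs.reshape((b, self.num_groups, c / self.num_groups * h * w))?;
        let mean = xs.mean_keepdim(2)?;
        let var = xs.broadcast_sub(&mean)?.sqr()?.mean_keepdim(2)?;
        let xs = xs
            .broadcast_sub(&mean)?
            .broadcast_div(&(var + self.eps)?.sqrt()?)?
            .reshape((b, c, h, w))?;
        // Per-channel affine transform using the learned weight and bias.
        xs.broadcast_mul(&self.weight.reshape((1, c, 1, 1))?)?
            .broadcast_add(&self.bias.reshape((1, c, 1, 1))?)
    }
}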
|
candle-nn/src/lib.rs
@ -3,6 +3,7 @@
|
||||
pub mod activation;
|
||||
pub mod conv;
|
||||
pub mod embedding;
|
||||
pub mod group_norm;
|
||||
pub mod init;
|
||||
pub mod layer_norm;
|
||||
pub mod linear;
|
||||
@ -12,8 +13,9 @@ pub mod optim;
|
||||
pub mod var_builder;
|
||||
|
||||
pub use activation::Activation;
|
||||
pub use conv::{Conv1d, Conv1dConfig};
|
||||
pub use conv::{conv1d, conv2d, Conv1d, Conv1dConfig, Conv2d, Conv2dConfig};
|
||||
pub use embedding::{embedding, Embedding};
|
||||
pub use group_norm::{group_norm, GroupNorm};
|
||||
pub use init::Init;
|
||||
pub use layer_norm::{layer_norm, LayerNorm};
|
||||
pub use linear::{linear, linear_no_bias, Linear};
|
||||
|
candle-nn/src/ops.rs
@ -32,3 +32,7 @@ pub fn log_softmax<D: candle::shape::Dim>(xs: &Tensor, d: D) -> Result<Tensor> {
|
||||
let log_sm = diff.broadcast_sub(&sum_exp.log()?)?;
|
||||
Ok(log_sm)
|
||||
}
|
||||
|
||||
pub fn silu(xs: &Tensor) -> Result<Tensor> {
|
||||
xs / (xs.neg()?.exp()? + 1.0)?
|
||||
}
|
||||
|