mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 03:54:56 +00:00
Add a stable diffusion example (#328)
* Start adding a stable-diffusion example. * Proper computation of the causal mask. * Add the chunk operation. * Work in progress: port the attention module. * Add some dummy modules for conv2d and group-norm, get the attention module to compile. * Re-enable the 2d convolution. * Add the embeddings module. * Add the resnet module. * Add the unet blocks. * Add the unet. * And add the variational auto-encoder. * Use the pad function from utils.
This commit is contained in:
445
candle-examples/examples/stable-diffusion/attention.rs
Normal file
445
candle-examples/examples/stable-diffusion/attention.rs
Normal file
@ -0,0 +1,445 @@
|
||||
#![allow(dead_code)]
|
||||
//! Attention Based Building Blocks
|
||||
use candle::{IndexOp, Result, Tensor, D};
|
||||
use candle_nn as nn;
|
||||
|
||||
#[derive(Debug)]
|
||||
struct GeGlu {
|
||||
proj: nn::Linear,
|
||||
}
|
||||
|
||||
impl GeGlu {
|
||||
fn new(vs: nn::VarBuilder, dim_in: usize, dim_out: usize) -> Result<Self> {
|
||||
let proj = nn::linear(dim_in, dim_out * 2, vs.pp("proj"))?;
|
||||
Ok(Self { proj })
|
||||
}
|
||||
}
|
||||
|
||||
impl GeGlu {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let hidden_states_and_gate = self.proj.forward(xs)?.chunk(2, D::Minus1)?;
|
||||
&hidden_states_and_gate[0] * hidden_states_and_gate[1].gelu()?
|
||||
}
|
||||
}
|
||||
|
||||
/// A feed-forward layer.
|
||||
#[derive(Debug)]
|
||||
struct FeedForward {
|
||||
project_in: GeGlu,
|
||||
linear: nn::Linear,
|
||||
}
|
||||
|
||||
impl FeedForward {
|
||||
// The glu parameter in the python code is unused?
|
||||
// https://github.com/huggingface/diffusers/blob/d3d22ce5a894becb951eec03e663951b28d45135/src/diffusers/models/attention.py#L347
|
||||
/// Creates a new feed-forward layer based on some given input dimension, some
|
||||
/// output dimension, and a multiplier to be used for the intermediary layer.
|
||||
fn new(vs: nn::VarBuilder, dim: usize, dim_out: Option<usize>, mult: usize) -> Result<Self> {
|
||||
let inner_dim = dim * mult;
|
||||
let dim_out = dim_out.unwrap_or(dim);
|
||||
let vs = vs.pp("net");
|
||||
let project_in = GeGlu::new(vs.pp("0"), dim, inner_dim)?;
|
||||
let linear = nn::linear(inner_dim, dim_out, vs.pp("2"))?;
|
||||
Ok(Self { project_in, linear })
|
||||
}
|
||||
}
|
||||
|
||||
impl FeedForward {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let xs = self.project_in.forward(xs)?;
|
||||
self.linear.forward(&xs)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct CrossAttention {
|
||||
to_q: nn::Linear,
|
||||
to_k: nn::Linear,
|
||||
to_v: nn::Linear,
|
||||
to_out: nn::Linear,
|
||||
heads: usize,
|
||||
scale: f64,
|
||||
slice_size: Option<usize>,
|
||||
}
|
||||
|
||||
impl CrossAttention {
|
||||
// Defaults should be heads = 8, dim_head = 64, context_dim = None
|
||||
fn new(
|
||||
vs: nn::VarBuilder,
|
||||
query_dim: usize,
|
||||
context_dim: Option<usize>,
|
||||
heads: usize,
|
||||
dim_head: usize,
|
||||
slice_size: Option<usize>,
|
||||
) -> Result<Self> {
|
||||
let inner_dim = dim_head * heads;
|
||||
let context_dim = context_dim.unwrap_or(query_dim);
|
||||
let scale = 1.0 / f64::sqrt(dim_head as f64);
|
||||
let to_q = nn::linear_no_bias(query_dim, inner_dim, vs.pp("to_q"))?;
|
||||
let to_k = nn::linear_no_bias(context_dim, inner_dim, vs.pp("to_k"))?;
|
||||
let to_v = nn::linear_no_bias(context_dim, inner_dim, vs.pp("to_v"))?;
|
||||
let to_out = nn::linear(inner_dim, query_dim, vs.pp("to_out.0"))?;
|
||||
Ok(Self {
|
||||
to_q,
|
||||
to_k,
|
||||
to_v,
|
||||
to_out,
|
||||
heads,
|
||||
scale,
|
||||
slice_size,
|
||||
})
|
||||
}
|
||||
|
||||
fn reshape_heads_to_batch_dim(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let (batch_size, seq_len, dim) = xs.dims3()?;
|
||||
xs.reshape((batch_size, seq_len, self.heads, dim / self.heads))?
|
||||
.transpose(1, 2)?
|
||||
.reshape((batch_size * self.heads, seq_len, dim / self.heads))
|
||||
}
|
||||
|
||||
fn reshape_batch_dim_to_heads(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let (batch_size, seq_len, dim) = xs.dims3()?;
|
||||
xs.reshape((batch_size / self.heads, self.heads, seq_len, dim))?
|
||||
.transpose(1, 2)?
|
||||
.reshape((batch_size / self.heads, seq_len, dim * self.heads))
|
||||
}
|
||||
|
||||
fn sliced_attention(
|
||||
&self,
|
||||
query: &Tensor,
|
||||
key: &Tensor,
|
||||
value: &Tensor,
|
||||
slice_size: usize,
|
||||
) -> Result<Tensor> {
|
||||
let batch_size_attention = query.dim(0)?;
|
||||
let mut hidden_states = Vec::with_capacity(batch_size_attention / slice_size);
|
||||
|
||||
for i in 0..batch_size_attention / slice_size {
|
||||
let start_idx = i * slice_size;
|
||||
let end_idx = (i + 1) * slice_size;
|
||||
|
||||
let xs = query
|
||||
.i(start_idx..end_idx)?
|
||||
.matmul(&(key.i(start_idx..end_idx)?.t()? * self.scale)?)?;
|
||||
let xs = nn::ops::softmax(&xs, D::Minus1)?.matmul(&value.i(start_idx..end_idx)?)?;
|
||||
hidden_states.push(xs)
|
||||
}
|
||||
let hidden_states = Tensor::stack(&hidden_states, 0)?;
|
||||
self.reshape_batch_dim_to_heads(&hidden_states)
|
||||
}
|
||||
|
||||
fn attention(&self, query: &Tensor, key: &Tensor, value: &Tensor) -> Result<Tensor> {
|
||||
let xs = query.matmul(&(key.transpose(D::Minus1, D::Minus2)? * self.scale)?)?;
|
||||
let xs = nn::ops::softmax(&xs, D::Minus1)?.matmul(value)?;
|
||||
self.reshape_batch_dim_to_heads(&xs)
|
||||
}
|
||||
|
||||
fn forward(&self, xs: &Tensor, context: Option<&Tensor>) -> Result<Tensor> {
|
||||
let query = self.to_q.forward(xs)?;
|
||||
let context = context.unwrap_or(xs);
|
||||
let key = self.to_k.forward(context)?;
|
||||
let value = self.to_v.forward(context)?;
|
||||
let query = self.reshape_heads_to_batch_dim(&query)?;
|
||||
let key = self.reshape_heads_to_batch_dim(&key)?;
|
||||
let value = self.reshape_heads_to_batch_dim(&value)?;
|
||||
let xs = match self.slice_size {
|
||||
None => self.attention(&query, &key, &value)?,
|
||||
Some(slice_size) => {
|
||||
if query.dim(0)? / slice_size <= 1 {
|
||||
self.attention(&query, &key, &value)?
|
||||
} else {
|
||||
self.sliced_attention(&query, &key, &value, slice_size)?
|
||||
}
|
||||
}
|
||||
};
|
||||
self.to_out.forward(&xs)
|
||||
}
|
||||
}
|
||||
|
||||
/// A basic Transformer block.
|
||||
#[derive(Debug)]
|
||||
struct BasicTransformerBlock {
|
||||
attn1: CrossAttention,
|
||||
ff: FeedForward,
|
||||
attn2: CrossAttention,
|
||||
norm1: nn::LayerNorm,
|
||||
norm2: nn::LayerNorm,
|
||||
norm3: nn::LayerNorm,
|
||||
}
|
||||
|
||||
impl BasicTransformerBlock {
|
||||
fn new(
|
||||
vs: nn::VarBuilder,
|
||||
dim: usize,
|
||||
n_heads: usize,
|
||||
d_head: usize,
|
||||
context_dim: Option<usize>,
|
||||
sliced_attention_size: Option<usize>,
|
||||
) -> Result<Self> {
|
||||
let attn1 = CrossAttention::new(
|
||||
vs.pp("attn1"),
|
||||
dim,
|
||||
None,
|
||||
n_heads,
|
||||
d_head,
|
||||
sliced_attention_size,
|
||||
)?;
|
||||
let ff = FeedForward::new(vs.pp("ff"), dim, None, 4)?;
|
||||
let attn2 = CrossAttention::new(
|
||||
vs.pp("attn2"),
|
||||
dim,
|
||||
context_dim,
|
||||
n_heads,
|
||||
d_head,
|
||||
sliced_attention_size,
|
||||
)?;
|
||||
let norm1 = nn::layer_norm(dim, 1e-5, vs.pp("norm1"))?;
|
||||
let norm2 = nn::layer_norm(dim, 1e-5, vs.pp("norm2"))?;
|
||||
let norm3 = nn::layer_norm(dim, 1e-5, vs.pp("norm3"))?;
|
||||
Ok(Self {
|
||||
attn1,
|
||||
ff,
|
||||
attn2,
|
||||
norm1,
|
||||
norm2,
|
||||
norm3,
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(&self, xs: &Tensor, context: Option<&Tensor>) -> Result<Tensor> {
|
||||
let xs = (self.attn1.forward(&self.norm1.forward(xs)?, None)? + xs)?;
|
||||
let xs = (self.attn2.forward(&self.norm2.forward(&xs)?, context)? + xs)?;
|
||||
self.ff.forward(&self.norm3.forward(&xs)?)? + xs
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct SpatialTransformerConfig {
|
||||
pub depth: usize,
|
||||
pub num_groups: usize,
|
||||
pub context_dim: Option<usize>,
|
||||
pub sliced_attention_size: Option<usize>,
|
||||
pub use_linear_projection: bool,
|
||||
}
|
||||
|
||||
impl Default for SpatialTransformerConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
depth: 1,
|
||||
num_groups: 32,
|
||||
context_dim: None,
|
||||
sliced_attention_size: None,
|
||||
use_linear_projection: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
enum Proj {
|
||||
Conv2d(nn::Conv2d),
|
||||
Linear(nn::Linear),
|
||||
}
|
||||
|
||||
// Aka Transformer2DModel
|
||||
#[derive(Debug)]
|
||||
pub struct SpatialTransformer {
|
||||
norm: nn::GroupNorm,
|
||||
proj_in: Proj,
|
||||
transformer_blocks: Vec<BasicTransformerBlock>,
|
||||
proj_out: Proj,
|
||||
pub config: SpatialTransformerConfig,
|
||||
}
|
||||
|
||||
impl SpatialTransformer {
|
||||
pub fn new(
|
||||
vs: nn::VarBuilder,
|
||||
in_channels: usize,
|
||||
n_heads: usize,
|
||||
d_head: usize,
|
||||
config: SpatialTransformerConfig,
|
||||
) -> Result<Self> {
|
||||
let inner_dim = n_heads * d_head;
|
||||
let norm = nn::group_norm(config.num_groups, in_channels, 1e-6, vs.pp("norm"))?;
|
||||
let proj_in = if config.use_linear_projection {
|
||||
Proj::Linear(nn::linear(in_channels, inner_dim, vs.pp("proj_in"))?)
|
||||
} else {
|
||||
Proj::Conv2d(nn::conv2d(
|
||||
in_channels,
|
||||
inner_dim,
|
||||
1,
|
||||
Default::default(),
|
||||
vs.pp("proj_in"),
|
||||
)?)
|
||||
};
|
||||
let mut transformer_blocks = vec![];
|
||||
let vs_tb = vs.pp("transformer_blocks");
|
||||
for index in 0..config.depth {
|
||||
let tb = BasicTransformerBlock::new(
|
||||
vs_tb.pp(&index.to_string()),
|
||||
inner_dim,
|
||||
n_heads,
|
||||
d_head,
|
||||
config.context_dim,
|
||||
config.sliced_attention_size,
|
||||
)?;
|
||||
transformer_blocks.push(tb)
|
||||
}
|
||||
let proj_out = if config.use_linear_projection {
|
||||
Proj::Linear(nn::linear(in_channels, inner_dim, vs.pp("proj_out"))?)
|
||||
} else {
|
||||
Proj::Conv2d(nn::conv2d(
|
||||
inner_dim,
|
||||
in_channels,
|
||||
1,
|
||||
Default::default(),
|
||||
vs.pp("proj_out"),
|
||||
)?)
|
||||
};
|
||||
Ok(Self {
|
||||
norm,
|
||||
proj_in,
|
||||
transformer_blocks,
|
||||
proj_out,
|
||||
config,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn forward(&self, xs: &Tensor, context: Option<&Tensor>) -> Result<Tensor> {
|
||||
let (batch, _channel, height, weight) = xs.dims4()?;
|
||||
let residual = xs;
|
||||
let xs = self.norm.forward(xs)?;
|
||||
let (inner_dim, xs) = match &self.proj_in {
|
||||
Proj::Conv2d(p) => {
|
||||
let xs = p.forward(&xs)?;
|
||||
let inner_dim = xs.dim(1)?;
|
||||
let xs = xs
|
||||
.transpose(1, 2)?
|
||||
.t()?
|
||||
.reshape((batch, height * weight, inner_dim))?;
|
||||
(inner_dim, xs)
|
||||
}
|
||||
Proj::Linear(p) => {
|
||||
let inner_dim = xs.dim(1)?;
|
||||
let xs = xs
|
||||
.transpose(1, 2)?
|
||||
.t()?
|
||||
.reshape((batch, height * weight, inner_dim))?;
|
||||
(inner_dim, p.forward(&xs)?)
|
||||
}
|
||||
};
|
||||
let mut xs = xs;
|
||||
for block in self.transformer_blocks.iter() {
|
||||
xs = block.forward(&xs, context)?
|
||||
}
|
||||
let xs = match &self.proj_out {
|
||||
Proj::Conv2d(p) => p.forward(
|
||||
&xs.reshape((batch, height, weight, inner_dim))?
|
||||
.t()?
|
||||
.transpose(1, 2)?,
|
||||
)?,
|
||||
Proj::Linear(p) => p
|
||||
.forward(&xs)?
|
||||
.reshape((batch, height, weight, inner_dim))?
|
||||
.t()?
|
||||
.transpose(1, 2)?,
|
||||
};
|
||||
xs + residual
|
||||
}
|
||||
}
|
||||
|
||||
/// Configuration for an attention block.
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct AttentionBlockConfig {
|
||||
pub num_head_channels: Option<usize>,
|
||||
pub num_groups: usize,
|
||||
pub rescale_output_factor: f64,
|
||||
pub eps: f64,
|
||||
}
|
||||
|
||||
impl Default for AttentionBlockConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
num_head_channels: None,
|
||||
num_groups: 32,
|
||||
rescale_output_factor: 1.,
|
||||
eps: 1e-5,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct AttentionBlock {
|
||||
group_norm: nn::GroupNorm,
|
||||
query: nn::Linear,
|
||||
key: nn::Linear,
|
||||
value: nn::Linear,
|
||||
proj_attn: nn::Linear,
|
||||
channels: usize,
|
||||
num_heads: usize,
|
||||
config: AttentionBlockConfig,
|
||||
}
|
||||
|
||||
impl AttentionBlock {
|
||||
pub fn new(vs: nn::VarBuilder, channels: usize, config: AttentionBlockConfig) -> Result<Self> {
|
||||
let num_head_channels = config.num_head_channels.unwrap_or(channels);
|
||||
let num_heads = channels / num_head_channels;
|
||||
let group_norm =
|
||||
nn::group_norm(config.num_groups, channels, config.eps, vs.pp("group_norm"))?;
|
||||
let query = nn::linear(channels, channels, vs.pp("query"))?;
|
||||
let key = nn::linear(channels, channels, vs.pp("key"))?;
|
||||
let value = nn::linear(channels, channels, vs.pp("value"))?;
|
||||
let proj_attn = nn::linear(channels, channels, vs.pp("proj_attn"))?;
|
||||
Ok(Self {
|
||||
group_norm,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
proj_attn,
|
||||
channels,
|
||||
num_heads,
|
||||
config,
|
||||
})
|
||||
}
|
||||
|
||||
fn transpose_for_scores(&self, xs: Tensor) -> Result<Tensor> {
|
||||
let (batch, t, h_times_d) = xs.dims3()?;
|
||||
xs.reshape((batch, t, self.num_heads, h_times_d / self.num_heads))?
|
||||
.transpose(1, 2)
|
||||
}
|
||||
}
|
||||
|
||||
impl AttentionBlock {
|
||||
pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let residual = xs;
|
||||
let (batch, channel, height, width) = xs.dims4()?;
|
||||
let xs = self
|
||||
.group_norm
|
||||
.forward(xs)?
|
||||
.reshape((batch, channel, height * width))?
|
||||
.transpose(1, 2)?;
|
||||
|
||||
let query_proj = self.query.forward(&xs)?;
|
||||
let key_proj = self.key.forward(&xs)?;
|
||||
let value_proj = self.value.forward(&xs)?;
|
||||
|
||||
let query_states = self.transpose_for_scores(query_proj)?;
|
||||
let key_states = self.transpose_for_scores(key_proj)?;
|
||||
let value_states = self.transpose_for_scores(value_proj)?;
|
||||
|
||||
let scale = f64::powf((self.channels as f64) / (self.num_heads as f64), -0.25);
|
||||
let attention_scores =
|
||||
// TODO: Check that this needs two multiplication by `scale`.
|
||||
(query_states * scale)?.matmul(&(key_states.t()? * scale)?)?;
|
||||
let attention_probs = nn::ops::softmax(&attention_scores, D::Minus1)?;
|
||||
|
||||
let xs = attention_probs.matmul(&value_states)?;
|
||||
let xs = xs.transpose(1, 2)?.contiguous()?;
|
||||
let xs = xs.flatten_from(D::Minus2)?;
|
||||
let xs = self
|
||||
.proj_attn
|
||||
.forward(&xs)?
|
||||
.t()?
|
||||
.reshape((batch, channel, height, width))?;
|
||||
(xs + residual)? / self.config.rescale_output_factor
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user