diff --git a/candle-transformers/src/models/wuerstchen/attention_processor.rs b/candle-transformers/src/models/wuerstchen/attention_processor.rs
new file mode 100644
index 00000000..3f1a72eb
--- /dev/null
+++ b/candle-transformers/src/models/wuerstchen/attention_processor.rs
@@ -0,0 +1,74 @@
+use candle::{Module, Result, Tensor};
+use candle_nn::{linear, Linear, VarBuilder};
+
+// A simplified version of:
+// https://github.com/huggingface/diffusers/blob/119ad2c3dc8a8fb8446a83f4bf6f20929487b47f/src/diffusers/models/attention_processor.py#L38
+#[derive(Debug)]
+pub struct Attention {
+    to_q: Linear,
+    to_k: Linear,
+    to_v: Linear,
+    to_out: Linear,
+    heads: usize,
+    scale: f64,
+}
+
+impl Attention {
+    pub fn new(query_dim: usize, heads: usize, dim_head: usize, vb: VarBuilder) -> Result<Self> {
+        let inner_dim = dim_head * heads;
+        let scale = 1.0 / f64::sqrt(dim_head as f64);
+        let to_q = linear(query_dim, inner_dim, vb.pp("to_q"))?;
+        let to_k = linear(query_dim, inner_dim, vb.pp("to_k"))?;
+        let to_v = linear(query_dim, inner_dim, vb.pp("to_v"))?;
+        let to_out = linear(inner_dim, query_dim, vb.pp("to_out.0"))?;
+        Ok(Self {
+            to_q,
+            to_k,
+            to_v,
+            to_out,
+            scale,
+            heads,
+        })
+    }
+
+    fn batch_to_head_dim(&self, xs: &Tensor) -> Result<Tensor> {
+        let (b_size, seq_len, dim) = xs.dims3()?;
+        xs.reshape((b_size / self.heads, self.heads, seq_len, dim))?
+            .permute((0, 2, 1, 3))?
+            .reshape((b_size / self.heads, seq_len, dim * self.heads))
+    }
+
+    fn head_to_batch_dim(&self, xs: &Tensor) -> Result<Tensor> {
+        let (b_size, seq_len, dim) = xs.dims3()?;
+        xs.reshape((b_size, seq_len, self.heads, dim / self.heads))?
+            .permute((0, 2, 1, 3))?
+            .reshape((b_size * self.heads, seq_len, dim / self.heads))
+    }
+
+    fn get_attention_scores(&self, query: &Tensor, key: &Tensor) -> Result<Tensor> {
+        let attn_probs = (query.matmul(&key.t()?)? * self.scale)?;
+        candle_nn::ops::softmax_last_dim(&attn_probs)
+    }
+
+    pub fn forward(&self, xs: &Tensor, encoder_hidden_states: &Tensor) -> Result<Tensor> {
+        let (b_size, channel, h, w) = xs.dims4()?;
+        let xs = xs.reshape((b_size, channel, h * w))?.t()?;
+
+        let query = self.to_q.forward(&xs)?;
+        let key = self.to_k.forward(encoder_hidden_states)?;
+        let value = self.to_v.forward(encoder_hidden_states)?;
+
+        let query = self.head_to_batch_dim(&query)?;
+        let key = self.head_to_batch_dim(&key)?;
+        let value = self.head_to_batch_dim(&value)?;
+
+        let attn_prs = self.get_attention_scores(&query, &key)?;
+        let xs = attn_prs.matmul(&value)?;
+        let xs = self.batch_to_head_dim(&xs)?;
+
+        self.to_out
+            .forward(&xs)?
+            .t()?
+            .reshape((b_size, channel, h, w))
+    }
+}
diff --git a/candle-transformers/src/models/wuerstchen/common.rs b/candle-transformers/src/models/wuerstchen/common.rs
index 5337fdc6..1eb0c2e7 100644
--- a/candle-transformers/src/models/wuerstchen/common.rs
+++ b/candle-transformers/src/models/wuerstchen/common.rs
@@ -131,7 +131,7 @@ impl ResBlock {
         xs + x_res
     }
 }
-use crate::models::stable_diffusion::attention::CrossAttention as Attention;
+use super::attention_processor::Attention;
 #[derive(Debug)]
 pub struct AttnBlock {
     self_attn: bool,
@@ -149,7 +149,7 @@ impl AttnBlock {
         vb: VarBuilder,
     ) -> Result<Self> {
         let norm = WLayerNorm::new(c)?;
-        let attention = Attention::new(vb.pp("attention"), c, None, nhead, c / nhead, None, false)?;
+        let attention = Attention::new(c, nhead, c / nhead, vb.pp("attention"))?;
         let kv_mapper_lin = candle_nn::linear(c_cond, c, vb.pp("kv_mapper.1"))?;
         Ok(Self {
             self_attn,
@@ -165,10 +165,10 @@ impl AttnBlock {
         let kv = if self.self_attn {
            let (b_size, channel, _, _) = xs.dims4()?;
            let norm_xs = norm_xs.reshape((b_size, channel, ()))?.transpose(1, 2)?;
-           Tensor::cat(&[&norm_xs, &kv], 1)?
+           Tensor::cat(&[&norm_xs, &kv], 1)?.contiguous()?
        } else {
            kv
        };
-        xs + self.attention.forward(&norm_xs, Some(&kv))
+        xs + self.attention.forward(&norm_xs, &kv)
     }
 }
diff --git a/candle-transformers/src/models/wuerstchen/mod.rs b/candle-transformers/src/models/wuerstchen/mod.rs
index f499bc35..7b076f06 100644
--- a/candle-transformers/src/models/wuerstchen/mod.rs
+++ b/candle-transformers/src/models/wuerstchen/mod.rs
@@ -1,3 +1,4 @@
+pub mod attention_processor;
 pub mod common;
 pub mod ddpm;
 pub mod diffnext;
diff --git a/candle-transformers/src/models/wuerstchen/prior.rs b/candle-transformers/src/models/wuerstchen/prior.rs
index 93385a32..168b70a6 100644
--- a/candle-transformers/src/models/wuerstchen/prior.rs
+++ b/candle-transformers/src/models/wuerstchen/prior.rs
@@ -94,7 +94,7 @@ impl WPrior {
             xs = block.ts_block.forward(&xs, &r_embed)?;
             xs = block.attn_block.forward(&xs, &c_embed)?;
         }
-        let ab = xs.apply(&self.out_ln)?.apply(&self.out_conv)?.chunk(1, 2)?;
+        let ab = xs.apply(&self.out_ln)?.apply(&self.out_conv)?.chunk(2, 1)?;
         (x_in - &ab[0])? / ((&ab[1] - 1.)?.abs()? + 1e-5)
     }
 }
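
For reference, the sketch below (not part of the patch) shows how the new Attention processor could be exercised on dummy tensors, mirroring the call pattern in AttnBlock::forward. The zero-initialised VarBuilder, the 64-channel / 8-head configuration, and the (1, c, 8, 8) feature map with a 77-token conditioning sequence are illustrative assumptions, not values taken from the Würstchen checkpoints.

// Illustrative usage sketch -- not part of the diff above.
use candle::{DType, Device, Result, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::wuerstchen::attention_processor::Attention;

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Zero weights just to exercise the shapes; real use would load the
    // Würstchen prior checkpoint into the VarBuilder instead.
    let vb = VarBuilder::zeros(DType::F32, &dev);
    let (c, nhead) = (64, 8);
    let attn = Attention::new(c, nhead, c / nhead, vb.pp("attention"))?;

    // A (batch, channels, h, w) feature map and a conditioning sequence that
    // shares the channel dimension, as in AttnBlock::forward.
    let xs = Tensor::randn(0f32, 1f32, (1, c, 8, 8), &dev)?;
    let cond = Tensor::randn(0f32, 1f32, (1, 77, c), &dev)?;
    let out = attn.forward(&xs, &cond)?;
    // The processor returns a feature map with the same shape as its input.
    assert_eq!(out.dims(), xs.dims());
    Ok(())
}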