Mirror of https://github.com/huggingface/candle.git (synced 2025-06-16 10:38:54 +00:00)
Specialized attention module for Wuerstchen. (#890)
* Specialized attention module for Wuerstchen.
* Reshaping ops.
* Attention processor.
* Finish the forward pass.
* Hook the new attention processor.
* Get the prior forward pass to work.
* Make it contiguous.
@@ -0,0 +1,74 @@
use candle::{Module, Result, Tensor};
use candle_nn::{linear, Linear, VarBuilder};

// A simplified version of:
// https://github.com/huggingface/diffusers/blob/119ad2c3dc8a8fb8446a83f4bf6f20929487b47f/src/diffusers/models/attention_processor.py#L38
#[derive(Debug)]
pub struct Attention {
    to_q: Linear,
    to_k: Linear,
    to_v: Linear,
    to_out: Linear,
    heads: usize,
    scale: f64,
}

impl Attention {
    pub fn new(query_dim: usize, heads: usize, dim_head: usize, vb: VarBuilder) -> Result<Self> {
        let inner_dim = dim_head * heads;
        let scale = 1.0 / f64::sqrt(dim_head as f64);
        let to_q = linear(query_dim, inner_dim, vb.pp("to_q"))?;
        let to_k = linear(query_dim, inner_dim, vb.pp("to_k"))?;
        let to_v = linear(query_dim, inner_dim, vb.pp("to_v"))?;
        let to_out = linear(inner_dim, query_dim, vb.pp("to_out.0"))?;
        Ok(Self {
            to_q,
            to_k,
            to_v,
            to_out,
            scale,
            heads,
        })
    }

    // (b_size * heads, seq_len, head_dim) -> (b_size, seq_len, heads * head_dim)
    fn batch_to_head_dim(&self, xs: &Tensor) -> Result<Tensor> {
        let (b_size, seq_len, dim) = xs.dims3()?;
        xs.reshape((b_size / self.heads, self.heads, seq_len, dim))?
            .permute((0, 2, 1, 3))?
            .reshape((b_size / self.heads, seq_len, dim * self.heads))
    }

    // (b_size, seq_len, heads * head_dim) -> (b_size * heads, seq_len, head_dim)
    fn head_to_batch_dim(&self, xs: &Tensor) -> Result<Tensor> {
        let (b_size, seq_len, dim) = xs.dims3()?;
        xs.reshape((b_size, seq_len, self.heads, dim / self.heads))?
            .permute((0, 2, 1, 3))?
            .reshape((b_size * self.heads, seq_len, dim / self.heads))
    }

    // Scaled dot-product scores followed by a softmax over the key positions.
    fn get_attention_scores(&self, query: &Tensor, key: &Tensor) -> Result<Tensor> {
        let attn_probs = (query.matmul(&key.t()?)? * self.scale)?;
        candle_nn::ops::softmax_last_dim(&attn_probs)
    }

    // Cross-attention over a (b_size, channel, h, w) input: the spatial dims are
    // flattened into a sequence, attended against encoder_hidden_states, then the
    // result is reshaped back to the original layout.
    pub fn forward(&self, xs: &Tensor, encoder_hidden_states: &Tensor) -> Result<Tensor> {
        let (b_size, channel, h, w) = xs.dims4()?;
        let xs = xs.reshape((b_size, channel, h * w))?.t()?;

        let query = self.to_q.forward(&xs)?;
        let key = self.to_k.forward(encoder_hidden_states)?;
        let value = self.to_v.forward(encoder_hidden_states)?;

        let query = self.head_to_batch_dim(&query)?;
        let key = self.head_to_batch_dim(&key)?;
        let value = self.head_to_batch_dim(&value)?;

        let attn_prs = self.get_attention_scores(&query, &key)?;
        let xs = attn_prs.matmul(&value)?;
        let xs = self.batch_to_head_dim(&xs)?;

        self.to_out
            .forward(&xs)?
            .t()?
            .reshape((b_size, channel, h, w))
    }
}
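A minimal usage sketch, not part of the commit, assuming the Attention type above is in scope and using made-up shapes: query_dim (64 here) must match the channel count of xs, encoder_hidden_states is a (batch, seq_len, query_dim) tensor, and VarBuilder::zeros stands in for real weights.

use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;

fn attention_demo() -> candle::Result<Tensor> {
    let device = Device::Cpu;
    // Zero-initialized weights, just to exercise the shapes.
    let vb = VarBuilder::zeros(DType::F32, &device);
    // query_dim = 64, heads = 8, dim_head = 8 => inner_dim = 64.
    let attn = Attention::new(64, 8, 8, vb)?;
    let xs = Tensor::zeros((1, 64, 4, 4), DType::F32, &device)?;
    let encoder_hidden_states = Tensor::zeros((1, 10, 64), DType::F32, &device)?;
    // The output keeps the (batch, channels, h, w) layout of `xs`.
    attn.forward(&xs, &encoder_hidden_states)
}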
@@ -131,7 +131,7 @@ impl ResBlock {
        xs + x_res
    }
}
-use crate::models::stable_diffusion::attention::CrossAttention as Attention;
+use super::attention_processor::Attention;
#[derive(Debug)]
pub struct AttnBlock {
    self_attn: bool,
@@ -149,7 +149,7 @@ impl AttnBlock {
        vb: VarBuilder,
    ) -> Result<Self> {
        let norm = WLayerNorm::new(c)?;
-        let attention = Attention::new(vb.pp("attention"), c, None, nhead, c / nhead, None, false)?;
+        let attention = Attention::new(c, nhead, c / nhead, vb.pp("attention"))?;
        let kv_mapper_lin = candle_nn::linear(c_cond, c, vb.pp("kv_mapper.1"))?;
        Ok(Self {
            self_attn,
@@ -165,10 +165,10 @@ impl AttnBlock {
        let kv = if self.self_attn {
            let (b_size, channel, _, _) = xs.dims4()?;
            let norm_xs = norm_xs.reshape((b_size, channel, ()))?.transpose(1, 2)?;
-            Tensor::cat(&[&norm_xs, &kv], 1)?
+            Tensor::cat(&[&norm_xs, &kv], 1)?.contiguous()?
        } else {
            kv
        };
-        xs + self.attention.forward(&norm_xs, Some(&kv))
+        xs + self.attention.forward(&norm_xs, &kv)
    }
}
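For reference, a minimal sketch (hypothetical shapes, not part of the commit) of the flatten-and-transpose step above that turns the (batch, channels, h, w) activations into the (batch, seq_len, channels) layout the new attention processor expects:

use candle::{DType, Device, Tensor};

fn flatten_demo() -> candle::Result<()> {
    let device = Device::Cpu;
    let xs = Tensor::zeros((1, 64, 4, 4), DType::F32, &device)?;
    // `()` lets candle infer the flattened h * w dimension.
    let seq = xs.reshape((1, 64, ()))?.transpose(1, 2)?;
    assert_eq!(seq.dims(), &[1, 16, 64]);
    Ok(())
}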
@@ -1,3 +1,4 @@
+pub mod attention_processor;
pub mod common;
pub mod ddpm;
pub mod diffnext;
@@ -94,7 +94,7 @@ impl WPrior {
            xs = block.ts_block.forward(&xs, &r_embed)?;
            xs = block.attn_block.forward(&xs, &c_embed)?;
        }
-        let ab = xs.apply(&self.out_ln)?.apply(&self.out_conv)?.chunk(1, 2)?;
+        let ab = xs.apply(&self.out_ln)?.apply(&self.out_conv)?.chunk(2, 1)?;
        (x_in - &ab[0])? / ((&ab[1] - 1.)?.abs()? + 1e-5)
    }
}
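A minimal sketch (hypothetical 3D shape, not part of the commit) of the corrected call: Tensor::chunk takes the number of chunks first and the dimension second, so chunk(2, 1) splits the channel dimension in two, which is what the ab[0] / ab[1] indexing in the hunk above expects.

use candle::{DType, Device, Tensor};

fn chunk_demo() -> candle::Result<()> {
    let xs = Tensor::zeros((1, 8, 4), DType::F32, &Device::Cpu)?;
    // Two chunks along dim 1: each keeps the remaining dimensions intact.
    let ab = xs.chunk(2, 1)?;
    assert_eq!(ab[0].dims(), &[1, 4, 4]);
    assert_eq!(ab[1].dims(), &[1, 4, 4]);
    Ok(())
}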