Add the attention block. (#846)

* Add the attention block.

* Add more to clipnext.

Author: Laurent Mazare (committed by GitHub)
Date:   2023-09-14 16:40:09 +02:00
Parent: 286f01db14
Commit: a0c6d5548c

4 changed files with 98 additions and 8 deletions

@@ -124,3 +124,44 @@ impl ResBlock {
        xs + x_res
    }
}
use crate::models::stable_diffusion::attention::CrossAttention as Attention;

/// Attention block: a `WLayerNorm` followed by (cross-)attention, with the
/// conditioning tensor first projected to the block width by a SiLU + linear
/// "kv mapper".
#[derive(Debug)]
pub struct AttnBlock {
    self_attn: bool,
    norm: WLayerNorm,
    attention: Attention,
    kv_mapper_lin: candle_nn::Linear,
}

impl AttnBlock {
    pub fn new(
        c: usize,
        c_cond: usize,
        nhead: usize,
        self_attn: bool,
        vb: VarBuilder,
    ) -> Result<Self> {
        let norm = WLayerNorm::new(c, vb.pp("norm"))?;
        let attention = Attention::new(vb.pp("attention"), c, None, nhead, c / nhead, None, false)?;
        let kv_mapper_lin = candle_nn::linear(c_cond, c, vb.pp("kv_mapper.1"))?;
        Ok(Self {
            self_attn,
            norm,
            attention,
            kv_mapper_lin,
        })
    }

    pub fn forward(&self, xs: &Tensor, kv: &Tensor) -> Result<Tensor> {
        // Project the conditioning from c_cond to c channels.
        let kv = candle_nn::ops::silu(kv)?.apply(&self.kv_mapper_lin)?;
        let norm_xs = self.norm.forward(xs)?;
        // In self-attention mode, the flattened image tokens are concatenated
        // with the conditioning sequence to form the keys/values.
        let kv = if self.self_attn {
            let (b_size, channel, _, _) = xs.dims4()?;
            let norm_xs = norm_xs.reshape((b_size, channel, ()))?.transpose(1, 2)?;
            Tensor::cat(&[&norm_xs, &kv], 1)?
        } else {
            kv
        };
        // Residual connection around the attention output.
        xs + self.attention.forward(&norm_xs, Some(&kv))
    }
}
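
For reference, a minimal instantiation sketch (not part of this diff): it assumes the block ends up publicly exported, e.g. as candle_transformers::models::wuerstchen::common::AttnBlock (a hypothetical path, since the changed file names are not shown here), and uses a VarMap-backed VarBuilder so parameters are created on demand rather than loaded from a checkpoint. The sizes c = 64, c_cond = 32 and nhead = 8 are arbitrary.

use candle_core::{DType, Device};
use candle_nn::{VarBuilder, VarMap};
// Hypothetical re-export path, not confirmed by this diff.
use candle_transformers::models::wuerstchen::common::AttnBlock;

fn main() -> candle_core::Result<()> {
    let dev = Device::Cpu;
    // A VarMap-backed VarBuilder creates any missing parameter on the fly,
    // so the block can be built without weights on disk.
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &dev);
    let (c, c_cond, nhead) = (64, 32, 8);
    let _block = AttnBlock::new(c, c_cond, nhead, /* self_attn */ true, vb)?;
    Ok(())
}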