Mirror of https://github.com/huggingface/candle.git, synced 2025-06-21 04:10:46 +00:00
Add the attention block. (#846)
* Add the attention block.
* Add more to clipnext.
@@ -124,3 +124,44 @@ impl ResBlock {
        xs + x_res
    }
}
use crate::models::stable_diffusion::attention::CrossAttention as Attention;
#[derive(Debug)]
pub struct AttnBlock {
    self_attn: bool,
    norm: WLayerNorm,
    attention: Attention,
    kv_mapper_lin: candle_nn::Linear,
}

impl AttnBlock {
    pub fn new(
        c: usize,
        c_cond: usize,
        nhead: usize,
        self_attn: bool,
        vb: VarBuilder,
    ) -> Result<Self> {
        let norm = WLayerNorm::new(c, vb.pp("norm"))?;
        // Stable-diffusion style cross-attention: query dim c, nhead heads of dim c / nhead.
        let attention = Attention::new(vb.pp("attention"), c, None, nhead, c / nhead, None, false)?;
        let kv_mapper_lin = candle_nn::linear(c_cond, c, vb.pp("kv_mapper.1"))?;
        Ok(Self {
            self_attn,
            norm,
            attention,
            kv_mapper_lin,
        })
    }

    pub fn forward(&self, xs: &Tensor, kv: &Tensor) -> Result<Tensor> {
        // Project the conditioning kv through SiLU + linear to the block's channel dim.
        let kv = candle_nn::ops::silu(kv)?.apply(&self.kv_mapper_lin)?;
        let norm_xs = self.norm.forward(xs)?;
        let kv = if self.self_attn {
            // Flatten the normalized input to (b, h * w, c) and prepend it to the
            // projected conditioning so the block attends over both.
            let (b_size, channel, _, _) = xs.dims4()?;
            let norm_xs = norm_xs.reshape((b_size, channel, ()))?.transpose(1, 2)?;
            Tensor::cat(&[&norm_xs, &kv], 1)?
        } else {
            kv
        };
        // Residual connection around the attention layer.
        xs + self.attention.forward(&norm_xs, Some(&kv))
    }
}
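
For context, here is a minimal construction sketch; it is not part of the commit. It assumes AttnBlock and the surrounding wuerstchen module are in scope, uses the `candle`/`candle_nn` paths as they appear inside the repo (where `candle` aliases the candle-core crate), and initializes fresh weights through a VarMap-backed VarBuilder. The function name and the dimensions (c = 64, c_cond = 32, nhead = 4) are illustrative only.

// Sketch only: builds an AttnBlock with freshly initialized weights under the
// expected variable paths ("norm", "attention.*", "kv_mapper.1").
// Dimensions below are illustrative, not taken from the commit.
use candle::{DType, Device, Result};
use candle_nn::{VarBuilder, VarMap};

fn build_attn_block() -> Result<AttnBlock> {
    let dev = Device::Cpu;
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &dev);
    // c must be divisible by nhead: the per-head dimension is c / nhead.
    AttnBlock::new(64, 32, 4, /* self_attn */ true, vb)
}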