Specialized attention module for Wuerstchen. (#890)

* Specialized attention module for Wuerstchen.

* Reshaping ops.

* Attention processor.

* Finish the forward pass.

* Hook the new attention processor.

* Get the prior forward pass to work.

* Make it contiguous.
Authored by Laurent Mazare on 2023-09-18 21:16:09 +01:00; committed by GitHub
parent 1542e92629
commit 92db8cecd3
4 changed files with 80 additions and 5 deletions

@@ -131,7 +131,7 @@ impl ResBlock {
         xs + x_res
     }
 }
-use crate::models::stable_diffusion::attention::CrossAttention as Attention;
+use super::attention_processor::Attention;
 #[derive(Debug)]
 pub struct AttnBlock {
     self_attn: bool,
@@ -149,7 +149,7 @@ impl AttnBlock {
         vb: VarBuilder,
     ) -> Result<Self> {
         let norm = WLayerNorm::new(c)?;
-        let attention = Attention::new(vb.pp("attention"), c, None, nhead, c / nhead, None, false)?;
+        let attention = Attention::new(c, nhead, c / nhead, vb.pp("attention"))?;
         let kv_mapper_lin = candle_nn::linear(c_cond, c, vb.pp("kv_mapper.1"))?;
         Ok(Self {
             self_attn,
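
The new constructor takes `(query_dim, heads, dim_head, vb)`, and the last hunk below shows `forward` now taking the key/value tensor directly rather than an `Option`. The committed `attention_processor.rs` itself is not part of this excerpt; the following is only a minimal sketch of what such a module could look like in candle, with the struct fields, weight names (`to_q`, `to_k`, `to_v`, `to_out.0`), and the plain softmax-attention body all assumed from the call sites rather than taken from the actual file.

```rust
use candle::{Result, Tensor, D};
use candle_nn::{linear, Linear, Module, VarBuilder};

// Hypothetical sketch of the specialized attention processor. Field names and
// the projection layout are assumptions inferred from the call sites in this
// commit, not the contents of attention_processor.rs.
#[derive(Debug)]
pub struct Attention {
    to_q: Linear,
    to_k: Linear,
    to_v: Linear,
    to_out: Linear,
    heads: usize,
    scale: f64,
}

impl Attention {
    pub fn new(query_dim: usize, heads: usize, dim_head: usize, vb: VarBuilder) -> Result<Self> {
        let inner_dim = heads * dim_head;
        let to_q = linear(query_dim, inner_dim, vb.pp("to_q"))?;
        let to_k = linear(query_dim, inner_dim, vb.pp("to_k"))?;
        let to_v = linear(query_dim, inner_dim, vb.pp("to_v"))?;
        let to_out = linear(inner_dim, query_dim, vb.pp("to_out.0"))?;
        let scale = 1.0 / (dim_head as f64).sqrt();
        Ok(Self { to_q, to_k, to_v, to_out, heads, scale })
    }

    // xs: (batch, seq_q, query_dim), kv: (batch, seq_kv, query_dim)
    pub fn forward(&self, xs: &Tensor, kv: &Tensor) -> Result<Tensor> {
        let (b_size, seq_q, _) = xs.dims3()?;
        // Split the hidden dimension into heads: (b, seq, h * d) -> (b * h, seq, d).
        let split_heads = |t: &Tensor| -> Result<Tensor> {
            let (b, seq, _) = t.dims3()?;
            t.reshape((b, seq, self.heads, ()))?
                .transpose(1, 2)?
                .contiguous()?
                .reshape((b * self.heads, seq, ()))
        };
        let q = split_heads(&self.to_q.forward(xs)?)?;
        let k = split_heads(&self.to_k.forward(kv)?)?;
        let v = split_heads(&self.to_v.forward(kv)?)?;
        // Scaled dot-product attention.
        let attn = (q.matmul(&k.t()?)? * self.scale)?;
        let attn = candle_nn::ops::softmax(&attn, D::Minus1)?;
        // Merge the heads back: (b * h, seq_q, d) -> (b, seq_q, h * d).
        let out = attn
            .matmul(&v)?
            .reshape((b_size, self.heads, seq_q, ()))?
            .transpose(1, 2)?
            .contiguous()?
            .reshape((b_size, seq_q, ()))?;
        self.to_out.forward(&out)
    }
}
```

Under these assumed shapes, the `AttnBlock` call sites in this diff line up: `norm_xs` is the flattened, transposed feature map `(b, h*w, c)` and `kv` is either the conditioning alone or the self-attention concatenation.
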
@@ -165,10 +165,10 @@ impl AttnBlock {
         let kv = if self.self_attn {
             let (b_size, channel, _, _) = xs.dims4()?;
             let norm_xs = norm_xs.reshape((b_size, channel, ()))?.transpose(1, 2)?;
-            Tensor::cat(&[&norm_xs, &kv], 1)?
+            Tensor::cat(&[&norm_xs, &kv], 1)?.contiguous()?
         } else {
             kv
         };
-        xs + self.attention.forward(&norm_xs, Some(&kv))
+        xs + self.attention.forward(&norm_xs, &kv)
     }
 }
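
The `.contiguous()` added after `Tensor::cat` (the "Make it contiguous" item in the commit message) matters because `norm_xs` is a transposed view at that point, so the concatenation along dim 1 can carry a strided layout into `kv`, which the attention matmuls then have to consume. A minimal standalone illustration of the transpose/contiguous behavior in candle (not from this commit; shapes chosen arbitrarily):

```rust
use candle::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // A small (batch, channels, h*w) tensor laid out contiguously.
    let xs = Tensor::arange(0f32, 24., &dev)?.reshape((1, 4, 6))?;
    // transpose only swaps strides, producing a non-contiguous view.
    let t = xs.transpose(1, 2)?;
    assert!(!t.is_contiguous());
    // contiguous() copies the data into a dense row-major layout, which is
    // what layout-sensitive ops such as matmul expect.
    let t = t.contiguous()?;
    assert!(t.is_contiguous());
    println!("{:?}", t.dims()); // [1, 6, 4]
    Ok(())
}
```
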