Add the attention block. (#846)

* Add the attention block.

* Add more to clipnext.
Laurent Mazare, 2023-09-14 16:40:09 +02:00, committed by GitHub
parent 286f01db14
commit a0c6d5548c
4 changed files with 98 additions and 8 deletions

View File

@@ -78,7 +78,7 @@ fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Ten
 }
 
 #[derive(Debug)]
-struct CrossAttention {
+pub struct CrossAttention {
     to_q: nn::Linear,
     to_k: nn::Linear,
     to_v: nn::Linear,
@@ -94,7 +94,7 @@ struct CrossAttention {
 
 impl CrossAttention {
     // Defaults should be heads = 8, dim_head = 64, context_dim = None
-    fn new(
+    pub fn new(
         vs: nn::VarBuilder,
         query_dim: usize,
         context_dim: Option<usize>,
@@ -205,7 +205,7 @@ impl CrossAttention {
         self.reshape_batch_dim_to_heads(&xs)
     }
 
-    fn forward(&self, xs: &Tensor, context: Option<&Tensor>) -> Result<Tensor> {
+    pub fn forward(&self, xs: &Tensor, context: Option<&Tensor>) -> Result<Tensor> {
         let _enter = self.span.enter();
         let query = self.to_q.forward(xs)?;
         let context = context.unwrap_or(xs).contiguous()?;
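Not part of the diff: with CrossAttention, new, and forward now public, the layer can be exercised from other model code in the crate. A minimal smoke-test sketch, assuming the crate-internal path implied by the wuerstchen import below, a VarMap-backed VarBuilder for on-the-fly weight init, and invented dimensions:

use crate::models::stable_diffusion::attention::CrossAttention;
use candle::{DType, Device, Result, Tensor};
use candle_nn::{VarBuilder, VarMap};

fn cross_attention_smoke_test() -> Result<()> {
    let dev = Device::Cpu;
    // A VarMap-backed VarBuilder hands out freshly initialized weights, which
    // is enough for a shape check without loading a real checkpoint.
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &dev);
    // query_dim = 320, self-attention (no context_dim), 8 heads of 64 dims,
    // no attention slicing, flash-attn disabled.
    let attn = CrossAttention::new(vb, 320, None, 8, 64, None, false)?;
    let xs = Tensor::zeros((2, 77, 320), DType::F32, &dev)?;
    // With context = None the layer attends over xs itself and keeps the
    // (batch, seq, query_dim) shape.
    let out = attn.forward(&xs, None)?;
    assert_eq!(out.dims(), &[2, 77, 320]);
    Ok(())
}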

View File

@@ -124,3 +124,44 @@ impl ResBlock {
         xs + x_res
     }
 }
+use crate::models::stable_diffusion::attention::CrossAttention as Attention;
+#[derive(Debug)]
+pub struct AttnBlock {
+    self_attn: bool,
+    norm: WLayerNorm,
+    attention: Attention,
+    kv_mapper_lin: candle_nn::Linear,
+}
+
+impl AttnBlock {
+    pub fn new(
+        c: usize,
+        c_cond: usize,
+        nhead: usize,
+        self_attn: bool,
+        vb: VarBuilder,
+    ) -> Result<Self> {
+        let norm = WLayerNorm::new(c, vb.pp("norm"))?;
+        let attention = Attention::new(vb.pp("attention"), c, None, nhead, c / nhead, None, false)?;
+        let kv_mapper_lin = candle_nn::linear(c_cond, c, vb.pp("kv_mapper.1"))?;
+        Ok(Self {
+            self_attn,
+            norm,
+            attention,
+            kv_mapper_lin,
+        })
+    }
+
+    pub fn forward(&self, xs: &Tensor, kv: &Tensor) -> Result<Tensor> {
+        let kv = candle_nn::ops::silu(kv)?.apply(&self.kv_mapper_lin)?;
+        let norm_xs = self.norm.forward(xs)?;
+        let kv = if self.self_attn {
+            let (b_size, channel, _, _) = xs.dims4()?;
+            let norm_xs = norm_xs.reshape((b_size, channel, ()))?.transpose(1, 2)?;
+            Tensor::cat(&[&norm_xs, &kv], 1)?
+        } else {
+            kv
+        };
+        xs + self.attention.forward(&norm_xs, Some(&kv))
+    }
+}
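Not part of the diff: a rough wiring sketch for the new block, mirroring its use in WPrior later in this commit. The module path, channel counts, and head count are assumptions, and the VarMap-backed VarBuilder only provides randomly initialized weights:

use crate::models::wuerstchen::common::AttnBlock;
use candle::{DType, Device, Result};
use candle_nn::{VarBuilder, VarMap};

fn build_attn_block() -> Result<AttnBlock> {
    let dev = Device::Cpu;
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &dev);
    // c is the channel count of the (b, c, h, w) feature map fed to forward,
    // c_cond the width of the (b, seq, c_cond) conditioning passed as kv, and
    // each of the nhead heads gets c / nhead dims. self_attn = true makes
    // forward concatenate the flattened feature map to the conditioning tokens.
    let (c, c_cond, nhead) = (1024, 1024, 16);
    AttnBlock::new(c, c_cond, nhead, true, vb)
}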

View File

@@ -52,4 +52,45 @@ impl ResBlockStageB {
 }
 
 #[derive(Debug)]
-pub struct WDiffNeXt {}
+pub struct WDiffNeXt {
+    clip_mapper: candle_nn::Linear,
+    effnet_mappers: Vec<candle_nn::Conv2d>,
+    seq_norm: candle_nn::LayerNorm,
+    embedding_conv: candle_nn::Conv2d,
+    embedding_ln: WLayerNorm,
+}
+
+impl WDiffNeXt {
+    pub fn new(
+        c_in: usize,
+        c_out: usize,
+        vb: VarBuilder,
+        c_cond: usize,
+        clip_embd: usize,
+        patch_size: usize,
+    ) -> Result<Self> {
+        const C_HIDDEN: [usize; 4] = [320, 640, 1280, 1280];
+        let clip_mapper = candle_nn::linear(clip_embd, c_cond, vb.pp("clip_mapper"))?;
+        let effnet_mappers = vec![];
+        let cfg = candle_nn::layer_norm::LayerNormConfig {
+            ..Default::default()
+        };
+        let seq_norm = candle_nn::layer_norm(c_cond, cfg, vb.pp("seq_norm"))?;
+        let embedding_ln = WLayerNorm::new(C_HIDDEN[0], vb.pp("embedding.1"))?;
+        let embedding_conv = candle_nn::conv2d(
+            c_in * patch_size * patch_size,
+            C_HIDDEN[1],
+            1,
+            Default::default(),
+            vb.pp("embedding.2"),
+        )?;
+
+        Ok(Self {
+            clip_mapper,
+            effnet_mappers,
+            seq_norm,
+            embedding_conv,
+            embedding_ln,
+        })
+    }
+}
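Not part of the diff: the clip_mapper/seq_norm pair added here is the text-conditioning path, projecting a CLIP embedding of width clip_embd to c_cond and layer-normalizing it. A standalone shape sketch with invented sizes and a hypothetical 77-token sequence:

use candle::{DType, Device, Result, Tensor};
use candle_nn::{VarBuilder, VarMap};

fn clip_conditioning_sketch() -> Result<()> {
    let dev = Device::Cpu;
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &dev);
    let (clip_embd, c_cond) = (1024, 1280);
    // Same two layers as in WDiffNeXt::new above, built directly here.
    let clip_mapper = candle_nn::linear(clip_embd, c_cond, vb.pp("clip_mapper"))?;
    let cfg = candle_nn::layer_norm::LayerNormConfig {
        ..Default::default()
    };
    let seq_norm = candle_nn::layer_norm(c_cond, cfg, vb.pp("seq_norm"))?;
    // A CLIP text embedding (batch, tokens, clip_embd) is projected to c_cond
    // and layer-normalized over the last dimension.
    let clip_txt = Tensor::zeros((1, 77, clip_embd), DType::F32, &dev)?;
    let cond = clip_txt.apply(&clip_mapper)?.apply(&seq_norm)?;
    assert_eq!(cond.dims(), &[1, 77, c_cond]);
    Ok(())
}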

View File

@@ -1,5 +1,5 @@
 #![allow(unused)]
-use super::common::{ResBlock, TimestepBlock};
+use super::common::{AttnBlock, ResBlock, TimestepBlock};
 use candle::{DType, Module, Result, Tensor, D};
 use candle_nn::VarBuilder;
 
@@ -7,7 +7,7 @@ use candle_nn::VarBuilder;
 struct Block {
     res_block: ResBlock,
     ts_block: TimestepBlock,
-    // TODO: attn_block: super::common::AttnBlock,
+    attn_block: AttnBlock,
 }
 
 #[derive(Debug)]
@@ -28,7 +28,7 @@ impl WPrior {
         c_cond: usize,
         c_r: usize,
         depth: usize,
-        _nhead: usize,
+        nhead: usize,
         vb: VarBuilder,
     ) -> Result<Self> {
         let projection = candle_nn::conv2d(c_in, c, 1, Default::default(), vb.pp("projection"))?;
@@ -40,9 +40,17 @@ impl WPrior {
         for index in 0..depth {
             let res_block = ResBlock::new(c, 0, 3, vb.pp(format!("blocks.{}", 3 * index)))?;
             let ts_block = TimestepBlock::new(c, c_r, vb.pp(format!("blocks.{}", 3 * index + 1)))?;
+            let attn_block = AttnBlock::new(
+                c,
+                c,
+                nhead,
+                true,
+                vb.pp(format!("blocks.{}", 3 * index + 2)),
+            )?;
             blocks.push(Block {
                 res_block,
                 ts_block,
+                attn_block,
             })
         }
         Ok(Self {
@@ -86,7 +94,7 @@ impl WPrior {
         for block in self.blocks.iter() {
             xs = block.res_block.forward(&xs, None)?;
             xs = block.ts_block.forward(&xs, &r_embed)?;
-            // TODO: attn
+            xs = block.attn_block.forward(&xs, &c_embed)?;
         }
         let ab = xs.apply(&self.out_ln)?.apply(&self.out_conv)?.chunk(1, 2)?;
         (x_in - &ab[0])? / ((&ab[1] - 1.)?.abs()? + 1e-5)