Add the attention block. (#846)

* Add the attention block.

* Add more to clipnext.
This commit is contained in:
Laurent Mazare
2023-09-14 16:40:09 +02:00
committed by GitHub
parent 286f01db14
commit a0c6d5548c
4 changed files with 98 additions and 8 deletions

View File

@ -1,5 +1,5 @@
#![allow(unused)]
use super::common::{ResBlock, TimestepBlock};
use super::common::{AttnBlock, ResBlock, TimestepBlock};
use candle::{DType, Module, Result, Tensor, D};
use candle_nn::VarBuilder;
@ -7,7 +7,7 @@ use candle_nn::VarBuilder;
struct Block {
res_block: ResBlock,
ts_block: TimestepBlock,
// TODO: attn_block: super::common::AttnBlock,
attn_block: AttnBlock,
}
#[derive(Debug)]
@ -28,7 +28,7 @@ impl WPrior {
c_cond: usize,
c_r: usize,
depth: usize,
_nhead: usize,
nhead: usize,
vb: VarBuilder,
) -> Result<Self> {
let projection = candle_nn::conv2d(c_in, c, 1, Default::default(), vb.pp("projection"))?;
@ -40,9 +40,17 @@ impl WPrior {
for index in 0..depth {
let res_block = ResBlock::new(c, 0, 3, vb.pp(format!("blocks.{}", 3 * index)))?;
let ts_block = TimestepBlock::new(c, c_r, vb.pp(format!("blocks.{}", 3 * index + 1)))?;
let attn_block = AttnBlock::new(
c,
c,
nhead,
true,
vb.pp(format!("blocks.{}", 3 * index + 2)),
)?;
blocks.push(Block {
res_block,
ts_block,
attn_block,
})
}
Ok(Self {
@ -86,7 +94,7 @@ impl WPrior {
for block in self.blocks.iter() {
xs = block.res_block.forward(&xs, None)?;
xs = block.ts_block.forward(&xs, &r_embed)?;
// TODO: attn
xs = block.attn_block.forward(&xs, &c_embed)?;
}
let ab = xs.apply(&self.out_ln)?.apply(&self.out_conv)?.chunk(1, 2)?;
(x_in - &ab[0])? / ((&ab[1] - 1.)?.abs()? + 1e-5)