mirror of
https://github.com/huggingface/candle.git
synced 2025-06-21 20:22:49 +00:00
Add the attention block. (#846)
* Add the attention block. * Add more to clipnext.
This commit is contained in:
@ -1,5 +1,5 @@
|
||||
#![allow(unused)]
|
||||
use super::common::{ResBlock, TimestepBlock};
|
||||
use super::common::{AttnBlock, ResBlock, TimestepBlock};
|
||||
use candle::{DType, Module, Result, Tensor, D};
|
||||
use candle_nn::VarBuilder;
|
||||
|
||||
@ -7,7 +7,7 @@ use candle_nn::VarBuilder;
|
||||
struct Block {
|
||||
res_block: ResBlock,
|
||||
ts_block: TimestepBlock,
|
||||
// TODO: attn_block: super::common::AttnBlock,
|
||||
attn_block: AttnBlock,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
@ -28,7 +28,7 @@ impl WPrior {
|
||||
c_cond: usize,
|
||||
c_r: usize,
|
||||
depth: usize,
|
||||
_nhead: usize,
|
||||
nhead: usize,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Self> {
|
||||
let projection = candle_nn::conv2d(c_in, c, 1, Default::default(), vb.pp("projection"))?;
|
||||
@ -40,9 +40,17 @@ impl WPrior {
|
||||
for index in 0..depth {
|
||||
let res_block = ResBlock::new(c, 0, 3, vb.pp(format!("blocks.{}", 3 * index)))?;
|
||||
let ts_block = TimestepBlock::new(c, c_r, vb.pp(format!("blocks.{}", 3 * index + 1)))?;
|
||||
let attn_block = AttnBlock::new(
|
||||
c,
|
||||
c,
|
||||
nhead,
|
||||
true,
|
||||
vb.pp(format!("blocks.{}", 3 * index + 2)),
|
||||
)?;
|
||||
blocks.push(Block {
|
||||
res_block,
|
||||
ts_block,
|
||||
attn_block,
|
||||
})
|
||||
}
|
||||
Ok(Self {
|
||||
@ -86,7 +94,7 @@ impl WPrior {
|
||||
for block in self.blocks.iter() {
|
||||
xs = block.res_block.forward(&xs, None)?;
|
||||
xs = block.ts_block.forward(&xs, &r_embed)?;
|
||||
// TODO: attn
|
||||
xs = block.attn_block.forward(&xs, &c_embed)?;
|
||||
}
|
||||
let ab = xs.apply(&self.out_ln)?.apply(&self.out_conv)?.chunk(1, 2)?;
|
||||
(x_in - &ab[0])? / ((&ab[1] - 1.)?.abs()? + 1e-5)
|
||||
|
Reference in New Issue
Block a user