Mirror of https://github.com/huggingface/candle.git (synced 2025-06-16 18:48:51 +00:00)
Add the attention block. (#846)
* Add the attention block.
* Add more to clipnext.
@@ -78,7 +78,7 @@ fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Ten
 }

 #[derive(Debug)]
-struct CrossAttention {
+pub struct CrossAttention {
     to_q: nn::Linear,
     to_k: nn::Linear,
     to_v: nn::Linear,
@@ -94,7 +94,7 @@ struct CrossAttention {

 impl CrossAttention {
     // Defaults should be heads = 8, dim_head = 64, context_dim = None
-    fn new(
+    pub fn new(
         vs: nn::VarBuilder,
         query_dim: usize,
         context_dim: Option<usize>,
@@ -205,7 +205,7 @@ impl CrossAttention {
         self.reshape_batch_dim_to_heads(&xs)
     }

-    fn forward(&self, xs: &Tensor, context: Option<&Tensor>) -> Result<Tensor> {
+    pub fn forward(&self, xs: &Tensor, context: Option<&Tensor>) -> Result<Tensor> {
         let _enter = self.span.enter();
         let query = self.to_q.forward(xs)?;
         let context = context.unwrap_or(xs).contiguous()?;
@@ -124,3 +124,44 @@ impl ResBlock {
         xs + x_res
     }
 }
+use crate::models::stable_diffusion::attention::CrossAttention as Attention;
+#[derive(Debug)]
+pub struct AttnBlock {
+    self_attn: bool,
+    norm: WLayerNorm,
+    attention: Attention,
+    kv_mapper_lin: candle_nn::Linear,
+}
+
+impl AttnBlock {
+    pub fn new(
+        c: usize,
+        c_cond: usize,
+        nhead: usize,
+        self_attn: bool,
+        vb: VarBuilder,
+    ) -> Result<Self> {
+        let norm = WLayerNorm::new(c, vb.pp("norm"))?;
+        let attention = Attention::new(vb.pp("attention"), c, None, nhead, c / nhead, None, false)?;
+        let kv_mapper_lin = candle_nn::linear(c_cond, c, vb.pp("kv_mapper.1"))?;
+        Ok(Self {
+            self_attn,
+            norm,
+            attention,
+            kv_mapper_lin,
+        })
+    }
+
+    pub fn forward(&self, xs: &Tensor, kv: &Tensor) -> Result<Tensor> {
+        let kv = candle_nn::ops::silu(kv)?.apply(&self.kv_mapper_lin)?;
+        let norm_xs = self.norm.forward(xs)?;
+        let kv = if self.self_attn {
+            let (b_size, channel, _, _) = xs.dims4()?;
+            let norm_xs = norm_xs.reshape((b_size, channel, ()))?.transpose(1, 2)?;
+            Tensor::cat(&[&norm_xs, &kv], 1)?
+        } else {
+            kv
+        };
+        xs + self.attention.forward(&norm_xs, Some(&kv))
+    }
+}
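In the self-attention path of AttnBlock::forward above, the normalized input is flattened from (batch, channels, height, width) to (batch, height * width, channels) and concatenated with the mapped conditioning along the sequence dimension, so the keys and values cover both the image tokens and the conditioning tokens. The snippet below is a minimal shape sketch of that reshape/concat step with made-up dimensions; it is only an illustration, not part of the commit.

use candle::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Hypothetical sizes, chosen only to show the shapes.
    let (b, c, h, w, seq) = (1, 16, 8, 8, 4);
    let norm_xs = Tensor::zeros((b, c, h, w), DType::F32, &dev)?;
    let kv = Tensor::zeros((b, seq, c), DType::F32, &dev)?;
    // (b, c, h, w) -> (b, h * w, c): flatten the spatial dims, channels last.
    let norm_xs = norm_xs.reshape((b, c, ()))?.transpose(1, 2)?;
    // Concatenate along the sequence dimension: (b, h * w + seq, c).
    let kv = Tensor::cat(&[&norm_xs, &kv], 1)?;
    assert_eq!(kv.dims(), &[b, h * w + seq, c]);
    Ok(())
}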
@@ -52,4 +52,45 @@ impl ResBlockStageB {
 }

 #[derive(Debug)]
-pub struct WDiffNeXt {}
+pub struct WDiffNeXt {
+    clip_mapper: candle_nn::Linear,
+    effnet_mappers: Vec<candle_nn::Conv2d>,
+    seq_norm: candle_nn::LayerNorm,
+    embedding_conv: candle_nn::Conv2d,
+    embedding_ln: WLayerNorm,
+}
+
+impl WDiffNeXt {
+    pub fn new(
+        c_in: usize,
+        c_out: usize,
+        vb: VarBuilder,
+        c_cond: usize,
+        clip_embd: usize,
+        patch_size: usize,
+    ) -> Result<Self> {
+        const C_HIDDEN: [usize; 4] = [320, 640, 1280, 1280];
+
+        let clip_mapper = candle_nn::linear(clip_embd, c_cond, vb.pp("clip_mapper"))?;
+        let effnet_mappers = vec![];
+        let cfg = candle_nn::layer_norm::LayerNormConfig {
+            ..Default::default()
+        };
+        let seq_norm = candle_nn::layer_norm(c_cond, cfg, vb.pp("seq_norm"))?;
+        let embedding_ln = WLayerNorm::new(C_HIDDEN[0], vb.pp("embedding.1"))?;
+        let embedding_conv = candle_nn::conv2d(
+            c_in * patch_size * patch_size,
+            C_HIDDEN[1],
+            1,
+            Default::default(),
+            vb.pp("embedding.2"),
+        )?;
+        Ok(Self {
+            clip_mapper,
+            effnet_mappers,
+            seq_norm,
+            embedding_conv,
+            embedding_ln,
+        })
+    }
+}
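The vb.pp("...") calls above determine the parameter names that will be looked up in the checkpoint; for example the "embedding.2" prefix combined with candle_nn::conv2d produces "embedding.2.weight" and "embedding.2.bias". Below is a small sketch of that prefixing, using candle_nn's VarMap to make the generated names visible; the sizes are illustrative values, not taken from the commit.

use candle::{DType, Device, Result};
use candle_nn::{VarBuilder, VarMap};

fn main() -> Result<()> {
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::Cpu);
    // Same prefixing scheme as above: "embedding.2" + ".weight" / ".bias".
    let _conv = candle_nn::conv2d(16, 320, 1, Default::default(), vb.pp("embedding.2"))?;
    let mut names: Vec<String> = varmap.data().lock().unwrap().keys().cloned().collect();
    names.sort();
    println!("{names:?}"); // ["embedding.2.bias", "embedding.2.weight"]
    Ok(())
}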
@@ -1,5 +1,5 @@
 #![allow(unused)]
-use super::common::{ResBlock, TimestepBlock};
+use super::common::{AttnBlock, ResBlock, TimestepBlock};
 use candle::{DType, Module, Result, Tensor, D};
 use candle_nn::VarBuilder;

@@ -7,7 +7,7 @@ use candle_nn::VarBuilder;
 struct Block {
     res_block: ResBlock,
     ts_block: TimestepBlock,
-    // TODO: attn_block: super::common::AttnBlock,
+    attn_block: AttnBlock,
 }

 #[derive(Debug)]
@@ -28,7 +28,7 @@ impl WPrior {
         c_cond: usize,
         c_r: usize,
         depth: usize,
-        _nhead: usize,
+        nhead: usize,
         vb: VarBuilder,
     ) -> Result<Self> {
         let projection = candle_nn::conv2d(c_in, c, 1, Default::default(), vb.pp("projection"))?;
@@ -40,9 +40,17 @@ impl WPrior {
         for index in 0..depth {
             let res_block = ResBlock::new(c, 0, 3, vb.pp(format!("blocks.{}", 3 * index)))?;
             let ts_block = TimestepBlock::new(c, c_r, vb.pp(format!("blocks.{}", 3 * index + 1)))?;
+            let attn_block = AttnBlock::new(
+                c,
+                c,
+                nhead,
+                true,
+                vb.pp(format!("blocks.{}", 3 * index + 2)),
+            )?;
             blocks.push(Block {
                 res_block,
                 ts_block,
+                attn_block,
             })
         }
         Ok(Self {
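Each depth step of WPrior now registers three sub-blocks, and their weight prefixes follow the blocks.{3 * index + offset} scheme used above: the residual block at offset 0, the timestep block at offset 1, and the new attention block at offset 2. A tiny illustration of the resulting prefixes for a hypothetical depth of 2:

fn main() {
    let depth = 2;
    for index in 0..depth {
        // blocks.0 / blocks.3, blocks.1 / blocks.4, blocks.2 / blocks.5
        println!("blocks.{} -> ResBlock", 3 * index);
        println!("blocks.{} -> TimestepBlock", 3 * index + 1);
        println!("blocks.{} -> AttnBlock", 3 * index + 2);
    }
}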
@@ -86,7 +94,7 @@ impl WPrior {
         for block in self.blocks.iter() {
             xs = block.res_block.forward(&xs, None)?;
             xs = block.ts_block.forward(&xs, &r_embed)?;
-            // TODO: attn
+            xs = block.attn_block.forward(&xs, &c_embed)?;
         }
         let ab = xs.apply(&self.out_ln)?.apply(&self.out_conv)?.chunk(1, 2)?;
         (x_in - &ab[0])? / ((&ab[1] - 1.)?.abs()? + 1e-5)