mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 04:00:28 +00:00
Build alibi bias. (#1115)
* Build alibi bias. * Apply the alibi attention bias. * Add the replit-code example.
This commit is contained in:
@ -15,7 +15,9 @@ pub struct Config {
|
||||
pub(crate) max_seq_len: usize,
|
||||
pub(crate) vocab_size: usize,
|
||||
pub(crate) kv_n_heads: usize,
|
||||
// pub(crate) attn_config: AttnConfig,
|
||||
pub(crate) attn_prefix_lm: bool,
|
||||
pub(crate) attn_alibi: bool,
|
||||
pub(crate) attn_alibi_bias_max: usize,
|
||||
}
|
||||
|
||||
impl Config {
|
||||
@ -28,8 +30,15 @@ impl Config {
|
||||
max_seq_len: 4096,
|
||||
vocab_size: 32768,
|
||||
kv_n_heads: 8,
|
||||
attn_prefix_lm: false,
|
||||
attn_alibi: true,
|
||||
attn_alibi_bias_max: 8,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_causal(&self) -> bool {
|
||||
!self.attn_prefix_lm
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
@ -42,6 +51,7 @@ struct GroupedQueryAttention {
|
||||
d_model: usize,
|
||||
n_heads: usize,
|
||||
kv_n_heads: usize,
|
||||
attn_bias: Tensor,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
@ -52,6 +62,7 @@ impl GroupedQueryAttention {
|
||||
let head_dim = cfg.d_model / cfg.n_heads;
|
||||
let softmax_scale = 1f64 / (head_dim as f64).sqrt();
|
||||
let out_proj = linear(cfg.d_model, cfg.d_model, vb.pp("out_proj"))?;
|
||||
let attn_bias = build_alibi_bias(cfg)?.to_device(vb.device())?;
|
||||
Ok(Self {
|
||||
wqkv,
|
||||
out_proj,
|
||||
@ -61,6 +72,7 @@ impl GroupedQueryAttention {
|
||||
d_model: cfg.d_model,
|
||||
n_heads: cfg.n_heads,
|
||||
kv_n_heads: cfg.kv_n_heads,
|
||||
attn_bias,
|
||||
span: tracing::span!(tracing::Level::TRACE, "gqa"),
|
||||
})
|
||||
}
|
||||
@ -94,7 +106,23 @@ impl GroupedQueryAttention {
|
||||
let key = repeat_kv(key, self.n_heads / self.kv_n_heads)?;
|
||||
let value = repeat_kv(value, self.n_heads / self.kv_n_heads)?;
|
||||
let attn_weights = (query.matmul(&key)? * self.softmax_scale)?;
|
||||
// TODO: attn_bias, alibi
|
||||
let attn_bias = {
|
||||
let s_q = query.dim(D::Minus2)?;
|
||||
let s_k = key.dim(D::Minus1)?;
|
||||
let (_, _, a_q, a_k) = self.attn_bias.dims4()?;
|
||||
self.attn_bias
|
||||
.narrow(2, a_q - s_q, s_q)?
|
||||
.narrow(3, a_k - s_k, s_k)?
|
||||
};
|
||||
let attn_weights = (attn_weights + attn_bias)?;
|
||||
let attn_weights = match mask {
|
||||
None => attn_weights,
|
||||
Some(mask) => masked_fill(
|
||||
&attn_weights,
|
||||
&mask.broadcast_left(b_size * self.n_heads)?,
|
||||
f32::NEG_INFINITY,
|
||||
)?,
|
||||
};
|
||||
let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
|
||||
let attn_output = attn_weights
|
||||
.matmul(&value)?
|
||||
@ -172,15 +200,49 @@ impl MPTBlock {
|
||||
}
|
||||
}
|
||||
|
||||
fn build_alibi_bias(cfg: &Config) -> Result<Tensor> {
|
||||
let full = !cfg.is_causal();
|
||||
let seq_len = cfg.max_seq_len;
|
||||
let alibi_bias = Tensor::arange(1 - seq_len as i64, 1, &Device::Cpu)?;
|
||||
let alibi_bias = if full {
|
||||
let a1 = alibi_bias.reshape((1, 1, 1, seq_len))?;
|
||||
let a2 = alibi_bias.reshape((1, 1, seq_len, 1))?;
|
||||
a1.broadcast_sub(&a2)?.abs()?.neg()?
|
||||
} else {
|
||||
alibi_bias.reshape((1, 1, 1, seq_len))?
|
||||
};
|
||||
let mut n_heads2 = 1;
|
||||
while 2 * n_heads2 <= cfg.n_heads {
|
||||
n_heads2 *= 2
|
||||
}
|
||||
let slopes = (1..=n_heads2)
|
||||
.map(|v| 1f32 / 2f32.powf((v * cfg.attn_alibi_bias_max) as f32 / n_heads2 as f32))
|
||||
.collect::<Vec<_>>();
|
||||
let slopes = if n_heads2 == cfg.n_heads {
|
||||
slopes
|
||||
} else {
|
||||
slopes
|
||||
.iter()
|
||||
.skip(1)
|
||||
.step_by(2)
|
||||
.chain(slopes.iter().step_by(2))
|
||||
.take(cfg.n_heads)
|
||||
.cloned()
|
||||
.collect::<Vec<f32>>()
|
||||
};
|
||||
let slopes = Tensor::new(slopes, &Device::Cpu)?;
|
||||
alibi_bias.broadcast_mul(&slopes)
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct Model {
|
||||
pub struct Model {
|
||||
wte: candle_nn::Embedding,
|
||||
blocks: Vec<MPTBlock>,
|
||||
norm_f: LayerNorm,
|
||||
}
|
||||
|
||||
impl Model {
|
||||
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
let wte = candle_nn::embedding(cfg.vocab_size, cfg.d_model, vb.pp("wte"))?;
|
||||
let vb_b = vb.pp("blocks");
|
||||
let mut blocks = Vec::with_capacity(cfg.n_layers);
|
||||
@ -196,7 +258,33 @@ impl Model {
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(&mut self, xs: &Tensor, mask: Option<&Tensor>) -> Result<Tensor> {
|
||||
todo!()
|
||||
pub fn forward(&mut self, xs: &Tensor) -> Result<Tensor> {
|
||||
let (_b_size, seq_len) = xs.dims2()?;
|
||||
let mut xs = xs.apply(&self.wte)?;
|
||||
let mask = if seq_len <= 1 {
|
||||
None
|
||||
} else {
|
||||
Some(get_mask(seq_len, xs.device())?)
|
||||
};
|
||||
for block in self.blocks.iter_mut() {
|
||||
xs = block.forward(&xs, mask.as_ref())?
|
||||
}
|
||||
xs.narrow(1, seq_len - 1, 1)?
|
||||
.matmul(&self.wte.embeddings().t()?)?
|
||||
.squeeze(1)
|
||||
}
|
||||
}
|
||||
|
||||
fn get_mask(size: usize, device: &Device) -> Result<Tensor> {
|
||||
let mask: Vec<_> = (0..size)
|
||||
.flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
|
||||
.collect();
|
||||
Tensor::from_slice(&mask, (size, size), device)
|
||||
}
|
||||
|
||||
fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
|
||||
let shape = mask.shape();
|
||||
let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
|
||||
let m = mask.where_cond(&on_true, on_false)?;
|
||||
Ok(m)
|
||||
}
|
||||
|
Reference in New Issue
Block a user