//! Based from the Stanford Hazy Research group.
//!
//! See "Simple linear attention language models balance the recall-throughput tradeoff", Arora et al. 2024
//! - Simple linear attention language models balance the recall-throughput tradeoff. [Arxiv](https://arxiv.org/abs/2402.18668)
//! - [Github Rep](https://github.com/HazyResearch/based)
//! - [Blogpost](https://hazyresearch.stanford.edu/blog/2024-03-03-based)
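//!
//! # Example (sketch)
//!
//! A minimal, hypothetical inference sketch; the file names, the `anyhow`
//! error handling, and the way the weights are obtained are illustrative
//! assumptions, not a fixed API.
//!
//! ```ignore
//! use candle::{DType, Device, Tensor};
//! use candle_nn::VarBuilder;
//! use candle_transformers::models::based::{Config, Model};
//!
//! fn run() -> anyhow::Result<()> {
//!     let device = Device::Cpu;
//!     // Assumed local copies of the checkpoint files.
//!     let config: Config = serde_json::from_slice(&std::fs::read("config.json")?)?;
//!     let vb = unsafe {
//!         VarBuilder::from_mmaped_safetensors(&["model.safetensors"], DType::F32, &device)?
//!     };
//!     let mut model = Model::new(&config, vb)?;
//!     // Token ids would normally come from a tokenizer.
//!     let input = Tensor::new(&[[1u32, 2, 3]], &device)?;
//!     // Returns the logits for the last position only.
//!     let logits = model.forward(&input, 0)?;
//!     println!("{logits}");
//!     Ok(())
//! }
//! ```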
use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
use candle_nn::{
conv1d_no_bias, linear, linear_no_bias, ops::softmax_last_dim, rms_norm, Conv1d, Conv1dConfig,
Func, Linear, RmsNorm, VarBuilder,
};
use std::sync::Arc;
#[derive(Debug, Clone, serde::Deserialize)]
pub struct LinearAttentionFeatureMapConfig {
input_dim: usize,
}
#[derive(Debug, Clone, serde::Deserialize)]
pub struct LinearAttentionConfig {
num_heads: usize,
feature_dim: usize,
feature_map: LinearAttentionFeatureMapConfig,
}
#[derive(Debug, Clone, serde::Deserialize)]
pub struct SlidingWindowAttentionConfig {
num_heads: usize,
window_size: usize,
}
#[derive(Debug, Clone, serde::Deserialize)]
pub struct Config {
vocab_size: usize,
#[serde(rename = "n_embd")]
hidden_size: usize,
#[serde(rename = "n_inner")]
intermediate_size: usize,
#[serde(rename = "n_layer")]
num_hidden_layers: usize,
#[serde(rename = "n_head")]
num_attention_heads: usize,
layer_norm_epsilon: f64,
#[serde(default = "default_rope", rename = "rotary_emb_base")]
rope_theta: f64,
alt_mixer_layers: Vec<usize>,
alt_mixer_2_layers: Vec<usize>,
#[serde(rename = "alt_mixer")]
la: LinearAttentionConfig,
#[serde(rename = "alt_mixer_2")]
swa: SlidingWindowAttentionConfig,
}
fn default_rope() -> f64 {
10_000.0
}
#[derive(Debug, Clone)]
#[allow(clippy::upper_case_acronyms)]
struct MLP {
fc1: Linear,
fc2: Linear,
}
impl MLP {
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let fc1 = linear_no_bias(cfg.hidden_size, cfg.intermediate_size * 2, vb.pp("fc1"))?;
let fc2 = linear_no_bias(cfg.intermediate_size, cfg.hidden_size, vb.pp("fc2"))?;
Ok(Self { fc1, fc2 })
}
}
// Swiglu implementation.
// Activation::Swiglu is not used here because it has the gate and value
// arguments swapped compared to the version in candle-nn/src/ops.rs.
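// For xs = [x, gate] split along the last dimension this computes
// silu(gate) * x, halving the width of the input.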
fn swiglu(xs: &Tensor) -> Result<Tensor> {
let xs = xs.chunk(2, D::Minus1)?;
&xs[1].silu()? * &xs[0]
}
impl Module for MLP {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let xs = xs.apply(&self.fc1)?;
let xs = swiglu(&xs)?;
let xs = xs.apply(&self.fc2)?;
Ok(xs)
}
}
// A gated convolutional block.
#[derive(Debug, Clone)]
struct BasedConv {
in_proj: Linear,
out_proj: Linear,
conv: Conv1d,
state: Tensor,
}
impl BasedConv {
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let dim = cfg.hidden_size * 2;
let conv1d_cfg = Conv1dConfig {
groups: dim,
padding: 2,
..Default::default()
};
let in_proj = linear(cfg.hidden_size, cfg.hidden_size * 4, vb.pp("in_proj"))?;
let out_proj = linear(dim, cfg.hidden_size, vb.pp("out_proj"))?;
let conv = conv1d_no_bias(dim, dim, 3, conv1d_cfg, vb.pp("conv.conv"))?;
let state = Tensor::zeros((1, dim, 3), vb.dtype(), vb.device())?;
Ok(Self {
in_proj,
out_proj,
conv,
state,
})
}
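// One decode step: shift the (1, dim, 3) rolling state left by one token,
// append the new token, and apply the depthwise kernel to it directly.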
fn step(&mut self, xs: &Tensor) -> Result<Tensor> {
let (_, _, l) = self.state.dims3()?;
self.state = self.state.narrow(D::Minus1, 1, l - 1)?;
self.state = Tensor::cat(&[&self.state, &xs.transpose(1, 2)?], 2)?;
let xs = (&self.state * self.conv.weight().permute((1, 0, 2))?)?
.sum_keepdim(0)?
.sum(D::Minus1)?;
let xs = xs.unsqueeze(1)?;
Ok(xs)
}
fn forward(&mut self, xs: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
let xs = xs.apply(&self.in_proj)?;
let us = xs.chunk(2, D::Minus1)?;
let (_b, l, _d) = us[0].dims3()?;
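// Decode reuses the cached conv state; prefill runs the full convolution
// and primes the state with the last (up to) three positions.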
let u_conv = if seqlen_offset > 0 {
self.step(&us[0])?
} else {
let k = std::cmp::min(3, l);
self.state = self.state.narrow(D::Minus1, 0, 3 - k)?;
let xs = us[0].narrow(1, l - k, k)?.transpose(1, 2)?;
self.state = Tensor::cat(&[&self.state, &xs], 2)?;
us[0]
.transpose(1, 2)?
.apply(&self.conv)?
.narrow(D::Minus1, 0, l)?
.transpose(1, 2)?
};
let u_conv = u_conv.silu()?;
let v = u_conv.broadcast_mul(&us[1])?;
let xs = v.apply(&self.out_proj)?;
Ok(xs)
}
}
// Linear attention approximating softmax with a second-order Taylor polynomial.
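// The feature map implements the second-order Taylor expansion of exp:
// exp(x) ≈ 1 + x + x²/2. Mapping φ(x) = [1, x / d^(1/4), (x ⊗ x) / (√2 · √d)]
// gives φ(q) · φ(k) ≈ exp(q · k / √d) while keeping the cost linear in
// sequence length.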
#[derive(Debug, Clone)]
struct LinearAttention {
proj_q: Linear,
proj_k: Linear,
proj_v: Linear,
out_proj: Linear,
feature_dim: usize,
num_heads: usize,
input_dim: usize,
k_state: Tensor,
kv_state: Tensor,
}
impl LinearAttention {
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let input_dim = cfg.la.feature_map.input_dim;
let out_proj = linear_no_bias(cfg.hidden_size, cfg.hidden_size, vb.pp("out_proj"))?;
let proj_k = linear_no_bias(
cfg.hidden_size,
cfg.la.num_heads * cfg.la.feature_dim,
vb.pp("proj_k"),
)?;
let proj_q = linear_no_bias(
cfg.hidden_size,
cfg.la.num_heads * cfg.la.feature_dim,
vb.pp("proj_q"),
)?;
let proj_v = linear_no_bias(cfg.hidden_size, cfg.hidden_size, vb.pp("proj_v"))?;
let expanded_size = cfg.la.feature_dim.pow(2) + cfg.la.feature_dim + 1;
let k_state = Tensor::zeros(
(1, cfg.la.num_heads, 1, 1, expanded_size),
vb.dtype(),
vb.device(),
)?;
let kv_state = Tensor::zeros(
(1, cfg.la.num_heads, cfg.la.feature_dim, expanded_size),
vb.dtype(),
vb.device(),
)?;
Ok(Self {
proj_q,
proj_k,
proj_v,
out_proj,
feature_dim: cfg.la.feature_dim,
num_heads: cfg.la.num_heads,
input_dim,
k_state,
kv_state,
})
}
fn taylor_expansion(&self) -> Result<Func<'static>> {
let r2 = std::f64::consts::SQRT_2;
let rd = (self.input_dim as f64).sqrt();
let rrd = rd.sqrt();
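// r2, rd and rrd normalize the constant, first- and second-order terms so
// that φ(q) · φ(k) matches the Taylor series of exp(q · k / √d).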
Ok(Func::new(move |xs| {
let dims = xs.dims();
let mut d = dims.to_vec();
if let Some(last) = d.last_mut() {
*last = 1;
};
let x = xs
.unsqueeze(D::Minus1)?
.broadcast_mul(&xs.unsqueeze(D::Minus2)?)?;
let x = (x.flatten_from(D::Minus2)? / r2)?;
let o = Tensor::ones(d, xs.dtype(), xs.device())?;
let x = Tensor::cat(&[o, (xs / rrd)?, (&x / rd)?], D::Minus1)?;
Ok(x)
}))
}
fn forward(&mut self, xs: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
let eps = 1e-12;
let feature_map = self.taylor_expansion()?;
let (b, l, d) = xs.dims3()?;
let q = xs.apply(&self.proj_q)?;
let k = xs.apply(&self.proj_k)?;
let v = xs.apply(&self.proj_v)?;
let q = q
.reshape((b, l, self.num_heads, self.feature_dim))?
.transpose(1, 2)?
.contiguous()?;
let k = k
.reshape((b, l, self.num_heads, self.feature_dim))?
.transpose(1, 2)?
.contiguous()?;
let v = v
.reshape((b, l, self.num_heads, d / self.num_heads))?
.transpose(1, 2)?
.contiguous()?;
let q = feature_map.forward(&q)?;
let k = feature_map.forward(&k)?;
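// Decode (seqlen_offset > 0) uses the recurrent view: kv_state accumulates
// φ(k_t) ⊗ v_t and k_state accumulates φ(k_t), so each new token costs O(1):
//   y_t = (φ(q_t) · kv_state) / (φ(q_t) · k_state + eps)
// Prefill uses the quadratic causal form and initializes both states.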
let y = if seqlen_offset > 0 {
let (_b, _h, l, _d) = k.dims4()?;
let q = q.unsqueeze(D::Minus2)?;
let k = k.unsqueeze(D::Minus2)?;
let v = v.unsqueeze(D::Minus1)?;
let kn = k.narrow(2, l - 1, 1)?;
let vn = v.narrow(2, l - 1, 1)?;
self.k_state = self.k_state.broadcast_add(&kn)?;
self.kv_state = self.kv_state.broadcast_add(&kn.broadcast_mul(&vn)?)?;
let num = q.broadcast_mul(&self.kv_state)?.sum(D::Minus1)?;
let den = (q.broadcast_mul(&self.k_state)?.sum(D::Minus1)? + eps)?;
num.broadcast_div(&den)?
} else {
self.k_state = k.sum(2)?.unsqueeze(2)?.unsqueeze(3)?;
self.kv_state = k
.transpose(2, 3)?
.matmul(&v)?
.transpose(2, 3)?
.unsqueeze(2)?;
let aqk = q.matmul(&k.transpose(D::Minus1, D::Minus2)?)?;
let tril = Tensor::tril2(l, aqk.dtype(), aqk.device())?;
let aqk = aqk.broadcast_mul(&tril)?.matmul(&v)?;
let z = (1f64 / (q.mul(&k.cumsum(2)?)?.sum(D::Minus1)? + eps)?)?;
aqk.broadcast_mul(&z.unsqueeze(D::Minus1)?)?
};
let (b, h, l, d) = y.dims4()?;
let y = y.permute((0, 2, 1, 3))?.reshape((b, l, h * d))?;
let y = self.out_proj.forward(&y)?;
Ok(y)
}
}
// Rotary embeddings used in local attention.
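// The sin/cos tables are precomputed for positions 0..max_seq_len (hardcoded
// to 2048 below) and sliced at seqlen_offset when applied.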
#[derive(Debug, Clone)]
struct RotaryEmbedding {
sin: Tensor,
cos: Tensor,
}
impl RotaryEmbedding {
fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
let dim = cfg.hidden_size / cfg.num_attention_heads;
let max_seq_len = 2048; // Hardcoded, missing from config.
let inv_freq: Vec<_> = (0..dim)
.step_by(2)
.map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
.collect();
let inv_freq_len = inv_freq.len();
let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
.to_dtype(dtype)?
.reshape((max_seq_len, 1))?;
let freqs = t.matmul(&inv_freq)?;
Ok(Self {
sin: freqs.sin()?,
cos: freqs.cos()?,
})
}
fn apply_rotary_emb_qkv(
&self,
q: &Tensor,
k: &Tensor,
seqlen_offset: usize,
) -> Result<(Tensor, Tensor)> {
let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
Ok((q_embed, k_embed))
}
}
// Local attention using a small sliding window.
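// The window restriction is applied through the additive attention mask built
// in Model::prepare_decoder_attention_mask; this module itself runs standard
// softmax attention with rotary embeddings and a kv cache.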
#[derive(Debug, Clone)]
struct SlidingWindowAttention {
wqkv: Linear,
out_proj: Linear,
num_heads: usize,
head_dim: usize,
hidden_size: usize,
rotary_emb: Arc<RotaryEmbedding>,
kv_cache: Option<(Tensor, Tensor)>,
}
impl SlidingWindowAttention {
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let hidden_size = cfg.hidden_size;
let num_heads = cfg.swa.num_heads;
let head_dim = hidden_size / num_heads;
let out_proj = linear_no_bias(hidden_size, hidden_size, vb.pp("out_proj"))?;
let wqkv = linear_no_bias(hidden_size, hidden_size * 3, vb.pp("Wqkv"))?;
let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb.device())?);
Ok(Self {
wqkv,
out_proj,
hidden_size,
num_heads,
head_dim,
rotary_emb,
kv_cache: None,
})
}
fn forward(
&mut self,
xs: &Tensor,
attention_mask: Option<&Tensor>,
seqlen_offset: usize,
) -> Result<Tensor> {
let (b_sz, q_len, _) = xs.dims3()?;
let qkv = xs.apply(&self.wqkv)?;
let qkv = qkv.reshape((b_sz, q_len, 3, (), self.head_dim))?;
let q = qkv.i((.., .., 0))?;
let k = qkv.i((.., .., 1))?;
let v = qkv.i((.., .., 2))?;
let q = q
.reshape((b_sz, q_len, self.num_heads, self.head_dim))?
.transpose(1, 2)?;
let k = k
.reshape((b_sz, q_len, self.num_heads, self.head_dim))?
.transpose(1, 2)?;
let v = v
.reshape((b_sz, q_len, self.num_heads, self.head_dim))?
.transpose(1, 2)?;
let (q, k) = self
.rotary_emb
.apply_rotary_emb_qkv(&q, &k, seqlen_offset)?;
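// Grow the kv cache; the sliding window is enforced by the mask rather
// than by truncating cached entries.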
let (k, v) = match &self.kv_cache {
None => (k, v),
Some((prev_k, prev_v)) => {
let k = Tensor::cat(&[prev_k, &k], 2)?;
let v = Tensor::cat(&[prev_v, &v], 2)?;
(k, v)
}
};
self.kv_cache = Some((k.clone(), v.clone()));
let scale = 1f64 / f64::sqrt(self.head_dim as f64);
let attn_weights = (q.matmul(&k.transpose(2, 3)?)? * scale)?;
let attn_weights = match attention_mask {
None => attn_weights,
Some(mask) => attn_weights.broadcast_add(mask)?,
};
let attn_weights = softmax_last_dim(&attn_weights)?;
let attn_output = attn_weights.matmul(&v)?;
let out = attn_output
.transpose(1, 2)?
.reshape((b_sz, q_len, self.hidden_size))?
.apply(&self.out_proj)?;
Ok(out)
}
}
// The model layers use three types of mixers.
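// Layer indices in alt_mixer_layers get linear attention, indices in
// alt_mixer_2_layers get sliding-window attention, and every other layer
// gets the gated convolution (see DecoderLayer::new).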
#[derive(Debug, Clone)]
enum SequenceMixer {
Based(BasedConv),
Linear(LinearAttention),
Sliding(SlidingWindowAttention),
}
impl SequenceMixer {
fn forward(
&mut self,
xs: &Tensor,
attention_mask: Option<&Tensor>,
pos: usize,
) -> Result<Tensor> {
match self {
Self::Based(b) => b.forward(xs, pos),
Self::Linear(b) => b.forward(xs, pos),
Self::Sliding(b) => b.forward(xs, attention_mask, pos),
}
}
}
#[derive(Debug, Clone)]
struct DecoderLayer {
mlp: MLP,
norm1: RmsNorm,
norm2: RmsNorm,
mixer: SequenceMixer,
}
impl DecoderLayer {
fn new(layer_idx: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
let mlp = MLP::new(cfg, vb.pp("mlp"))?;
let norm1 = rms_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("norm1"))?;
let norm2 = rms_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("norm2"))?;
let l_attn = cfg.alt_mixer_layers.contains(&layer_idx);
let sw_attn = cfg.alt_mixer_2_layers.contains(&layer_idx);
let mixer = if l_attn {
SequenceMixer::Linear(LinearAttention::new(cfg, vb.pp("mixer"))?)
} else if sw_attn {
SequenceMixer::Sliding(SlidingWindowAttention::new(cfg, vb.pp("mixer"))?)
} else {
SequenceMixer::Based(BasedConv::new(cfg, vb.pp("mixer"))?)
};
Ok(Self {
mlp,
norm1,
norm2,
mixer,
})
}
fn forward(
&mut self,
xs: &Tensor,
attention_mask: Option<&Tensor>,
seqlen_offset: usize,
) -> Result<Tensor> {
let residual = xs;
let xs = self.norm1.forward(xs)?;
let xs = self.mixer.forward(&xs, attention_mask, seqlen_offset)?;
let xs = (xs + residual)?;
let residual = &xs;
let xs = xs.apply(&self.norm2)?.apply(&self.mlp)?;
residual + xs
}
}
#[derive(Debug, Clone)]
pub struct Model {
embed_tokens: super::with_tracing::Embedding,
layers: Vec<DecoderLayer>,
norm: RmsNorm,
lm_head: Linear,
sliding_window: usize,
device: Device,
dtype: DType,
}
impl Model {
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
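// Round the vocabulary up to a multiple of 8 so the shapes match the
// padded lm_head weights in the checkpoint.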
let vocab_size = cfg.vocab_size + (8 - cfg.vocab_size % 8) % 8;
let lm_head = linear_no_bias(cfg.hidden_size, vocab_size, vb.pp("lm_head"))?;
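// The token embedding weights are tied to the lm_head projection.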
let embed_tokens = super::with_tracing::Embedding::from_weights(lm_head.weight().clone())?;
let vb_m = vb.pp("transformer");
let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
let vb_l = vb_m.pp("layers");
for layer_idx in 0..cfg.num_hidden_layers {
let layer = DecoderLayer::new(layer_idx, cfg, vb_l.pp(layer_idx))?;
layers.push(layer)
}
let norm = rms_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb_m.pp("ln_f"))?;
Ok(Self {
embed_tokens,
layers,
norm,
lm_head,
sliding_window: cfg.swa.window_size,
device: vb.device().clone(),
dtype: vb.dtype(),
})
}
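// Additive causal mask that also disallows looking further back than half
// the configured window: position i may attend to j only when
// i - window/2 <= j <= i.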
fn prepare_decoder_attention_mask(
&self,
b_size: usize,
tgt_len: usize,
seqlen_offset: usize,
) -> Result<Tensor> {
let sliding_window = self.sliding_window / 2;
let mask: Vec<_> = (0..tgt_len)
.flat_map(|i| {
(0..tgt_len).map(move |j| {
if i < j || j + sliding_window < i {
f32::NEG_INFINITY
} else {
0.
}
})
})
.collect();
let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?;
let mask = if seqlen_offset > 0 {
let mask0 = Tensor::zeros((tgt_len, seqlen_offset), self.dtype, &self.device)?;
Tensor::cat(&[&mask0, &mask], D::Minus1)?
} else {
mask
};
mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))?
.to_dtype(self.dtype)
}
pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
let (b_size, seq_len) = input_ids.dims2()?;
let attention_mask = if seq_len <= 1 {
None
} else {
let mask = self.prepare_decoder_attention_mask(b_size, seq_len, seqlen_offset)?;
Some(mask)
};
let mut xs = self.embed_tokens.forward(input_ids)?;
for layer in self.layers.iter_mut() {
xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)?
}
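// Compute logits for the last position only.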
xs.narrow(1, seq_len - 1, 1)?
.apply(&self.norm)?
.apply(&self.lm_head)
}
}