mirror of https://github.com/huggingface/candle.git
synced 2025-06-15 02:16:37 +00:00
Add Based LLM from Hazy Research. (#2411)
candle-examples/examples/based/README.md (new file, 20 lines)
@@ -0,0 +1,20 @@
# candle-based

An experimental small LLM from the Hazy Research group, combining local and linear attention layers. The released checkpoints are base models, not instruction-tuned.

[Blogpost](https://hazyresearch.stanford.edu/blog/2024-03-03-based)

[Simple linear attention language models balance the recall-throughput tradeoff](https://arxiv.org/abs/2402.18668)

## Running an example

```bash
$ cargo run --example based --release -- --prompt "Flying monkeys are" --which 1b-50b --sample-len 100

Flying monkeys are a common sight in the wild, but they are also a threat to humans.

The new study, published today (July 31) in the journal Science Advances, shows that the monkeys are using their brains to solve the problem of how to get around the problem.

"We found that the monkeys were using a strategy called 'cognitive mapping' - they would use their brains to map out the route ahead," says lead author Dr. David J. Smith from the University of California
```
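The `--which` flag selects the checkpoint. As a sketch of another invocation (flags as defined in the example's `main.rs` below; sampled output will vary), the smaller 360m model can be run on CPU with sampling enabled:

```bash
$ cargo run --example based --release -- --prompt "Flying monkeys are" --which 360m --temperature 0.8 --sample-len 50 --cpu
```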
candle-examples/examples/based/main.rs (new file, 275 lines)
@@ -0,0 +1,275 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use anyhow::{Error as E, Result};
use clap::{Parser, ValueEnum};

use candle_transformers::models::based::Model;

use candle::{DType, Device, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;

struct TextGeneration {
    model: Model,
    device: Device,
    tokenizer: TokenOutputStream,
    logits_processor: LogitsProcessor,
    repeat_penalty: f32,
    repeat_last_n: usize,
}
impl TextGeneration {
    #[allow(clippy::too_many_arguments)]
    fn new(
        model: Model,
        tokenizer: Tokenizer,
        seed: u64,
        temp: Option<f64>,
        top_p: Option<f64>,
        repeat_penalty: f32,
        repeat_last_n: usize,
        device: &Device,
    ) -> Self {
        let logits_processor = LogitsProcessor::new(seed, temp, top_p);
        Self {
            model,
            tokenizer: TokenOutputStream::new(tokenizer),
            logits_processor,
            repeat_penalty,
            repeat_last_n,
            device: device.clone(),
        }
    }

    fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
        use std::io::Write;
        self.tokenizer.clear();
        let mut tokens = self
            .tokenizer
            .tokenizer()
            .encode(prompt, true)
            .map_err(E::msg)?
            .get_ids()
            .to_vec();
        for &t in tokens.iter() {
            if let Some(t) = self.tokenizer.next_token(t)? {
                print!("{t}")
            }
        }
        std::io::stdout().flush()?;

        let mut generated_tokens = 0usize;
        let eos_token = match self.tokenizer.get_token("<|endoftext|>") {
            Some(token) => token,
            None => anyhow::bail!("cannot find the <|endoftext|> token"),
        };
        let start_gen = std::time::Instant::now();
        for index in 0..sample_len {
            // The whole prompt is processed on the first iteration; after that
            // only the newly sampled token is fed back in.
            let context_size = if index > 0 { 1 } else { tokens.len() };
            let start_pos = tokens.len().saturating_sub(context_size);
            let ctxt = &tokens[start_pos..];
            let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
            let logits = self.model.forward(&input, start_pos)?;
            let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
            let logits = if self.repeat_penalty == 1. {
                logits
            } else {
                // Down-weight the logits of tokens seen in the last
                // `repeat_last_n` positions to reduce repetition.
                let start_at = tokens.len().saturating_sub(self.repeat_last_n);
                candle_transformers::utils::apply_repeat_penalty(
                    &logits,
                    self.repeat_penalty,
                    &tokens[start_at..],
                )?
            };

            let next_token = self.logits_processor.sample(&logits)?;
            tokens.push(next_token);
            generated_tokens += 1;
            if next_token == eos_token {
                break;
            }
            if let Some(t) = self.tokenizer.next_token(next_token)? {
                print!("{t}");
                std::io::stdout().flush()?;
            }
        }
        let dt = start_gen.elapsed();
        if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
            print!("{rest}");
        }
        std::io::stdout().flush()?;
        println!(
            "\n{generated_tokens} tokens generated ({:.2} token/s)",
            generated_tokens as f64 / dt.as_secs_f64(),
        );
        Ok(())
    }
}
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
enum Which {
    #[value(name = "360m")]
    W360m,
    #[value(name = "1b")]
    W1b,
    #[value(name = "1b-50b")]
    W1b50b,
}

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// Enable tracing (generates a trace-timestamp.json file).
    #[arg(long)]
    tracing: bool,

    #[arg(long)]
    prompt: String,

    /// The temperature used to generate samples.
    #[arg(long)]
    temperature: Option<f64>,

    /// Nucleus sampling probability cutoff.
    #[arg(long)]
    top_p: Option<f64>,

    /// The seed to use when generating random samples.
    #[arg(long, default_value_t = 299792458)]
    seed: u64,

    /// The length of the sample to generate (in tokens).
    #[arg(long, short = 'n', default_value_t = 10000)]
    sample_len: usize,

    #[arg(long)]
    model_id: Option<String>,

    #[arg(long, default_value = "refs/pr/1")]
    revision: String,

    #[arg(long)]
    config_file: Option<String>,

    #[arg(long)]
    tokenizer_file: Option<String>,

    #[arg(long)]
    weight_files: Option<String>,

    /// Penalty to be applied for repeating tokens, 1. means no penalty.
    #[arg(long, default_value_t = 1.1)]
    repeat_penalty: f32,

    /// The context size to consider for the repeat penalty.
    #[arg(long, default_value_t = 64)]
    repeat_last_n: usize,

    #[arg(long, default_value = "360m")]
    which: Which,
}
fn main() -> Result<()> {
    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;

    let args = Args::parse();
    let _guard = if args.tracing {
        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();
        Some(guard)
    } else {
        None
    };
    println!(
        "avx: {}, neon: {}, simd128: {}, f16c: {}",
        candle::utils::with_avx(),
        candle::utils::with_neon(),
        candle::utils::with_simd128(),
        candle::utils::with_f16c()
    );
    println!(
        "temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
        args.temperature.unwrap_or(0.),
        args.repeat_penalty,
        args.repeat_last_n
    );

    let start = std::time::Instant::now();
    let api = Api::new()?;
    let model_id = match args.model_id {
        Some(model_id) => model_id,
        None => match args.which {
            Which::W360m => "hazyresearch/based-360m".to_string(),
            Which::W1b => "hazyresearch/based-1b".to_string(),
            Which::W1b50b => "hazyresearch/based-1b-50b".to_string(),
        },
    };
    let repo = api.repo(Repo::with_revision(
        model_id,
        RepoType::Model,
        args.revision,
    ));
    let config_file = match args.config_file {
        Some(file) => std::path::PathBuf::from(file),
        None => repo.get("config.json")?,
    };
    let filenames = match args.weight_files {
        Some(files) => files
            .split(',')
            .map(std::path::PathBuf::from)
            .collect::<Vec<_>>(),
        None => vec![repo.get("model.safetensors")?],
    };

    // The Based checkpoints rely on the GPT-2 tokenizer.
    let repo = api.model("openai-community/gpt2".to_string());
    let tokenizer_file = match args.tokenizer_file {
        Some(file) => std::path::PathBuf::from(file),
        None => repo.get("tokenizer.json")?,
    };

    println!("retrieved the files in {:?}", start.elapsed());
    let tokenizer = Tokenizer::from_file(tokenizer_file).map_err(E::msg)?;

    let start = std::time::Instant::now();
    let config = serde_json::from_reader(std::fs::File::open(config_file)?)?;
    let device = candle_examples::device(args.cpu)?;
    let dtype = if device.is_cuda() {
        DType::BF16
    } else {
        DType::F32
    };

    let mut vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
    // The 1b-50b checkpoint nests its weights under a "model" prefix.
    if args.which == Which::W1b50b {
        vb = vb.pp("model");
    };

    let model = Model::new(&config, vb)?;

    println!("loaded the model in {:?}", start.elapsed());

    let mut pipeline = TextGeneration::new(
        model,
        tokenizer,
        args.seed,
        args.temperature,
        args.top_p,
        args.repeat_penalty,
        args.repeat_last_n,
        &device,
    );
    pipeline.run(&args.prompt, args.sample_len)?;
    Ok(())
}
candle-transformers/src/models/based.rs (new file, 589 lines)
@@ -0,0 +1,589 @@
//! Based from the Stanford Hazy Research group.
//!
//! See "Simple linear attention language models balance the recall-throughput tradeoff", Arora et al. 2024
//! <https://arxiv.org/abs/2402.18668>

//! Original code:
//! <https://github.com/HazyResearch/based>

use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
use candle_nn::{
    conv1d_no_bias, linear, linear_no_bias, ops::softmax_last_dim, rms_norm, Conv1d, Conv1dConfig,
    Func, Linear, RmsNorm, VarBuilder,
};
use std::sync::Arc;
#[derive(Debug, Clone, serde::Deserialize)]
pub struct LinearAttentionFeatureMapConfig {
    input_dim: usize,
}

#[derive(Debug, Clone, serde::Deserialize)]
pub struct LinearAttentionConfig {
    num_heads: usize,
    feature_dim: usize,
    feature_map: LinearAttentionFeatureMapConfig,
}

#[derive(Debug, Clone, serde::Deserialize)]
pub struct SlidingWindowAttentionConfig {
    num_heads: usize,
    window_size: usize,
}

#[derive(Debug, Clone, serde::Deserialize)]
pub struct Config {
    vocab_size: usize,
    #[serde(rename = "n_embd")]
    hidden_size: usize,
    #[serde(rename = "n_inner")]
    intermediate_size: usize,
    #[serde(rename = "n_layer")]
    num_hidden_layers: usize,
    #[serde(rename = "n_head")]
    num_attention_heads: usize,

    layer_norm_epsilon: f64,
    #[serde(default = "default_rope", rename = "rotary_emb_base")]
    rope_theta: f64,

    alt_mixer_layers: Vec<usize>,
    alt_mixer_2_layers: Vec<usize>,
    #[serde(rename = "alt_mixer")]
    la: LinearAttentionConfig,
    #[serde(rename = "alt_mixer_2")]
    swa: SlidingWindowAttentionConfig,
}

fn default_rope() -> f64 {
    10_000.0
}
#[derive(Debug, Clone)]
#[allow(clippy::upper_case_acronyms)]
struct MLP {
    fc1: Linear,
    fc2: Linear,
}

impl MLP {
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let fc1 = linear_no_bias(cfg.hidden_size, cfg.hidden_size * 4, vb.pp("fc1"))?;
        let fc2 = linear_no_bias(cfg.intermediate_size, cfg.hidden_size, vb.pp("fc2"))?;
        Ok(Self { fc1, fc2 })
    }
}

// Swiglu implementation: swiglu([x0, x1]) = silu(x1) * x0.
// Not using Activation::Swiglu because that version in candle-nn/src/ops.rs has
// the gate and y arguments switched compared to this one.
fn swiglu(xs: &Tensor) -> Result<Tensor> {
    let xs = xs.chunk(2, D::Minus1)?;
    &xs[1].silu()? * &xs[0]
}

impl Module for MLP {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let xs = xs.apply(&self.fc1)?;
        let xs = swiglu(&xs)?;
        let xs = xs.apply(&self.fc2)?;
        Ok(xs)
    }
}
// A gated convolutional block.
#[derive(Debug, Clone)]
struct BasedConv {
    in_proj: Linear,
    out_proj: Linear,
    conv: Conv1d,
    state: Tensor,
}

impl BasedConv {
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let dim = cfg.hidden_size * 2;

        let conv1d_cfg = Conv1dConfig {
            groups: dim,
            padding: 2,
            ..Default::default()
        };

        let in_proj = linear(cfg.hidden_size, cfg.hidden_size * 4, vb.pp("in_proj"))?;
        let out_proj = linear(dim, cfg.hidden_size, vb.pp("out_proj"))?;
        let conv = conv1d_no_bias(dim, dim, 3, conv1d_cfg, vb.pp("conv.conv"))?;
        let state = Tensor::zeros((1, dim, 3), vb.dtype(), vb.device())?;
        Ok(Self {
            in_proj,
            out_proj,
            conv,
            state,
        })
    }
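
    // Single-token decoding: the layer keeps a rolling buffer of the last 3
    // inputs (the conv kernel width) so the depthwise causal convolution can
    // be evaluated at the newest position only. Rough sketch of the
    // recurrence, as I read the code below:
    //   state <- [state[.., 1], state[.., 2], x_t]
    //   y_t   <- sum_k w[.., k] * state[.., k]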
    fn step(&mut self, xs: &Tensor) -> Result<Tensor> {
        self.state = self.state.roll(-1, D::Minus1)?;
        let (_, _, l) = self.state.dims3()?;
        self.state = self.state.narrow(D::Minus1, 0, l - 1)?;
        self.state = Tensor::cat(&[&self.state, &xs.transpose(1, 2)?], 2)?;

        let xs = (&self.state * self.conv.weight().permute((1, 0, 2))?)?
            .sum_keepdim(0)?
            .sum(D::Minus1)?;

        let xs = xs.unsqueeze(1)?;

        Ok(xs)
    }

    fn forward(&mut self, xs: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
        let xs = xs.apply(&self.in_proj)?;
        let us = xs.chunk(2, D::Minus1)?;
        let (_b, l, _d) = us[0].dims3()?;
        let u_conv = if seqlen_offset > 0 {
            self.step(&us[0])?
        } else {
            // Prompt processing: prime the rolling state with the last inputs
            // and run the full convolution over the sequence.
            let k = std::cmp::min(3, l);
            self.state = self.state.narrow(D::Minus1, 0, 3 - k)?;
            let xs = us[0].narrow(1, l - k, k)?.transpose(1, 2)?;
            self.state = Tensor::cat(&[&self.state, &xs], 2)?;

            us[0]
                .transpose(1, 2)?
                .apply(&self.conv)?
                .narrow(D::Minus1, 0, l)?
                .transpose(1, 2)?
        };

        let u_conv = u_conv.silu()?;
        let v = u_conv.broadcast_mul(&us[1])?;
        let xs = v.apply(&self.out_proj)?;

        Ok(xs)
    }
}

// Linear attention approximating softmax using second order Taylor polynomials.
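// Sketch of the approximation, as implemented in `taylor_expansion` below: the
// softmax kernel exp(q.k / sqrt(d)) is replaced by its second order Taylor
// series 1 + q.k / sqrt(d) + (q.k)^2 / (2d), with d = feature_map.input_dim.
// This factors as an inner product phi(q) . phi(k) of the feature map
//   phi(x) = [1, x / d^(1/4), flatten(x (x) x) / (sqrt(2) * sqrt(d))]
// so attention becomes linear in sequence length and can be carried as the
// running sums `k_state` and `kv_state` during decoding.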
#[derive(Debug, Clone)]
struct LinearAttention {
    proj_q: Linear,
    proj_k: Linear,
    proj_v: Linear,
    out_proj: Linear,
    feature_dim: usize,
    num_heads: usize,
    input_dim: usize,
    k_state: Tensor,
    kv_state: Tensor,
}

impl LinearAttention {
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let input_dim = cfg.la.feature_map.input_dim;
        let out_proj = linear_no_bias(cfg.hidden_size, cfg.hidden_size, vb.pp("out_proj"))?;
        let proj_k = linear_no_bias(
            cfg.hidden_size,
            cfg.la.num_heads * cfg.la.feature_dim,
            vb.pp("proj_k"),
        )?;
        let proj_q = linear_no_bias(
            cfg.hidden_size,
            cfg.la.num_heads * cfg.la.feature_dim,
            vb.pp("proj_q"),
        )?;

        let proj_v = linear_no_bias(cfg.hidden_size, cfg.hidden_size, vb.pp("proj_v"))?;
        let expanded_size = cfg.la.feature_dim.pow(2) + cfg.la.feature_dim + 1;
        let k_state = Tensor::zeros(
            (1, cfg.la.num_heads, 1, 1, expanded_size),
            vb.dtype(),
            vb.device(),
        )?;
        let kv_state = Tensor::zeros(
            (1, cfg.la.num_heads, cfg.la.feature_dim, expanded_size),
            vb.dtype(),
            vb.device(),
        )?;

        Ok(Self {
            proj_q,
            proj_k,
            proj_v,
            out_proj,
            feature_dim: cfg.la.feature_dim,
            num_heads: cfg.la.num_heads,
            input_dim,
            k_state,
            kv_state,
        })
    }

    fn taylor_expansion(&self) -> Result<Func<'static>> {
        let r2 = std::f64::consts::SQRT_2;
        let rd = (self.input_dim as f64).sqrt();
        let rrd = rd.sqrt();

        Ok(Func::new(move |xs| {
            let dims = xs.dims();
            let mut d = dims.to_vec();
            if let Some(last) = d.last_mut() {
                *last = 1;
            };

            // Outer product x (x) x, flattened over the last two dimensions.
            let x = xs
                .unsqueeze(D::Minus1)?
                .broadcast_mul(&xs.unsqueeze(D::Minus2)?)?;
            let x = (x.flatten_from(D::Minus2)? / r2)?;
            let o = Tensor::ones(d, xs.dtype(), xs.device())?;
            // phi(x) = [1, x / d^(1/4), (x (x) x) / (sqrt(2) * sqrt(d))]
            let x = Tensor::cat(&[o, (xs / rrd)?, (&x / rd)?], D::Minus1)?;

            Ok(x)
        }))
    }
    fn forward(&mut self, xs: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
        let eps = 1e-12;

        let feature_map = self.taylor_expansion()?;

        let (b, l, d) = xs.dims3()?;
        let q = xs.apply(&self.proj_q)?;
        let k = xs.apply(&self.proj_k)?;
        let v = xs.apply(&self.proj_v)?;

        let q = q
            .reshape((b, l, self.num_heads, self.feature_dim))?
            .transpose(1, 2)?
            .contiguous()?;
        let k = k
            .reshape((b, l, self.num_heads, self.feature_dim))?
            .transpose(1, 2)?
            .contiguous()?;
        let v = v
            .reshape((b, l, self.num_heads, d / self.num_heads))?
            .transpose(1, 2)?
            .contiguous()?;

        let q = feature_map.forward(&q)?;
        let k = feature_map.forward(&k)?;

        let y = if seqlen_offset > 0 {
            // Recurrent path for decoding: fold the newest position into the
            // running key and key-value states and read the output from them.
            let (_b, _h, l, _d) = k.dims4()?;
            let q = q.unsqueeze(D::Minus2)?;
            let k = k.unsqueeze(D::Minus2)?;
            let v = v.unsqueeze(D::Minus1)?;
            let kn = k.narrow(D::Minus1, l - 1, 1)?;
            let vn = v.narrow(D::Minus1, l - 1, 1)?;

            self.k_state = self.k_state.broadcast_add(&kn)?;
            self.kv_state = self.kv_state.broadcast_add(&kn.broadcast_mul(&vn)?)?;

            let num = q.broadcast_mul(&self.kv_state)?.sum(D::Minus1)?;
            let den = (q.broadcast_mul(&self.k_state)?.sum(D::Minus1)? + eps)?;
            num.broadcast_div(&den)?
        } else {
            // Parallel path for the prompt: quadratic attention with a causal
            // mask, initializing the states for subsequent decoding.
            self.k_state = k.sum(2)?.unsqueeze(2)?.unsqueeze(3)?;
            self.kv_state = k
                .transpose(2, 3)?
                .matmul(&v)?
                .transpose(2, 3)?
                .unsqueeze(2)?;
            let aqk = q.matmul(&k.transpose(D::Minus1, D::Minus2)?)?;
            let tril = Tensor::tril2(l, aqk.dtype(), aqk.device())?;
            let aqk = aqk.broadcast_mul(&tril)?.matmul(&v)?;

            let z = (1f64 / (q.mul(&k.cumsum(2)?)?.sum(D::Minus1)? + eps)?)?;
            aqk.broadcast_mul(&z.unsqueeze(D::Minus1)?)?
        };

        let (b, h, l, d) = y.dims4()?;
        let y = y.permute((0, 2, 1, 3))?.reshape((b, l, h * d))?;
        let y = self.out_proj.forward(&y)?;

        Ok(y)
    }
}
// Rotary embeddings used in local attention.
#[derive(Debug, Clone)]
struct RotaryEmbedding {
    sin: Tensor,
    cos: Tensor,
}

impl RotaryEmbedding {
    fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
        let dim = cfg.hidden_size / cfg.num_attention_heads;
        let max_seq_len = 2048; // Hardcoded, missing from config.
        let inv_freq: Vec<_> = (0..dim)
            .step_by(2)
            .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
            .collect();
        let inv_freq_len = inv_freq.len();
        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
            .to_dtype(dtype)?
            .reshape((max_seq_len, 1))?;
        let freqs = t.matmul(&inv_freq)?;
        Ok(Self {
            sin: freqs.sin()?,
            cos: freqs.cos()?,
        })
    }

    fn apply_rotary_emb_qkv(
        &self,
        q: &Tensor,
        k: &Tensor,
        seqlen_offset: usize,
    ) -> Result<(Tensor, Tensor)> {
        let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
        let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
        let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
        let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
        let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
        Ok((q_embed, k_embed))
    }
}
// Local attention using a small sliding window.
#[derive(Debug, Clone)]
struct SlidingWindowAttention {
    wqkv: Linear,
    out_proj: Linear,
    num_heads: usize,
    head_dim: usize,
    hidden_size: usize,
    rotary_emb: Arc<RotaryEmbedding>,
    kv_cache: Option<(Tensor, Tensor)>,
}

impl SlidingWindowAttention {
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let hidden_size = cfg.hidden_size;
        let num_heads = cfg.swa.num_heads;
        let head_dim = hidden_size / num_heads;
        let out_proj = linear_no_bias(hidden_size, hidden_size, vb.pp("out_proj"))?;
        let wqkv = linear_no_bias(hidden_size, hidden_size * 3, vb.pp("Wqkv"))?;
        let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb.device())?);
        Ok(Self {
            wqkv,
            out_proj,
            hidden_size,
            num_heads,
            head_dim,
            rotary_emb,
            kv_cache: None,
        })
    }

    fn forward(
        &mut self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offset: usize,
    ) -> Result<Tensor> {
        let (b_sz, q_len, _) = xs.dims3()?;

        let qkv = xs.apply(&self.wqkv)?;
        let qkv = qkv.reshape((b_sz, q_len, 3, (), self.head_dim))?;

        let q = qkv.i((.., .., 0))?;
        let k = qkv.i((.., .., 1))?;
        let v = qkv.i((.., .., 2))?;

        let q = q
            .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
            .transpose(1, 2)?;
        let k = k
            .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
            .transpose(1, 2)?;
        let v = v
            .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
            .transpose(1, 2)?;

        let (q, k) = self
            .rotary_emb
            .apply_rotary_emb_qkv(&q, &k, seqlen_offset)?;

        let (k, v) = match &self.kv_cache {
            None => (k, v),
            Some((prev_k, prev_v)) => {
                let k = Tensor::cat(&[prev_k, &k], 2)?;
                let v = Tensor::cat(&[prev_v, &v], 2)?;
                (k, v)
            }
        };
        self.kv_cache = Some((k.clone(), v.clone()));

        let scale = 1f64 / f64::sqrt(self.head_dim as f64);
        let attn_weights = (q.matmul(&k.transpose(2, 3)?)? * scale)?;

        let attn_weights = match attention_mask {
            None => attn_weights,
            Some(mask) => attn_weights.broadcast_add(mask)?,
        };
        let attn_weights = softmax_last_dim(&attn_weights)?;
        let attn_output = attn_weights.matmul(&v)?;
        let out = attn_output
            .transpose(1, 2)?
            .reshape((b_sz, q_len, self.hidden_size))?
            .apply(&self.out_proj)?;

        Ok(out)
    }
}
// The model layers use three types of mixers.
#[derive(Debug, Clone)]
enum SequenceMixer {
    Based(BasedConv),
    Linear(LinearAttention),
    Sliding(SlidingWindowAttention),
}

impl SequenceMixer {
    fn forward(
        &mut self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        pos: usize,
    ) -> Result<Tensor> {
        match self {
            Self::Based(b) => b.forward(xs, pos),
            Self::Linear(b) => b.forward(xs, pos),
            Self::Sliding(b) => b.forward(xs, attention_mask, pos),
        }
    }
}
#[derive(Debug, Clone)]
struct DecoderLayer {
    mlp: MLP,
    norm1: RmsNorm,
    norm2: RmsNorm,
    mixer: SequenceMixer,
}

impl DecoderLayer {
    fn new(layer_idx: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let mlp = MLP::new(cfg, vb.pp("mlp"))?;
        let norm1 = rms_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("norm1"))?;
        let norm2 = rms_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("norm2"))?;

        let l_attn = cfg.alt_mixer_layers.contains(&layer_idx);
        let sw_attn = cfg.alt_mixer_2_layers.contains(&layer_idx);

        let mixer = if l_attn {
            SequenceMixer::Linear(LinearAttention::new(cfg, vb.pp("mixer"))?)
        } else if sw_attn {
            SequenceMixer::Sliding(SlidingWindowAttention::new(cfg, vb.pp("mixer"))?)
        } else {
            SequenceMixer::Based(BasedConv::new(cfg, vb.pp("mixer"))?)
        };

        Ok(Self {
            mlp,
            norm1,
            norm2,
            mixer,
        })
    }

    fn forward(
        &mut self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offset: usize,
    ) -> Result<Tensor> {
        let residual = xs;
        let xs = self.norm1.forward(xs)?;
        let xs = self.mixer.forward(&xs, attention_mask, seqlen_offset)?;
        let xs = (xs + residual)?;
        let residual = &xs;
        let xs = xs.apply(&self.norm2)?.apply(&self.mlp)?;
        residual + xs
    }
}
#[derive(Debug, Clone)]
pub struct Model {
    embed_tokens: super::with_tracing::Embedding,
    layers: Vec<DecoderLayer>,
    norm: RmsNorm,
    lm_head: Linear,
    sliding_window: usize,
    device: Device,
    dtype: DType,
}

impl Model {
    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        // Round the vocabulary size up to a multiple of 8.
        let vocab_size = cfg.vocab_size + (8 - cfg.vocab_size % 8) % 8;
        let lm_head = linear_no_bias(cfg.hidden_size, vocab_size, vb.pp("lm_head"))?;
        // The token embeddings are tied to the lm_head weights.
        let embed_tokens = super::with_tracing::Embedding::from_weights(lm_head.weight().clone())?;
        let vb_m = vb.pp("transformer");
        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
        let vb_l = vb_m.pp("layers");
        for layer_idx in 0..cfg.num_hidden_layers {
            let layer = DecoderLayer::new(layer_idx, cfg, vb_l.pp(layer_idx))?;
            layers.push(layer)
        }
        let norm = rms_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb_m.pp("ln_f"))?;
        Ok(Self {
            embed_tokens,
            layers,
            norm,
            lm_head,
            sliding_window: cfg.swa.window_size,
            device: vb.device().clone(),
            dtype: vb.dtype(),
        })
    }
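
    // The mask built below is causal and additionally drops keys more than
    // `sliding_window / 2` positions in the past. A hypothetical 4-position
    // example with a half-window of 2 (0 = attend, -inf = masked):
    //   [[   0, -inf, -inf, -inf],
    //    [   0,    0, -inf, -inf],
    //    [   0,    0,    0, -inf],
    //    [-inf,    0,    0,    0]]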
    fn prepare_decoder_attention_mask(
        &self,
        b_size: usize,
        tgt_len: usize,
        seqlen_offset: usize,
    ) -> Result<Tensor> {
        let sliding_window = self.sliding_window / 2;
        let mask: Vec<_> = (0..tgt_len)
            .flat_map(|i| {
                (0..tgt_len).map(move |j| {
                    if i < j || j + sliding_window < i {
                        f32::NEG_INFINITY
                    } else {
                        0.
                    }
                })
            })
            .collect();
        let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?;
        let mask = if seqlen_offset > 0 {
            let mask0 = Tensor::zeros((tgt_len, seqlen_offset), self.dtype, &self.device)?;
            Tensor::cat(&[&mask0, &mask], D::Minus1)?
        } else {
            mask
        };
        mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))?
            .to_dtype(self.dtype)
    }

    pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
        let (b_size, seq_len) = input_ids.dims2()?;
        let attention_mask = if seq_len <= 1 {
            None
        } else {
            let mask = self.prepare_decoder_attention_mask(b_size, seq_len, seqlen_offset)?;
            Some(mask)
        };
        let mut xs = self.embed_tokens.forward(input_ids)?;
        for layer in self.layers.iter_mut() {
            xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)?
        }
        // Only the logits for the last position are returned.
        xs.narrow(1, seq_len - 1, 1)?
            .apply(&self.norm)?
            .apply(&self.lm_head)
    }
}
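Not part of the diff: a minimal sketch of driving the new module directly, mirroring what `main.rs` above does. The local file names are hypothetical stand-ins for the files fetched from the hub:

```rust
use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::based::{Config, Model};

fn run_based() -> anyhow::Result<()> {
    let device = Device::Cpu;
    // Hypothetical local copies of the checkpoint's config and weights.
    let config: Config = serde_json::from_reader(std::fs::File::open("config.json")?)?;
    let vb = unsafe {
        VarBuilder::from_mmaped_safetensors(&["model.safetensors"], DType::F32, &device)?
    };
    let mut model = Model::new(&config, vb)?;
    // A (batch, seq_len) tensor of token ids; forward returns the logits for
    // the last position only, shape (batch, 1, padded_vocab_size).
    let input = Tensor::new(&[[1u32, 2, 3]], &device)?;
    let logits = model.forward(&input, 0)?;
    println!("{:?}", logits.dims());
    Ok(())
}
```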
candle-transformers/src/models/mod.rs
@@ -1,3 +1,4 @@
pub mod based;
pub mod beit;
pub mod bert;
pub mod bigcode;