diff --git a/README.md b/README.md index 90344b34..9bfa30d8 100644 --- a/README.md +++ b/README.md @@ -67,7 +67,7 @@ We also provide a some command line based examples using state of the art models - [StableLM-3B-4E1T](./candle-examples/examples/stable-lm/): a 3b general LLM pre-trained on 1T tokens of English and code datasets. Also supports StableLM-2, a 1.6b LLM trained on 2T tokens, as well as the code variants. -- [Minimal Mamba](./candle-examples/examples/mamba-minimal/): a minimal +- [Mamba](./candle-examples/examples/mamba/): an inference only implementation of the Mamba state space model. - [Mistral7b-v0.1](./candle-examples/examples/mistral/): a 7b general LLM with better performance than all publicly available 13b models as of 2023-09-28. @@ -186,7 +186,7 @@ If you have an addition to this list, please submit a pull request. - Falcon. - StarCoder. - Phi 1, 1.5, and 2. - - Minimal Mamba + - Mamba, Minimal Mamba - Mistral 7b v0.1. - Mixtral 8x7b v0.1. - StableLM-3B-4E1T, StableLM-2-1.6B, Stable-Code-3B. diff --git a/candle-examples/examples/mamba-minimal/README.md b/candle-examples/examples/mamba-minimal/README.md index 0ce42123..46479828 100644 --- a/candle-examples/examples/mamba-minimal/README.md +++ b/candle-examples/examples/mamba-minimal/README.md @@ -2,6 +2,9 @@ This is based on [mamba-minimal](https://github.com/johnma2006/mamba-minimal). +Compared to the mamba example, this version can handle training but is much +slower. + ## Running the example ```bash diff --git a/candle-examples/examples/mamba/README.md b/candle-examples/examples/mamba/README.md new file mode 100644 index 00000000..507434a1 --- /dev/null +++ b/candle-examples/examples/mamba/README.md @@ -0,0 +1,17 @@ +# candle-mamba: Mamba implementation + +Candle implementation of *Mamba* [1] inference only. Mamba is an alternative to +the transformer architecture. It leverages State Space Models (SSMs) with the +goal of being computationally efficient on long sequences. The implementation is +based on [mamba.rs](https://github.com/LaurentMazare/mamba.rs). + +- [1]. [Mamba: Linear-Time Sequence Modeling with Selective State Spaces](https://arxiv.org/abs/2312.00752). + +Compared to the mamba-minimal example, this version is far more efficient but +would only work for inference. +## Running the example + +```bash +$ cargo run --example mamba-minimal --release -- --prompt "Mamba is the" +``` + diff --git a/candle-examples/examples/mamba/main.rs b/candle-examples/examples/mamba/main.rs new file mode 100644 index 00000000..4802f960 --- /dev/null +++ b/candle-examples/examples/mamba/main.rs @@ -0,0 +1,299 @@ +#[cfg(feature = "mkl")] +extern crate intel_mkl_src; + +#[cfg(feature = "accelerate")] +extern crate accelerate_src; + +use anyhow::{Error as E, Result}; +use clap::{Parser, ValueEnum}; + +use candle_transformers::models::mamba::{Config, Model, State}; + +use candle::{DType, Device, Tensor}; +use candle_examples::token_output_stream::TokenOutputStream; +use candle_nn::VarBuilder; +use candle_transformers::generation::LogitsProcessor; +use hf_hub::{api::sync::Api, Repo, RepoType}; +use tokenizers::Tokenizer; + +struct TextGeneration { + model: Model, + config: Config, + device: Device, + tokenizer: TokenOutputStream, + logits_processor: LogitsProcessor, + repeat_penalty: f32, + repeat_last_n: usize, +} + +impl TextGeneration { + #[allow(clippy::too_many_arguments)] + fn new( + model: Model, + config: Config, + tokenizer: Tokenizer, + seed: u64, + temp: Option, + top_p: Option, + repeat_penalty: f32, + repeat_last_n: usize, + device: &Device, + ) -> Self { + let logits_processor = LogitsProcessor::new(seed, temp, top_p); + Self { + model, + config, + tokenizer: TokenOutputStream::new(tokenizer), + logits_processor, + repeat_penalty, + repeat_last_n, + device: device.clone(), + } + } + + fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> { + use std::io::Write; + self.tokenizer.clear(); + let mut tokens = self + .tokenizer + .tokenizer() + .encode(prompt, true) + .map_err(E::msg)? + .get_ids() + .to_vec(); + let mut generated_tokens = 0usize; + let eos_token = match self.tokenizer.get_token("<|endoftext|>") { + Some(token) => token, + None => anyhow::bail!("cannot find the token"), + }; + let mut state = State::new(1, &self.config, &self.device)?; + let mut next_logits = None; + for &t in tokens.iter() { + let input = Tensor::new(&[t], &self.device)?; + let logits = self.model.forward(&input, &mut state)?; + next_logits = Some(logits); + if let Some(t) = self.tokenizer.next_token(t)? { + print!("{t}") + } + } + std::io::stdout().flush()?; + + let start_gen = std::time::Instant::now(); + for _ in 0..sample_len { + let logits = match next_logits.as_ref() { + Some(logits) => logits, + None => anyhow::bail!("cannot work on an empty prompt"), + }; + let logits = logits.squeeze(0)?.to_dtype(DType::F32)?; + let logits = if self.repeat_penalty == 1. { + logits + } else { + let start_at = tokens.len().saturating_sub(self.repeat_last_n); + candle_transformers::utils::apply_repeat_penalty( + &logits, + self.repeat_penalty, + &tokens[start_at..], + )? + }; + let next_token = self.logits_processor.sample(&logits)?; + tokens.push(next_token); + generated_tokens += 1; + if next_token == eos_token { + break; + } + if let Some(t) = self.tokenizer.next_token(next_token)? { + print!("{t}"); + std::io::stdout().flush()?; + } + + let input = Tensor::new(&[next_token], &self.device)?; + next_logits = Some(self.model.forward(&input, &mut state)?) + } + let dt = start_gen.elapsed(); + if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? { + print!("{rest}"); + } + std::io::stdout().flush()?; + println!( + "\n{generated_tokens} tokens generated ({:.2} token/s)", + generated_tokens as f64 / dt.as_secs_f64(), + ); + Ok(()) + } +} + +#[derive(Parser, ValueEnum, Clone, Copy, PartialEq, Eq, Debug)] +enum Which { + Mamba130m, + Mamba370m, + Mamba790m, + Mamba1_4b, + Mamba2_8b, + Mamba2_8bSlimPj, +} + +impl std::fmt::Display for Which { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{:?}", self) + } +} + +impl Which { + fn model_id(&self) -> &'static str { + match self { + Self::Mamba130m => "state-spaces/mamba-130m", + Self::Mamba370m => "state-spaces/mamba-370m", + Self::Mamba790m => "state-spaces/mamba-790m", + Self::Mamba1_4b => "state-spaces/mamba-1.4b", + Self::Mamba2_8b => "state-spaces/mamba-2.8b", + Self::Mamba2_8bSlimPj => "state-spaces/mamba-2.8b-slimpj'", + } + } + + fn revision(&self) -> &'static str { + match self { + Self::Mamba130m + | Self::Mamba370m + | Self::Mamba790m + | Self::Mamba1_4b + | Self::Mamba2_8bSlimPj => "refs/pr/1", + Self::Mamba2_8b => "refs/pr/4", + } + } +} + +#[derive(Parser, Debug)] +#[command(author, version, about, long_about = None)] +struct Args { + /// Run on CPU rather than on GPU. + #[arg(long)] + cpu: bool, + + /// Enable tracing (generates a trace-timestamp.json file). + #[arg(long)] + tracing: bool, + + #[arg(long)] + prompt: String, + + /// The temperature used to generate samples. + #[arg(long)] + temperature: Option, + + /// Nucleus sampling probability cutoff. + #[arg(long)] + top_p: Option, + + /// The seed to use when generating random samples. + #[arg(long, default_value_t = 299792458)] + seed: u64, + + /// The length of the sample to generate (in tokens). + #[arg(long, short = 'n', default_value_t = 5000)] + sample_len: usize, + + #[arg(long, default_value = "mamba130m")] + which: Which, + + #[arg(long)] + model_id: Option, + + #[arg(long)] + revision: Option, + + #[arg(long)] + tokenizer_file: Option, + + #[arg(long)] + weight_files: Option, + + #[arg(long)] + config_file: Option, + + /// Penalty to be applied for repeating tokens, 1. means no penalty. + #[arg(long, default_value_t = 1.1)] + repeat_penalty: f32, + + /// The context size to consider for the repeat penalty. + #[arg(long, default_value_t = 64)] + repeat_last_n: usize, +} + +fn main() -> Result<()> { + use tracing_chrome::ChromeLayerBuilder; + use tracing_subscriber::prelude::*; + + let args = Args::parse(); + let _guard = if args.tracing { + let (chrome_layer, guard) = ChromeLayerBuilder::new().build(); + tracing_subscriber::registry().with(chrome_layer).init(); + Some(guard) + } else { + None + }; + println!( + "avx: {}, neon: {}, simd128: {}, f16c: {}", + candle::utils::with_avx(), + candle::utils::with_neon(), + candle::utils::with_simd128(), + candle::utils::with_f16c() + ); + println!( + "temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}", + args.temperature.unwrap_or(0.), + args.repeat_penalty, + args.repeat_last_n + ); + + let start = std::time::Instant::now(); + let api = Api::new()?; + let repo = api.repo(Repo::with_revision( + args.model_id + .unwrap_or_else(|| args.which.model_id().to_string()), + RepoType::Model, + args.revision + .unwrap_or_else(|| args.which.revision().to_string()), + )); + let tokenizer_filename = match args.tokenizer_file { + Some(file) => std::path::PathBuf::from(file), + None => api + .model("EleutherAI/gpt-neox-20b".to_string()) + .get("tokenizer.json")?, + }; + let config_filename = match args.config_file { + Some(file) => std::path::PathBuf::from(file), + None => repo.get("config.json")?, + }; + let filenames = match args.weight_files { + Some(files) => files + .split(',') + .map(std::path::PathBuf::from) + .collect::>(), + None => { + vec![repo.get("model.safetensors")?] + } + }; + println!("retrieved the files in {:?}", start.elapsed()); + let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; + + let start = std::time::Instant::now(); + let config: Config = serde_json::from_slice(&std::fs::read(config_filename)?)?; + let device = candle_examples::device(args.cpu)?; + let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? }; + let model = Model::new(&config, vb.pp("backbone"))?; + println!("loaded the model in {:?}", start.elapsed()); + + let mut pipeline = TextGeneration::new( + model, + config, + tokenizer, + args.seed, + args.temperature, + args.top_p, + args.repeat_penalty, + args.repeat_last_n, + &device, + ); + pipeline.run(&args.prompt, args.sample_len)?; + Ok(()) +} diff --git a/candle-transformers/src/models/mamba.rs b/candle-transformers/src/models/mamba.rs new file mode 100644 index 00000000..da254bd1 --- /dev/null +++ b/candle-transformers/src/models/mamba.rs @@ -0,0 +1,211 @@ +#![allow(unused)] +/// A fast implementation of mamba for inference only. +/// This is based on: https://github.com/LaurentMazare/mamba.rs +use crate::models::with_tracing::{linear, linear_no_bias, Linear}; +use candle::{DType, Device, IndexOp, Module, Result, Tensor, D}; +use candle_nn::{RmsNorm, VarBuilder}; + +const D_CONV: usize = 4; +const D_STATE: usize = 16; + +#[derive(Debug, Clone, serde::Deserialize)] +pub struct Config { + d_model: usize, + n_layer: usize, + vocab_size: usize, + pad_vocab_size_multiple: usize, +} + +impl Config { + fn vocab_size(&self) -> usize { + let pad = self.pad_vocab_size_multiple; + (self.vocab_size + pad - 1) / pad * pad + } + + fn dt_rank(&self) -> usize { + (self.d_model + 15) / 16 + } + + fn d_inner(&self) -> usize { + self.d_model * 2 + } +} + +pub struct State { + hs: Vec, + prev_xs: Vec<[Tensor; D_CONV]>, + pos: usize, +} + +impl State { + pub fn new(batch_size: usize, cfg: &Config, device: &Device) -> Result { + let mut hs = Vec::with_capacity(cfg.n_layer); + let mut prev_xs = Vec::with_capacity(cfg.n_layer); + for _i in 0..cfg.n_layer { + let h = Tensor::zeros((batch_size, cfg.d_inner(), D_STATE), DType::F32, device)?; + let x = Tensor::zeros((batch_size, cfg.d_inner()), DType::F32, device)?; + hs.push(h); + prev_xs.push([x.clone(), x.clone(), x.clone(), x.clone()]); + } + Ok(Self { + hs, + prev_xs, + pos: 0, + }) + } +} + +#[derive(Clone, Debug)] +pub struct MambaBlock { + in_proj: Linear, + conv1d_bias: Tensor, + conv1d_weights: [Tensor; D_CONV], + x_proj: Linear, + dt_proj: Linear, + a_log: Tensor, + d: Tensor, + out_proj: Linear, + dt_rank: usize, + layer_index: usize, + d_inner: usize, +} + +impl MambaBlock { + pub fn new(layer_index: usize, cfg: &Config, vb: VarBuilder) -> Result { + let d_inner = cfg.d_inner(); + let dt_rank = cfg.dt_rank(); + let in_proj = linear_no_bias(cfg.d_model, d_inner * 2, vb.pp("in_proj"))?; + let x_proj = linear_no_bias(d_inner, dt_rank + D_STATE * 2, vb.pp("x_proj"))?; + let dt_proj = linear(dt_rank, d_inner, vb.pp("dt_proj"))?; + let a_log = vb.get((d_inner, D_STATE), "A_log")?; + let d = vb.get(d_inner, "D")?; + let out_proj = linear_no_bias(d_inner, cfg.d_model, vb.pp("out_proj"))?; + let conv1d_bias = vb.get(d_inner, "conv1d.bias")?; + let conv1d_weight = vb.get((d_inner, 1, D_CONV), "conv1d.weight")?; + let conv1d_weights = [ + conv1d_weight.i((.., 0, 0))?, + conv1d_weight.i((.., 0, 1))?, + conv1d_weight.i((.., 0, 2))?, + conv1d_weight.i((.., 0, 3))?, + ]; + Ok(Self { + in_proj, + conv1d_bias, + conv1d_weights, + x_proj, + dt_proj, + a_log, + d, + out_proj, + dt_rank, + layer_index, + d_inner, + }) + } + + pub fn forward(&self, xs: &Tensor, state: &mut State) -> Result { + let (b_sz, _dim) = xs.dims2()?; + let li = self.layer_index; + let mut xs = xs.apply(&self.in_proj)?.chunk(2, D::Minus1)?; + let proj_for_silu = xs.remove(1); + state.prev_xs[li][state.pos % D_CONV] = xs.remove(0); + let mut proj_for_conv = self.conv1d_bias.broadcast_as((b_sz, self.d_inner))?; + for d_c in 0..D_CONV { + proj_for_conv = (proj_for_conv + + self.conv1d_weights[d_c] + .broadcast_mul(&state.prev_xs[li][(d_c + 1 + state.pos) % D_CONV])?)?; + } + let proj_for_conv = candle_nn::ops::silu(&proj_for_conv)?; + // SSM + Selection, we're doing inference here so only need the last step of + // the sequence. + // Algorithm 3.2 on page 6, https://arxiv.org/pdf/2312.00752.pdf + + let x_proj = self.x_proj.forward(&proj_for_conv)?; + let delta = x_proj.narrow(D::Minus1, 0, self.dt_rank)?; + let b = x_proj.narrow(D::Minus1, self.dt_rank, D_STATE)?; + let c = x_proj.narrow(D::Minus1, self.dt_rank + D_STATE, D_STATE)?; + + let delta = delta.apply(&self.dt_proj)?; + // softplus + let delta = (delta.exp()? + 1.)?.log()?; + let a = self.a_log.to_dtype(candle::DType::F32)?.exp()?.neg()?; + let d = self.d.to_dtype(candle::DType::F32)?; + + // Selective scan part + // Eqn (2a), page 3, h_t = Ab h_{t-1} + Bb x_t + let delta = delta + .unsqueeze(D::Minus1)? + .broadcast_as((b_sz, self.d_inner, D_STATE))?; + let a = a.broadcast_as((b_sz, self.d_inner, D_STATE))?; + let b = b.broadcast_as((b_sz, self.d_inner, D_STATE))?; + let proj_for_conv_b = + proj_for_conv + .unsqueeze(D::Minus1)? + .broadcast_as((b_sz, self.d_inner, D_STATE))?; + state.hs[li] = ((&state.hs[li] * (&delta * &a)?.exp()?)? + &delta * &b * &proj_for_conv_b)?; + let ss = (state.hs[li] + .matmul(&c.unsqueeze(D::Minus1)?)? + .squeeze(D::Minus1)? + + proj_for_conv.broadcast_mul(&d)?)?; + + let ys = (ss * candle_nn::ops::silu(&proj_for_silu))?; + ys.apply(&self.out_proj) + } +} + +#[derive(Clone, Debug)] +pub struct ResidualBlock { + mixer: MambaBlock, + norm: RmsNorm, +} + +impl ResidualBlock { + pub fn new(layer_index: usize, cfg: &Config, vb: VarBuilder) -> Result { + let norm = candle_nn::rms_norm(cfg.d_model, 1e-5, vb.pp("norm"))?; + let mixer = MambaBlock::new(layer_index, cfg, vb.pp("mixer"))?; + Ok(Self { mixer, norm }) + } + + fn forward(&self, xs: &Tensor, state: &mut State) -> Result { + self.mixer.forward(&xs.apply(&self.norm)?, state)? + xs + } +} + +// https://github.com/johnma2006/mamba-minimal/blob/61f01953ca153f8c4a850d7111beecbf4be9cee1/model.py#L56 +#[derive(Clone, Debug)] +pub struct Model { + embedding: candle_nn::Embedding, + layers: Vec, + norm_f: RmsNorm, + lm_head: Linear, +} + +impl Model { + pub fn new(cfg: &Config, vb: VarBuilder) -> Result { + let embedding = candle_nn::embedding(cfg.vocab_size(), cfg.d_model, vb.pp("embedding"))?; + let mut layers = Vec::with_capacity(cfg.n_layer); + let vb_l = vb.pp("layers"); + for layer_idx in 0..cfg.n_layer { + let layer = ResidualBlock::new(layer_idx, cfg, vb_l.pp(layer_idx))?; + layers.push(layer) + } + let norm_f = candle_nn::rms_norm(cfg.d_model, 1e-5, vb.pp("norm_f"))?; + let lm_head = Linear::from_weights(embedding.embeddings().clone(), None); + Ok(Self { + embedding, + layers, + norm_f, + lm_head, + }) + } + + pub fn forward(&self, input_ids: &Tensor, state: &mut State) -> Result { + let _b_size = input_ids.dims1()?; + let mut xs = self.embedding.forward(input_ids)?; + for layer in self.layers.iter() { + xs = layer.forward(&xs, state)? + } + state.pos += 1; + xs.apply(&self.norm_f)?.apply(&self.lm_head) + } +} diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs index f3782fff..769fd650 100644 --- a/candle-transformers/src/models/mod.rs +++ b/candle-transformers/src/models/mod.rs @@ -13,6 +13,7 @@ pub mod jina_bert; pub mod llama; pub mod llama2_c; pub mod llama2_c_weights; +pub mod mamba; pub mod marian; pub mod mistral; pub mod mixformer;