diff --git a/candle-examples/examples/mistral/main.rs b/candle-examples/examples/mistral/main.rs
new file mode 100644
index 00000000..fb147922
--- /dev/null
+++ b/candle-examples/examples/mistral/main.rs
@@ -0,0 +1,226 @@
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+
+use anyhow::{Error as E, Result};
+use clap::Parser;
+
+use candle_transformers::models::mistral::{Config, Model};
+
+use candle::{DType, Device, Tensor};
+use candle_nn::VarBuilder;
+use candle_transformers::generation::LogitsProcessor;
+use hf_hub::{api::sync::Api, Repo, RepoType};
+use tokenizers::Tokenizer;
+
+struct TextGeneration {
+    model: Model,
+    device: Device,
+    tokenizer: Tokenizer,
+    logits_processor: LogitsProcessor,
+    repeat_penalty: f32,
+    repeat_last_n: usize,
+}
+
+impl TextGeneration {
+    #[allow(clippy::too_many_arguments)]
+    fn new(
+        model: Model,
+        tokenizer: Tokenizer,
+        seed: u64,
+        temp: Option<f64>,
+        top_p: Option<f64>,
+        repeat_penalty: f32,
+        repeat_last_n: usize,
+        device: &Device,
+    ) -> Self {
+        let logits_processor = LogitsProcessor::new(seed, temp, top_p);
+        Self {
+            model,
+            tokenizer,
+            logits_processor,
+            repeat_penalty,
+            repeat_last_n,
+            device: device.clone(),
+        }
+    }
+
+    fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
+        use std::io::Write;
+        println!("starting the inference loop");
+        print!("{prompt}");
+        std::io::stdout().flush()?;
+        let mut tokens = self
+            .tokenizer
+            .encode(prompt, true)
+            .map_err(E::msg)?
+            .get_ids()
+            .to_vec();
+
+        let mut new_tokens = vec![];
+        let start_gen = std::time::Instant::now();
+        for index in 0..sample_len {
+            let context_size = if index > 0 { 1 } else { tokens.len() };
+            let start_pos = tokens.len().saturating_sub(context_size);
+            let ctxt = &tokens[start_pos..];
+            let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
+            let logits = self.model.forward(&input, start_pos)?;
+            let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
+            let logits = if self.repeat_penalty == 1. {
+                logits
+            } else {
+                let start_at = tokens.len().saturating_sub(self.repeat_last_n);
+                candle_transformers::utils::apply_repeat_penalty(
+                    &logits,
+                    self.repeat_penalty,
+                    &tokens[start_at..],
+                )?
+            };
+
+            let next_token = self.logits_processor.sample(&logits)?;
+            tokens.push(next_token);
+            new_tokens.push(next_token);
+            let token = self.tokenizer.decode(&[next_token], true).map_err(E::msg)?;
+            print!("{token}");
+            std::io::stdout().flush()?;
+        }
+        let dt = start_gen.elapsed();
+        println!(
+            "\n{sample_len} tokens generated ({:.2} token/s)",
+            sample_len as f64 / dt.as_secs_f64(),
+        );
+        Ok(())
+    }
+}
+
+#[derive(Parser, Debug)]
+#[command(author, version, about, long_about = None)]
+struct Args {
+    /// Run on CPU rather than on GPU.
+    #[arg(long)]
+    cpu: bool,
+
+    /// Enable tracing (generates a trace-timestamp.json file).
+    #[arg(long)]
+    tracing: bool,
+
+    #[arg(long)]
+    prompt: String,
+
+    /// The temperature used to generate samples.
+    #[arg(long)]
+    temperature: Option<f64>,
+
+    /// Nucleus sampling probability cutoff.
+    #[arg(long)]
+    top_p: Option<f64>,
+
+    /// The seed to use when generating random samples.
+    #[arg(long, default_value_t = 299792458)]
+    seed: u64,
+
+    /// The length of the sample to generate (in tokens).
+    #[arg(long, default_value_t = 100)]
+    sample_len: usize,
+
+    #[arg(long, default_value = "lmz/candle-mistral")]
+    model_id: String,
+
+    #[arg(long, default_value = "main")]
+    revision: String,
+
+    #[arg(long)]
+    tokenizer_file: Option<String>,
+
+    #[arg(long)]
+    weight_files: Option<String>,
+
+    #[arg(long)]
+    quantized: bool,
+
+    /// Penalty to be applied for repeating tokens, 1. means no penalty.
+    #[arg(long, default_value_t = 1.1)]
+    repeat_penalty: f32,
+
+    /// The context size to consider for the repeat penalty.
+    #[arg(long, default_value_t = 64)]
+    repeat_last_n: usize,
+}
+
+fn main() -> Result<()> {
+    use tracing_chrome::ChromeLayerBuilder;
+    use tracing_subscriber::prelude::*;
+
+    let args = Args::parse();
+    let _guard = if args.tracing {
+        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
+        tracing_subscriber::registry().with(chrome_layer).init();
+        Some(guard)
+    } else {
+        None
+    };
+    println!(
+        "avx: {}, neon: {}, simd128: {}, f16c: {}",
+        candle::utils::with_avx(),
+        candle::utils::with_neon(),
+        candle::utils::with_simd128(),
+        candle::utils::with_f16c()
+    );
+    println!(
+        "temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
+        args.temperature.unwrap_or(0.),
+        args.repeat_penalty,
+        args.repeat_last_n
+    );
+
+    let start = std::time::Instant::now();
+    let api = Api::new()?;
+    let repo = api.repo(Repo::with_revision(
+        args.model_id,
+        RepoType::Model,
+        args.revision,
+    ));
+    let tokenizer_filename = match args.tokenizer_file {
+        Some(file) => std::path::PathBuf::from(file),
+        None => repo.get("tokenizer.json")?,
+    };
+    let filenames = match args.weight_files {
+        Some(files) => files
+            .split(',')
+            .map(std::path::PathBuf::from)
+            .collect::<Vec<_>>(),
+        None => vec![
+            repo.get("pytorch_model-00001-of-00002.safetensors")?,
+            repo.get("pytorch_model-00002-of-00002.safetensors")?,
+        ],
+    };
+    println!("retrieved the files in {:?}", start.elapsed());
+    let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
+
+    let start = std::time::Instant::now();
+    let config = Config::config_7b_v0_1();
+    let device = candle_examples::device(args.cpu)?;
+    let dtype = if device.is_cuda() {
+        DType::BF16
+    } else {
+        DType::F32
+    };
+    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
+    let model = Model::new(&config, vb)?;
+    println!("loaded the model in {:?}", start.elapsed());
+
+    let mut pipeline = TextGeneration::new(
+        model,
+        tokenizer,
+        args.seed,
+        args.temperature,
+        args.top_p,
+        args.repeat_penalty,
+        args.repeat_last_n,
+        &device,
+    );
+    pipeline.run(&args.prompt, args.sample_len)?;
+    Ok(())
+}
diff --git a/candle-transformers/src/models/mistral.rs b/candle-transformers/src/models/mistral.rs
index 4ce89081..7db83ff1 100644
--- a/candle-transformers/src/models/mistral.rs
+++ b/candle-transformers/src/models/mistral.rs
@@ -1,7 +1,6 @@
-#![allow(unused)]
-use crate::models::with_tracing::{linear_no_bias, Embedding as E, Linear};
+use crate::models::with_tracing::{linear_no_bias, Linear};
 /// Mistral LLM, https://github.com/mistralai/mistral-src
-use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
+use candle::{DType, Device, Module, Result, Tensor, D};
 use candle_nn::{Activation, VarBuilder};
 use std::sync::Arc;
 
@@ -99,7 +98,7 @@ impl RotaryEmbedding {
         k: &Tensor,
         seqlen_offset: usize,
     ) -> Result<(Tensor, Tensor)> {
-        let (b_sz, seq_len, h, n_embd) = q.dims4()?;
+        let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
         let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
         let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
         let cos = cos.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim)
@@ -240,7 +239,7 @@ impl Attention {
 
         let attn_weights = match attention_mask {
             None => attn_weights,
-            Some(mask) => (attn_weights + mask)?,
+            Some(mask) => attn_weights.broadcast_add(mask)?,
         };
         let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
         let attn_output = attn_weights.matmul(&value_states)?;
@@ -290,7 +289,7 @@ impl DecoderLayer {
         let xs = (xs + residual)?;
         let residual = &xs;
         let xs = xs.apply(&self.post_attention_layernorm)?.apply(&self.mlp)?;
-        Ok(xs)
+        residual + xs
     }
 }
 
@@ -300,22 +299,24 @@ pub struct Model {
     layers: Vec<DecoderLayer>,
     norm: RmsNorm,
     lm_head: Linear,
+    #[allow(unused)]
     sliding_window: usize,
     device: Device,
 }
 
 impl Model {
     pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let vb_m = vb.pp("model");
         let embed_tokens =
-            candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("embed_tokens"))?;
-        let rotary_emb = Arc::new(RotaryEmbedding::new(cfg, vb.device())?);
+            candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?;
+        let rotary_emb = Arc::new(RotaryEmbedding::new(cfg, vb_m.device())?);
         let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
-        let vb_l = vb.pp("layers");
+        let vb_l = vb_m.pp("layers");
         for layer_idx in 0..cfg.num_hidden_layers {
             let layer = DecoderLayer::new(rotary_emb.clone(), cfg, vb_l.pp(layer_idx))?;
            layers.push(layer)
         }
-        let norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("norm"))?;
+        let norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb_m.pp("norm"))?;
         let lm_head = linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
         Ok(Self {
             embed_tokens,
@@ -359,6 +360,8 @@ impl Model {
         for layer in self.layers.iter_mut() {
             xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)?
         }
-        xs.apply(&self.norm)?.apply(&self.lm_head)
+        xs.narrow(1, seq_len - 1, 1)?
+            .apply(&self.norm)?
+            .apply(&self.lm_head)
     }
 }
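
Two notes on the changes above, outside the patch itself:

1. The generation loop in `main.rs` feeds the whole prompt to `Model::forward` on the first iteration and a single token afterwards (`context_size` is `1` once `index > 0`); `seqlen_offset` tells the rotary embedding where that token sits in the sequence while the per-layer KV cache supplies the earlier positions.
2. The new `xs.narrow(1, seq_len - 1, 1)` in `Model::forward` applies `norm` and `lm_head` only to the last position, since only its logits are sampled; for a long prompt this skips a `seq_len x vocab_size` projection whose output would be discarded.

For reference, a minimal self-contained sketch of the repeat-penalty semantics the example relies on, assuming it mirrors `candle_transformers::utils::apply_repeat_penalty` (positive logits divided by the penalty, negative ones multiplied, each context token counted once):

```rust
use std::collections::HashSet;

// Sketch only: the real helper operates on a candle Tensor and returns a
// Result; this mirrors the assumed per-logit arithmetic on a plain slice.
fn repeat_penalty_sketch(logits: &mut [f32], penalty: f32, context: &[u32]) {
    // Penalize each token in the recent context once, however often it occurs.
    let seen: HashSet<u32> = context.iter().copied().collect();
    for &token_id in &seen {
        if let Some(logit) = logits.get_mut(token_id as usize) {
            // Push the logit toward "less likely" whatever its sign.
            if *logit >= 0.0 {
                *logit /= penalty;
            } else {
                *logit *= penalty;
            }
        }
    }
}

fn main() {
    let mut logits = vec![2.0_f32, -1.0, 0.5];
    repeat_penalty_sketch(&mut logits, 1.1, &[0, 1, 1]);
    // Token 0: 2.0 / 1.1 ~= 1.818; token 1: -1.0 * 1.1 = -1.1; token 2 untouched.
    println!("{logits:?}");
}
```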