mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 18:28:24 +00:00
Add the mistral example. (#984)
* Add the mistral example. * Use the two model files. * Adjust the dtype. * Tweak the weight paths. * Remove the end of text token. * Get the mistral model to generate some text.
This commit is contained in:
226
candle-examples/examples/mistral/main.rs
Normal file
226
candle-examples/examples/mistral/main.rs
Normal file
@ -0,0 +1,226 @@
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
use anyhow::{Error as E, Result};
|
||||
use clap::Parser;
|
||||
|
||||
use candle_transformers::models::mistral::{Config, Model};
|
||||
|
||||
use candle::{DType, Device, Tensor};
|
||||
use candle_nn::VarBuilder;
|
||||
use candle_transformers::generation::LogitsProcessor;
|
||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
struct TextGeneration {
|
||||
model: Model,
|
||||
device: Device,
|
||||
tokenizer: Tokenizer,
|
||||
logits_processor: LogitsProcessor,
|
||||
repeat_penalty: f32,
|
||||
repeat_last_n: usize,
|
||||
}
|
||||
|
||||
impl TextGeneration {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn new(
|
||||
model: Model,
|
||||
tokenizer: Tokenizer,
|
||||
seed: u64,
|
||||
temp: Option<f64>,
|
||||
top_p: Option<f64>,
|
||||
repeat_penalty: f32,
|
||||
repeat_last_n: usize,
|
||||
device: &Device,
|
||||
) -> Self {
|
||||
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
|
||||
Self {
|
||||
model,
|
||||
tokenizer,
|
||||
logits_processor,
|
||||
repeat_penalty,
|
||||
repeat_last_n,
|
||||
device: device.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
|
||||
use std::io::Write;
|
||||
println!("starting the inference loop");
|
||||
print!("{prompt}");
|
||||
std::io::stdout().flush()?;
|
||||
let mut tokens = self
|
||||
.tokenizer
|
||||
.encode(prompt, true)
|
||||
.map_err(E::msg)?
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
|
||||
let mut new_tokens = vec![];
|
||||
let start_gen = std::time::Instant::now();
|
||||
for index in 0..sample_len {
|
||||
let context_size = if index > 0 { 1 } else { tokens.len() };
|
||||
let start_pos = tokens.len().saturating_sub(context_size);
|
||||
let ctxt = &tokens[start_pos..];
|
||||
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
|
||||
let logits = self.model.forward(&input, start_pos)?;
|
||||
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
|
||||
let logits = if self.repeat_penalty == 1. {
|
||||
logits
|
||||
} else {
|
||||
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
|
||||
candle_transformers::utils::apply_repeat_penalty(
|
||||
&logits,
|
||||
self.repeat_penalty,
|
||||
&tokens[start_at..],
|
||||
)?
|
||||
};
|
||||
|
||||
let next_token = self.logits_processor.sample(&logits)?;
|
||||
tokens.push(next_token);
|
||||
new_tokens.push(next_token);
|
||||
let token = self.tokenizer.decode(&[next_token], true).map_err(E::msg)?;
|
||||
print!("{token}");
|
||||
std::io::stdout().flush()?;
|
||||
}
|
||||
let dt = start_gen.elapsed();
|
||||
println!(
|
||||
"\n{sample_len} tokens generated ({:.2} token/s)",
|
||||
sample_len as f64 / dt.as_secs_f64(),
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
struct Args {
|
||||
/// Run on CPU rather than on GPU.
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
|
||||
/// Enable tracing (generates a trace-timestamp.json file).
|
||||
#[arg(long)]
|
||||
tracing: bool,
|
||||
|
||||
#[arg(long)]
|
||||
prompt: String,
|
||||
|
||||
/// The temperature used to generate samples.
|
||||
#[arg(long)]
|
||||
temperature: Option<f64>,
|
||||
|
||||
/// Nucleus sampling probability cutoff.
|
||||
#[arg(long)]
|
||||
top_p: Option<f64>,
|
||||
|
||||
/// The seed to use when generating random samples.
|
||||
#[arg(long, default_value_t = 299792458)]
|
||||
seed: u64,
|
||||
|
||||
/// The length of the sample to generate (in tokens).
|
||||
#[arg(long, default_value_t = 100)]
|
||||
sample_len: usize,
|
||||
|
||||
#[arg(long, default_value = "lmz/candle-mistral")]
|
||||
model_id: String,
|
||||
|
||||
#[arg(long, default_value = "main")]
|
||||
revision: String,
|
||||
|
||||
#[arg(long)]
|
||||
tokenizer_file: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
weight_files: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
quantized: bool,
|
||||
|
||||
/// Penalty to be applied for repeating tokens, 1. means no penalty.
|
||||
#[arg(long, default_value_t = 1.1)]
|
||||
repeat_penalty: f32,
|
||||
|
||||
/// The context size to consider for the repeat penalty.
|
||||
#[arg(long, default_value_t = 64)]
|
||||
repeat_last_n: usize,
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
use tracing_chrome::ChromeLayerBuilder;
|
||||
use tracing_subscriber::prelude::*;
|
||||
|
||||
let args = Args::parse();
|
||||
let _guard = if args.tracing {
|
||||
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
||||
tracing_subscriber::registry().with(chrome_layer).init();
|
||||
Some(guard)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
println!(
|
||||
"avx: {}, neon: {}, simd128: {}, f16c: {}",
|
||||
candle::utils::with_avx(),
|
||||
candle::utils::with_neon(),
|
||||
candle::utils::with_simd128(),
|
||||
candle::utils::with_f16c()
|
||||
);
|
||||
println!(
|
||||
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
|
||||
args.temperature.unwrap_or(0.),
|
||||
args.repeat_penalty,
|
||||
args.repeat_last_n
|
||||
);
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let api = Api::new()?;
|
||||
let repo = api.repo(Repo::with_revision(
|
||||
args.model_id,
|
||||
RepoType::Model,
|
||||
args.revision,
|
||||
));
|
||||
let tokenizer_filename = match args.tokenizer_file {
|
||||
Some(file) => std::path::PathBuf::from(file),
|
||||
None => repo.get("tokenizer.json")?,
|
||||
};
|
||||
let filenames = match args.weight_files {
|
||||
Some(files) => files
|
||||
.split(',')
|
||||
.map(std::path::PathBuf::from)
|
||||
.collect::<Vec<_>>(),
|
||||
None => vec![
|
||||
repo.get("pytorch_model-00001-of-00002.safetensors")?,
|
||||
repo.get("pytorch_model-00002-of-00002.safetensors")?,
|
||||
],
|
||||
};
|
||||
println!("retrieved the files in {:?}", start.elapsed());
|
||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let config = Config::config_7b_v0_1();
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
let dtype = if device.is_cuda() {
|
||||
DType::BF16
|
||||
} else {
|
||||
DType::F32
|
||||
};
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
|
||||
let model = Model::new(&config, vb)?;
|
||||
println!("loaded the model in {:?}", start.elapsed());
|
||||
|
||||
let mut pipeline = TextGeneration::new(
|
||||
model,
|
||||
tokenizer,
|
||||
args.seed,
|
||||
args.temperature,
|
||||
args.top_p,
|
||||
args.repeat_penalty,
|
||||
args.repeat_last_n,
|
||||
&device,
|
||||
);
|
||||
pipeline.run(&args.prompt, args.sample_len)?;
|
||||
Ok(())
|
||||
}
|
@ -1,7 +1,6 @@
|
||||
#![allow(unused)]
|
||||
use crate::models::with_tracing::{linear_no_bias, Embedding as E, Linear};
|
||||
use crate::models::with_tracing::{linear_no_bias, Linear};
|
||||
/// Mistral LLM, https://github.com/mistralai/mistral-src
|
||||
use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
|
||||
use candle::{DType, Device, Module, Result, Tensor, D};
|
||||
use candle_nn::{Activation, VarBuilder};
|
||||
use std::sync::Arc;
|
||||
|
||||
@ -99,7 +98,7 @@ impl RotaryEmbedding {
|
||||
k: &Tensor,
|
||||
seqlen_offset: usize,
|
||||
) -> Result<(Tensor, Tensor)> {
|
||||
let (b_sz, seq_len, h, n_embd) = q.dims4()?;
|
||||
let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
|
||||
let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
|
||||
let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
|
||||
let cos = cos.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim)
|
||||
@ -240,7 +239,7 @@ impl Attention {
|
||||
|
||||
let attn_weights = match attention_mask {
|
||||
None => attn_weights,
|
||||
Some(mask) => (attn_weights + mask)?,
|
||||
Some(mask) => attn_weights.broadcast_add(mask)?,
|
||||
};
|
||||
let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
|
||||
let attn_output = attn_weights.matmul(&value_states)?;
|
||||
@ -290,7 +289,7 @@ impl DecoderLayer {
|
||||
let xs = (xs + residual)?;
|
||||
let residual = &xs;
|
||||
let xs = xs.apply(&self.post_attention_layernorm)?.apply(&self.mlp)?;
|
||||
Ok(xs)
|
||||
residual + xs
|
||||
}
|
||||
}
|
||||
|
||||
@ -300,22 +299,24 @@ pub struct Model {
|
||||
layers: Vec<DecoderLayer>,
|
||||
norm: RmsNorm,
|
||||
lm_head: Linear,
|
||||
#[allow(unused)]
|
||||
sliding_window: usize,
|
||||
device: Device,
|
||||
}
|
||||
|
||||
impl Model {
|
||||
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
let vb_m = vb.pp("model");
|
||||
let embed_tokens =
|
||||
candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("embed_tokens"))?;
|
||||
let rotary_emb = Arc::new(RotaryEmbedding::new(cfg, vb.device())?);
|
||||
candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?;
|
||||
let rotary_emb = Arc::new(RotaryEmbedding::new(cfg, vb_m.device())?);
|
||||
let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
|
||||
let vb_l = vb.pp("layers");
|
||||
let vb_l = vb_m.pp("layers");
|
||||
for layer_idx in 0..cfg.num_hidden_layers {
|
||||
let layer = DecoderLayer::new(rotary_emb.clone(), cfg, vb_l.pp(layer_idx))?;
|
||||
layers.push(layer)
|
||||
}
|
||||
let norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("norm"))?;
|
||||
let norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb_m.pp("norm"))?;
|
||||
let lm_head = linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
|
||||
Ok(Self {
|
||||
embed_tokens,
|
||||
@ -359,6 +360,8 @@ impl Model {
|
||||
for layer in self.layers.iter_mut() {
|
||||
xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)?
|
||||
}
|
||||
xs.apply(&self.norm)?.apply(&self.lm_head)
|
||||
xs.narrow(1, seq_len - 1, 1)?
|
||||
.apply(&self.norm)?
|
||||
.apply(&self.lm_head)
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user