From 996a7f2e241b21dd5198b2f0fa8e01a9d8d20a11 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Tue, 26 Dec 2023 22:24:04 +0100
Subject: [PATCH] Rework the llama example config, add the solar model. (#1485)

---
 candle-examples/examples/llama/main.rs | 102 ++++++++-----------------
 1 file changed, 33 insertions(+), 69 deletions(-)

diff --git a/candle-examples/examples/llama/main.rs b/candle-examples/examples/llama/main.rs
index 4bf91d92..c2ed0e25 100644
--- a/candle-examples/examples/llama/main.rs
+++ b/candle-examples/examples/llama/main.rs
@@ -13,7 +13,7 @@ extern crate accelerate_src;
 extern crate intel_mkl_src;
 
 use anyhow::{bail, Error as E, Result};
-use clap::Parser;
+use clap::{Parser, ValueEnum};
 
 use candle::{DType, Tensor};
 use candle_nn::VarBuilder;
@@ -22,11 +22,19 @@ use hf_hub::{api::sync::Api, Repo, RepoType};
 use std::io::Write;
 
 use candle_transformers::models::llama as model;
-use model::{Config, Llama, LlamaConfig};
+use model::{Llama, LlamaConfig};
 
 const EOS_TOKEN: &str = "</s>";
 const DEFAULT_PROMPT: &str = "My favorite theorem is ";
 
+#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
+enum Which {
+    V1,
+    V2,
+    #[value(name = "solar-10.7b")]
+    Solar10_7B,
+}
+
 #[derive(Parser, Debug)]
 #[command(author, version, about, long_about = None)]
 struct Args {
@@ -34,10 +42,6 @@ struct Args {
     #[arg(long)]
     cpu: bool,
 
-    /// Use npy instead of safetensors
-    #[arg(long)]
-    npy: Option<String>,
-
     /// The temperature used to generate samples.
     #[arg(long)]
     temperature: Option<f64>,
@@ -76,17 +80,13 @@ struct Args {
     #[arg(long)]
     revision: Option<String>,
 
-    #[arg(long)]
-    v1: bool,
+    /// The model size to use.
+    #[arg(long, default_value = "v2")]
+    which: Which,
 
     #[arg(long)]
    use_flash_attn: bool,
 
-    /// The folder name that contains safetensor weights and json files
-    /// (same structure as huggingface online)
-    #[arg(long)]
-    local_weights: Option<String>,
-
     /// Penalty to be applied for repeating tokens, 1. means no penalty.
     #[arg(long, default_value_t = 1.0)]
     repeat_penalty: f32,
@@ -118,65 +118,29 @@ fn main() -> Result<()> {
         Some(dtype) => bail!("Unsupported dtype {dtype}"),
         None => DType::F16,
     };
-    let (llama, tokenizer_filename, cache) = match args.npy {
-        Some(filename) => {
-            let config = if args.v1 {
-                Config::config_7b_v1(args.use_flash_attn)
-            } else {
-                Config::config_7b_v2(args.use_flash_attn)
-            };
-            let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;
-            let vb = VarBuilder::from_npz(filename, dtype, &device)?;
-            let tokenizer = std::path::PathBuf::from("llama-tokenizer.json");
-            (Llama::load(vb, &cache, &config)?, tokenizer, cache)
-        }
-        None => {
-            let api = Api::new()?;
-            let model_id = args.model_id.unwrap_or_else(|| {
-                if args.v1 {
-                    "Narsil/amall-7b".to_string()
-                } else {
-                    "meta-llama/Llama-2-7b-hf".to_string()
-                }
-            });
-            println!("loading the model weights from {model_id}");
-            let revision = args.revision.unwrap_or("main".to_string());
-            let api = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
+    let (llama, tokenizer_filename, cache) = {
+        let api = Api::new()?;
+        let model_id = args.model_id.unwrap_or_else(|| match args.which {
+            Which::V1 => "Narsil/amall-7b".to_string(),
+            Which::V2 => "meta-llama/Llama-2-7b-hf".to_string(),
+            Which::Solar10_7B => "upstage/SOLAR-10.7B-v1.0".to_string(),
+        });
+        println!("loading the model weights from {model_id}");
+        let revision = args.revision.unwrap_or("main".to_string());
+        let api = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
 
-            let tokenizer_filename = match &args.local_weights {
-                Some(path) => (path.to_owned() + "tokenizer.json").into(),
-                _ => api.get("tokenizer.json")?,
-            };
+        let tokenizer_filename = api.get("tokenizer.json")?;
+        let config_filename = api.get("config.json")?;
+        let config: LlamaConfig = serde_json::from_slice(&std::fs::read(config_filename)?)?;
+        let config = config.into_config(args.use_flash_attn);
 
-            let config_filename = match &args.local_weights {
-                Some(path) => (path.to_owned() + "config.json").into(),
-                _ => api.get("config.json")?,
-            };
-            let config: LlamaConfig = serde_json::from_slice(&std::fs::read(config_filename)?)?;
-            let config = config.into_config(args.use_flash_attn);
+        let filenames =
+            candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?;
+        println!("building the model");
+        let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;
 
-            let mut filenames = vec![];
-            for rfilename in [
-                "model-00001-of-00002.safetensors",
-                "model-00002-of-00002.safetensors",
-            ] {
-                match &args.local_weights {
-                    Some(path) => {
-                        filenames.push((path.to_owned() + rfilename).into());
-                    }
-                    _ => {
-                        let filename = api.get(rfilename)?;
-                        filenames.push(filename);
-                    }
-                };
-            }
-
-            println!("building the model");
-            let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;
-
-            let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
-            (Llama::load(vb, &cache, &config)?, tokenizer_filename, cache)
-        }
+        let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
+        (Llama::load(vb, &cache, &config)?, tokenizer_filename, cache)
     };
     let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
     let eos_token_id = tokenizer.token_to_id(EOS_TOKEN);
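
The boolean `--v1` flag becomes a `--which` value enum, so adding further models only needs a new variant. Below is a minimal standalone sketch of how clap maps the command-line strings onto the enum, assuming clap 4 with the `derive` feature; the `Args` struct here is a stripped-down stand-in for the example's real one, not the patched code itself.

use clap::{Parser, ValueEnum};

#[derive(Clone, Copy, Debug, PartialEq, Eq, ValueEnum)]
enum Which {
    V1,
    V2,
    // Without the explicit name, clap would derive an awkward kebab-case
    // spelling from `Solar10_7B`; this pins the CLI value to "solar-10.7b".
    #[value(name = "solar-10.7b")]
    Solar10_7B,
}

#[derive(Parser, Debug)]
struct Args {
    // The default value goes through the same ValueEnum parser, so omitting
    // the flag selects Which::V2.
    #[arg(long, default_value = "v2")]
    which: Which,
}

fn main() {
    // `--which v1`, `--which v2` and `--which solar-10.7b` all parse; any
    // other value is rejected at argument-parsing time with the allowed list.
    let args = Args::parse();
    println!("{:?}", args.which);
}

With this in place the solar checkpoint is selected by passing `--which solar-10.7b` to the example, and an unknown model name fails during argument parsing rather than deep inside weight loading.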
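The hard-coded `model-00001-of-00002.safetensors` / `model-00002-of-00002.safetensors` list is replaced by `candle_examples::hub_load_safetensors`, which lets the same code handle checkpoints split into any number of shards (SOLAR-10.7B is sharded differently from Llama-2-7b). The helper itself is not part of this diff; the following is a rough sketch of what such a helper can look like, assuming the usual `model.safetensors.index.json` layout with a `weight_map` object and the `hf-hub` sync API. `load_sharded_safetensors` is a hypothetical name for illustration, not the actual candle-examples implementation.

use std::collections::HashSet;
use std::path::PathBuf;

use anyhow::Result;
use hf_hub::api::sync::ApiRepo;

// Hypothetical helper (not the actual candle-examples code): read the
// safetensors index, collect the unique shard names from its "weight_map",
// and download each shard from the Hub.
fn load_sharded_safetensors(repo: &ApiRepo, json_file: &str) -> Result<Vec<PathBuf>> {
    let index_path = repo.get(json_file)?;
    let index: serde_json::Value = serde_json::from_slice(&std::fs::read(index_path)?)?;
    let weight_map = index["weight_map"]
        .as_object()
        .ok_or_else(|| anyhow::anyhow!("no weight_map in {json_file}"))?;
    // The map goes tensor-name -> shard-file; several tensors share a shard,
    // so de-duplicate before downloading.
    let shards: HashSet<&str> = weight_map.values().filter_map(|v| v.as_str()).collect();
    let mut filenames = Vec::new();
    for shard in shards {
        filenames.push(repo.get(shard)?);
    }
    Ok(filenames)
}

Because the index maps tensor names to shard files, de-duplicating its values yields exactly the set of files to fetch, whatever the shard count of the chosen model.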