mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Rework the llama example config, add the solar model. (#1485)
This commit is contained in:
@ -13,7 +13,7 @@ extern crate accelerate_src;
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
use anyhow::{bail, Error as E, Result};
|
||||
use clap::Parser;
|
||||
use clap::{Parser, ValueEnum};
|
||||
|
||||
use candle::{DType, Tensor};
|
||||
use candle_nn::VarBuilder;
|
||||
@ -22,11 +22,19 @@ use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||
use std::io::Write;
|
||||
|
||||
use candle_transformers::models::llama as model;
|
||||
use model::{Config, Llama, LlamaConfig};
|
||||
use model::{Llama, LlamaConfig};
|
||||
|
||||
const EOS_TOKEN: &str = "</s>";
|
||||
const DEFAULT_PROMPT: &str = "My favorite theorem is ";
|
||||
|
||||
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
|
||||
enum Which {
|
||||
V1,
|
||||
V2,
|
||||
#[value(name = "solar-10.7b")]
|
||||
Solar10_7B,
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
struct Args {
|
||||
@ -34,10 +42,6 @@ struct Args {
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
|
||||
/// Use npy instead of safetensors
|
||||
#[arg(long)]
|
||||
npy: Option<String>,
|
||||
|
||||
/// The temperature used to generate samples.
|
||||
#[arg(long)]
|
||||
temperature: Option<f64>,
|
||||
@ -76,17 +80,13 @@ struct Args {
|
||||
#[arg(long)]
|
||||
revision: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
v1: bool,
|
||||
/// The model size to use.
|
||||
#[arg(long, default_value = "v2")]
|
||||
which: Which,
|
||||
|
||||
#[arg(long)]
|
||||
use_flash_attn: bool,
|
||||
|
||||
/// The folder name that contains safetensor weights and json files
|
||||
/// (same structure as huggingface online)
|
||||
#[arg(long)]
|
||||
local_weights: Option<String>,
|
||||
|
||||
/// Penalty to be applied for repeating tokens, 1. means no penalty.
|
||||
#[arg(long, default_value_t = 1.0)]
|
||||
repeat_penalty: f32,
|
||||
@ -118,65 +118,29 @@ fn main() -> Result<()> {
|
||||
Some(dtype) => bail!("Unsupported dtype {dtype}"),
|
||||
None => DType::F16,
|
||||
};
|
||||
let (llama, tokenizer_filename, cache) = match args.npy {
|
||||
Some(filename) => {
|
||||
let config = if args.v1 {
|
||||
Config::config_7b_v1(args.use_flash_attn)
|
||||
} else {
|
||||
Config::config_7b_v2(args.use_flash_attn)
|
||||
};
|
||||
let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;
|
||||
let vb = VarBuilder::from_npz(filename, dtype, &device)?;
|
||||
let tokenizer = std::path::PathBuf::from("llama-tokenizer.json");
|
||||
(Llama::load(vb, &cache, &config)?, tokenizer, cache)
|
||||
}
|
||||
None => {
|
||||
let api = Api::new()?;
|
||||
let model_id = args.model_id.unwrap_or_else(|| {
|
||||
if args.v1 {
|
||||
"Narsil/amall-7b".to_string()
|
||||
} else {
|
||||
"meta-llama/Llama-2-7b-hf".to_string()
|
||||
}
|
||||
});
|
||||
println!("loading the model weights from {model_id}");
|
||||
let revision = args.revision.unwrap_or("main".to_string());
|
||||
let api = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
|
||||
let (llama, tokenizer_filename, cache) = {
|
||||
let api = Api::new()?;
|
||||
let model_id = args.model_id.unwrap_or_else(|| match args.which {
|
||||
Which::V1 => "Narsil/amall-7b".to_string(),
|
||||
Which::V2 => "meta-llama/Llama-2-7b-hf".to_string(),
|
||||
Which::Solar10_7B => "upstage/SOLAR-10.7B-v1.0".to_string(),
|
||||
});
|
||||
println!("loading the model weights from {model_id}");
|
||||
let revision = args.revision.unwrap_or("main".to_string());
|
||||
let api = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
|
||||
|
||||
let tokenizer_filename = match &args.local_weights {
|
||||
Some(path) => (path.to_owned() + "tokenizer.json").into(),
|
||||
_ => api.get("tokenizer.json")?,
|
||||
};
|
||||
let tokenizer_filename = api.get("tokenizer.json")?;
|
||||
let config_filename = api.get("config.json")?;
|
||||
let config: LlamaConfig = serde_json::from_slice(&std::fs::read(config_filename)?)?;
|
||||
let config = config.into_config(args.use_flash_attn);
|
||||
|
||||
let config_filename = match &args.local_weights {
|
||||
Some(path) => (path.to_owned() + "config.json").into(),
|
||||
_ => api.get("config.json")?,
|
||||
};
|
||||
let config: LlamaConfig = serde_json::from_slice(&std::fs::read(config_filename)?)?;
|
||||
let config = config.into_config(args.use_flash_attn);
|
||||
let filenames =
|
||||
candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?;
|
||||
println!("building the model");
|
||||
let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;
|
||||
|
||||
let mut filenames = vec![];
|
||||
for rfilename in [
|
||||
"model-00001-of-00002.safetensors",
|
||||
"model-00002-of-00002.safetensors",
|
||||
] {
|
||||
match &args.local_weights {
|
||||
Some(path) => {
|
||||
filenames.push((path.to_owned() + rfilename).into());
|
||||
}
|
||||
_ => {
|
||||
let filename = api.get(rfilename)?;
|
||||
filenames.push(filename);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
println!("building the model");
|
||||
let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;
|
||||
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
|
||||
(Llama::load(vb, &cache, &config)?, tokenizer_filename, cache)
|
||||
}
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
|
||||
(Llama::load(vb, &cache, &config)?, tokenizer_filename, cache)
|
||||
};
|
||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||
let eos_token_id = tokenizer.token_to_id(EOS_TOKEN);
|
||||
|
Reference in New Issue
Block a user