Mirror of https://github.com/huggingface/candle.git, synced 2025-06-16 18:48:51 +00:00
Rework the llama example config, add the solar model. (#1485)
@@ -13,7 +13,7 @@ extern crate accelerate_src;
 extern crate intel_mkl_src;
 
 use anyhow::{bail, Error as E, Result};
-use clap::Parser;
+use clap::{Parser, ValueEnum};
 
 use candle::{DType, Tensor};
 use candle_nn::VarBuilder;
@@ -22,11 +22,19 @@ use hf_hub::{api::sync::Api, Repo, RepoType};
 use std::io::Write;
 
 use candle_transformers::models::llama as model;
-use model::{Config, Llama, LlamaConfig};
+use model::{Llama, LlamaConfig};
 
 const EOS_TOKEN: &str = "</s>";
 const DEFAULT_PROMPT: &str = "My favorite theorem is ";
 
+#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
+enum Which {
+    V1,
+    V2,
+    #[value(name = "solar-10.7b")]
+    Solar10_7B,
+}
+
 #[derive(Parser, Debug)]
 #[command(author, version, about, long_about = None)]
 struct Args {
@@ -34,10 +42,6 @@ struct Args {
     #[arg(long)]
     cpu: bool,
 
-    /// Use npy instead of safetensors
-    #[arg(long)]
-    npy: Option<String>,
-
     /// The temperature used to generate samples.
     #[arg(long)]
     temperature: Option<f64>,
@@ -76,17 +80,13 @@ struct Args {
     #[arg(long)]
     revision: Option<String>,
 
-    #[arg(long)]
-    v1: bool,
+    /// The model size to use.
+    #[arg(long, default_value = "v2")]
+    which: Which,
 
     #[arg(long)]
     use_flash_attn: bool,
 
-    /// The folder name that contains safetensor weights and json files
-    /// (same structure as huggingface online)
-    #[arg(long)]
-    local_weights: Option<String>,
-
     /// Penalty to be applied for repeating tokens, 1. means no penalty.
     #[arg(long, default_value_t = 1.0)]
     repeat_penalty: f32,
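For reference, a minimal sketch of how the new flag parses (this snippet is not part of the commit; the `Demo` struct and its `main` are hypothetical stand-ins for the example's `Args`): clap's `ValueEnum` derive maps variants to kebab-cased value names, so `V1`/`V2` become `v1`/`v2`, while `#[value(name = "solar-10.7b")]` overrides the name for `Solar10_7B`.

use clap::{Parser, ValueEnum};

#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
enum Which {
    V1,
    V2,
    #[value(name = "solar-10.7b")]
    Solar10_7B,
}

// Hypothetical stand-in for the example's Args struct, reduced to the relevant field.
#[derive(Parser, Debug)]
struct Demo {
    #[arg(long, default_value = "v2")]
    which: Which,
}

fn main() {
    // Defaults to Which::V2 when --which is not given.
    let default = Demo::parse_from(["demo"]);
    assert_eq!(default.which, Which::V2);

    // "--which solar-10.7b" selects the solar variant via its custom value name.
    let solar = Demo::parse_from(["demo", "--which", "solar-10.7b"]);
    assert_eq!(solar.which, Which::Solar10_7B);
}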
@@ -118,65 +118,29 @@ fn main() -> Result<()> {
         Some(dtype) => bail!("Unsupported dtype {dtype}"),
         None => DType::F16,
     };
-    let (llama, tokenizer_filename, cache) = match args.npy {
-        Some(filename) => {
-            let config = if args.v1 {
-                Config::config_7b_v1(args.use_flash_attn)
-            } else {
-                Config::config_7b_v2(args.use_flash_attn)
-            };
-            let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;
-            let vb = VarBuilder::from_npz(filename, dtype, &device)?;
-            let tokenizer = std::path::PathBuf::from("llama-tokenizer.json");
-            (Llama::load(vb, &cache, &config)?, tokenizer, cache)
-        }
-        None => {
-            let api = Api::new()?;
-            let model_id = args.model_id.unwrap_or_else(|| {
-                if args.v1 {
-                    "Narsil/amall-7b".to_string()
-                } else {
-                    "meta-llama/Llama-2-7b-hf".to_string()
-                }
-            });
-            println!("loading the model weights from {model_id}");
-            let revision = args.revision.unwrap_or("main".to_string());
-            let api = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
-
-            let tokenizer_filename = match &args.local_weights {
-                Some(path) => (path.to_owned() + "tokenizer.json").into(),
-                _ => api.get("tokenizer.json")?,
-            };
-
-            let config_filename = match &args.local_weights {
-                Some(path) => (path.to_owned() + "config.json").into(),
-                _ => api.get("config.json")?,
-            };
-            let config: LlamaConfig = serde_json::from_slice(&std::fs::read(config_filename)?)?;
-            let config = config.into_config(args.use_flash_attn);
-
-            let mut filenames = vec![];
-            for rfilename in [
-                "model-00001-of-00002.safetensors",
-                "model-00002-of-00002.safetensors",
-            ] {
-                match &args.local_weights {
-                    Some(path) => {
-                        filenames.push((path.to_owned() + rfilename).into());
-                    }
-                    _ => {
-                        let filename = api.get(rfilename)?;
-                        filenames.push(filename);
-                    }
-                };
-            }
-
-            println!("building the model");
-            let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;
-
-            let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
-            (Llama::load(vb, &cache, &config)?, tokenizer_filename, cache)
-        }
+    let (llama, tokenizer_filename, cache) = {
+        let api = Api::new()?;
+        let model_id = args.model_id.unwrap_or_else(|| match args.which {
+            Which::V1 => "Narsil/amall-7b".to_string(),
+            Which::V2 => "meta-llama/Llama-2-7b-hf".to_string(),
+            Which::Solar10_7B => "upstage/SOLAR-10.7B-v1.0".to_string(),
+        });
+        println!("loading the model weights from {model_id}");
+        let revision = args.revision.unwrap_or("main".to_string());
+        let api = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
+
+        let tokenizer_filename = api.get("tokenizer.json")?;
+        let config_filename = api.get("config.json")?;
+        let config: LlamaConfig = serde_json::from_slice(&std::fs::read(config_filename)?)?;
+        let config = config.into_config(args.use_flash_attn);
+
+        let filenames =
+            candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?;
+        println!("building the model");
+        let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;
+
+        let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
+        (Llama::load(vb, &cache, &config)?, tokenizer_filename, cache)
     };
     let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
     let eos_token_id = tokenizer.token_to_id(EOS_TOKEN);
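The hard-coded two-shard download loop is replaced by `candle_examples::hub_load_safetensors`. As a rough sketch of what such a helper has to do (an assumption about its behavior, not the actual candle_examples implementation; `load_safetensors_sketch` is a hypothetical name): fetch `model.safetensors.index.json` from the hub, collect the unique shard filenames from its `weight_map`, and download each shard.

use hf_hub::api::sync::ApiRepo;
use std::collections::HashSet;
use std::path::PathBuf;

fn load_safetensors_sketch(repo: &ApiRepo, index_file: &str) -> anyhow::Result<Vec<PathBuf>> {
    // Download and parse the safetensors index, e.g. "model.safetensors.index.json".
    let index_path = repo.get(index_file)?;
    let index: serde_json::Value = serde_json::from_slice(&std::fs::read(index_path)?)?;

    // The weight_map maps tensor names to shard files; keep each shard once.
    let mut shards = HashSet::new();
    if let Some(weight_map) = index["weight_map"].as_object() {
        for shard in weight_map.values() {
            if let Some(name) = shard.as_str() {
                shards.insert(name.to_string());
            }
        }
    }

    // Download every shard and return the local paths for VarBuilder::from_mmaped_safetensors.
    shards.into_iter().map(|s| Ok(repo.get(&s)?)).collect()
}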