Using the real config from the hub when available.

Nicolas Patry
2023-08-16 10:29:46 +02:00
parent 965597a873
commit 76804730c6
2 changed files with 75 additions and 43 deletions
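
The second changed file, model.rs, is not shown in this view. As a hedged sketch of what it presumably adds, here is a serde-deserializable mirror of the hub's config.json that converts into the example's internal Config; the field set follows the standard Llama config.json keys, and the Config shape is assumed, not the commit's verbatim code:

// Hypothetical sketch of the model.rs side of this commit (not shown in
// the diff below): a struct matching the hub's config.json field names
// that converts into the example's internal Config.
use serde::Deserialize;

#[derive(Deserialize)]
pub struct LlamaConfig {
    pub hidden_size: usize,
    pub intermediate_size: usize,
    pub vocab_size: usize,
    pub num_hidden_layers: usize,
    pub num_attention_heads: usize,
    pub num_key_value_heads: Option<usize>,
    pub rms_norm_eps: f64,
}

impl LlamaConfig {
    pub fn into_config(self, use_flash_attn: bool) -> Config {
        Config {
            hidden_size: self.hidden_size,
            intermediate_size: self.intermediate_size,
            vocab_size: self.vocab_size,
            num_hidden_layers: self.num_hidden_layers,
            num_attention_heads: self.num_attention_heads,
            // Checkpoints without grouped-query attention omit this key.
            num_key_value_heads: self.num_key_value_heads.unwrap_or(self.num_attention_heads),
            rms_norm_eps: self.rms_norm_eps,
            use_flash_attn,
        }
    }
}

With a struct like this, the hub path no longer hardcodes the 7B hyperparameters, while the npy path keeps the config_7b_v1/v2 presets, as the diff below shows.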


@@ -22,7 +22,7 @@ use hf_hub::api::sync::Api;
 use std::io::Write;
 mod model;
-use model::{Config, Llama};
+use model::{Config, Llama, LlamaConfig};
 const EOS_TOKEN: &str = "</s>";
 const MAX_SEQ_LEN: usize = 4096;
@@ -98,18 +98,18 @@ fn main() -> Result<()> {
     };
     let device = candle_examples::device(args.cpu)?;
-    let config = if args.v1 {
-        Config::config_7b_v1(args.use_flash_attn)
-    } else {
-        Config::config_7b_v2(args.use_flash_attn)
-    };
     let dtype = if args.use_f32 { DType::F32 } else { DType::F16 };
-    let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;
-    let (llama, tokenizer_filename) = match args.npy {
+    let (llama, tokenizer_filename, cache) = match args.npy {
         Some(filename) => {
+            let config = if args.v1 {
+                Config::config_7b_v1(args.use_flash_attn)
+            } else {
+                Config::config_7b_v2(args.use_flash_attn)
+            };
+            let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;
             let vb = VarBuilder::from_npz(filename, dtype, &device)?;
             let tokenizer = std::path::PathBuf::from("llama-tokenizer.json");
-            (Llama::load(vb, &cache, &config)?, tokenizer)
+            (Llama::load(vb, &cache, &config)?, tokenizer, cache)
         }
         None => {
             let api = Api::new()?;
@@ -128,6 +128,13 @@ fn main() -> Result<()> {
                 _ => api.get("tokenizer.json")?,
             };
+            let config_filename = match &args.local_weights {
+                Some(path) => (path.to_owned() + "config.json").into(),
+                _ => api.get("config.json")?,
+            };
+            let config: LlamaConfig = serde_json::from_slice(&std::fs::read(config_filename)?)?;
+            let config = config.into_config(args.use_flash_attn);
+
             let mut filenames = vec![];
             for rfilename in [
                 "model-00001-of-00002.safetensors",
@@ -153,9 +160,10 @@ fn main() -> Result<()> {
                 .iter()
                 .map(|h| Ok(h.deserialize()?))
                 .collect::<Result<Vec<_>>>()?;
+            let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;
             let vb = VarBuilder::from_safetensors(tensors, dtype, &device);
-            (Llama::load(vb, &cache, &config)?, tokenizer_filename)
+            (Llama::load(vb, &cache, &config)?, tokenizer_filename, cache)
         }
     };
     let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;