Use the repo for the quantized phi model. (#954)

This commit is contained in:
Laurent Mazare
2023-09-24 16:30:26 +01:00
committed by GitHub
parent 0007ae9c11
commit f5069dd354

View File

@ -151,26 +151,29 @@ fn main() -> Result<()> {
args.revision, args.revision,
)); ));
let tokenizer_filename = repo.get("tokenizer.json")?; let tokenizer_filename = repo.get("tokenizer.json")?;
let filenames = match args.weight_file { let filename = match args.weight_file {
Some(weight_file) => vec![std::path::PathBuf::from(weight_file)], Some(weight_file) => std::path::PathBuf::from(weight_file),
None => ["model.safetensors"] None => {
.iter() if args.quantized {
.map(|f| repo.get(f)) api.model("lmz/candle-quantized-phi".to_string())
.collect::<std::result::Result<Vec<_>, _>>()?, .get("model-q4k.gguf")?
} else {
repo.get("model.safetensors")?
}
}
}; };
println!("retrieved the files in {:?}", start.elapsed()); println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let start = std::time::Instant::now(); let start = std::time::Instant::now();
let config = Config::v1_5();
let (model, device) = if args.quantized { let (model, device) = if args.quantized {
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filenames[0])?; let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filename)?;
let config = Config::v1_5();
let model = QMixFormer::new(&config, vb)?; let model = QMixFormer::new(&config, vb)?;
(Model::Quantized(model), Device::Cpu) (Model::Quantized(model), Device::Cpu)
} else { } else {
let device = candle_examples::device(args.cpu)?; let device = candle_examples::device(args.cpu)?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? }; let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[filename], DType::F32, &device)? };
let config = Config::v1_5();
let model = MixFormer::new(&config, vb)?; let model = MixFormer::new(&config, vb)?;
(Model::MixFormer(model), device) (Model::MixFormer(model), device)
}; };