diff --git a/candle-examples/examples/phi/main.rs b/candle-examples/examples/phi/main.rs index 9bd7e8da..5cdbb4b9 100644 --- a/candle-examples/examples/phi/main.rs +++ b/candle-examples/examples/phi/main.rs @@ -151,26 +151,29 @@ fn main() -> Result<()> { args.revision, )); let tokenizer_filename = repo.get("tokenizer.json")?; - let filenames = match args.weight_file { - Some(weight_file) => vec![std::path::PathBuf::from(weight_file)], - None => ["model.safetensors"] - .iter() - .map(|f| repo.get(f)) - .collect::, _>>()?, + let filename = match args.weight_file { + Some(weight_file) => std::path::PathBuf::from(weight_file), + None => { + if args.quantized { + api.model("lmz/candle-quantized-phi".to_string()) + .get("model-q4k.gguf")? + } else { + repo.get("model.safetensors")? + } + } }; println!("retrieved the files in {:?}", start.elapsed()); let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; let start = std::time::Instant::now(); + let config = Config::v1_5(); let (model, device) = if args.quantized { - let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filenames[0])?; - let config = Config::v1_5(); + let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filename)?; let model = QMixFormer::new(&config, vb)?; (Model::Quantized(model), Device::Cpu) } else { let device = candle_examples::device(args.cpu)?; - let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? }; - let config = Config::v1_5(); + let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[filename], DType::F32, &device)? }; let model = MixFormer::new(&config, vb)?; (Model::MixFormer(model), device) };