Rebase after phi2 merge + fix replit always defaulting to CPU.

This commit is contained in:
Nicolas Patry
2024-01-15 17:52:49 +01:00
parent b2db5adf82
commit 3dbf65ef20
2 changed files with 6 additions and 12 deletions

View File

@ -314,16 +314,14 @@ fn main() -> Result<()> {
&filenames[0], &filenames[0],
&device, &device,
)?; )?;
println!("Loaded vb");
let model = match args.model { let model = match args.model {
WhichModel::V2 | WhichModel::V2Old => QMixFormer::new_v2(&config, vb)?, WhichModel::V2 | WhichModel::V2Old => QMixFormer::new_v2(&config, vb)?,
_ => QMixFormer::new(&config, vb)?, _ => QMixFormer::new(&config, vb)?,
}; };
println!("Loaded model");
Model::Quantized(model) Model::Quantized(model)
} else { } else {
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? }; let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
let model = match args.model { match args.model {
WhichModel::V1 | WhichModel::V1_5 | WhichModel::V2 => { WhichModel::V1 | WhichModel::V1_5 | WhichModel::V2 => {
let config_filename = repo.get("config.json")?; let config_filename = repo.get("config.json")?;
let config = std::fs::read_to_string(config_filename)?; let config = std::fs::read_to_string(config_filename)?;
@ -339,8 +337,7 @@ fn main() -> Result<()> {
let config = config(); let config = config();
Model::MixFormer(MixFormer::new(&config, vb)?) Model::MixFormer(MixFormer::new(&config, vb)?)
} }
}; }
model
}; };
println!("loaded the model in {:?}", start.elapsed()); println!("loaded the model in {:?}", start.elapsed());

View File

@ -236,18 +236,15 @@ fn main() -> Result<()> {
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let start = std::time::Instant::now(); let start = std::time::Instant::now();
let device = Device::Cpu; let device = candle_examples::device(args.cpu)?;
let config = Config::replit_code_v1_5_3b(); let config = Config::replit_code_v1_5_3b();
let (model, device) = if args.quantized { let model = if args.quantized {
let vb = let vb =
candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filename, &device)?; candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filename, &device)?;
let model = Model::Q(Q::new(&config, vb.pp("transformer"))?); Model::Q(Q::new(&config, vb.pp("transformer"))?)
(model, Device::Cpu)
} else { } else {
let device = candle_examples::device(args.cpu)?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[filename], DType::F32, &device)? }; let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[filename], DType::F32, &device)? };
let model = Model::M(M::new(&config, vb.pp("transformer"))?); Model::M(M::new(&config, vb.pp("transformer"))?)
(model, device)
}; };
println!("loaded the model in {:?}", start.elapsed()); println!("loaded the model in {:?}", start.elapsed());