Rebase after the phi2 merge + fix the replit example always defaulting to CPU.

This commit is contained in:
Nicolas Patry
2024-01-15 17:52:49 +01:00
parent b2db5adf82
commit 3dbf65ef20
2 changed files with 6 additions and 12 deletions

View File

@ -314,16 +314,14 @@ fn main() -> Result<()> {
&filenames[0],
&device,
)?;
println!("Loaded vb");
let model = match args.model {
WhichModel::V2 | WhichModel::V2Old => QMixFormer::new_v2(&config, vb)?,
_ => QMixFormer::new(&config, vb)?,
};
println!("Loaded model");
Model::Quantized(model)
} else {
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
let model = match args.model {
match args.model {
WhichModel::V1 | WhichModel::V1_5 | WhichModel::V2 => {
let config_filename = repo.get("config.json")?;
let config = std::fs::read_to_string(config_filename)?;
@ -339,8 +337,7 @@ fn main() -> Result<()> {
let config = config();
Model::MixFormer(MixFormer::new(&config, vb)?)
}
};
model
}
};
println!("loaded the model in {:?}", start.elapsed());

View File

@ -236,18 +236,15 @@ fn main() -> Result<()> {
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let start = std::time::Instant::now();
let device = Device::Cpu;
let device = candle_examples::device(args.cpu)?;
let config = Config::replit_code_v1_5_3b();
let (model, device) = if args.quantized {
let model = if args.quantized {
let vb =
candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filename, &device)?;
let model = Model::Q(Q::new(&config, vb.pp("transformer"))?);
(model, Device::Cpu)
Model::Q(Q::new(&config, vb.pp("transformer"))?)
} else {
let device = candle_examples::device(args.cpu)?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[filename], DType::F32, &device)? };
let model = Model::M(M::new(&config, vb.pp("transformer"))?);
(model, device)
Model::M(M::new(&config, vb.pp("transformer"))?)
};
println!("loaded the model in {:?}", start.elapsed());