mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 11:37:11 +00:00
Rebase after phi2 merge + fix replit default to CPU.
This commit is contained in:
@ -314,16 +314,14 @@ fn main() -> Result<()> {
|
||||
&filenames[0],
|
||||
&device,
|
||||
)?;
|
||||
println!("Loaded vb");
|
||||
let model = match args.model {
|
||||
WhichModel::V2 | WhichModel::V2Old => QMixFormer::new_v2(&config, vb)?,
|
||||
_ => QMixFormer::new(&config, vb)?,
|
||||
};
|
||||
println!("Loaded model");
|
||||
Model::Quantized(model)
|
||||
} else {
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
|
||||
let model = match args.model {
|
||||
match args.model {
|
||||
WhichModel::V1 | WhichModel::V1_5 | WhichModel::V2 => {
|
||||
let config_filename = repo.get("config.json")?;
|
||||
let config = std::fs::read_to_string(config_filename)?;
|
||||
@ -339,8 +337,7 @@ fn main() -> Result<()> {
|
||||
let config = config();
|
||||
Model::MixFormer(MixFormer::new(&config, vb)?)
|
||||
}
|
||||
};
|
||||
model
|
||||
}
|
||||
};
|
||||
println!("loaded the model in {:?}", start.elapsed());
|
||||
|
||||
|
Reference in New Issue
Block a user