mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 11:37:11 +00:00
Rebase after phi2 merge + fix replit default to CPU.
This commit is contained in:
@ -314,16 +314,14 @@ fn main() -> Result<()> {
|
|||||||
&filenames[0],
|
&filenames[0],
|
||||||
&device,
|
&device,
|
||||||
)?;
|
)?;
|
||||||
println!("Loaded vb");
|
|
||||||
let model = match args.model {
|
let model = match args.model {
|
||||||
WhichModel::V2 | WhichModel::V2Old => QMixFormer::new_v2(&config, vb)?,
|
WhichModel::V2 | WhichModel::V2Old => QMixFormer::new_v2(&config, vb)?,
|
||||||
_ => QMixFormer::new(&config, vb)?,
|
_ => QMixFormer::new(&config, vb)?,
|
||||||
};
|
};
|
||||||
println!("Loaded model");
|
|
||||||
Model::Quantized(model)
|
Model::Quantized(model)
|
||||||
} else {
|
} else {
|
||||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
|
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
|
||||||
let model = match args.model {
|
match args.model {
|
||||||
WhichModel::V1 | WhichModel::V1_5 | WhichModel::V2 => {
|
WhichModel::V1 | WhichModel::V1_5 | WhichModel::V2 => {
|
||||||
let config_filename = repo.get("config.json")?;
|
let config_filename = repo.get("config.json")?;
|
||||||
let config = std::fs::read_to_string(config_filename)?;
|
let config = std::fs::read_to_string(config_filename)?;
|
||||||
@ -339,8 +337,7 @@ fn main() -> Result<()> {
|
|||||||
let config = config();
|
let config = config();
|
||||||
Model::MixFormer(MixFormer::new(&config, vb)?)
|
Model::MixFormer(MixFormer::new(&config, vb)?)
|
||||||
}
|
}
|
||||||
};
|
}
|
||||||
model
|
|
||||||
};
|
};
|
||||||
println!("loaded the model in {:?}", start.elapsed());
|
println!("loaded the model in {:?}", start.elapsed());
|
||||||
|
|
||||||
|
@ -236,18 +236,15 @@ fn main() -> Result<()> {
|
|||||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||||
|
|
||||||
let start = std::time::Instant::now();
|
let start = std::time::Instant::now();
|
||||||
let device = Device::Cpu;
|
let device = candle_examples::device(args.cpu)?;
|
||||||
let config = Config::replit_code_v1_5_3b();
|
let config = Config::replit_code_v1_5_3b();
|
||||||
let (model, device) = if args.quantized {
|
let model = if args.quantized {
|
||||||
let vb =
|
let vb =
|
||||||
candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filename, &device)?;
|
candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filename, &device)?;
|
||||||
let model = Model::Q(Q::new(&config, vb.pp("transformer"))?);
|
Model::Q(Q::new(&config, vb.pp("transformer"))?)
|
||||||
(model, Device::Cpu)
|
|
||||||
} else {
|
} else {
|
||||||
let device = candle_examples::device(args.cpu)?;
|
|
||||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[filename], DType::F32, &device)? };
|
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[filename], DType::F32, &device)? };
|
||||||
let model = Model::M(M::new(&config, vb.pp("transformer"))?);
|
Model::M(M::new(&config, vb.pp("transformer"))?)
|
||||||
(model, device)
|
|
||||||
};
|
};
|
||||||
println!("loaded the model in {:?}", start.elapsed());
|
println!("loaded the model in {:?}", start.elapsed());
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user