mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Use the new phi model by default. (#1589)
This commit is contained in:
@ -128,9 +128,8 @@ enum WhichModel {
|
|||||||
V1_5,
|
V1_5,
|
||||||
#[value(name = "2")]
|
#[value(name = "2")]
|
||||||
V2,
|
V2,
|
||||||
// TODO: Make this the default once it has been battle tested.
|
#[value(name = "2-old")]
|
||||||
#[value(name = "2-new")]
|
V2Old,
|
||||||
V2New,
|
|
||||||
PuffinPhiV2,
|
PuffinPhiV2,
|
||||||
PhiHermes,
|
PhiHermes,
|
||||||
}
|
}
|
||||||
@ -236,7 +235,7 @@ fn main() -> Result<()> {
|
|||||||
match args.model {
|
match args.model {
|
||||||
WhichModel::V1 => "microsoft/phi-1".to_string(),
|
WhichModel::V1 => "microsoft/phi-1".to_string(),
|
||||||
WhichModel::V1_5 => "microsoft/phi-1_5".to_string(),
|
WhichModel::V1_5 => "microsoft/phi-1_5".to_string(),
|
||||||
WhichModel::V2 | WhichModel::V2New => "microsoft/phi-2".to_string(),
|
WhichModel::V2 | WhichModel::V2Old => "microsoft/phi-2".to_string(),
|
||||||
WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
|
WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
|
||||||
"lmz/candle-quantized-phi".to_string()
|
"lmz/candle-quantized-phi".to_string()
|
||||||
}
|
}
|
||||||
@ -251,10 +250,10 @@ fn main() -> Result<()> {
|
|||||||
"main".to_string()
|
"main".to_string()
|
||||||
} else {
|
} else {
|
||||||
match args.model {
|
match args.model {
|
||||||
WhichModel::V1 => "refs/pr/2".to_string(),
|
WhichModel::V1 => "refs/pr/8".to_string(),
|
||||||
WhichModel::V1_5 => "refs/pr/18".to_string(),
|
WhichModel::V1_5 => "refs/pr/73".to_string(),
|
||||||
WhichModel::V2 => "834565c23f9b28b96ccbeabe614dd906b6db551a".to_string(),
|
WhichModel::V2Old => "834565c23f9b28b96ccbeabe614dd906b6db551a".to_string(),
|
||||||
WhichModel::V2New | WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
|
WhichModel::V2 | WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
|
||||||
"main".to_string()
|
"main".to_string()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -265,7 +264,7 @@ fn main() -> Result<()> {
|
|||||||
let tokenizer_filename = match args.tokenizer {
|
let tokenizer_filename = match args.tokenizer {
|
||||||
Some(file) => std::path::PathBuf::from(file),
|
Some(file) => std::path::PathBuf::from(file),
|
||||||
None => match args.model {
|
None => match args.model {
|
||||||
WhichModel::V1 | WhichModel::V1_5 | WhichModel::V2 | WhichModel::V2New => {
|
WhichModel::V1 | WhichModel::V1_5 | WhichModel::V2 | WhichModel::V2Old => {
|
||||||
repo.get("tokenizer.json")?
|
repo.get("tokenizer.json")?
|
||||||
}
|
}
|
||||||
WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
|
WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
|
||||||
@ -280,14 +279,14 @@ fn main() -> Result<()> {
|
|||||||
match args.model {
|
match args.model {
|
||||||
WhichModel::V1 => vec![repo.get("model-v1-q4k.gguf")?],
|
WhichModel::V1 => vec![repo.get("model-v1-q4k.gguf")?],
|
||||||
WhichModel::V1_5 => vec![repo.get("model-q4k.gguf")?],
|
WhichModel::V1_5 => vec![repo.get("model-q4k.gguf")?],
|
||||||
WhichModel::V2 | WhichModel::V2New => vec![repo.get("model-v2-q4k.gguf")?],
|
WhichModel::V2 | WhichModel::V2Old => vec![repo.get("model-v2-q4k.gguf")?],
|
||||||
WhichModel::PuffinPhiV2 => vec![repo.get("model-puffin-phi-v2-q4k.gguf")?],
|
WhichModel::PuffinPhiV2 => vec![repo.get("model-puffin-phi-v2-q4k.gguf")?],
|
||||||
WhichModel::PhiHermes => vec![repo.get("model-phi-hermes-1_3B-q4k.gguf")?],
|
WhichModel::PhiHermes => vec![repo.get("model-phi-hermes-1_3B-q4k.gguf")?],
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
match args.model {
|
match args.model {
|
||||||
WhichModel::V1 | WhichModel::V1_5 => vec![repo.get("model.safetensors")?],
|
WhichModel::V1 | WhichModel::V1_5 => vec![repo.get("model.safetensors")?],
|
||||||
WhichModel::V2 | WhichModel::V2New => candle_examples::hub_load_safetensors(
|
WhichModel::V2 | WhichModel::V2Old => candle_examples::hub_load_safetensors(
|
||||||
&repo,
|
&repo,
|
||||||
"model.safetensors.index.json",
|
"model.safetensors.index.json",
|
||||||
)?,
|
)?,
|
||||||
@ -304,35 +303,39 @@ fn main() -> Result<()> {
|
|||||||
let config = || match args.model {
|
let config = || match args.model {
|
||||||
WhichModel::V1 => Config::v1(),
|
WhichModel::V1 => Config::v1(),
|
||||||
WhichModel::V1_5 => Config::v1_5(),
|
WhichModel::V1_5 => Config::v1_5(),
|
||||||
WhichModel::V2 | WhichModel::V2New => Config::v2(),
|
WhichModel::V2 | WhichModel::V2Old => Config::v2(),
|
||||||
WhichModel::PuffinPhiV2 => Config::puffin_phi_v2(),
|
WhichModel::PuffinPhiV2 => Config::puffin_phi_v2(),
|
||||||
WhichModel::PhiHermes => Config::phi_hermes_1_3b(),
|
WhichModel::PhiHermes => Config::phi_hermes_1_3b(),
|
||||||
};
|
};
|
||||||
let (model, device) = if args.model == WhichModel::V2New {
|
let (model, device) = if args.quantized {
|
||||||
let device = candle_examples::device(args.cpu)?;
|
|
||||||
let config_filename = repo.get("config.json")?;
|
|
||||||
let config = std::fs::read_to_string(config_filename)?;
|
|
||||||
let config: PhiConfig = serde_json::from_str(&config)?;
|
|
||||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
|
|
||||||
let phi = Phi::new(&config, vb)?;
|
|
||||||
(Model::Phi(phi), device)
|
|
||||||
} else if args.quantized {
|
|
||||||
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filenames[0])?;
|
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filenames[0])?;
|
||||||
let config = config();
|
let config = config();
|
||||||
let model = match args.model {
|
let model = match args.model {
|
||||||
WhichModel::V2 | WhichModel::V2New => QMixFormer::new_v2(&config, vb)?,
|
WhichModel::V2 | WhichModel::V2Old => QMixFormer::new_v2(&config, vb)?,
|
||||||
_ => QMixFormer::new(&config, vb)?,
|
_ => QMixFormer::new(&config, vb)?,
|
||||||
};
|
};
|
||||||
(Model::Quantized(model), Device::Cpu)
|
(Model::Quantized(model), Device::Cpu)
|
||||||
} else {
|
} else {
|
||||||
let device = candle_examples::device(args.cpu)?;
|
let device = candle_examples::device(args.cpu)?;
|
||||||
let config = config();
|
|
||||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
|
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
|
||||||
let model = match args.model {
|
let model = match args.model {
|
||||||
WhichModel::V2 | WhichModel::V2New => MixFormer::new_v2(&config, vb)?,
|
WhichModel::V1 | WhichModel::V1_5 | WhichModel::V2 => {
|
||||||
_ => MixFormer::new(&config, vb)?,
|
let config_filename = repo.get("config.json")?;
|
||||||
|
let config = std::fs::read_to_string(config_filename)?;
|
||||||
|
let config: PhiConfig = serde_json::from_str(&config)?;
|
||||||
|
let phi = Phi::new(&config, vb)?;
|
||||||
|
Model::Phi(phi)
|
||||||
|
}
|
||||||
|
WhichModel::V2Old => {
|
||||||
|
let config = config();
|
||||||
|
Model::MixFormer(MixFormer::new_v2(&config, vb)?)
|
||||||
|
}
|
||||||
|
WhichModel::PhiHermes | WhichModel::PuffinPhiV2 => {
|
||||||
|
let config = config();
|
||||||
|
Model::MixFormer(MixFormer::new(&config, vb)?)
|
||||||
|
}
|
||||||
};
|
};
|
||||||
(Model::MixFormer(model), device)
|
(model, device)
|
||||||
};
|
};
|
||||||
println!("loaded the model in {:?}", start.elapsed());
|
println!("loaded the model in {:?}", start.elapsed());
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user