mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 10:26:33 +00:00
Make the dtype configurable for phi. (#2133)
This commit is contained in:
@ -209,6 +209,10 @@ struct Args {
|
||||
/// The context size to consider for the repeat penalty.
|
||||
#[arg(long, default_value_t = 64)]
|
||||
repeat_last_n: usize,
|
||||
|
||||
/// The dtype to be used for running the model, e.g. f32, bf16, or f16.
|
||||
#[arg(long)]
|
||||
dtype: Option<String>,
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
@ -345,10 +349,15 @@ fn main() -> Result<()> {
|
||||
};
|
||||
Model::Quantized(model)
|
||||
} else {
|
||||
let dtype = if args.model == WhichModel::V3 && device.is_cuda() {
|
||||
DType::BF16
|
||||
} else {
|
||||
DType::F32
|
||||
let dtype = match args.dtype {
|
||||
Some(dtype) => std::str::FromStr::from_str(&dtype)?,
|
||||
None => {
|
||||
if args.model == WhichModel::V3 && device.is_cuda() {
|
||||
DType::BF16
|
||||
} else {
|
||||
DType::F32
|
||||
}
|
||||
}
|
||||
};
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
|
||||
match args.model {
|
||||
|
Reference in New Issue
Block a user