Make the dtype configurable for phi. (#2133)

This commit is contained in:
Laurent Mazare
2024-04-27 21:32:49 +02:00
committed by GitHub
parent 96a48e5cc4
commit 3b429f3023

View File

@@ -209,6 +209,10 @@ struct Args {
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
+/// The dtype to be used for running the model, e.g. f32, bf16, or f16.
+#[arg(long)]
+dtype: Option<String>,
}
fn main() -> Result<()> {
@@ -345,10 +349,15 @@ fn main() -> Result<()> {
};
Model::Quantized(model)
} else {
-let dtype = if args.model == WhichModel::V3 && device.is_cuda() {
-DType::BF16
-} else {
-DType::F32
+let dtype = match args.dtype {
+Some(dtype) => std::str::FromStr::from_str(&dtype)?,
+None => {
+if args.model == WhichModel::V3 && device.is_cuda() {
+DType::BF16
+} else {
+DType::F32
+}
+}
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
match args.model {