Use F16 for moondream on cuda. (#2013)

This commit is contained in:
Laurent Mazare
2024-04-04 23:30:10 +02:00
committed by GitHub
parent c5626b8271
commit c87381fc96
2 changed files with 17 additions and 8 deletions

View File

@ -283,6 +283,11 @@ async fn main() -> anyhow::Result<()> {
let start = std::time::Instant::now();
let device = candle_examples::device(args.cpu)?;
let config = moondream::Config::v2();
let dtype = if device.is_cuda() && !args.quantized {
DType::F16
} else {
DType::F32
};
let model = if args.quantized {
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(
&model_file,
@ -291,15 +296,16 @@ async fn main() -> anyhow::Result<()> {
let model = quantized_moondream::Model::new(&config, vb)?;
Model::Quantized(model)
} else {
let vb =
unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)? };
let model = moondream::Model::new(&config, vb)?;
Model::Moondream(model)
};
println!("loaded the model in {:?}", start.elapsed());
let start = std::time::Instant::now();
let image = load_image(args.image)?.to_device(&device)?;
let image = load_image(args.image)?
.to_device(&device)?
.to_dtype(dtype)?;
let image_embeds = image.unsqueeze(0)?;
let image_embeds = match model {
Model::Moondream(ref m) => image_embeds.apply(m.vision_encoder())?,