Mirror of https://github.com/huggingface/candle.git (synced 2025-06-15 02:16:37 +00:00)
Use F16 for moondream on cuda. (#2013)
This commit is contained in:
@@ -283,6 +283,11 @@ async fn main() -> anyhow::Result<()> {
|
||||
let start = std::time::Instant::now();
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
let config = moondream::Config::v2();
|
||||
+ let dtype = if device.is_cuda() && !args.quantized {
|
||||
+ DType::F16
|
||||
+ } else {
|
||||
+ DType::F32
|
||||
+ };
|
||||
let model = if args.quantized {
|
||||
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(
|
||||
&model_file,
|
||||
@@ -291,15 +296,16 @@ async fn main() -> anyhow::Result<()> {
|
||||
let model = quantized_moondream::Model::new(&config, vb)?;
|
||||
Model::Quantized(model)
|
||||
} else {
|
||||
- let vb =
|
||||
- unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
|
||||
+ let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)? };
|
||||
let model = moondream::Model::new(&config, vb)?;
|
||||
Model::Moondream(model)
|
||||
};
|
||||
println!("loaded the model in {:?}", start.elapsed());
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
- let image = load_image(args.image)?.to_device(&device)?;
|
||||
+ let image = load_image(args.image)?
|
||||
+ .to_device(&device)?
|
||||
+ .to_dtype(dtype)?;
|
||||
let image_embeds = image.unsqueeze(0)?;
|
||||
let image_embeds = match model {
|
||||
Model::Moondream(ref m) => image_embeds.apply(m.vision_encoder())?,
|
||||
|
Reference in New Issue
Block a user