diff --git a/candle-examples/examples/moondream/main.rs b/candle-examples/examples/moondream/main.rs index bcc21337..dfd83037 100644 --- a/candle-examples/examples/moondream/main.rs +++ b/candle-examples/examples/moondream/main.rs @@ -283,6 +283,11 @@ async fn main() -> anyhow::Result<()> { let start = std::time::Instant::now(); let device = candle_examples::device(args.cpu)?; let config = moondream::Config::v2(); + let dtype = if device.is_cuda() && !args.quantized { + DType::F16 + } else { + DType::F32 + }; let model = if args.quantized { let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf( &model_file, @@ -291,15 +296,16 @@ async fn main() -> anyhow::Result<()> { let model = quantized_moondream::Model::new(&config, vb)?; Model::Quantized(model) } else { - let vb = - unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? }; + let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)? }; let model = moondream::Model::new(&config, vb)?; Model::Moondream(model) }; println!("loaded the model in {:?}", start.elapsed()); let start = std::time::Instant::now(); - let image = load_image(args.image)?.to_device(&device)?; + let image = load_image(args.image)? + .to_device(&device)? + .to_dtype(dtype)?; let image_embeds = image.unsqueeze(0)?; let image_embeds = match model { Model::Moondream(ref m) => image_embeds.apply(m.vision_encoder())?, diff --git a/candle-transformers/src/models/mixformer.rs b/candle-transformers/src/models/mixformer.rs index 65a1665a..de15c3a5 100644 --- a/candle-transformers/src/models/mixformer.rs +++ b/candle-transformers/src/models/mixformer.rs @@ -135,7 +135,9 @@ fn get_mask(size: usize, device: &Device) -> Result { fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result { let shape = mask.shape(); - let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?; + let on_true = Tensor::new(on_true, on_false.device())? + .to_dtype(on_false.dtype())? + .broadcast_as(shape.dims())?; let m = mask.where_cond(&on_true, on_false)?; Ok(m) } @@ -147,7 +149,7 @@ struct RotaryEmbedding { } impl RotaryEmbedding { - fn new(dim: usize, max_seq_len: usize, dev: &Device) -> Result { + fn new(dim: usize, max_seq_len: usize, dtype: DType, dev: &Device) -> Result { let inv_freq: Vec<_> = (0..dim) .step_by(2) .map(|i| 1f32 / 10000f32.powf(i as f32 / dim as f32)) @@ -159,8 +161,8 @@ impl RotaryEmbedding { .reshape((max_seq_len, 1))?; let freqs = t.matmul(&inv_freq)?; Ok(Self { - sin: freqs.sin()?, - cos: freqs.cos()?, + sin: freqs.sin()?.to_dtype(dtype)?, + cos: freqs.cos()?.to_dtype(dtype)?, }) } @@ -274,7 +276,8 @@ impl MHA { let op_size = cfg.n_embd; let wqkv = linear(cfg.n_embd, 3 * op_size, vb.pp("Wqkv"))?; let out_proj = linear(op_size, cfg.n_embd, vb.pp("out_proj"))?; - let rotary_emb = RotaryEmbedding::new(cfg.rotary_dim, MAX_SEQ_LEN, vb.device())?; + let rotary_emb = + RotaryEmbedding::new(cfg.rotary_dim, MAX_SEQ_LEN, vb.dtype(), vb.device())?; let softmax_scale = 1f64 / (head_dim as f64).sqrt(); Ok(Self { wqkv,