diff --git a/candle-examples/examples/moondream/main.rs b/candle-examples/examples/moondream/main.rs
index 2ec04256..3e0f6d57 100644
--- a/candle-examples/examples/moondream/main.rs
+++ b/candle-examples/examples/moondream/main.rs
@@ -155,6 +155,18 @@ struct Args {
     /// The context size to consider for the repeat penalty.
     #[arg(long, default_value_t = 64)]
     repeat_last_n: usize,
+
+    #[arg(long, default_value = "vikhyatk/moondream2")]
+    model_id: String,
+
+    #[arg(long, default_value = "main")]
+    revision: String,
+
+    #[arg(long)]
+    model_file: Option<String>,
+
+    #[arg(long)]
+    tokenizer_file: Option<String>,
 }
 
 /// Loads an image from disk using the image crate, this returns a tensor with shape
@@ -204,9 +216,19 @@ async fn main() -> anyhow::Result<()> {
 
     let start = std::time::Instant::now();
     let api = hf_hub::api::tokio::Api::new()?;
-    let repo = api.model("vikhyatk/moondream2".to_string());
-    let model_file = repo.get("model.safetensors").await?;
-    let tokenizer = repo.get("tokenizer.json").await?;
+    let repo = api.repo(hf_hub::Repo::with_revision(
+        args.model_id,
+        hf_hub::RepoType::Model,
+        args.revision,
+    ));
+    let model_file = match args.model_file {
+        Some(m) => m.into(),
+        None => repo.get("model.safetensors").await?,
+    };
+    let tokenizer = match args.tokenizer_file {
+        Some(m) => m.into(),
+        None => repo.get("tokenizer.json").await?,
+    };
     println!("retrieved the files in {:?}", start.elapsed());
 
     let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;
diff --git a/candle-transformers/src/models/moondream.rs b/candle-transformers/src/models/moondream.rs
index 1172bf71..c36052c6 100644
--- a/candle-transformers/src/models/moondream.rs
+++ b/candle-transformers/src/models/moondream.rs
@@ -19,11 +19,8 @@ impl Config {
 fn scaled_dot_product_attention(q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
     let dim = q.dim(D::Minus1)?;
     let scale_factor = 1.0 / (dim as f64).sqrt();
-    let k = k.transpose(D::Minus2, D::Minus1)?.contiguous()?;
-    let mut attn_weights = (q.contiguous()?.matmul(&k)? * scale_factor)?;
-    attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?.contiguous()?;
-    let attn_weights = attn_weights.matmul(&v.contiguous()?)?;
-    Ok(attn_weights)
+    let attn_weights = (q.matmul(&k.t()?)? * scale_factor)?;
+    candle_nn::ops::softmax_last_dim(&attn_weights)?.matmul(v)
 }
 
 #[derive(Debug, Clone, PartialEq, serde::Deserialize)]
@@ -101,10 +98,15 @@ impl Module for Attention {
             .apply(&self.qkv)?
             .reshape((b, n, 3, self.num_heads, self.head_dim))?
             .permute((2, 0, 3, 1, 4))?;
-        let (q, k, v) = (qkv.i(0)?, qkv.i(1)?, qkv.i(2)?);
-        let attn_weights = scaled_dot_product_attention(&q, &k, &v)?;
-        let attn_weights = attn_weights.transpose(1, 2)?.reshape((b, n, c))?;
-        attn_weights.apply(&self.proj)
+        let (q, k, v) = (
+            qkv.i(0)?.contiguous()?,
+            qkv.i(1)?.contiguous()?,
+            qkv.i(2)?.contiguous()?,
+        );
+        scaled_dot_product_attention(&q, &k, &v)?
+            .transpose(1, 2)?
+            .reshape((b, n, c))?
+            .apply(&self.proj)
     }
 }
 
@@ -275,11 +277,11 @@ impl Module for VisionEncoder {
         let (p1, p2) = (14, 14);
         let h = hp1 / p1;
         let w = wp2 / p2;
-        let xs = xs
-            .reshape((b, c, h, p1, h, p2))?
-            .permute((0, 2, 4, 1, 3, 5))?
-            .reshape((b, h * w, c * p1 * p2))?;
-        xs.apply(&self.encoder)?.apply(&self.projection)
+        xs.reshape((b, c, h, p1, h, p2))?
+            .permute((0, 2, 4, 1, 3, 5))?
+            .reshape((b, h * w, c * p1 * p2))?
+            .apply(&self.encoder)?
+            .apply(&self.projection)
     }
 }
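
With this patch the example no longer hardcodes the hub repo: clap derives kebab-case flags from the new fields, so --model-id and --revision select an alternative checkpoint (defaulting to vikhyatk/moondream2 at main), while --model-file and --tokenizer-file skip the hub download entirely; the `m.into()` arms convert the user-supplied String paths into the PathBuf values that `repo.get` would otherwise return.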
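
For reference, a minimal standalone sketch (not part of the patch) of the simplified attention path. It assumes candle-core is aliased to `candle` and candle-nn is available as dependencies, as in the candle workspace, and uses a dummy (batch, heads, seq_len, head_dim) layout matching what the `permute((2, 0, 3, 1, 4))` in `Attention::forward` produces:

use candle::{Device, Result, Tensor, D};

fn scaled_dot_product_attention(q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
    let dim = q.dim(D::Minus1)?;
    let scale_factor = 1.0 / (dim as f64).sqrt();
    // k.t()? swaps the last two dims; the interior .contiguous() calls are no
    // longer needed because the caller now hands in contiguous q, k, and v.
    let attn_weights = (q.matmul(&k.t()?)? * scale_factor)?;
    candle_nn::ops::softmax_last_dim(&attn_weights)?.matmul(v)
}

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Dummy (batch, heads, seq_len, head_dim) inputs.
    let q = Tensor::randn(0f32, 1., (1, 8, 16, 64), &dev)?;
    let k = Tensor::randn(0f32, 1., (1, 8, 16, 64), &dev)?;
    let v = Tensor::randn(0f32, 1., (1, 8, 16, 64), &dev)?;
    let out = scaled_dot_product_attention(&q, &k, &v)?;
    // The output keeps the query's shape.
    assert_eq!(out.dims4()?, (1, 8, 16, 64));
    Ok(())
}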