mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Add options to use local files + specify a custom repo or branch. (#1973)
This commit is contained in:
@ -155,6 +155,18 @@ struct Args {
|
|||||||
/// The context size to consider for the repeat penalty.
|
/// The context size to consider for the repeat penalty.
|
||||||
#[arg(long, default_value_t = 64)]
|
#[arg(long, default_value_t = 64)]
|
||||||
repeat_last_n: usize,
|
repeat_last_n: usize,
|
||||||
|
|
||||||
|
#[arg(long, default_value = "vikhyatk/moondream2")]
|
||||||
|
model_id: String,
|
||||||
|
|
||||||
|
#[arg(long, default_value = "main")]
|
||||||
|
revision: String,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
model_file: Option<String>,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
tokenizer_file: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Loads an image from disk using the image crate, this returns a tensor with shape
|
/// Loads an image from disk using the image crate, this returns a tensor with shape
|
||||||
@ -204,9 +216,19 @@ async fn main() -> anyhow::Result<()> {
|
|||||||
|
|
||||||
let start = std::time::Instant::now();
|
let start = std::time::Instant::now();
|
||||||
let api = hf_hub::api::tokio::Api::new()?;
|
let api = hf_hub::api::tokio::Api::new()?;
|
||||||
let repo = api.model("vikhyatk/moondream2".to_string());
|
let repo = api.repo(hf_hub::Repo::with_revision(
|
||||||
let model_file = repo.get("model.safetensors").await?;
|
args.model_id,
|
||||||
let tokenizer = repo.get("tokenizer.json").await?;
|
hf_hub::RepoType::Model,
|
||||||
|
args.revision,
|
||||||
|
));
|
||||||
|
let model_file = match args.model_file {
|
||||||
|
Some(m) => m.into(),
|
||||||
|
None => repo.get("model.safetensors").await?,
|
||||||
|
};
|
||||||
|
let tokenizer = match args.tokenizer_file {
|
||||||
|
Some(m) => m.into(),
|
||||||
|
None => repo.get("tokenizer.json").await?,
|
||||||
|
};
|
||||||
println!("retrieved the files in {:?}", start.elapsed());
|
println!("retrieved the files in {:?}", start.elapsed());
|
||||||
let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;
|
let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;
|
||||||
|
|
||||||
|
@ -19,11 +19,8 @@ impl Config {
|
|||||||
fn scaled_dot_product_attention(q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
|
fn scaled_dot_product_attention(q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
|
||||||
let dim = q.dim(D::Minus1)?;
|
let dim = q.dim(D::Minus1)?;
|
||||||
let scale_factor = 1.0 / (dim as f64).sqrt();
|
let scale_factor = 1.0 / (dim as f64).sqrt();
|
||||||
let k = k.transpose(D::Minus2, D::Minus1)?.contiguous()?;
|
let attn_weights = (q.matmul(&k.t()?)? * scale_factor)?;
|
||||||
let mut attn_weights = (q.contiguous()?.matmul(&k)? * scale_factor)?;
|
candle_nn::ops::softmax_last_dim(&attn_weights)?.matmul(v)
|
||||||
attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?.contiguous()?;
|
|
||||||
let attn_weights = attn_weights.matmul(&v.contiguous()?)?;
|
|
||||||
Ok(attn_weights)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, serde::Deserialize)]
|
#[derive(Debug, Clone, PartialEq, serde::Deserialize)]
|
||||||
@ -101,10 +98,15 @@ impl Module for Attention {
|
|||||||
.apply(&self.qkv)?
|
.apply(&self.qkv)?
|
||||||
.reshape((b, n, 3, self.num_heads, self.head_dim))?
|
.reshape((b, n, 3, self.num_heads, self.head_dim))?
|
||||||
.permute((2, 0, 3, 1, 4))?;
|
.permute((2, 0, 3, 1, 4))?;
|
||||||
let (q, k, v) = (qkv.i(0)?, qkv.i(1)?, qkv.i(2)?);
|
let (q, k, v) = (
|
||||||
let attn_weights = scaled_dot_product_attention(&q, &k, &v)?;
|
qkv.i(0)?.contiguous()?,
|
||||||
let attn_weights = attn_weights.transpose(1, 2)?.reshape((b, n, c))?;
|
qkv.i(1)?.contiguous()?,
|
||||||
attn_weights.apply(&self.proj)
|
qkv.i(2)?.contiguous()?,
|
||||||
|
);
|
||||||
|
scaled_dot_product_attention(&q, &k, &v)?
|
||||||
|
.transpose(1, 2)?
|
||||||
|
.reshape((b, n, c))?
|
||||||
|
.apply(&self.proj)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -275,11 +277,11 @@ impl Module for VisionEncoder {
|
|||||||
let (p1, p2) = (14, 14);
|
let (p1, p2) = (14, 14);
|
||||||
let h = hp1 / p1;
|
let h = hp1 / p1;
|
||||||
let w = wp2 / p2;
|
let w = wp2 / p2;
|
||||||
let xs = xs
|
xs.reshape((b, c, h, p1, h, p2))?
|
||||||
.reshape((b, c, h, p1, h, p2))?
|
|
||||||
.permute((0, 2, 4, 1, 3, 5))?
|
.permute((0, 2, 4, 1, 3, 5))?
|
||||||
.reshape((b, h * w, c * p1 * p2))?;
|
.reshape((b, h * w, c * p1 * p2))?
|
||||||
xs.apply(&self.encoder)?.apply(&self.projection)
|
.apply(&self.encoder)?
|
||||||
|
.apply(&self.projection)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user