mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 19:58:35 +00:00
dinov2 - read images from disk and compute the class probabilities (#503)
* Load the image from disk and convert it to a tensor. * Tweak the function name.
This commit is contained in:
@ -85,7 +85,7 @@ impl LayerScale {
|
||||
|
||||
impl Module for LayerScale {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
xs * &self.gamma
|
||||
xs.broadcast_mul(&self.gamma)
|
||||
}
|
||||
}
|
||||
|
||||
@ -306,10 +306,17 @@ pub fn main() -> anyhow::Result<()> {
|
||||
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
|
||||
// TODO: apply imagenet normalization.
|
||||
let image = candle_examples::load_image(args.image)?;
|
||||
println!("loaded image {image:?}");
|
||||
|
||||
let weights = unsafe { candle::safetensors::MmapedFile::new(args.model)? };
|
||||
let weights = weights.deserialize()?;
|
||||
let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
|
||||
let _model = vit_small(vb)?;
|
||||
let model = vit_small(vb)?;
|
||||
println!("model built");
|
||||
let logits = model.forward(&image.unsqueeze(0)?)?;
|
||||
let prs = candle_nn::ops::softmax(&logits, D::Minus1)?;
|
||||
println!("{prs}");
|
||||
Ok(())
|
||||
}
|
||||
|
Reference in New Issue
Block a user