mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Print the recognized categories in dino-v2. (#506)
This commit is contained in:
@ -315,7 +315,17 @@ pub fn main() -> anyhow::Result<()> {
|
||||
let model = vit_small(vb)?;
|
||||
println!("model built");
|
||||
let logits = model.forward(&image.unsqueeze(0)?)?;
|
||||
let prs = candle_nn::ops::softmax(&logits, D::Minus1)?;
|
||||
println!("{prs}");
|
||||
let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
|
||||
.i(0)?
|
||||
.to_vec1::<f32>()?;
|
||||
let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
|
||||
prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
|
||||
for &(category_idx, pr) in prs.iter().take(5) {
|
||||
println!(
|
||||
"{:24}: {:.2}%",
|
||||
candle_examples::IMAGENET_CLASSES[category_idx],
|
||||
100. * pr
|
||||
);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
Reference in New Issue
Block a user