mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Print the recognized categories in dino-v2. (#506)
This commit is contained in:
@ -315,7 +315,17 @@ pub fn main() -> anyhow::Result<()> {
|
|||||||
let model = vit_small(vb)?;
|
let model = vit_small(vb)?;
|
||||||
println!("model built");
|
println!("model built");
|
||||||
let logits = model.forward(&image.unsqueeze(0)?)?;
|
let logits = model.forward(&image.unsqueeze(0)?)?;
|
||||||
let prs = candle_nn::ops::softmax(&logits, D::Minus1)?;
|
let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
|
||||||
println!("{prs}");
|
.i(0)?
|
||||||
|
.to_vec1::<f32>()?;
|
||||||
|
let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
|
||||||
|
prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
|
||||||
|
for &(category_idx, pr) in prs.iter().take(5) {
|
||||||
|
println!(
|
||||||
|
"{:24}: {:.2}%",
|
||||||
|
candle_examples::IMAGENET_CLASSES[category_idx],
|
||||||
|
100. * pr
|
||||||
|
);
|
||||||
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user