Print the recognized categories in dino-v2. (#506)

This commit is contained in:
Laurent Mazare
2023-08-18 17:32:58 +01:00
committed by GitHub
parent cb069d6063
commit e5dd5fd1b3
2 changed files with 1018 additions and 3 deletions

View File

@ -315,7 +315,17 @@ pub fn main() -> anyhow::Result<()> {
let model = vit_small(vb)?;
println!("model built");
let logits = model.forward(&image.unsqueeze(0)?)?;
let prs = candle_nn::ops::softmax(&logits, D::Minus1)?;
println!("{prs}");
let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
.i(0)?
.to_vec1::<f32>()?;
let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
for &(category_idx, pr) in prs.iter().take(5) {
println!(
"{:24}: {:.2}%",
candle_examples::IMAGENET_CLASSES[category_idx],
100. * pr
);
}
Ok(())
}

File diff suppressed because it is too large Load Diff