Yolo v8 fixes (#542)

* Fixes for the yolo-v8 layout.

* Bugfixes.

* Another silly bugfix.

* Remove the hf-hub dependency.

* Remove the transformers dependency.
This commit is contained in:
Laurent Mazare
2023-08-21 21:05:40 +01:00
committed by GitHub
parent de50e66af1
commit 3507e14c0c
2 changed files with 15 additions and 16 deletions

View File

@ -9,7 +9,9 @@ extern crate accelerate_src;
mod coco_classes;
use candle::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::{batch_norm, conv2d_no_bias, BatchNorm, Conv2d, Conv2dConfig, Module, VarBuilder};
use candle_nn::{
batch_norm, conv2d, conv2d_no_bias, BatchNorm, Conv2d, Conv2dConfig, Module, VarBuilder,
};
use clap::Parser;
use image::{DynamicImage, ImageBuffer};
@ -179,8 +181,7 @@ impl C2f {
impl Module for C2f {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let ys = self.cv1.forward(xs)?;
let ys_1 = ys.dim(1)?;
let mut ys = vec![ys.i((.., 0..ys_1 / 2))?, ys.i((.., ys_1 / 2..))?];
let mut ys = ys.chunk(2, 1)?;
for m in self.bottleneck.iter() {
ys.push(m.forward(ys.last().unwrap())?)
}
@ -349,7 +350,9 @@ impl DarkNet {
fn forward(&self, xs: &Tensor) -> Result<(Tensor, Tensor, Tensor)> {
let x1 = self.b1_1.forward(&self.b1_0.forward(xs)?)?;
let x2 = self.b2_1.forward(&self.b2_0.forward(&x1)?)?;
let x2 = self
.b2_2
.forward(&self.b2_1.forward(&self.b2_0.forward(&x1)?)?)?;
let x3 = self.b3_1.forward(&self.b3_0.forward(&x2)?)?;
let x4 = self.b4_1.forward(&self.b4_0.forward(&x3)?)?;
let x5 = self.b5.forward(&x4)?;
@ -529,7 +532,7 @@ impl DetectionHead {
) -> Result<(ConvBlock, ConvBlock, Conv2d)> {
let block0 = ConvBlock::load(vb.pp("0"), filter, c1, 3, 1, None)?;
let block1 = ConvBlock::load(vb.pp("1"), c1, c1, 3, 1, None)?;
let conv = conv2d_no_bias(c1, nc, 1, Default::default(), vb.pp("2"))?;
let conv = conv2d(c1, nc, 1, Default::default(), vb.pp("2"))?;
Ok((block0, block1, conv))
}
@ -541,7 +544,7 @@ impl DetectionHead {
) -> Result<(ConvBlock, ConvBlock, Conv2d)> {
let block0 = ConvBlock::load(vb.pp("0"), filter, c2, 3, 1, None)?;
let block1 = ConvBlock::load(vb.pp("1"), c2, c2, 3, 1, None)?;
let conv = conv2d_no_bias(c2, 4 * ch, 1, Default::default(), vb.pp("2"))?;
let conv = conv2d(c2, 4 * ch, 1, Default::default(), vb.pp("2"))?;
Ok((block0, block1, conv))
}
@ -579,7 +582,6 @@ impl DetectionHead {
let dbox = dist2bbox(&self.dfl.forward(&box_)?, &anchors.unsqueeze(0)?)?;
let dbox = dbox.broadcast_mul(&strides)?;
Tensor::cat(&[dbox, candle_nn::ops::sigmoid(&cls)?], 1)
}
}
@ -652,22 +654,22 @@ pub fn draw_rect(
}
pub fn report(pred: &Tensor, img: DynamicImage, w: usize, h: usize) -> Result<DynamicImage> {
let (npreds, pred_size) = pred.dims2()?;
let nclasses = pred_size - 5;
let (pred_size, npreds) = pred.dims2()?;
let nclasses = pred_size - 4;
// The bounding boxes grouped by (maximum) class index.
let mut bboxes: Vec<Vec<Bbox>> = (0..nclasses).map(|_| vec![]).collect();
// Extract the bounding boxes for which confidence is above the threshold.
for index in 0..npreds {
let pred = Vec::<f32>::try_from(pred.get(index)?)?;
let confidence = pred[4];
let pred = Vec::<f32>::try_from(pred.i((.., index))?)?;
let confidence = *pred[4..].iter().max_by(|x, y| x.total_cmp(y)).unwrap();
if confidence > CONFIDENCE_THRESHOLD {
let mut class_index = 0;
for i in 0..nclasses {
if pred[5 + i] > pred[5 + class_index] {
if pred[4 + i] > pred[4 + class_index] {
class_index = i
}
}
if pred[class_index + 5] > 0. {
if pred[class_index + 4] > 0. {
let bbox = Bbox {
xmin: pred[0] - pred[2] / 2.,
ymin: pred[1] - pred[3] / 2.,
@ -767,7 +769,6 @@ pub fn main() -> anyhow::Result<()> {
};
let image = (image.unsqueeze(0)?.to_dtype(DType::F32)? * (1. / 255.))?;
let predictions = model.forward(&image)?.squeeze(0)?;
let predictions = predictions.t()?;
println!("generated predictions {predictions:?}");
let image = report(&predictions, original_image, 640, 640)?;
image_name.set_extension("pp.jpg");

View File

@ -12,10 +12,8 @@ readme = "README.md"
[dependencies]
accelerate-src = { workspace = true, optional = true }
candle = { path = "../candle-core", version = "0.1.2", package = "candle-core" }
hf-hub = { workspace = true}
candle-nn = { path = "../candle-nn", version = "0.1.2" }
intel-mkl-src = { workspace = true, optional = true }
tokenizers = { workspace = true, features = ["onig"] }
rand = { workspace = true }
wav = { workspace = true }