Yolo v8 fixes (#542)

* Fixes for the yolo-v8 layout.

* Bugfixes.

* Another silly bugfix.

* Remove the hf-hub dependency.

* Remove the transformers dependency.
This commit is contained in:
Laurent Mazare
2023-08-21 21:05:40 +01:00
committed by GitHub
parent de50e66af1
commit 3507e14c0c
2 changed files with 15 additions and 16 deletions

View File

@ -9,7 +9,9 @@ extern crate accelerate_src;
mod coco_classes; mod coco_classes;
use candle::{DType, Device, IndexOp, Result, Tensor, D}; use candle::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::{batch_norm, conv2d_no_bias, BatchNorm, Conv2d, Conv2dConfig, Module, VarBuilder}; use candle_nn::{
batch_norm, conv2d, conv2d_no_bias, BatchNorm, Conv2d, Conv2dConfig, Module, VarBuilder,
};
use clap::Parser; use clap::Parser;
use image::{DynamicImage, ImageBuffer}; use image::{DynamicImage, ImageBuffer};
@ -179,8 +181,7 @@ impl C2f {
impl Module for C2f { impl Module for C2f {
fn forward(&self, xs: &Tensor) -> Result<Tensor> { fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let ys = self.cv1.forward(xs)?; let ys = self.cv1.forward(xs)?;
let ys_1 = ys.dim(1)?; let mut ys = ys.chunk(2, 1)?;
let mut ys = vec![ys.i((.., 0..ys_1 / 2))?, ys.i((.., ys_1 / 2..))?];
for m in self.bottleneck.iter() { for m in self.bottleneck.iter() {
ys.push(m.forward(ys.last().unwrap())?) ys.push(m.forward(ys.last().unwrap())?)
} }
@ -349,7 +350,9 @@ impl DarkNet {
fn forward(&self, xs: &Tensor) -> Result<(Tensor, Tensor, Tensor)> { fn forward(&self, xs: &Tensor) -> Result<(Tensor, Tensor, Tensor)> {
let x1 = self.b1_1.forward(&self.b1_0.forward(xs)?)?; let x1 = self.b1_1.forward(&self.b1_0.forward(xs)?)?;
let x2 = self.b2_1.forward(&self.b2_0.forward(&x1)?)?; let x2 = self
.b2_2
.forward(&self.b2_1.forward(&self.b2_0.forward(&x1)?)?)?;
let x3 = self.b3_1.forward(&self.b3_0.forward(&x2)?)?; let x3 = self.b3_1.forward(&self.b3_0.forward(&x2)?)?;
let x4 = self.b4_1.forward(&self.b4_0.forward(&x3)?)?; let x4 = self.b4_1.forward(&self.b4_0.forward(&x3)?)?;
let x5 = self.b5.forward(&x4)?; let x5 = self.b5.forward(&x4)?;
@ -529,7 +532,7 @@ impl DetectionHead {
) -> Result<(ConvBlock, ConvBlock, Conv2d)> { ) -> Result<(ConvBlock, ConvBlock, Conv2d)> {
let block0 = ConvBlock::load(vb.pp("0"), filter, c1, 3, 1, None)?; let block0 = ConvBlock::load(vb.pp("0"), filter, c1, 3, 1, None)?;
let block1 = ConvBlock::load(vb.pp("1"), c1, c1, 3, 1, None)?; let block1 = ConvBlock::load(vb.pp("1"), c1, c1, 3, 1, None)?;
let conv = conv2d_no_bias(c1, nc, 1, Default::default(), vb.pp("2"))?; let conv = conv2d(c1, nc, 1, Default::default(), vb.pp("2"))?;
Ok((block0, block1, conv)) Ok((block0, block1, conv))
} }
@ -541,7 +544,7 @@ impl DetectionHead {
) -> Result<(ConvBlock, ConvBlock, Conv2d)> { ) -> Result<(ConvBlock, ConvBlock, Conv2d)> {
let block0 = ConvBlock::load(vb.pp("0"), filter, c2, 3, 1, None)?; let block0 = ConvBlock::load(vb.pp("0"), filter, c2, 3, 1, None)?;
let block1 = ConvBlock::load(vb.pp("1"), c2, c2, 3, 1, None)?; let block1 = ConvBlock::load(vb.pp("1"), c2, c2, 3, 1, None)?;
let conv = conv2d_no_bias(c2, 4 * ch, 1, Default::default(), vb.pp("2"))?; let conv = conv2d(c2, 4 * ch, 1, Default::default(), vb.pp("2"))?;
Ok((block0, block1, conv)) Ok((block0, block1, conv))
} }
@ -579,7 +582,6 @@ impl DetectionHead {
let dbox = dist2bbox(&self.dfl.forward(&box_)?, &anchors.unsqueeze(0)?)?; let dbox = dist2bbox(&self.dfl.forward(&box_)?, &anchors.unsqueeze(0)?)?;
let dbox = dbox.broadcast_mul(&strides)?; let dbox = dbox.broadcast_mul(&strides)?;
Tensor::cat(&[dbox, candle_nn::ops::sigmoid(&cls)?], 1) Tensor::cat(&[dbox, candle_nn::ops::sigmoid(&cls)?], 1)
} }
} }
@ -652,22 +654,22 @@ pub fn draw_rect(
} }
pub fn report(pred: &Tensor, img: DynamicImage, w: usize, h: usize) -> Result<DynamicImage> { pub fn report(pred: &Tensor, img: DynamicImage, w: usize, h: usize) -> Result<DynamicImage> {
let (npreds, pred_size) = pred.dims2()?; let (pred_size, npreds) = pred.dims2()?;
let nclasses = pred_size - 5; let nclasses = pred_size - 4;
// The bounding boxes grouped by (maximum) class index. // The bounding boxes grouped by (maximum) class index.
let mut bboxes: Vec<Vec<Bbox>> = (0..nclasses).map(|_| vec![]).collect(); let mut bboxes: Vec<Vec<Bbox>> = (0..nclasses).map(|_| vec![]).collect();
// Extract the bounding boxes for which confidence is above the threshold. // Extract the bounding boxes for which confidence is above the threshold.
for index in 0..npreds { for index in 0..npreds {
let pred = Vec::<f32>::try_from(pred.get(index)?)?; let pred = Vec::<f32>::try_from(pred.i((.., index))?)?;
let confidence = pred[4]; let confidence = *pred[4..].iter().max_by(|x, y| x.total_cmp(y)).unwrap();
if confidence > CONFIDENCE_THRESHOLD { if confidence > CONFIDENCE_THRESHOLD {
let mut class_index = 0; let mut class_index = 0;
for i in 0..nclasses { for i in 0..nclasses {
if pred[5 + i] > pred[5 + class_index] { if pred[4 + i] > pred[4 + class_index] {
class_index = i class_index = i
} }
} }
if pred[class_index + 5] > 0. { if pred[class_index + 4] > 0. {
let bbox = Bbox { let bbox = Bbox {
xmin: pred[0] - pred[2] / 2., xmin: pred[0] - pred[2] / 2.,
ymin: pred[1] - pred[3] / 2., ymin: pred[1] - pred[3] / 2.,
@ -767,7 +769,6 @@ pub fn main() -> anyhow::Result<()> {
}; };
let image = (image.unsqueeze(0)?.to_dtype(DType::F32)? * (1. / 255.))?; let image = (image.unsqueeze(0)?.to_dtype(DType::F32)? * (1. / 255.))?;
let predictions = model.forward(&image)?.squeeze(0)?; let predictions = model.forward(&image)?.squeeze(0)?;
let predictions = predictions.t()?;
println!("generated predictions {predictions:?}"); println!("generated predictions {predictions:?}");
let image = report(&predictions, original_image, 640, 640)?; let image = report(&predictions, original_image, 640, 640)?;
image_name.set_extension("pp.jpg"); image_name.set_extension("pp.jpg");

View File

@ -12,10 +12,8 @@ readme = "README.md"
[dependencies] [dependencies]
accelerate-src = { workspace = true, optional = true } accelerate-src = { workspace = true, optional = true }
candle = { path = "../candle-core", version = "0.1.2", package = "candle-core" } candle = { path = "../candle-core", version = "0.1.2", package = "candle-core" }
hf-hub = { workspace = true}
candle-nn = { path = "../candle-nn", version = "0.1.2" } candle-nn = { path = "../candle-nn", version = "0.1.2" }
intel-mkl-src = { workspace = true, optional = true } intel-mkl-src = { workspace = true, optional = true }
tokenizers = { workspace = true, features = ["onig"] }
rand = { workspace = true } rand = { workspace = true }
wav = { workspace = true } wav = { workspace = true }