mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 18:28:24 +00:00
Yolo v8 fixes (#542)
* Fixes for the yolo-v8 layout. * Bugfixes. * Another silly bugfix. * Remove the hf-hub dependency. * Remove the transformers dependency.
This commit is contained in:
@ -9,7 +9,9 @@ extern crate accelerate_src;
|
||||
mod coco_classes;
|
||||
|
||||
use candle::{DType, Device, IndexOp, Result, Tensor, D};
|
||||
use candle_nn::{batch_norm, conv2d_no_bias, BatchNorm, Conv2d, Conv2dConfig, Module, VarBuilder};
|
||||
use candle_nn::{
|
||||
batch_norm, conv2d, conv2d_no_bias, BatchNorm, Conv2d, Conv2dConfig, Module, VarBuilder,
|
||||
};
|
||||
use clap::Parser;
|
||||
use image::{DynamicImage, ImageBuffer};
|
||||
|
||||
@ -179,8 +181,7 @@ impl C2f {
|
||||
impl Module for C2f {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let ys = self.cv1.forward(xs)?;
|
||||
let ys_1 = ys.dim(1)?;
|
||||
let mut ys = vec![ys.i((.., 0..ys_1 / 2))?, ys.i((.., ys_1 / 2..))?];
|
||||
let mut ys = ys.chunk(2, 1)?;
|
||||
for m in self.bottleneck.iter() {
|
||||
ys.push(m.forward(ys.last().unwrap())?)
|
||||
}
|
||||
@ -349,7 +350,9 @@ impl DarkNet {
|
||||
|
||||
fn forward(&self, xs: &Tensor) -> Result<(Tensor, Tensor, Tensor)> {
|
||||
let x1 = self.b1_1.forward(&self.b1_0.forward(xs)?)?;
|
||||
let x2 = self.b2_1.forward(&self.b2_0.forward(&x1)?)?;
|
||||
let x2 = self
|
||||
.b2_2
|
||||
.forward(&self.b2_1.forward(&self.b2_0.forward(&x1)?)?)?;
|
||||
let x3 = self.b3_1.forward(&self.b3_0.forward(&x2)?)?;
|
||||
let x4 = self.b4_1.forward(&self.b4_0.forward(&x3)?)?;
|
||||
let x5 = self.b5.forward(&x4)?;
|
||||
@ -529,7 +532,7 @@ impl DetectionHead {
|
||||
) -> Result<(ConvBlock, ConvBlock, Conv2d)> {
|
||||
let block0 = ConvBlock::load(vb.pp("0"), filter, c1, 3, 1, None)?;
|
||||
let block1 = ConvBlock::load(vb.pp("1"), c1, c1, 3, 1, None)?;
|
||||
let conv = conv2d_no_bias(c1, nc, 1, Default::default(), vb.pp("2"))?;
|
||||
let conv = conv2d(c1, nc, 1, Default::default(), vb.pp("2"))?;
|
||||
Ok((block0, block1, conv))
|
||||
}
|
||||
|
||||
@ -541,7 +544,7 @@ impl DetectionHead {
|
||||
) -> Result<(ConvBlock, ConvBlock, Conv2d)> {
|
||||
let block0 = ConvBlock::load(vb.pp("0"), filter, c2, 3, 1, None)?;
|
||||
let block1 = ConvBlock::load(vb.pp("1"), c2, c2, 3, 1, None)?;
|
||||
let conv = conv2d_no_bias(c2, 4 * ch, 1, Default::default(), vb.pp("2"))?;
|
||||
let conv = conv2d(c2, 4 * ch, 1, Default::default(), vb.pp("2"))?;
|
||||
Ok((block0, block1, conv))
|
||||
}
|
||||
|
||||
@ -579,7 +582,6 @@ impl DetectionHead {
|
||||
|
||||
let dbox = dist2bbox(&self.dfl.forward(&box_)?, &anchors.unsqueeze(0)?)?;
|
||||
let dbox = dbox.broadcast_mul(&strides)?;
|
||||
|
||||
Tensor::cat(&[dbox, candle_nn::ops::sigmoid(&cls)?], 1)
|
||||
}
|
||||
}
|
||||
@ -652,22 +654,22 @@ pub fn draw_rect(
|
||||
}
|
||||
|
||||
pub fn report(pred: &Tensor, img: DynamicImage, w: usize, h: usize) -> Result<DynamicImage> {
|
||||
let (npreds, pred_size) = pred.dims2()?;
|
||||
let nclasses = pred_size - 5;
|
||||
let (pred_size, npreds) = pred.dims2()?;
|
||||
let nclasses = pred_size - 4;
|
||||
// The bounding boxes grouped by (maximum) class index.
|
||||
let mut bboxes: Vec<Vec<Bbox>> = (0..nclasses).map(|_| vec![]).collect();
|
||||
// Extract the bounding boxes for which confidence is above the threshold.
|
||||
for index in 0..npreds {
|
||||
let pred = Vec::<f32>::try_from(pred.get(index)?)?;
|
||||
let confidence = pred[4];
|
||||
let pred = Vec::<f32>::try_from(pred.i((.., index))?)?;
|
||||
let confidence = *pred[4..].iter().max_by(|x, y| x.total_cmp(y)).unwrap();
|
||||
if confidence > CONFIDENCE_THRESHOLD {
|
||||
let mut class_index = 0;
|
||||
for i in 0..nclasses {
|
||||
if pred[5 + i] > pred[5 + class_index] {
|
||||
if pred[4 + i] > pred[4 + class_index] {
|
||||
class_index = i
|
||||
}
|
||||
}
|
||||
if pred[class_index + 5] > 0. {
|
||||
if pred[class_index + 4] > 0. {
|
||||
let bbox = Bbox {
|
||||
xmin: pred[0] - pred[2] / 2.,
|
||||
ymin: pred[1] - pred[3] / 2.,
|
||||
@ -767,7 +769,6 @@ pub fn main() -> anyhow::Result<()> {
|
||||
};
|
||||
let image = (image.unsqueeze(0)?.to_dtype(DType::F32)? * (1. / 255.))?;
|
||||
let predictions = model.forward(&image)?.squeeze(0)?;
|
||||
let predictions = predictions.t()?;
|
||||
println!("generated predictions {predictions:?}");
|
||||
let image = report(&predictions, original_image, 640, 640)?;
|
||||
image_name.set_extension("pp.jpg");
|
||||
|
@ -12,10 +12,8 @@ readme = "README.md"
|
||||
[dependencies]
|
||||
accelerate-src = { workspace = true, optional = true }
|
||||
candle = { path = "../candle-core", version = "0.1.2", package = "candle-core" }
|
||||
hf-hub = { workspace = true}
|
||||
candle-nn = { path = "../candle-nn", version = "0.1.2" }
|
||||
intel-mkl-src = { workspace = true, optional = true }
|
||||
tokenizers = { workspace = true, features = ["onig"] }
|
||||
rand = { workspace = true }
|
||||
wav = { workspace = true }
|
||||
|
||||
|
Reference in New Issue
Block a user