mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Yolo v8 fixes (#542)
* Fixes for the yolo-v8 layout. * Bugfixes. * Another silly bugfix. * Remove the hf-hub dependency. * Remove the transformers dependency.
This commit is contained in:
@ -9,7 +9,9 @@ extern crate accelerate_src;
|
|||||||
mod coco_classes;
|
mod coco_classes;
|
||||||
|
|
||||||
use candle::{DType, Device, IndexOp, Result, Tensor, D};
|
use candle::{DType, Device, IndexOp, Result, Tensor, D};
|
||||||
use candle_nn::{batch_norm, conv2d_no_bias, BatchNorm, Conv2d, Conv2dConfig, Module, VarBuilder};
|
use candle_nn::{
|
||||||
|
batch_norm, conv2d, conv2d_no_bias, BatchNorm, Conv2d, Conv2dConfig, Module, VarBuilder,
|
||||||
|
};
|
||||||
use clap::Parser;
|
use clap::Parser;
|
||||||
use image::{DynamicImage, ImageBuffer};
|
use image::{DynamicImage, ImageBuffer};
|
||||||
|
|
||||||
@ -179,8 +181,7 @@ impl C2f {
|
|||||||
impl Module for C2f {
|
impl Module for C2f {
|
||||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
let ys = self.cv1.forward(xs)?;
|
let ys = self.cv1.forward(xs)?;
|
||||||
let ys_1 = ys.dim(1)?;
|
let mut ys = ys.chunk(2, 1)?;
|
||||||
let mut ys = vec![ys.i((.., 0..ys_1 / 2))?, ys.i((.., ys_1 / 2..))?];
|
|
||||||
for m in self.bottleneck.iter() {
|
for m in self.bottleneck.iter() {
|
||||||
ys.push(m.forward(ys.last().unwrap())?)
|
ys.push(m.forward(ys.last().unwrap())?)
|
||||||
}
|
}
|
||||||
@ -349,7 +350,9 @@ impl DarkNet {
|
|||||||
|
|
||||||
fn forward(&self, xs: &Tensor) -> Result<(Tensor, Tensor, Tensor)> {
|
fn forward(&self, xs: &Tensor) -> Result<(Tensor, Tensor, Tensor)> {
|
||||||
let x1 = self.b1_1.forward(&self.b1_0.forward(xs)?)?;
|
let x1 = self.b1_1.forward(&self.b1_0.forward(xs)?)?;
|
||||||
let x2 = self.b2_1.forward(&self.b2_0.forward(&x1)?)?;
|
let x2 = self
|
||||||
|
.b2_2
|
||||||
|
.forward(&self.b2_1.forward(&self.b2_0.forward(&x1)?)?)?;
|
||||||
let x3 = self.b3_1.forward(&self.b3_0.forward(&x2)?)?;
|
let x3 = self.b3_1.forward(&self.b3_0.forward(&x2)?)?;
|
||||||
let x4 = self.b4_1.forward(&self.b4_0.forward(&x3)?)?;
|
let x4 = self.b4_1.forward(&self.b4_0.forward(&x3)?)?;
|
||||||
let x5 = self.b5.forward(&x4)?;
|
let x5 = self.b5.forward(&x4)?;
|
||||||
@ -529,7 +532,7 @@ impl DetectionHead {
|
|||||||
) -> Result<(ConvBlock, ConvBlock, Conv2d)> {
|
) -> Result<(ConvBlock, ConvBlock, Conv2d)> {
|
||||||
let block0 = ConvBlock::load(vb.pp("0"), filter, c1, 3, 1, None)?;
|
let block0 = ConvBlock::load(vb.pp("0"), filter, c1, 3, 1, None)?;
|
||||||
let block1 = ConvBlock::load(vb.pp("1"), c1, c1, 3, 1, None)?;
|
let block1 = ConvBlock::load(vb.pp("1"), c1, c1, 3, 1, None)?;
|
||||||
let conv = conv2d_no_bias(c1, nc, 1, Default::default(), vb.pp("2"))?;
|
let conv = conv2d(c1, nc, 1, Default::default(), vb.pp("2"))?;
|
||||||
Ok((block0, block1, conv))
|
Ok((block0, block1, conv))
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -541,7 +544,7 @@ impl DetectionHead {
|
|||||||
) -> Result<(ConvBlock, ConvBlock, Conv2d)> {
|
) -> Result<(ConvBlock, ConvBlock, Conv2d)> {
|
||||||
let block0 = ConvBlock::load(vb.pp("0"), filter, c2, 3, 1, None)?;
|
let block0 = ConvBlock::load(vb.pp("0"), filter, c2, 3, 1, None)?;
|
||||||
let block1 = ConvBlock::load(vb.pp("1"), c2, c2, 3, 1, None)?;
|
let block1 = ConvBlock::load(vb.pp("1"), c2, c2, 3, 1, None)?;
|
||||||
let conv = conv2d_no_bias(c2, 4 * ch, 1, Default::default(), vb.pp("2"))?;
|
let conv = conv2d(c2, 4 * ch, 1, Default::default(), vb.pp("2"))?;
|
||||||
Ok((block0, block1, conv))
|
Ok((block0, block1, conv))
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -579,7 +582,6 @@ impl DetectionHead {
|
|||||||
|
|
||||||
let dbox = dist2bbox(&self.dfl.forward(&box_)?, &anchors.unsqueeze(0)?)?;
|
let dbox = dist2bbox(&self.dfl.forward(&box_)?, &anchors.unsqueeze(0)?)?;
|
||||||
let dbox = dbox.broadcast_mul(&strides)?;
|
let dbox = dbox.broadcast_mul(&strides)?;
|
||||||
|
|
||||||
Tensor::cat(&[dbox, candle_nn::ops::sigmoid(&cls)?], 1)
|
Tensor::cat(&[dbox, candle_nn::ops::sigmoid(&cls)?], 1)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -652,22 +654,22 @@ pub fn draw_rect(
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn report(pred: &Tensor, img: DynamicImage, w: usize, h: usize) -> Result<DynamicImage> {
|
pub fn report(pred: &Tensor, img: DynamicImage, w: usize, h: usize) -> Result<DynamicImage> {
|
||||||
let (npreds, pred_size) = pred.dims2()?;
|
let (pred_size, npreds) = pred.dims2()?;
|
||||||
let nclasses = pred_size - 5;
|
let nclasses = pred_size - 4;
|
||||||
// The bounding boxes grouped by (maximum) class index.
|
// The bounding boxes grouped by (maximum) class index.
|
||||||
let mut bboxes: Vec<Vec<Bbox>> = (0..nclasses).map(|_| vec![]).collect();
|
let mut bboxes: Vec<Vec<Bbox>> = (0..nclasses).map(|_| vec![]).collect();
|
||||||
// Extract the bounding boxes for which confidence is above the threshold.
|
// Extract the bounding boxes for which confidence is above the threshold.
|
||||||
for index in 0..npreds {
|
for index in 0..npreds {
|
||||||
let pred = Vec::<f32>::try_from(pred.get(index)?)?;
|
let pred = Vec::<f32>::try_from(pred.i((.., index))?)?;
|
||||||
let confidence = pred[4];
|
let confidence = *pred[4..].iter().max_by(|x, y| x.total_cmp(y)).unwrap();
|
||||||
if confidence > CONFIDENCE_THRESHOLD {
|
if confidence > CONFIDENCE_THRESHOLD {
|
||||||
let mut class_index = 0;
|
let mut class_index = 0;
|
||||||
for i in 0..nclasses {
|
for i in 0..nclasses {
|
||||||
if pred[5 + i] > pred[5 + class_index] {
|
if pred[4 + i] > pred[4 + class_index] {
|
||||||
class_index = i
|
class_index = i
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if pred[class_index + 5] > 0. {
|
if pred[class_index + 4] > 0. {
|
||||||
let bbox = Bbox {
|
let bbox = Bbox {
|
||||||
xmin: pred[0] - pred[2] / 2.,
|
xmin: pred[0] - pred[2] / 2.,
|
||||||
ymin: pred[1] - pred[3] / 2.,
|
ymin: pred[1] - pred[3] / 2.,
|
||||||
@ -767,7 +769,6 @@ pub fn main() -> anyhow::Result<()> {
|
|||||||
};
|
};
|
||||||
let image = (image.unsqueeze(0)?.to_dtype(DType::F32)? * (1. / 255.))?;
|
let image = (image.unsqueeze(0)?.to_dtype(DType::F32)? * (1. / 255.))?;
|
||||||
let predictions = model.forward(&image)?.squeeze(0)?;
|
let predictions = model.forward(&image)?.squeeze(0)?;
|
||||||
let predictions = predictions.t()?;
|
|
||||||
println!("generated predictions {predictions:?}");
|
println!("generated predictions {predictions:?}");
|
||||||
let image = report(&predictions, original_image, 640, 640)?;
|
let image = report(&predictions, original_image, 640, 640)?;
|
||||||
image_name.set_extension("pp.jpg");
|
image_name.set_extension("pp.jpg");
|
||||||
|
@ -12,10 +12,8 @@ readme = "README.md"
|
|||||||
[dependencies]
|
[dependencies]
|
||||||
accelerate-src = { workspace = true, optional = true }
|
accelerate-src = { workspace = true, optional = true }
|
||||||
candle = { path = "../candle-core", version = "0.1.2", package = "candle-core" }
|
candle = { path = "../candle-core", version = "0.1.2", package = "candle-core" }
|
||||||
hf-hub = { workspace = true}
|
|
||||||
candle-nn = { path = "../candle-nn", version = "0.1.2" }
|
candle-nn = { path = "../candle-nn", version = "0.1.2" }
|
||||||
intel-mkl-src = { workspace = true, optional = true }
|
intel-mkl-src = { workspace = true, optional = true }
|
||||||
tokenizers = { workspace = true, features = ["onig"] }
|
|
||||||
rand = { workspace = true }
|
rand = { workspace = true }
|
||||||
wav = { workspace = true }
|
wav = { workspace = true }
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user