mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
More support for pose estimation in yolo-v8. (#599)
* More support for pose estimation in yolo-v8.
* Support both object detection and pose-estimation in the yolo-v8 example.
This commit is contained in:
@ -65,6 +65,7 @@ pub fn report(
|
|||||||
xmax: pred[0] + pred[2] / 2.,
|
xmax: pred[0] + pred[2] / 2.,
|
||||||
ymax: pred[1] + pred[3] / 2.,
|
ymax: pred[1] + pred[3] / 2.,
|
||||||
confidence,
|
confidence,
|
||||||
|
keypoints: vec![],
|
||||||
};
|
};
|
||||||
bboxes[class_index].push(bbox)
|
bboxes[class_index].push(bbox)
|
||||||
}
|
}
|
||||||
|
@ -16,7 +16,7 @@ use image::{DynamicImage, ImageBuffer};
|
|||||||
// https://github.com/tinygrad/tinygrad/blob/master/examples/yolov8.py
|
// https://github.com/tinygrad/tinygrad/blob/master/examples/yolov8.py
|
||||||
|
|
||||||
#[derive(Clone, Copy, PartialEq, Debug)]
|
#[derive(Clone, Copy, PartialEq, Debug)]
|
||||||
struct Multiples {
|
pub struct Multiples {
|
||||||
depth: f64,
|
depth: f64,
|
||||||
width: f64,
|
width: f64,
|
||||||
ratio: f64,
|
ratio: f64,
|
||||||
@ -727,7 +727,7 @@ pub fn draw_rect(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn report(
|
pub fn report_detect(
|
||||||
pred: &Tensor,
|
pred: &Tensor,
|
||||||
img: DynamicImage,
|
img: DynamicImage,
|
||||||
w: usize,
|
w: usize,
|
||||||
@ -757,6 +757,7 @@ pub fn report(
|
|||||||
xmax: pred[0] + pred[2] / 2.,
|
xmax: pred[0] + pred[2] / 2.,
|
||||||
ymax: pred[1] + pred[3] / 2.,
|
ymax: pred[1] + pred[3] / 2.,
|
||||||
confidence,
|
confidence,
|
||||||
|
keypoints: vec![],
|
||||||
};
|
};
|
||||||
bboxes[class_index].push(bbox)
|
bboxes[class_index].push(bbox)
|
||||||
}
|
}
|
||||||
@ -787,6 +788,85 @@ pub fn report(
|
|||||||
Ok(DynamicImage::ImageRgb8(img))
|
Ok(DynamicImage::ImageRgb8(img))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn report_pose(
|
||||||
|
pred: &Tensor,
|
||||||
|
img: DynamicImage,
|
||||||
|
w: usize,
|
||||||
|
h: usize,
|
||||||
|
confidence_threshold: f32,
|
||||||
|
nms_threshold: f32,
|
||||||
|
) -> Result<DynamicImage> {
|
||||||
|
let (pred_size, npreds) = pred.dims2()?;
|
||||||
|
let nclasses = pred_size - 4;
|
||||||
|
// The bounding boxes grouped by (maximum) class index.
|
||||||
|
let mut bboxes: Vec<Vec<Bbox>> = (0..nclasses).map(|_| vec![]).collect();
|
||||||
|
// Extract the bounding boxes for which confidence is above the threshold.
|
||||||
|
for index in 0..npreds {
|
||||||
|
let pred = Vec::<f32>::try_from(pred.i((.., index))?)?;
|
||||||
|
let confidence = pred[4];
|
||||||
|
if confidence > confidence_threshold {
|
||||||
|
let mut class_index = 0;
|
||||||
|
for i in 0..nclasses {
|
||||||
|
if pred[4 + i] > pred[4 + class_index] {
|
||||||
|
class_index = i
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if pred[class_index + 4] > 0. {
|
||||||
|
let keypoints = (0..17)
|
||||||
|
.map(|i| (pred[5 + 3 * i], pred[3 * i + 6], pred[3 * i + 7]))
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
let bbox = Bbox {
|
||||||
|
xmin: pred[0] - pred[2] / 2.,
|
||||||
|
ymin: pred[1] - pred[3] / 2.,
|
||||||
|
xmax: pred[0] + pred[2] / 2.,
|
||||||
|
ymax: pred[1] + pred[3] / 2.,
|
||||||
|
confidence,
|
||||||
|
keypoints,
|
||||||
|
};
|
||||||
|
bboxes[class_index].push(bbox)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
non_maximum_suppression(&mut bboxes, nms_threshold);
|
||||||
|
|
||||||
|
// Annotate the original image and print boxes information.
|
||||||
|
let (initial_h, initial_w) = (img.height(), img.width());
|
||||||
|
let w_ratio = initial_w as f32 / w as f32;
|
||||||
|
let h_ratio = initial_h as f32 / h as f32;
|
||||||
|
let mut img = img.to_rgb8();
|
||||||
|
for (class_index, bboxes_for_class) in bboxes.iter().enumerate() {
|
||||||
|
for b in bboxes_for_class.iter() {
|
||||||
|
println!(
|
||||||
|
"{}: {:?}",
|
||||||
|
candle_examples::coco_classes::NAMES[class_index],
|
||||||
|
b
|
||||||
|
);
|
||||||
|
let xmin = ((b.xmin * w_ratio) as u32).clamp(0, initial_w - 1);
|
||||||
|
let ymin = ((b.ymin * h_ratio) as u32).clamp(0, initial_h - 1);
|
||||||
|
let xmax = ((b.xmax * w_ratio) as u32).clamp(0, initial_w - 1);
|
||||||
|
let ymax = ((b.ymax * h_ratio) as u32).clamp(0, initial_h - 1);
|
||||||
|
draw_rect(&mut img, xmin, xmax, ymin, ymax);
|
||||||
|
for (x, y, z) in b.keypoints.iter() {
|
||||||
|
if z < &0.6 {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
let x = x * w_ratio;
|
||||||
|
let y = y * w_ratio;
|
||||||
|
for dx in -2..3 {
|
||||||
|
for dy in -2..3 {
|
||||||
|
let x = ((x + dx as f32) as u32).clamp(0, initial_w - 1);
|
||||||
|
let y = ((y + dy as f32) as u32).clamp(0, initial_h - 1);
|
||||||
|
let pixel = img.get_pixel_mut(x, y);
|
||||||
|
*pixel = image::Rgb([0, 255, 0]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(DynamicImage::ImageRgb8(img))
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Clone, Copy, ValueEnum, Debug)]
|
#[derive(Clone, Copy, ValueEnum, Debug)]
|
||||||
enum Which {
|
enum Which {
|
||||||
N,
|
N,
|
||||||
@ -796,9 +876,15 @@ enum Which {
|
|||||||
X,
|
X,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Copy, ValueEnum, Debug)]
|
||||||
|
enum YoloTask {
|
||||||
|
Detect,
|
||||||
|
Pose,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Parser, Debug)]
|
#[derive(Parser, Debug)]
|
||||||
#[command(author, version, about, long_about = None)]
|
#[command(author, version, about, long_about = None)]
|
||||||
struct Args {
|
pub struct Args {
|
||||||
/// Model weights, in safetensors format.
|
/// Model weights, in safetensors format.
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
model: Option<String>,
|
model: Option<String>,
|
||||||
@ -816,6 +902,10 @@ struct Args {
|
|||||||
/// Threshold for non-maximum suppression.
|
/// Threshold for non-maximum suppression.
|
||||||
#[arg(long, default_value_t = 0.45)]
|
#[arg(long, default_value_t = 0.45)]
|
||||||
nms_threshold: f32,
|
nms_threshold: f32,
|
||||||
|
|
||||||
|
/// The task to be run.
|
||||||
|
#[arg(long, default_value = "detect")]
|
||||||
|
task: YoloTask,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Args {
|
impl Args {
|
||||||
@ -825,23 +915,71 @@ impl Args {
|
|||||||
None => {
|
None => {
|
||||||
let api = hf_hub::api::sync::Api::new()?;
|
let api = hf_hub::api::sync::Api::new()?;
|
||||||
let api = api.model("lmz/candle-yolo-v8".to_string());
|
let api = api.model("lmz/candle-yolo-v8".to_string());
|
||||||
let filename = match self.which {
|
let size = match self.which {
|
||||||
Which::N => "yolov8n.safetensors",
|
Which::N => "n",
|
||||||
Which::S => "yolov8s.safetensors",
|
Which::S => "s",
|
||||||
Which::M => "yolov8m.safetensors",
|
Which::M => "m",
|
||||||
Which::L => "yolov8l.safetensors",
|
Which::L => "l",
|
||||||
Which::X => "yolov8x.safetensors",
|
Which::X => "x",
|
||||||
};
|
};
|
||||||
api.get(filename)?
|
let task = match self.task {
|
||||||
|
YoloTask::Pose => "-pose",
|
||||||
|
YoloTask::Detect => "",
|
||||||
|
};
|
||||||
|
api.get(&format!("yolov8{size}{task}.safetensors"))?
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
Ok(path)
|
Ok(path)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn main() -> anyhow::Result<()> {
|
pub trait Task: Module + Sized {
|
||||||
let args = Args::parse();
|
fn load(vb: VarBuilder, multiples: Multiples) -> Result<Self>;
|
||||||
|
fn report(
|
||||||
|
pred: &Tensor,
|
||||||
|
img: DynamicImage,
|
||||||
|
w: usize,
|
||||||
|
h: usize,
|
||||||
|
confidence_threshold: f32,
|
||||||
|
nms_threshold: f32,
|
||||||
|
) -> Result<DynamicImage>;
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Task for YoloV8 {
|
||||||
|
fn load(vb: VarBuilder, multiples: Multiples) -> Result<Self> {
|
||||||
|
YoloV8::load(vb, multiples, /* num_classes=*/ 80)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn report(
|
||||||
|
pred: &Tensor,
|
||||||
|
img: DynamicImage,
|
||||||
|
w: usize,
|
||||||
|
h: usize,
|
||||||
|
confidence_threshold: f32,
|
||||||
|
nms_threshold: f32,
|
||||||
|
) -> Result<DynamicImage> {
|
||||||
|
report_detect(pred, img, w, h, confidence_threshold, nms_threshold)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Task for YoloV8Pose {
|
||||||
|
fn load(vb: VarBuilder, multiples: Multiples) -> Result<Self> {
|
||||||
|
YoloV8Pose::load(vb, multiples, /* num_classes=*/ 1, (17, 3))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn report(
|
||||||
|
pred: &Tensor,
|
||||||
|
img: DynamicImage,
|
||||||
|
w: usize,
|
||||||
|
h: usize,
|
||||||
|
confidence_threshold: f32,
|
||||||
|
nms_threshold: f32,
|
||||||
|
) -> Result<DynamicImage> {
|
||||||
|
report_pose(pred, img, w, h, confidence_threshold, nms_threshold)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn run<T: Task>(args: Args) -> anyhow::Result<()> {
|
||||||
// Create the model and load the weights from the file.
|
// Create the model and load the weights from the file.
|
||||||
let multiples = match args.which {
|
let multiples = match args.which {
|
||||||
Which::N => Multiples::n(),
|
Which::N => Multiples::n(),
|
||||||
@ -854,8 +992,7 @@ pub fn main() -> anyhow::Result<()> {
|
|||||||
let weights = unsafe { candle::safetensors::MmapedFile::new(model)? };
|
let weights = unsafe { candle::safetensors::MmapedFile::new(model)? };
|
||||||
let weights = weights.deserialize()?;
|
let weights = weights.deserialize()?;
|
||||||
let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &Device::Cpu);
|
let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &Device::Cpu);
|
||||||
let model = YoloV8::load(vb, multiples, /* num_classes=*/ 80)?;
|
let model = T::load(vb, multiples)?;
|
||||||
// let model = YoloV8Pose::load(vb, multiples, /* num_classes=*/ 1, (17, 3))?;
|
|
||||||
println!("model loaded");
|
println!("model loaded");
|
||||||
for image_name in args.images.iter() {
|
for image_name in args.images.iter() {
|
||||||
println!("processing {image_name}");
|
println!("processing {image_name}");
|
||||||
@ -892,7 +1029,7 @@ pub fn main() -> anyhow::Result<()> {
|
|||||||
let image_t = (image_t.unsqueeze(0)?.to_dtype(DType::F32)? * (1. / 255.))?;
|
let image_t = (image_t.unsqueeze(0)?.to_dtype(DType::F32)? * (1. / 255.))?;
|
||||||
let predictions = model.forward(&image_t)?.squeeze(0)?;
|
let predictions = model.forward(&image_t)?.squeeze(0)?;
|
||||||
println!("generated predictions {predictions:?}");
|
println!("generated predictions {predictions:?}");
|
||||||
let image_t = report(
|
let image_t = T::report(
|
||||||
&predictions,
|
&predictions,
|
||||||
original_image,
|
original_image,
|
||||||
width,
|
width,
|
||||||
@ -907,3 +1044,12 @@ pub fn main() -> anyhow::Result<()> {
|
|||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn main() -> anyhow::Result<()> {
|
||||||
|
let args = Args::parse();
|
||||||
|
match args.task {
|
||||||
|
YoloTask::Detect => run::<YoloV8>(args)?,
|
||||||
|
YoloTask::Pose => run::<YoloV8Pose>(args)?,
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
@ -1,11 +1,12 @@
|
|||||||
/// A bounding box around an object.
|
/// A bounding box around an object.
|
||||||
#[derive(Debug, Clone, Copy)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct Bbox {
|
pub struct Bbox {
|
||||||
pub xmin: f32,
|
pub xmin: f32,
|
||||||
pub ymin: f32,
|
pub ymin: f32,
|
||||||
pub xmax: f32,
|
pub xmax: f32,
|
||||||
pub ymax: f32,
|
pub ymax: f32,
|
||||||
pub confidence: f32,
|
pub confidence: f32,
|
||||||
|
pub keypoints: Vec<(f32, f32, f32)>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Intersection over union of two bounding boxes.
|
/// Intersection over union of two bounding boxes.
|
||||||
|
Reference in New Issue
Block a user