mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 02:38:10 +00:00
More support for pose estimation in yolo-v8. (#599)
* More support for pose estimation in yolo-v8. * Support both object detection and pose-estimation in the yolo-v8 example.
This commit is contained in:
@ -16,7 +16,7 @@ use image::{DynamicImage, ImageBuffer};
|
||||
// https://github.com/tinygrad/tinygrad/blob/master/examples/yolov8.py
|
||||
|
||||
#[derive(Clone, Copy, PartialEq, Debug)]
|
||||
struct Multiples {
|
||||
pub struct Multiples {
|
||||
depth: f64,
|
||||
width: f64,
|
||||
ratio: f64,
|
||||
@ -727,7 +727,7 @@ pub fn draw_rect(
|
||||
}
|
||||
}
|
||||
|
||||
pub fn report(
|
||||
pub fn report_detect(
|
||||
pred: &Tensor,
|
||||
img: DynamicImage,
|
||||
w: usize,
|
||||
@ -757,6 +757,7 @@ pub fn report(
|
||||
xmax: pred[0] + pred[2] / 2.,
|
||||
ymax: pred[1] + pred[3] / 2.,
|
||||
confidence,
|
||||
keypoints: vec![],
|
||||
};
|
||||
bboxes[class_index].push(bbox)
|
||||
}
|
||||
@ -787,6 +788,85 @@ pub fn report(
|
||||
Ok(DynamicImage::ImageRgb8(img))
|
||||
}
|
||||
|
||||
pub fn report_pose(
|
||||
pred: &Tensor,
|
||||
img: DynamicImage,
|
||||
w: usize,
|
||||
h: usize,
|
||||
confidence_threshold: f32,
|
||||
nms_threshold: f32,
|
||||
) -> Result<DynamicImage> {
|
||||
let (pred_size, npreds) = pred.dims2()?;
|
||||
let nclasses = pred_size - 4;
|
||||
// The bounding boxes grouped by (maximum) class index.
|
||||
let mut bboxes: Vec<Vec<Bbox>> = (0..nclasses).map(|_| vec![]).collect();
|
||||
// Extract the bounding boxes for which confidence is above the threshold.
|
||||
for index in 0..npreds {
|
||||
let pred = Vec::<f32>::try_from(pred.i((.., index))?)?;
|
||||
let confidence = pred[4];
|
||||
if confidence > confidence_threshold {
|
||||
let mut class_index = 0;
|
||||
for i in 0..nclasses {
|
||||
if pred[4 + i] > pred[4 + class_index] {
|
||||
class_index = i
|
||||
}
|
||||
}
|
||||
if pred[class_index + 4] > 0. {
|
||||
let keypoints = (0..17)
|
||||
.map(|i| (pred[5 + 3 * i], pred[3 * i + 6], pred[3 * i + 7]))
|
||||
.collect::<Vec<_>>();
|
||||
let bbox = Bbox {
|
||||
xmin: pred[0] - pred[2] / 2.,
|
||||
ymin: pred[1] - pred[3] / 2.,
|
||||
xmax: pred[0] + pred[2] / 2.,
|
||||
ymax: pred[1] + pred[3] / 2.,
|
||||
confidence,
|
||||
keypoints,
|
||||
};
|
||||
bboxes[class_index].push(bbox)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
non_maximum_suppression(&mut bboxes, nms_threshold);
|
||||
|
||||
// Annotate the original image and print boxes information.
|
||||
let (initial_h, initial_w) = (img.height(), img.width());
|
||||
let w_ratio = initial_w as f32 / w as f32;
|
||||
let h_ratio = initial_h as f32 / h as f32;
|
||||
let mut img = img.to_rgb8();
|
||||
for (class_index, bboxes_for_class) in bboxes.iter().enumerate() {
|
||||
for b in bboxes_for_class.iter() {
|
||||
println!(
|
||||
"{}: {:?}",
|
||||
candle_examples::coco_classes::NAMES[class_index],
|
||||
b
|
||||
);
|
||||
let xmin = ((b.xmin * w_ratio) as u32).clamp(0, initial_w - 1);
|
||||
let ymin = ((b.ymin * h_ratio) as u32).clamp(0, initial_h - 1);
|
||||
let xmax = ((b.xmax * w_ratio) as u32).clamp(0, initial_w - 1);
|
||||
let ymax = ((b.ymax * h_ratio) as u32).clamp(0, initial_h - 1);
|
||||
draw_rect(&mut img, xmin, xmax, ymin, ymax);
|
||||
for (x, y, z) in b.keypoints.iter() {
|
||||
if z < &0.6 {
|
||||
continue;
|
||||
}
|
||||
let x = x * w_ratio;
|
||||
let y = y * w_ratio;
|
||||
for dx in -2..3 {
|
||||
for dy in -2..3 {
|
||||
let x = ((x + dx as f32) as u32).clamp(0, initial_w - 1);
|
||||
let y = ((y + dy as f32) as u32).clamp(0, initial_h - 1);
|
||||
let pixel = img.get_pixel_mut(x, y);
|
||||
*pixel = image::Rgb([0, 255, 0]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(DynamicImage::ImageRgb8(img))
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, ValueEnum, Debug)]
|
||||
enum Which {
|
||||
N,
|
||||
@ -796,9 +876,15 @@ enum Which {
|
||||
X,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, ValueEnum, Debug)]
|
||||
enum YoloTask {
|
||||
Detect,
|
||||
Pose,
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
struct Args {
|
||||
pub struct Args {
|
||||
/// Model weights, in safetensors format.
|
||||
#[arg(long)]
|
||||
model: Option<String>,
|
||||
@ -816,6 +902,10 @@ struct Args {
|
||||
/// Threshold for non-maximum suppression.
|
||||
#[arg(long, default_value_t = 0.45)]
|
||||
nms_threshold: f32,
|
||||
|
||||
/// The task to be run.
|
||||
#[arg(long, default_value = "detect")]
|
||||
task: YoloTask,
|
||||
}
|
||||
|
||||
impl Args {
|
||||
@ -825,23 +915,71 @@ impl Args {
|
||||
None => {
|
||||
let api = hf_hub::api::sync::Api::new()?;
|
||||
let api = api.model("lmz/candle-yolo-v8".to_string());
|
||||
let filename = match self.which {
|
||||
Which::N => "yolov8n.safetensors",
|
||||
Which::S => "yolov8s.safetensors",
|
||||
Which::M => "yolov8m.safetensors",
|
||||
Which::L => "yolov8l.safetensors",
|
||||
Which::X => "yolov8x.safetensors",
|
||||
let size = match self.which {
|
||||
Which::N => "n",
|
||||
Which::S => "s",
|
||||
Which::M => "m",
|
||||
Which::L => "l",
|
||||
Which::X => "x",
|
||||
};
|
||||
api.get(filename)?
|
||||
let task = match self.task {
|
||||
YoloTask::Pose => "-pose",
|
||||
YoloTask::Detect => "",
|
||||
};
|
||||
api.get(&format!("yolov8{size}{task}.safetensors"))?
|
||||
}
|
||||
};
|
||||
Ok(path)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn main() -> anyhow::Result<()> {
|
||||
let args = Args::parse();
|
||||
pub trait Task: Module + Sized {
|
||||
fn load(vb: VarBuilder, multiples: Multiples) -> Result<Self>;
|
||||
fn report(
|
||||
pred: &Tensor,
|
||||
img: DynamicImage,
|
||||
w: usize,
|
||||
h: usize,
|
||||
confidence_threshold: f32,
|
||||
nms_threshold: f32,
|
||||
) -> Result<DynamicImage>;
|
||||
}
|
||||
|
||||
impl Task for YoloV8 {
|
||||
fn load(vb: VarBuilder, multiples: Multiples) -> Result<Self> {
|
||||
YoloV8::load(vb, multiples, /* num_classes=*/ 80)
|
||||
}
|
||||
|
||||
fn report(
|
||||
pred: &Tensor,
|
||||
img: DynamicImage,
|
||||
w: usize,
|
||||
h: usize,
|
||||
confidence_threshold: f32,
|
||||
nms_threshold: f32,
|
||||
) -> Result<DynamicImage> {
|
||||
report_detect(pred, img, w, h, confidence_threshold, nms_threshold)
|
||||
}
|
||||
}
|
||||
|
||||
impl Task for YoloV8Pose {
|
||||
fn load(vb: VarBuilder, multiples: Multiples) -> Result<Self> {
|
||||
YoloV8Pose::load(vb, multiples, /* num_classes=*/ 1, (17, 3))
|
||||
}
|
||||
|
||||
fn report(
|
||||
pred: &Tensor,
|
||||
img: DynamicImage,
|
||||
w: usize,
|
||||
h: usize,
|
||||
confidence_threshold: f32,
|
||||
nms_threshold: f32,
|
||||
) -> Result<DynamicImage> {
|
||||
report_pose(pred, img, w, h, confidence_threshold, nms_threshold)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn run<T: Task>(args: Args) -> anyhow::Result<()> {
|
||||
// Create the model and load the weights from the file.
|
||||
let multiples = match args.which {
|
||||
Which::N => Multiples::n(),
|
||||
@ -854,8 +992,7 @@ pub fn main() -> anyhow::Result<()> {
|
||||
let weights = unsafe { candle::safetensors::MmapedFile::new(model)? };
|
||||
let weights = weights.deserialize()?;
|
||||
let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &Device::Cpu);
|
||||
let model = YoloV8::load(vb, multiples, /* num_classes=*/ 80)?;
|
||||
// let model = YoloV8Pose::load(vb, multiples, /* num_classes=*/ 1, (17, 3))?;
|
||||
let model = T::load(vb, multiples)?;
|
||||
println!("model loaded");
|
||||
for image_name in args.images.iter() {
|
||||
println!("processing {image_name}");
|
||||
@ -892,7 +1029,7 @@ pub fn main() -> anyhow::Result<()> {
|
||||
let image_t = (image_t.unsqueeze(0)?.to_dtype(DType::F32)? * (1. / 255.))?;
|
||||
let predictions = model.forward(&image_t)?.squeeze(0)?;
|
||||
println!("generated predictions {predictions:?}");
|
||||
let image_t = report(
|
||||
let image_t = T::report(
|
||||
&predictions,
|
||||
original_image,
|
||||
width,
|
||||
@ -907,3 +1044,12 @@ pub fn main() -> anyhow::Result<()> {
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn main() -> anyhow::Result<()> {
|
||||
let args = Args::parse();
|
||||
match args.task {
|
||||
YoloTask::Detect => run::<YoloV8>(args)?,
|
||||
YoloTask::Pose => run::<YoloV8Pose>(args)?,
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
Reference in New Issue
Block a user