More support for pose estimation in yolo-v8. (#599)

* More support for pose estimation in yolo-v8.

* Support both object detection and pose-estimation in the yolo-v8 example.
This commit is contained in:
Laurent Mazare
2023-08-25 11:21:11 +01:00
committed by GitHub
parent afc10a3232
commit 8bc5fffa45
3 changed files with 164 additions and 16 deletions

View File

@ -16,7 +16,7 @@ use image::{DynamicImage, ImageBuffer};
// https://github.com/tinygrad/tinygrad/blob/master/examples/yolov8.py
#[derive(Clone, Copy, PartialEq, Debug)]
struct Multiples {
pub struct Multiples {
depth: f64,
width: f64,
ratio: f64,
@ -727,7 +727,7 @@ pub fn draw_rect(
}
}
pub fn report(
pub fn report_detect(
pred: &Tensor,
img: DynamicImage,
w: usize,
@ -757,6 +757,7 @@ pub fn report(
xmax: pred[0] + pred[2] / 2.,
ymax: pred[1] + pred[3] / 2.,
confidence,
keypoints: vec![],
};
bboxes[class_index].push(bbox)
}
@ -787,6 +788,85 @@ pub fn report(
Ok(DynamicImage::ImageRgb8(img))
}
pub fn report_pose(
pred: &Tensor,
img: DynamicImage,
w: usize,
h: usize,
confidence_threshold: f32,
nms_threshold: f32,
) -> Result<DynamicImage> {
let (pred_size, npreds) = pred.dims2()?;
let nclasses = pred_size - 4;
// The bounding boxes grouped by (maximum) class index.
let mut bboxes: Vec<Vec<Bbox>> = (0..nclasses).map(|_| vec![]).collect();
// Extract the bounding boxes for which confidence is above the threshold.
for index in 0..npreds {
let pred = Vec::<f32>::try_from(pred.i((.., index))?)?;
let confidence = pred[4];
if confidence > confidence_threshold {
let mut class_index = 0;
for i in 0..nclasses {
if pred[4 + i] > pred[4 + class_index] {
class_index = i
}
}
if pred[class_index + 4] > 0. {
let keypoints = (0..17)
.map(|i| (pred[5 + 3 * i], pred[3 * i + 6], pred[3 * i + 7]))
.collect::<Vec<_>>();
let bbox = Bbox {
xmin: pred[0] - pred[2] / 2.,
ymin: pred[1] - pred[3] / 2.,
xmax: pred[0] + pred[2] / 2.,
ymax: pred[1] + pred[3] / 2.,
confidence,
keypoints,
};
bboxes[class_index].push(bbox)
}
}
}
non_maximum_suppression(&mut bboxes, nms_threshold);
// Annotate the original image and print boxes information.
let (initial_h, initial_w) = (img.height(), img.width());
let w_ratio = initial_w as f32 / w as f32;
let h_ratio = initial_h as f32 / h as f32;
let mut img = img.to_rgb8();
for (class_index, bboxes_for_class) in bboxes.iter().enumerate() {
for b in bboxes_for_class.iter() {
println!(
"{}: {:?}",
candle_examples::coco_classes::NAMES[class_index],
b
);
let xmin = ((b.xmin * w_ratio) as u32).clamp(0, initial_w - 1);
let ymin = ((b.ymin * h_ratio) as u32).clamp(0, initial_h - 1);
let xmax = ((b.xmax * w_ratio) as u32).clamp(0, initial_w - 1);
let ymax = ((b.ymax * h_ratio) as u32).clamp(0, initial_h - 1);
draw_rect(&mut img, xmin, xmax, ymin, ymax);
for (x, y, z) in b.keypoints.iter() {
if z < &0.6 {
continue;
}
let x = x * w_ratio;
let y = y * w_ratio;
for dx in -2..3 {
for dy in -2..3 {
let x = ((x + dx as f32) as u32).clamp(0, initial_w - 1);
let y = ((y + dy as f32) as u32).clamp(0, initial_h - 1);
let pixel = img.get_pixel_mut(x, y);
*pixel = image::Rgb([0, 255, 0]);
}
}
}
}
}
Ok(DynamicImage::ImageRgb8(img))
}
#[derive(Clone, Copy, ValueEnum, Debug)]
enum Which {
N,
@ -796,9 +876,15 @@ enum Which {
X,
}
#[derive(Clone, Copy, ValueEnum, Debug)]
enum YoloTask {
Detect,
Pose,
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
pub struct Args {
/// Model weights, in safetensors format.
#[arg(long)]
model: Option<String>,
@ -816,6 +902,10 @@ struct Args {
/// Threshold for non-maximum suppression.
#[arg(long, default_value_t = 0.45)]
nms_threshold: f32,
/// The task to be run.
#[arg(long, default_value = "detect")]
task: YoloTask,
}
impl Args {
@ -825,23 +915,71 @@ impl Args {
None => {
let api = hf_hub::api::sync::Api::new()?;
let api = api.model("lmz/candle-yolo-v8".to_string());
let filename = match self.which {
Which::N => "yolov8n.safetensors",
Which::S => "yolov8s.safetensors",
Which::M => "yolov8m.safetensors",
Which::L => "yolov8l.safetensors",
Which::X => "yolov8x.safetensors",
let size = match self.which {
Which::N => "n",
Which::S => "s",
Which::M => "m",
Which::L => "l",
Which::X => "x",
};
api.get(filename)?
let task = match self.task {
YoloTask::Pose => "-pose",
YoloTask::Detect => "",
};
api.get(&format!("yolov8{size}{task}.safetensors"))?
}
};
Ok(path)
}
}
pub fn main() -> anyhow::Result<()> {
let args = Args::parse();
pub trait Task: Module + Sized {
fn load(vb: VarBuilder, multiples: Multiples) -> Result<Self>;
fn report(
pred: &Tensor,
img: DynamicImage,
w: usize,
h: usize,
confidence_threshold: f32,
nms_threshold: f32,
) -> Result<DynamicImage>;
}
impl Task for YoloV8 {
fn load(vb: VarBuilder, multiples: Multiples) -> Result<Self> {
YoloV8::load(vb, multiples, /* num_classes=*/ 80)
}
fn report(
pred: &Tensor,
img: DynamicImage,
w: usize,
h: usize,
confidence_threshold: f32,
nms_threshold: f32,
) -> Result<DynamicImage> {
report_detect(pred, img, w, h, confidence_threshold, nms_threshold)
}
}
impl Task for YoloV8Pose {
fn load(vb: VarBuilder, multiples: Multiples) -> Result<Self> {
YoloV8Pose::load(vb, multiples, /* num_classes=*/ 1, (17, 3))
}
fn report(
pred: &Tensor,
img: DynamicImage,
w: usize,
h: usize,
confidence_threshold: f32,
nms_threshold: f32,
) -> Result<DynamicImage> {
report_pose(pred, img, w, h, confidence_threshold, nms_threshold)
}
}
pub fn run<T: Task>(args: Args) -> anyhow::Result<()> {
// Create the model and load the weights from the file.
let multiples = match args.which {
Which::N => Multiples::n(),
@ -854,8 +992,7 @@ pub fn main() -> anyhow::Result<()> {
let weights = unsafe { candle::safetensors::MmapedFile::new(model)? };
let weights = weights.deserialize()?;
let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &Device::Cpu);
let model = YoloV8::load(vb, multiples, /* num_classes=*/ 80)?;
// let model = YoloV8Pose::load(vb, multiples, /* num_classes=*/ 1, (17, 3))?;
let model = T::load(vb, multiples)?;
println!("model loaded");
for image_name in args.images.iter() {
println!("processing {image_name}");
@ -892,7 +1029,7 @@ pub fn main() -> anyhow::Result<()> {
let image_t = (image_t.unsqueeze(0)?.to_dtype(DType::F32)? * (1. / 255.))?;
let predictions = model.forward(&image_t)?.squeeze(0)?;
println!("generated predictions {predictions:?}");
let image_t = report(
let image_t = T::report(
&predictions,
original_image,
width,
@ -907,3 +1044,12 @@ pub fn main() -> anyhow::Result<()> {
Ok(())
}
pub fn main() -> anyhow::Result<()> {
let args = Args::parse();
match args.task {
YoloTask::Detect => run::<YoloV8>(args)?,
YoloTask::Pose => run::<YoloV8Pose>(args)?,
}
Ok(())
}