From 189442a0fa44eb1602221335aa3e025e54c475b2 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Thu, 24 Aug 2023 22:12:34 +0100
Subject: [PATCH] Add the pose estimation head for yolo. (#589)

* Add the pose estimation head for yolo.

* Properly handle the added position dimensions.

* Integrate the pose estimation head in the forward pass.

* Renaming.

* Fix for pose estimation.
---
 candle-examples/examples/yolo-v8/main.rs | 110 +++++++++++++++++++++--
 1 file changed, 104 insertions(+), 6 deletions(-)

diff --git a/candle-examples/examples/yolo-v8/main.rs b/candle-examples/examples/yolo-v8/main.rs
index 3b9c1ce9..a93aa035 100644
--- a/candle-examples/examples/yolo-v8/main.rs
+++ b/candle-examples/examples/yolo-v8/main.rs
@@ -428,7 +428,6 @@ impl YoloV8Neck {
     }
 
     fn forward(&self, p3: &Tensor, p4: &Tensor, p5: &Tensor) -> Result<(Tensor, Tensor, Tensor)> {
-        println!("{p3:?} {p4:?} {p5:?}");
         let x = self
             .n1
             .forward(&Tensor::cat(&[&self.up.forward(p5)?, p4], 1)?)?;
@@ -454,6 +453,13 @@ struct DetectionHead {
     no: usize,
 }
 
+#[derive(Debug)]
+struct PoseHead {
+    detect: DetectionHead,
+    cv4: [(ConvBlock, ConvBlock, Conv2d); 3],
+    kpt: (usize, usize),
+}
+
 fn make_anchors(
     xs0: &Tensor,
     xs1: &Tensor,
@@ -495,6 +501,12 @@ fn dist2bbox(distance: &Tensor, anchor_points: &Tensor) -> Result<Tensor> {
     Tensor::cat(&[c_xy, wh], 1)
 }
 
+struct DetectionHeadOut {
+    pred: Tensor,
+    anchors: Tensor,
+    strides: Tensor,
+}
+
 impl DetectionHead {
     fn load(vb: VarBuilder, nc: usize, filters: (usize, usize, usize)) -> Result<Self> {
         let ch = 16;
@@ -545,7 +557,7 @@
         Ok((block0, block1, conv))
     }
 
-    fn forward(&self, xs0: &Tensor, xs1: &Tensor, xs2: &Tensor) -> Result<Tensor> {
+    fn forward(&self, xs0: &Tensor, xs1: &Tensor, xs2: &Tensor) -> Result<DetectionHeadOut> {
         let forward_cv = |xs, i: usize| {
             let xs_2 = self.cv2[i].0.forward(xs)?;
             let xs_2 = self.cv2[i].1.forward(&xs_2)?;
@@ -561,7 +573,7 @@
         let xs2 = forward_cv(xs2, 2)?;
 
         let (anchors, strides) = make_anchors(&xs0, &xs1, &xs2, (8, 16, 32), 0.5)?;
-        let anchors = anchors.transpose(0, 1)?;
+        let anchors = anchors.transpose(0, 1)?.unsqueeze(0)?;
         let strides = strides.transpose(0, 1)?;
 
         let reshape = |xs: &Tensor| {
@@ -577,9 +589,70 @@
         let box_ = x_cat.i((.., ..self.ch * 4))?;
         let cls = x_cat.i((.., self.ch * 4..))?;
 
-        let dbox = dist2bbox(&self.dfl.forward(&box_)?, &anchors.unsqueeze(0)?)?;
+        let dbox = dist2bbox(&self.dfl.forward(&box_)?, &anchors)?;
         let dbox = dbox.broadcast_mul(&strides)?;
-        Tensor::cat(&[dbox, candle_nn::ops::sigmoid(&cls)?], 1)
+        let pred = Tensor::cat(&[dbox, candle_nn::ops::sigmoid(&cls)?], 1)?;
+        Ok(DetectionHeadOut {
+            pred,
+            anchors,
+            strides,
+        })
+    }
+}
+
+impl PoseHead {
+    // kpt: keypoints, (17, 3)
+    // nc: num-classes, 80
+    fn load(
+        vb: VarBuilder,
+        nc: usize,
+        kpt: (usize, usize),
+        filters: (usize, usize, usize),
+    ) -> Result<Self> {
+        let detect = DetectionHead::load(vb.clone(), nc, filters)?;
+        let nk = kpt.0 * kpt.1;
+        let c4 = usize::max(filters.0 / 4, nk);
+        let cv4 = [
+            Self::load_cv4(vb.pp("cv4.0"), c4, nk, filters.0)?,
+            Self::load_cv4(vb.pp("cv4.1"), c4, nk, filters.1)?,
+            Self::load_cv4(vb.pp("cv4.2"), c4, nk, filters.2)?,
+        ];
+        Ok(Self { detect, cv4, kpt })
+    }
+
+    fn load_cv4(
+        vb: VarBuilder,
+        c1: usize,
+        nc: usize,
+        filter: usize,
+    ) -> Result<(ConvBlock, ConvBlock, Conv2d)> {
+        let block0 = ConvBlock::load(vb.pp("0"), filter, c1, 3, 1, None)?;
+        let block1 = ConvBlock::load(vb.pp("1"), c1, c1, 3, 1, None)?;
+        let conv = conv2d(c1, nc, 1, Default::default(), vb.pp("2"))?;
+        Ok((block0, block1, conv))
+    }
+
+    fn forward(&self, xs0: &Tensor, xs1: &Tensor, xs2: &Tensor) -> Result<Tensor> {
+        let d = self.detect.forward(xs0, xs1, xs2)?;
+        let forward_cv = |xs: &Tensor, i: usize| {
+            let (b_sz, _, h, w) = xs.dims4()?;
+            let xs = self.cv4[i].0.forward(xs)?;
+            let xs = self.cv4[i].1.forward(&xs)?;
+            let xs = self.cv4[i].2.forward(&xs)?;
+            xs.reshape((b_sz, self.kpt.0 * self.kpt.1, h * w))
+        };
+        let xs0 = forward_cv(xs0, 0)?;
+        let xs1 = forward_cv(xs1, 1)?;
+        let xs2 = forward_cv(xs2, 2)?;
+        let xs = Tensor::cat(&[xs0, xs1, xs2], D::Minus1)?;
+        let (b_sz, _nk, hw) = xs.dims3()?;
+        let xs = xs.reshape((b_sz, self.kpt.0, self.kpt.1, hw))?;
+
+        let ys01 = ((xs.i((.., .., 0..2))? * 2.)?.broadcast_add(&d.anchors)? - 0.5)?
+            .broadcast_mul(&d.strides)?;
+        let ys2 = candle_nn::ops::sigmoid(&xs.i((.., .., 2..3))?)?;
+        let ys = Tensor::cat(&[ys01, ys2], 2)?.flatten(1, 2)?;
+        Tensor::cat(&[d.pred, ys], 1)
     }
 }
 
@@ -600,6 +673,31 @@ impl YoloV8 {
 }
 
 impl Module for YoloV8 {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let (xs1, xs2, xs3) = self.net.forward(xs)?;
+        let (xs1, xs2, xs3) = self.fpn.forward(&xs1, &xs2, &xs3)?;
+        Ok(self.head.forward(&xs1, &xs2, &xs3)?.pred)
+    }
+}
+
+#[derive(Debug)]
+struct YoloV8Pose {
+    net: DarkNet,
+    fpn: YoloV8Neck,
+    head: PoseHead,
+}
+
+#[allow(unused)]
+impl YoloV8Pose {
+    fn load(vb: VarBuilder, m: Multiples, num_classes: usize, kpt: (usize, usize)) -> Result<Self> {
+        let net = DarkNet::load(vb.pp("net"), m)?;
+        let fpn = YoloV8Neck::load(vb.pp("fpn"), m)?;
+        let head = PoseHead::load(vb.pp("head"), num_classes, kpt, m.filters())?;
+        Ok(Self { net, fpn, head })
+    }
+}
+
+impl Module for YoloV8Pose {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
         let (xs1, xs2, xs3) = self.net.forward(xs)?;
         let (xs1, xs2, xs3) = self.fpn.forward(&xs1, &xs2, &xs3)?;
@@ -757,6 +855,7 @@ pub fn main() -> anyhow::Result<()> {
     let weights = weights.deserialize()?;
     let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &Device::Cpu);
     let model = YoloV8::load(vb, multiples, /* num_classes=*/ 80)?;
+    // let model = YoloV8Pose::load(vb, multiples, /* num_classes=*/ 1, (17, 3))?;
     println!("model loaded");
     for image_name in args.images.iter() {
         println!("processing {image_name}");
@@ -790,7 +889,6 @@ pub fn main() -> anyhow::Result<()> {
             )?
             .permute((2, 0, 1))?
         };
-        println!("{image_t:?}");
        let image_t = (image_t.unsqueeze(0)?.to_dtype(DType::F32)? * (1. / 255.))?;
        let predictions = model.forward(&image_t)?.squeeze(0)?;
        println!("generated predictions {predictions:?}");
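
A minimal usage sketch for the new YoloV8Pose model, following the commented-out
line added to main() above. It assumes pose-estimation weights have been loaded
into `vb` (in place of the detection weights) and reuses the `multiples` and
`image_t` variables already present in main(); keypoint post-processing
(confidence filtering, non-maximum suppression) is not part of this patch and is
left out here.

    // Hypothetical swap of the detection model for the pose model in main().
    let model = YoloV8Pose::load(vb, multiples, /* num_classes=*/ 1, /* kpt=*/ (17, 3))?;
    // With num_classes=1 and kpt=(17, 3), a (1, 3, h, w) input yields a
    // (1, 4 + 1 + 17 * 3, n_anchors) tensor: 4 box coordinates and 1 confidence
    // score from the detection head, followed by an (x, y, confidence) triple
    // for each of the 17 keypoints at every anchor position.
    let predictions = model.forward(&image_t)?.squeeze(0)?;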