Add the pose estimation head for yolo. (#589)

* Add the pose estimation head for yolo.

* Properly handle the added position dimensions.

* Integrate the pose estimation head in the forward pass.

* Renaming.

* Fix for pose estimation.
This commit is contained in:
Laurent Mazare
2023-08-24 22:12:34 +01:00
committed by GitHub
parent 2cde0cb74b
commit 189442a0fa

View File

@ -428,7 +428,6 @@ impl YoloV8Neck {
}
fn forward(&self, p3: &Tensor, p4: &Tensor, p5: &Tensor) -> Result<(Tensor, Tensor, Tensor)> {
println!("{p3:?} {p4:?} {p5:?}");
let x = self
.n1
.forward(&Tensor::cat(&[&self.up.forward(p5)?, p4], 1)?)?;
@ -454,6 +453,13 @@ struct DetectionHead {
no: usize,
}
#[derive(Debug)]
struct PoseHead {
detect: DetectionHead,
cv4: [(ConvBlock, ConvBlock, Conv2d); 3],
kpt: (usize, usize),
}
fn make_anchors(
xs0: &Tensor,
xs1: &Tensor,
@ -495,6 +501,12 @@ fn dist2bbox(distance: &Tensor, anchor_points: &Tensor) -> Result<Tensor> {
Tensor::cat(&[c_xy, wh], 1)
}
struct DetectionHeadOut {
pred: Tensor,
anchors: Tensor,
strides: Tensor,
}
impl DetectionHead {
fn load(vb: VarBuilder, nc: usize, filters: (usize, usize, usize)) -> Result<Self> {
let ch = 16;
@ -545,7 +557,7 @@ impl DetectionHead {
Ok((block0, block1, conv))
}
fn forward(&self, xs0: &Tensor, xs1: &Tensor, xs2: &Tensor) -> Result<Tensor> {
fn forward(&self, xs0: &Tensor, xs1: &Tensor, xs2: &Tensor) -> Result<DetectionHeadOut> {
let forward_cv = |xs, i: usize| {
let xs_2 = self.cv2[i].0.forward(xs)?;
let xs_2 = self.cv2[i].1.forward(&xs_2)?;
@ -561,7 +573,7 @@ impl DetectionHead {
let xs2 = forward_cv(xs2, 2)?;
let (anchors, strides) = make_anchors(&xs0, &xs1, &xs2, (8, 16, 32), 0.5)?;
let anchors = anchors.transpose(0, 1)?;
let anchors = anchors.transpose(0, 1)?.unsqueeze(0)?;
let strides = strides.transpose(0, 1)?;
let reshape = |xs: &Tensor| {
@ -577,9 +589,70 @@ impl DetectionHead {
let box_ = x_cat.i((.., ..self.ch * 4))?;
let cls = x_cat.i((.., self.ch * 4..))?;
let dbox = dist2bbox(&self.dfl.forward(&box_)?, &anchors.unsqueeze(0)?)?;
let dbox = dist2bbox(&self.dfl.forward(&box_)?, &anchors)?;
let dbox = dbox.broadcast_mul(&strides)?;
Tensor::cat(&[dbox, candle_nn::ops::sigmoid(&cls)?], 1)
let pred = Tensor::cat(&[dbox, candle_nn::ops::sigmoid(&cls)?], 1)?;
Ok(DetectionHeadOut {
pred,
anchors,
strides,
})
}
}
impl PoseHead {
// kpt: keypoints, (17, 3)
// nc: num-classes, 80
fn load(
vb: VarBuilder,
nc: usize,
kpt: (usize, usize),
filters: (usize, usize, usize),
) -> Result<Self> {
let detect = DetectionHead::load(vb.clone(), nc, filters)?;
let nk = kpt.0 * kpt.1;
let c4 = usize::max(filters.0 / 4, nk);
let cv4 = [
Self::load_cv4(vb.pp("cv4.0"), c4, nk, filters.0)?,
Self::load_cv4(vb.pp("cv4.1"), c4, nk, filters.1)?,
Self::load_cv4(vb.pp("cv4.2"), c4, nk, filters.2)?,
];
Ok(Self { detect, cv4, kpt })
}
fn load_cv4(
vb: VarBuilder,
c1: usize,
nc: usize,
filter: usize,
) -> Result<(ConvBlock, ConvBlock, Conv2d)> {
let block0 = ConvBlock::load(vb.pp("0"), filter, c1, 3, 1, None)?;
let block1 = ConvBlock::load(vb.pp("1"), c1, c1, 3, 1, None)?;
let conv = conv2d(c1, nc, 1, Default::default(), vb.pp("2"))?;
Ok((block0, block1, conv))
}
fn forward(&self, xs0: &Tensor, xs1: &Tensor, xs2: &Tensor) -> Result<Tensor> {
let d = self.detect.forward(xs0, xs1, xs2)?;
let forward_cv = |xs: &Tensor, i: usize| {
let (b_sz, _, h, w) = xs.dims4()?;
let xs = self.cv4[i].0.forward(xs)?;
let xs = self.cv4[i].1.forward(&xs)?;
let xs = self.cv4[i].2.forward(&xs)?;
xs.reshape((b_sz, self.kpt.0 * self.kpt.1, h * w))
};
let xs0 = forward_cv(xs0, 0)?;
let xs1 = forward_cv(xs1, 1)?;
let xs2 = forward_cv(xs2, 2)?;
let xs = Tensor::cat(&[xs0, xs1, xs2], D::Minus1)?;
let (b_sz, _nk, hw) = xs.dims3()?;
let xs = xs.reshape((b_sz, self.kpt.0, self.kpt.1, hw))?;
let ys01 = ((xs.i((.., .., 0..2))? * 2.)?.broadcast_add(&d.anchors)? - 0.5)?
.broadcast_mul(&d.strides)?;
let ys2 = candle_nn::ops::sigmoid(&xs.i((.., .., 2..3))?)?;
let ys = Tensor::cat(&[ys01, ys2], 2)?.flatten(1, 2)?;
Tensor::cat(&[d.pred, ys], 1)
}
}
@ -600,6 +673,31 @@ impl YoloV8 {
}
impl Module for YoloV8 {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let (xs1, xs2, xs3) = self.net.forward(xs)?;
let (xs1, xs2, xs3) = self.fpn.forward(&xs1, &xs2, &xs3)?;
Ok(self.head.forward(&xs1, &xs2, &xs3)?.pred)
}
}
#[derive(Debug)]
struct YoloV8Pose {
net: DarkNet,
fpn: YoloV8Neck,
head: PoseHead,
}
#[allow(unused)]
impl YoloV8Pose {
fn load(vb: VarBuilder, m: Multiples, num_classes: usize, kpt: (usize, usize)) -> Result<Self> {
let net = DarkNet::load(vb.pp("net"), m)?;
let fpn = YoloV8Neck::load(vb.pp("fpn"), m)?;
let head = PoseHead::load(vb.pp("head"), num_classes, kpt, m.filters())?;
Ok(Self { net, fpn, head })
}
}
impl Module for YoloV8Pose {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let (xs1, xs2, xs3) = self.net.forward(xs)?;
let (xs1, xs2, xs3) = self.fpn.forward(&xs1, &xs2, &xs3)?;
@ -757,6 +855,7 @@ pub fn main() -> anyhow::Result<()> {
let weights = weights.deserialize()?;
let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &Device::Cpu);
let model = YoloV8::load(vb, multiples, /* num_classes=*/ 80)?;
// let model = YoloV8Pose::load(vb, multiples, /* num_classes=*/ 1, (17, 3))?;
println!("model loaded");
for image_name in args.images.iter() {
println!("processing {image_name}");
@ -790,7 +889,6 @@ pub fn main() -> anyhow::Result<()> {
)?
.permute((2, 0, 1))?
};
println!("{image_t:?}");
let image_t = (image_t.unsqueeze(0)?.to_dtype(DType::F32)? * (1. / 255.))?;
let predictions = model.forward(&image_t)?.squeeze(0)?;
println!("generated predictions {predictions:?}");