mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Add the pose estimation head for yolo. (#589)
* Add the pose estimation head for yolo. * Properly handle the added position dimensions. * Integrate the pose estimation head in the forward pass. * Renaming. * Fix for pose estimation.
This commit is contained in:
@ -428,7 +428,6 @@ impl YoloV8Neck {
|
||||
}
|
||||
|
||||
fn forward(&self, p3: &Tensor, p4: &Tensor, p5: &Tensor) -> Result<(Tensor, Tensor, Tensor)> {
|
||||
println!("{p3:?} {p4:?} {p5:?}");
|
||||
let x = self
|
||||
.n1
|
||||
.forward(&Tensor::cat(&[&self.up.forward(p5)?, p4], 1)?)?;
|
||||
@ -454,6 +453,13 @@ struct DetectionHead {
|
||||
no: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct PoseHead {
|
||||
detect: DetectionHead,
|
||||
cv4: [(ConvBlock, ConvBlock, Conv2d); 3],
|
||||
kpt: (usize, usize),
|
||||
}
|
||||
|
||||
fn make_anchors(
|
||||
xs0: &Tensor,
|
||||
xs1: &Tensor,
|
||||
@ -495,6 +501,12 @@ fn dist2bbox(distance: &Tensor, anchor_points: &Tensor) -> Result<Tensor> {
|
||||
Tensor::cat(&[c_xy, wh], 1)
|
||||
}
|
||||
|
||||
struct DetectionHeadOut {
|
||||
pred: Tensor,
|
||||
anchors: Tensor,
|
||||
strides: Tensor,
|
||||
}
|
||||
|
||||
impl DetectionHead {
|
||||
fn load(vb: VarBuilder, nc: usize, filters: (usize, usize, usize)) -> Result<Self> {
|
||||
let ch = 16;
|
||||
@ -545,7 +557,7 @@ impl DetectionHead {
|
||||
Ok((block0, block1, conv))
|
||||
}
|
||||
|
||||
fn forward(&self, xs0: &Tensor, xs1: &Tensor, xs2: &Tensor) -> Result<Tensor> {
|
||||
fn forward(&self, xs0: &Tensor, xs1: &Tensor, xs2: &Tensor) -> Result<DetectionHeadOut> {
|
||||
let forward_cv = |xs, i: usize| {
|
||||
let xs_2 = self.cv2[i].0.forward(xs)?;
|
||||
let xs_2 = self.cv2[i].1.forward(&xs_2)?;
|
||||
@ -561,7 +573,7 @@ impl DetectionHead {
|
||||
let xs2 = forward_cv(xs2, 2)?;
|
||||
|
||||
let (anchors, strides) = make_anchors(&xs0, &xs1, &xs2, (8, 16, 32), 0.5)?;
|
||||
let anchors = anchors.transpose(0, 1)?;
|
||||
let anchors = anchors.transpose(0, 1)?.unsqueeze(0)?;
|
||||
let strides = strides.transpose(0, 1)?;
|
||||
|
||||
let reshape = |xs: &Tensor| {
|
||||
@ -577,9 +589,70 @@ impl DetectionHead {
|
||||
let box_ = x_cat.i((.., ..self.ch * 4))?;
|
||||
let cls = x_cat.i((.., self.ch * 4..))?;
|
||||
|
||||
let dbox = dist2bbox(&self.dfl.forward(&box_)?, &anchors.unsqueeze(0)?)?;
|
||||
let dbox = dist2bbox(&self.dfl.forward(&box_)?, &anchors)?;
|
||||
let dbox = dbox.broadcast_mul(&strides)?;
|
||||
Tensor::cat(&[dbox, candle_nn::ops::sigmoid(&cls)?], 1)
|
||||
let pred = Tensor::cat(&[dbox, candle_nn::ops::sigmoid(&cls)?], 1)?;
|
||||
Ok(DetectionHeadOut {
|
||||
pred,
|
||||
anchors,
|
||||
strides,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl PoseHead {
|
||||
// kpt: keypoints, (17, 3)
|
||||
// nc: num-classes, 80
|
||||
fn load(
|
||||
vb: VarBuilder,
|
||||
nc: usize,
|
||||
kpt: (usize, usize),
|
||||
filters: (usize, usize, usize),
|
||||
) -> Result<Self> {
|
||||
let detect = DetectionHead::load(vb.clone(), nc, filters)?;
|
||||
let nk = kpt.0 * kpt.1;
|
||||
let c4 = usize::max(filters.0 / 4, nk);
|
||||
let cv4 = [
|
||||
Self::load_cv4(vb.pp("cv4.0"), c4, nk, filters.0)?,
|
||||
Self::load_cv4(vb.pp("cv4.1"), c4, nk, filters.1)?,
|
||||
Self::load_cv4(vb.pp("cv4.2"), c4, nk, filters.2)?,
|
||||
];
|
||||
Ok(Self { detect, cv4, kpt })
|
||||
}
|
||||
|
||||
fn load_cv4(
|
||||
vb: VarBuilder,
|
||||
c1: usize,
|
||||
nc: usize,
|
||||
filter: usize,
|
||||
) -> Result<(ConvBlock, ConvBlock, Conv2d)> {
|
||||
let block0 = ConvBlock::load(vb.pp("0"), filter, c1, 3, 1, None)?;
|
||||
let block1 = ConvBlock::load(vb.pp("1"), c1, c1, 3, 1, None)?;
|
||||
let conv = conv2d(c1, nc, 1, Default::default(), vb.pp("2"))?;
|
||||
Ok((block0, block1, conv))
|
||||
}
|
||||
|
||||
fn forward(&self, xs0: &Tensor, xs1: &Tensor, xs2: &Tensor) -> Result<Tensor> {
|
||||
let d = self.detect.forward(xs0, xs1, xs2)?;
|
||||
let forward_cv = |xs: &Tensor, i: usize| {
|
||||
let (b_sz, _, h, w) = xs.dims4()?;
|
||||
let xs = self.cv4[i].0.forward(xs)?;
|
||||
let xs = self.cv4[i].1.forward(&xs)?;
|
||||
let xs = self.cv4[i].2.forward(&xs)?;
|
||||
xs.reshape((b_sz, self.kpt.0 * self.kpt.1, h * w))
|
||||
};
|
||||
let xs0 = forward_cv(xs0, 0)?;
|
||||
let xs1 = forward_cv(xs1, 1)?;
|
||||
let xs2 = forward_cv(xs2, 2)?;
|
||||
let xs = Tensor::cat(&[xs0, xs1, xs2], D::Minus1)?;
|
||||
let (b_sz, _nk, hw) = xs.dims3()?;
|
||||
let xs = xs.reshape((b_sz, self.kpt.0, self.kpt.1, hw))?;
|
||||
|
||||
let ys01 = ((xs.i((.., .., 0..2))? * 2.)?.broadcast_add(&d.anchors)? - 0.5)?
|
||||
.broadcast_mul(&d.strides)?;
|
||||
let ys2 = candle_nn::ops::sigmoid(&xs.i((.., .., 2..3))?)?;
|
||||
let ys = Tensor::cat(&[ys01, ys2], 2)?.flatten(1, 2)?;
|
||||
Tensor::cat(&[d.pred, ys], 1)
|
||||
}
|
||||
}
|
||||
|
||||
@ -600,6 +673,31 @@ impl YoloV8 {
|
||||
}
|
||||
|
||||
impl Module for YoloV8 {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let (xs1, xs2, xs3) = self.net.forward(xs)?;
|
||||
let (xs1, xs2, xs3) = self.fpn.forward(&xs1, &xs2, &xs3)?;
|
||||
Ok(self.head.forward(&xs1, &xs2, &xs3)?.pred)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct YoloV8Pose {
|
||||
net: DarkNet,
|
||||
fpn: YoloV8Neck,
|
||||
head: PoseHead,
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
impl YoloV8Pose {
|
||||
fn load(vb: VarBuilder, m: Multiples, num_classes: usize, kpt: (usize, usize)) -> Result<Self> {
|
||||
let net = DarkNet::load(vb.pp("net"), m)?;
|
||||
let fpn = YoloV8Neck::load(vb.pp("fpn"), m)?;
|
||||
let head = PoseHead::load(vb.pp("head"), num_classes, kpt, m.filters())?;
|
||||
Ok(Self { net, fpn, head })
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for YoloV8Pose {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let (xs1, xs2, xs3) = self.net.forward(xs)?;
|
||||
let (xs1, xs2, xs3) = self.fpn.forward(&xs1, &xs2, &xs3)?;
|
||||
@ -757,6 +855,7 @@ pub fn main() -> anyhow::Result<()> {
|
||||
let weights = weights.deserialize()?;
|
||||
let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &Device::Cpu);
|
||||
let model = YoloV8::load(vb, multiples, /* num_classes=*/ 80)?;
|
||||
// let model = YoloV8Pose::load(vb, multiples, /* num_classes=*/ 1, (17, 3))?;
|
||||
println!("model loaded");
|
||||
for image_name in args.images.iter() {
|
||||
println!("processing {image_name}");
|
||||
@ -790,7 +889,6 @@ pub fn main() -> anyhow::Result<()> {
|
||||
)?
|
||||
.permute((2, 0, 1))?
|
||||
};
|
||||
println!("{image_t:?}");
|
||||
let image_t = (image_t.unsqueeze(0)?.to_dtype(DType::F32)? * (1. / 255.))?;
|
||||
let predictions = model.forward(&image_t)?.squeeze(0)?;
|
||||
println!("generated predictions {predictions:?}");
|
||||
|
Reference in New Issue
Block a user