mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Add the pose estimation head for yolo. (#589)
* Add the pose estimation head for yolo.
* Properly handle the added position dimensions.
* Integrate the pose estimation head in the forward pass.
* Renaming.
* Fix for pose estimation.
This commit is contained in:
@ -428,7 +428,6 @@ impl YoloV8Neck {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn forward(&self, p3: &Tensor, p4: &Tensor, p5: &Tensor) -> Result<(Tensor, Tensor, Tensor)> {
|
fn forward(&self, p3: &Tensor, p4: &Tensor, p5: &Tensor) -> Result<(Tensor, Tensor, Tensor)> {
|
||||||
println!("{p3:?} {p4:?} {p5:?}");
|
|
||||||
let x = self
|
let x = self
|
||||||
.n1
|
.n1
|
||||||
.forward(&Tensor::cat(&[&self.up.forward(p5)?, p4], 1)?)?;
|
.forward(&Tensor::cat(&[&self.up.forward(p5)?, p4], 1)?)?;
|
||||||
@ -454,6 +453,13 @@ struct DetectionHead {
|
|||||||
no: usize,
|
no: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Pose head: a `DetectionHead` plus three keypoint branches (`cv4`), one per
// input feature map passed to `PoseHead::forward`.
#[derive(Debug)]
|
||||||
|
struct PoseHead {
|
||||||
|
// Detection head; `PoseHead::forward` reuses its prediction, anchors and strides.
detect: DetectionHead,
|
||||||
|
// One (ConvBlock, ConvBlock, Conv2d) branch per feature map, applied in that order.
cv4: [(ConvBlock, ConvBlock, Conv2d); 3],
|
||||||
|
// Keypoint layout; `kpt.0 * kpt.1` is the channel count each branch produces.
kpt: (usize, usize),
|
||||||
|
}
|
||||||
|
|
||||||
fn make_anchors(
|
fn make_anchors(
|
||||||
xs0: &Tensor,
|
xs0: &Tensor,
|
||||||
xs1: &Tensor,
|
xs1: &Tensor,
|
||||||
@ -495,6 +501,12 @@ fn dist2bbox(distance: &Tensor, anchor_points: &Tensor) -> Result<Tensor> {
|
|||||||
Tensor::cat(&[c_xy, wh], 1)
|
Tensor::cat(&[c_xy, wh], 1)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct DetectionHeadOut {
|
||||||
|
pred: Tensor,
|
||||||
|
anchors: Tensor,
|
||||||
|
strides: Tensor,
|
||||||
|
}
|
||||||
|
|
||||||
impl DetectionHead {
|
impl DetectionHead {
|
||||||
fn load(vb: VarBuilder, nc: usize, filters: (usize, usize, usize)) -> Result<Self> {
|
fn load(vb: VarBuilder, nc: usize, filters: (usize, usize, usize)) -> Result<Self> {
|
||||||
let ch = 16;
|
let ch = 16;
|
||||||
@ -545,7 +557,7 @@ impl DetectionHead {
|
|||||||
Ok((block0, block1, conv))
|
Ok((block0, block1, conv))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn forward(&self, xs0: &Tensor, xs1: &Tensor, xs2: &Tensor) -> Result<Tensor> {
|
fn forward(&self, xs0: &Tensor, xs1: &Tensor, xs2: &Tensor) -> Result<DetectionHeadOut> {
|
||||||
let forward_cv = |xs, i: usize| {
|
let forward_cv = |xs, i: usize| {
|
||||||
let xs_2 = self.cv2[i].0.forward(xs)?;
|
let xs_2 = self.cv2[i].0.forward(xs)?;
|
||||||
let xs_2 = self.cv2[i].1.forward(&xs_2)?;
|
let xs_2 = self.cv2[i].1.forward(&xs_2)?;
|
||||||
@ -561,7 +573,7 @@ impl DetectionHead {
|
|||||||
let xs2 = forward_cv(xs2, 2)?;
|
let xs2 = forward_cv(xs2, 2)?;
|
||||||
|
|
||||||
let (anchors, strides) = make_anchors(&xs0, &xs1, &xs2, (8, 16, 32), 0.5)?;
|
let (anchors, strides) = make_anchors(&xs0, &xs1, &xs2, (8, 16, 32), 0.5)?;
|
||||||
let anchors = anchors.transpose(0, 1)?;
|
let anchors = anchors.transpose(0, 1)?.unsqueeze(0)?;
|
||||||
let strides = strides.transpose(0, 1)?;
|
let strides = strides.transpose(0, 1)?;
|
||||||
|
|
||||||
let reshape = |xs: &Tensor| {
|
let reshape = |xs: &Tensor| {
|
||||||
@ -577,9 +589,70 @@ impl DetectionHead {
|
|||||||
let box_ = x_cat.i((.., ..self.ch * 4))?;
|
let box_ = x_cat.i((.., ..self.ch * 4))?;
|
||||||
let cls = x_cat.i((.., self.ch * 4..))?;
|
let cls = x_cat.i((.., self.ch * 4..))?;
|
||||||
|
|
||||||
let dbox = dist2bbox(&self.dfl.forward(&box_)?, &anchors.unsqueeze(0)?)?;
|
let dbox = dist2bbox(&self.dfl.forward(&box_)?, &anchors)?;
|
||||||
let dbox = dbox.broadcast_mul(&strides)?;
|
let dbox = dbox.broadcast_mul(&strides)?;
|
||||||
Tensor::cat(&[dbox, candle_nn::ops::sigmoid(&cls)?], 1)
|
let pred = Tensor::cat(&[dbox, candle_nn::ops::sigmoid(&cls)?], 1)?;
|
||||||
|
Ok(DetectionHeadOut {
|
||||||
|
pred,
|
||||||
|
anchors,
|
||||||
|
strides,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PoseHead {
|
||||||
|
// kpt: keypoints, (17, 3)
|
||||||
|
// nc: num-classes, 80
|
||||||
|
fn load(
|
||||||
|
vb: VarBuilder,
|
||||||
|
nc: usize,
|
||||||
|
kpt: (usize, usize),
|
||||||
|
filters: (usize, usize, usize),
|
||||||
|
) -> Result<Self> {
|
||||||
|
let detect = DetectionHead::load(vb.clone(), nc, filters)?;
|
||||||
|
let nk = kpt.0 * kpt.1;
|
||||||
|
let c4 = usize::max(filters.0 / 4, nk);
|
||||||
|
let cv4 = [
|
||||||
|
Self::load_cv4(vb.pp("cv4.0"), c4, nk, filters.0)?,
|
||||||
|
Self::load_cv4(vb.pp("cv4.1"), c4, nk, filters.1)?,
|
||||||
|
Self::load_cv4(vb.pp("cv4.2"), c4, nk, filters.2)?,
|
||||||
|
];
|
||||||
|
Ok(Self { detect, cv4, kpt })
|
||||||
|
}
|
||||||
|
|
||||||
|
fn load_cv4(
|
||||||
|
vb: VarBuilder,
|
||||||
|
c1: usize,
|
||||||
|
nc: usize,
|
||||||
|
filter: usize,
|
||||||
|
) -> Result<(ConvBlock, ConvBlock, Conv2d)> {
|
||||||
|
let block0 = ConvBlock::load(vb.pp("0"), filter, c1, 3, 1, None)?;
|
||||||
|
let block1 = ConvBlock::load(vb.pp("1"), c1, c1, 3, 1, None)?;
|
||||||
|
let conv = conv2d(c1, nc, 1, Default::default(), vb.pp("2"))?;
|
||||||
|
Ok((block0, block1, conv))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn forward(&self, xs0: &Tensor, xs1: &Tensor, xs2: &Tensor) -> Result<Tensor> {
|
||||||
|
let d = self.detect.forward(xs0, xs1, xs2)?;
|
||||||
|
let forward_cv = |xs: &Tensor, i: usize| {
|
||||||
|
let (b_sz, _, h, w) = xs.dims4()?;
|
||||||
|
let xs = self.cv4[i].0.forward(xs)?;
|
||||||
|
let xs = self.cv4[i].1.forward(&xs)?;
|
||||||
|
let xs = self.cv4[i].2.forward(&xs)?;
|
||||||
|
xs.reshape((b_sz, self.kpt.0 * self.kpt.1, h * w))
|
||||||
|
};
|
||||||
|
let xs0 = forward_cv(xs0, 0)?;
|
||||||
|
let xs1 = forward_cv(xs1, 1)?;
|
||||||
|
let xs2 = forward_cv(xs2, 2)?;
|
||||||
|
let xs = Tensor::cat(&[xs0, xs1, xs2], D::Minus1)?;
|
||||||
|
let (b_sz, _nk, hw) = xs.dims3()?;
|
||||||
|
let xs = xs.reshape((b_sz, self.kpt.0, self.kpt.1, hw))?;
|
||||||
|
|
||||||
|
let ys01 = ((xs.i((.., .., 0..2))? * 2.)?.broadcast_add(&d.anchors)? - 0.5)?
|
||||||
|
.broadcast_mul(&d.strides)?;
|
||||||
|
let ys2 = candle_nn::ops::sigmoid(&xs.i((.., .., 2..3))?)?;
|
||||||
|
let ys = Tensor::cat(&[ys01, ys2], 2)?.flatten(1, 2)?;
|
||||||
|
Tensor::cat(&[d.pred, ys], 1)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -600,6 +673,31 @@ impl YoloV8 {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl Module for YoloV8 {
|
// `Module` implementation for `YoloV8`: chains `net`, `fpn` and `head`.
impl Module for YoloV8 {
|
||||||
|
// Feeds `xs` through `net`, then `fpn`, then `head`; any error is propagated with `?`.
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
|
let (xs1, xs2, xs3) = self.net.forward(xs)?;
|
||||||
|
let (xs1, xs2, xs3) = self.fpn.forward(&xs1, &xs2, &xs3)?;
|
||||||
|
// Only the `pred` field of the head output is handed back to the caller.
Ok(self.head.forward(&xs1, &xs2, &xs3)?.pred)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pose variant of the model: same `net` and `fpn` stages as loaded by
// `YoloV8Pose::load`, with a `PoseHead` as the final stage.
#[derive(Debug)]
|
||||||
|
struct YoloV8Pose {
|
||||||
|
// Loaded from the "net" prefix of the weights.
net: DarkNet,
|
||||||
|
// Loaded from the "fpn" prefix of the weights.
fpn: YoloV8Neck,
|
||||||
|
// Loaded from the "head" prefix of the weights.
head: PoseHead,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(unused)]
|
||||||
|
impl YoloV8Pose {
|
||||||
|
fn load(vb: VarBuilder, m: Multiples, num_classes: usize, kpt: (usize, usize)) -> Result<Self> {
|
||||||
|
let net = DarkNet::load(vb.pp("net"), m)?;
|
||||||
|
let fpn = YoloV8Neck::load(vb.pp("fpn"), m)?;
|
||||||
|
let head = PoseHead::load(vb.pp("head"), num_classes, kpt, m.filters())?;
|
||||||
|
Ok(Self { net, fpn, head })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module for YoloV8Pose {
|
||||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
let (xs1, xs2, xs3) = self.net.forward(xs)?;
|
let (xs1, xs2, xs3) = self.net.forward(xs)?;
|
||||||
let (xs1, xs2, xs3) = self.fpn.forward(&xs1, &xs2, &xs3)?;
|
let (xs1, xs2, xs3) = self.fpn.forward(&xs1, &xs2, &xs3)?;
|
||||||
@ -757,6 +855,7 @@ pub fn main() -> anyhow::Result<()> {
|
|||||||
let weights = weights.deserialize()?;
|
let weights = weights.deserialize()?;
|
||||||
let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &Device::Cpu);
|
let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &Device::Cpu);
|
||||||
let model = YoloV8::load(vb, multiples, /* num_classes=*/ 80)?;
|
let model = YoloV8::load(vb, multiples, /* num_classes=*/ 80)?;
|
||||||
|
// let model = YoloV8Pose::load(vb, multiples, /* num_classes=*/ 1, (17, 3))?;
|
||||||
println!("model loaded");
|
println!("model loaded");
|
||||||
for image_name in args.images.iter() {
|
for image_name in args.images.iter() {
|
||||||
println!("processing {image_name}");
|
println!("processing {image_name}");
|
||||||
@ -790,7 +889,6 @@ pub fn main() -> anyhow::Result<()> {
|
|||||||
)?
|
)?
|
||||||
.permute((2, 0, 1))?
|
.permute((2, 0, 1))?
|
||||||
};
|
};
|
||||||
println!("{image_t:?}");
|
|
||||||
let image_t = (image_t.unsqueeze(0)?.to_dtype(DType::F32)? * (1. / 255.))?;
|
let image_t = (image_t.unsqueeze(0)?.to_dtype(DType::F32)? * (1. / 255.))?;
|
||||||
let predictions = model.forward(&image_t)?.squeeze(0)?;
|
let predictions = model.forward(&image_t)?.squeeze(0)?;
|
||||||
println!("generated predictions {predictions:?}");
|
println!("generated predictions {predictions:?}");
|
||||||
|
Reference in New Issue
Block a user