mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Add tracing. (#943)
This commit is contained in:
@ -7,7 +7,7 @@ extern crate accelerate_src;
|
|||||||
mod model;
|
mod model;
|
||||||
use model::{Multiples, YoloV8, YoloV8Pose};
|
use model::{Multiples, YoloV8, YoloV8Pose};
|
||||||
|
|
||||||
use candle::{DType, Device, IndexOp, Result, Tensor};
|
use candle::{DType, IndexOp, Result, Tensor};
|
||||||
use candle_nn::{Module, VarBuilder};
|
use candle_nn::{Module, VarBuilder};
|
||||||
use candle_transformers::object_detection::{non_maximum_suppression, Bbox, KeyPoint};
|
use candle_transformers::object_detection::{non_maximum_suppression, Bbox, KeyPoint};
|
||||||
use clap::{Parser, ValueEnum};
|
use clap::{Parser, ValueEnum};
|
||||||
@ -253,6 +253,14 @@ enum YoloTask {
|
|||||||
#[derive(Parser, Debug)]
|
#[derive(Parser, Debug)]
|
||||||
#[command(author, version, about, long_about = None)]
|
#[command(author, version, about, long_about = None)]
|
||||||
pub struct Args {
|
pub struct Args {
|
||||||
|
/// Run on CPU rather than on GPU.
|
||||||
|
#[arg(long)]
|
||||||
|
cpu: bool,
|
||||||
|
|
||||||
|
/// Enable tracing (generates a trace-timestamp.json file).
|
||||||
|
#[arg(long)]
|
||||||
|
tracing: bool,
|
||||||
|
|
||||||
/// Model weights, in safetensors format.
|
/// Model weights, in safetensors format.
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
model: Option<String>,
|
model: Option<String>,
|
||||||
@ -363,6 +371,7 @@ impl Task for YoloV8Pose {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn run<T: Task>(args: Args) -> anyhow::Result<()> {
|
pub fn run<T: Task>(args: Args) -> anyhow::Result<()> {
|
||||||
|
let device = candle_examples::device(args.cpu)?;
|
||||||
// Create the model and load the weights from the file.
|
// Create the model and load the weights from the file.
|
||||||
let multiples = match args.which {
|
let multiples = match args.which {
|
||||||
Which::N => Multiples::n(),
|
Which::N => Multiples::n(),
|
||||||
@ -374,7 +383,7 @@ pub fn run<T: Task>(args: Args) -> anyhow::Result<()> {
|
|||||||
let model = args.model()?;
|
let model = args.model()?;
|
||||||
let weights = unsafe { candle::safetensors::MmapedFile::new(model)? };
|
let weights = unsafe { candle::safetensors::MmapedFile::new(model)? };
|
||||||
let weights = weights.deserialize()?;
|
let weights = weights.deserialize()?;
|
||||||
let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &Device::Cpu);
|
let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
|
||||||
let model = T::load(vb, multiples)?;
|
let model = T::load(vb, multiples)?;
|
||||||
println!("model loaded");
|
println!("model loaded");
|
||||||
for image_name in args.images.iter() {
|
for image_name in args.images.iter() {
|
||||||
@ -405,7 +414,7 @@ pub fn run<T: Task>(args: Args) -> anyhow::Result<()> {
|
|||||||
Tensor::from_vec(
|
Tensor::from_vec(
|
||||||
data,
|
data,
|
||||||
(img.height() as usize, img.width() as usize, 3),
|
(img.height() as usize, img.width() as usize, 3),
|
||||||
&Device::Cpu,
|
&device,
|
||||||
)?
|
)?
|
||||||
.permute((2, 0, 1))?
|
.permute((2, 0, 1))?
|
||||||
};
|
};
|
||||||
@ -430,7 +439,19 @@ pub fn run<T: Task>(args: Args) -> anyhow::Result<()> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn main() -> anyhow::Result<()> {
|
pub fn main() -> anyhow::Result<()> {
|
||||||
|
use tracing_chrome::ChromeLayerBuilder;
|
||||||
|
use tracing_subscriber::prelude::*;
|
||||||
|
|
||||||
let args = Args::parse();
|
let args = Args::parse();
|
||||||
|
|
||||||
|
let _guard = if args.tracing {
|
||||||
|
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
||||||
|
tracing_subscriber::registry().with(chrome_layer).init();
|
||||||
|
Some(guard)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
match args.task {
|
match args.task {
|
||||||
YoloTask::Detect => run::<YoloV8>(args)?,
|
YoloTask::Detect => run::<YoloV8>(args)?,
|
||||||
YoloTask::Pose => run::<YoloV8Pose>(args)?,
|
YoloTask::Pose => run::<YoloV8Pose>(args)?,
|
||||||
|
@ -77,6 +77,7 @@ impl Module for Upsample {
|
|||||||
struct ConvBlock {
|
struct ConvBlock {
|
||||||
conv: Conv2d,
|
conv: Conv2d,
|
||||||
bn: BatchNorm,
|
bn: BatchNorm,
|
||||||
|
span: tracing::Span,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ConvBlock {
|
impl ConvBlock {
|
||||||
@ -97,12 +98,17 @@ impl ConvBlock {
|
|||||||
};
|
};
|
||||||
let conv = conv2d_no_bias(c1, c2, k, cfg, vb.pp("conv"))?;
|
let conv = conv2d_no_bias(c1, c2, k, cfg, vb.pp("conv"))?;
|
||||||
let bn = batch_norm(c2, 1e-3, vb.pp("bn"))?;
|
let bn = batch_norm(c2, 1e-3, vb.pp("bn"))?;
|
||||||
Ok(Self { conv, bn })
|
Ok(Self {
|
||||||
|
conv,
|
||||||
|
bn,
|
||||||
|
span: tracing::span!(tracing::Level::TRACE, "conv-block"),
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Module for ConvBlock {
|
impl Module for ConvBlock {
|
||||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
|
let _enter = self.span.enter();
|
||||||
let xs = self.conv.forward(xs)?;
|
let xs = self.conv.forward(xs)?;
|
||||||
let xs = self.bn.forward(&xs)?;
|
let xs = self.bn.forward(&xs)?;
|
||||||
candle_nn::ops::silu(&xs)
|
candle_nn::ops::silu(&xs)
|
||||||
@ -114,6 +120,7 @@ struct Bottleneck {
|
|||||||
cv1: ConvBlock,
|
cv1: ConvBlock,
|
||||||
cv2: ConvBlock,
|
cv2: ConvBlock,
|
||||||
residual: bool,
|
residual: bool,
|
||||||
|
span: tracing::Span,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Bottleneck {
|
impl Bottleneck {
|
||||||
@ -123,12 +130,18 @@ impl Bottleneck {
|
|||||||
let cv1 = ConvBlock::load(vb.pp("cv1"), c1, c_, 3, 1, None)?;
|
let cv1 = ConvBlock::load(vb.pp("cv1"), c1, c_, 3, 1, None)?;
|
||||||
let cv2 = ConvBlock::load(vb.pp("cv2"), c_, c2, 3, 1, None)?;
|
let cv2 = ConvBlock::load(vb.pp("cv2"), c_, c2, 3, 1, None)?;
|
||||||
let residual = c1 == c2 && shortcut;
|
let residual = c1 == c2 && shortcut;
|
||||||
Ok(Self { cv1, cv2, residual })
|
Ok(Self {
|
||||||
|
cv1,
|
||||||
|
cv2,
|
||||||
|
residual,
|
||||||
|
span: tracing::span!(tracing::Level::TRACE, "bottleneck"),
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Module for Bottleneck {
|
impl Module for Bottleneck {
|
||||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
|
let _enter = self.span.enter();
|
||||||
let ys = self.cv2.forward(&self.cv1.forward(xs)?)?;
|
let ys = self.cv2.forward(&self.cv1.forward(xs)?)?;
|
||||||
if self.residual {
|
if self.residual {
|
||||||
xs + ys
|
xs + ys
|
||||||
@ -143,6 +156,7 @@ struct C2f {
|
|||||||
cv1: ConvBlock,
|
cv1: ConvBlock,
|
||||||
cv2: ConvBlock,
|
cv2: ConvBlock,
|
||||||
bottleneck: Vec<Bottleneck>,
|
bottleneck: Vec<Bottleneck>,
|
||||||
|
span: tracing::Span,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl C2f {
|
impl C2f {
|
||||||
@ -159,12 +173,14 @@ impl C2f {
|
|||||||
cv1,
|
cv1,
|
||||||
cv2,
|
cv2,
|
||||||
bottleneck,
|
bottleneck,
|
||||||
|
span: tracing::span!(tracing::Level::TRACE, "c2f"),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Module for C2f {
|
impl Module for C2f {
|
||||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
|
let _enter = self.span.enter();
|
||||||
let ys = self.cv1.forward(xs)?;
|
let ys = self.cv1.forward(xs)?;
|
||||||
let mut ys = ys.chunk(2, 1)?;
|
let mut ys = ys.chunk(2, 1)?;
|
||||||
for m in self.bottleneck.iter() {
|
for m in self.bottleneck.iter() {
|
||||||
@ -180,6 +196,7 @@ struct Sppf {
|
|||||||
cv1: ConvBlock,
|
cv1: ConvBlock,
|
||||||
cv2: ConvBlock,
|
cv2: ConvBlock,
|
||||||
k: usize,
|
k: usize,
|
||||||
|
span: tracing::Span,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Sppf {
|
impl Sppf {
|
||||||
@ -187,12 +204,18 @@ impl Sppf {
|
|||||||
let c_ = c1 / 2;
|
let c_ = c1 / 2;
|
||||||
let cv1 = ConvBlock::load(vb.pp("cv1"), c1, c_, 1, 1, None)?;
|
let cv1 = ConvBlock::load(vb.pp("cv1"), c1, c_, 1, 1, None)?;
|
||||||
let cv2 = ConvBlock::load(vb.pp("cv2"), c_ * 4, c2, 1, 1, None)?;
|
let cv2 = ConvBlock::load(vb.pp("cv2"), c_ * 4, c2, 1, 1, None)?;
|
||||||
Ok(Self { cv1, cv2, k })
|
Ok(Self {
|
||||||
|
cv1,
|
||||||
|
cv2,
|
||||||
|
k,
|
||||||
|
span: tracing::span!(tracing::Level::TRACE, "sppf"),
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Module for Sppf {
|
impl Module for Sppf {
|
||||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
|
let _enter = self.span.enter();
|
||||||
let (_, _, _, _) = xs.dims4()?;
|
let (_, _, _, _) = xs.dims4()?;
|
||||||
let xs = self.cv1.forward(xs)?;
|
let xs = self.cv1.forward(xs)?;
|
||||||
let xs2 = xs
|
let xs2 = xs
|
||||||
@ -215,17 +238,23 @@ impl Module for Sppf {
|
|||||||
struct Dfl {
|
struct Dfl {
|
||||||
conv: Conv2d,
|
conv: Conv2d,
|
||||||
num_classes: usize,
|
num_classes: usize,
|
||||||
|
span: tracing::Span,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Dfl {
|
impl Dfl {
|
||||||
fn load(vb: VarBuilder, num_classes: usize) -> Result<Self> {
|
fn load(vb: VarBuilder, num_classes: usize) -> Result<Self> {
|
||||||
let conv = conv2d_no_bias(num_classes, 1, 1, Default::default(), vb.pp("conv"))?;
|
let conv = conv2d_no_bias(num_classes, 1, 1, Default::default(), vb.pp("conv"))?;
|
||||||
Ok(Self { conv, num_classes })
|
Ok(Self {
|
||||||
|
conv,
|
||||||
|
num_classes,
|
||||||
|
span: tracing::span!(tracing::Level::TRACE, "dfl"),
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Module for Dfl {
|
impl Module for Dfl {
|
||||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
|
let _enter = self.span.enter();
|
||||||
let (b_sz, _channels, anchors) = xs.dims3()?;
|
let (b_sz, _channels, anchors) = xs.dims3()?;
|
||||||
let xs = xs
|
let xs = xs
|
||||||
.reshape((b_sz, 4, self.num_classes, anchors))?
|
.reshape((b_sz, 4, self.num_classes, anchors))?
|
||||||
@ -247,6 +276,7 @@ struct DarkNet {
|
|||||||
b4_0: ConvBlock,
|
b4_0: ConvBlock,
|
||||||
b4_1: C2f,
|
b4_1: C2f,
|
||||||
b5: Sppf,
|
b5: Sppf,
|
||||||
|
span: tracing::Span,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl DarkNet {
|
impl DarkNet {
|
||||||
@ -330,10 +360,12 @@ impl DarkNet {
|
|||||||
b4_0,
|
b4_0,
|
||||||
b4_1,
|
b4_1,
|
||||||
b5,
|
b5,
|
||||||
|
span: tracing::span!(tracing::Level::TRACE, "darknet"),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn forward(&self, xs: &Tensor) -> Result<(Tensor, Tensor, Tensor)> {
|
fn forward(&self, xs: &Tensor) -> Result<(Tensor, Tensor, Tensor)> {
|
||||||
|
let _enter = self.span.enter();
|
||||||
let x1 = self.b1_1.forward(&self.b1_0.forward(xs)?)?;
|
let x1 = self.b1_1.forward(&self.b1_0.forward(xs)?)?;
|
||||||
let x2 = self
|
let x2 = self
|
||||||
.b2_2
|
.b2_2
|
||||||
@ -354,6 +386,7 @@ struct YoloV8Neck {
|
|||||||
n4: C2f,
|
n4: C2f,
|
||||||
n5: ConvBlock,
|
n5: ConvBlock,
|
||||||
n6: C2f,
|
n6: C2f,
|
||||||
|
span: tracing::Span,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl YoloV8Neck {
|
impl YoloV8Neck {
|
||||||
@ -413,10 +446,12 @@ impl YoloV8Neck {
|
|||||||
n4,
|
n4,
|
||||||
n5,
|
n5,
|
||||||
n6,
|
n6,
|
||||||
|
span: tracing::span!(tracing::Level::TRACE, "neck"),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn forward(&self, p3: &Tensor, p4: &Tensor, p5: &Tensor) -> Result<(Tensor, Tensor, Tensor)> {
|
fn forward(&self, p3: &Tensor, p4: &Tensor, p5: &Tensor) -> Result<(Tensor, Tensor, Tensor)> {
|
||||||
|
let _enter = self.span.enter();
|
||||||
let x = self
|
let x = self
|
||||||
.n1
|
.n1
|
||||||
.forward(&Tensor::cat(&[&self.up.forward(p5)?, p4], 1)?)?;
|
.forward(&Tensor::cat(&[&self.up.forward(p5)?, p4], 1)?)?;
|
||||||
@ -440,6 +475,7 @@ struct DetectionHead {
|
|||||||
cv3: [(ConvBlock, ConvBlock, Conv2d); 3],
|
cv3: [(ConvBlock, ConvBlock, Conv2d); 3],
|
||||||
ch: usize,
|
ch: usize,
|
||||||
no: usize,
|
no: usize,
|
||||||
|
span: tracing::Span,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
@ -447,6 +483,7 @@ struct PoseHead {
|
|||||||
detect: DetectionHead,
|
detect: DetectionHead,
|
||||||
cv4: [(ConvBlock, ConvBlock, Conv2d); 3],
|
cv4: [(ConvBlock, ConvBlock, Conv2d); 3],
|
||||||
kpt: (usize, usize),
|
kpt: (usize, usize),
|
||||||
|
span: tracing::Span,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn make_anchors(
|
fn make_anchors(
|
||||||
@ -519,6 +556,7 @@ impl DetectionHead {
|
|||||||
cv3,
|
cv3,
|
||||||
ch,
|
ch,
|
||||||
no,
|
no,
|
||||||
|
span: tracing::span!(tracing::Level::TRACE, "detection-head"),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -547,6 +585,7 @@ impl DetectionHead {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn forward(&self, xs0: &Tensor, xs1: &Tensor, xs2: &Tensor) -> Result<DetectionHeadOut> {
|
fn forward(&self, xs0: &Tensor, xs1: &Tensor, xs2: &Tensor) -> Result<DetectionHeadOut> {
|
||||||
|
let _enter = self.span.enter();
|
||||||
let forward_cv = |xs, i: usize| {
|
let forward_cv = |xs, i: usize| {
|
||||||
let xs_2 = self.cv2[i].0.forward(xs)?;
|
let xs_2 = self.cv2[i].0.forward(xs)?;
|
||||||
let xs_2 = self.cv2[i].1.forward(&xs_2)?;
|
let xs_2 = self.cv2[i].1.forward(&xs_2)?;
|
||||||
@ -606,7 +645,12 @@ impl PoseHead {
|
|||||||
Self::load_cv4(vb.pp("cv4.1"), c4, nk, filters.1)?,
|
Self::load_cv4(vb.pp("cv4.1"), c4, nk, filters.1)?,
|
||||||
Self::load_cv4(vb.pp("cv4.2"), c4, nk, filters.2)?,
|
Self::load_cv4(vb.pp("cv4.2"), c4, nk, filters.2)?,
|
||||||
];
|
];
|
||||||
Ok(Self { detect, cv4, kpt })
|
Ok(Self {
|
||||||
|
detect,
|
||||||
|
cv4,
|
||||||
|
kpt,
|
||||||
|
span: tracing::span!(tracing::Level::TRACE, "pose-head"),
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn load_cv4(
|
fn load_cv4(
|
||||||
@ -622,6 +666,7 @@ impl PoseHead {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn forward(&self, xs0: &Tensor, xs1: &Tensor, xs2: &Tensor) -> Result<Tensor> {
|
fn forward(&self, xs0: &Tensor, xs1: &Tensor, xs2: &Tensor) -> Result<Tensor> {
|
||||||
|
let _enter = self.span.enter();
|
||||||
let d = self.detect.forward(xs0, xs1, xs2)?;
|
let d = self.detect.forward(xs0, xs1, xs2)?;
|
||||||
let forward_cv = |xs: &Tensor, i: usize| {
|
let forward_cv = |xs: &Tensor, i: usize| {
|
||||||
let (b_sz, _, h, w) = xs.dims4()?;
|
let (b_sz, _, h, w) = xs.dims4()?;
|
||||||
@ -650,6 +695,7 @@ pub struct YoloV8 {
|
|||||||
net: DarkNet,
|
net: DarkNet,
|
||||||
fpn: YoloV8Neck,
|
fpn: YoloV8Neck,
|
||||||
head: DetectionHead,
|
head: DetectionHead,
|
||||||
|
span: tracing::Span,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl YoloV8 {
|
impl YoloV8 {
|
||||||
@ -657,12 +703,18 @@ impl YoloV8 {
|
|||||||
let net = DarkNet::load(vb.pp("net"), m)?;
|
let net = DarkNet::load(vb.pp("net"), m)?;
|
||||||
let fpn = YoloV8Neck::load(vb.pp("fpn"), m)?;
|
let fpn = YoloV8Neck::load(vb.pp("fpn"), m)?;
|
||||||
let head = DetectionHead::load(vb.pp("head"), num_classes, m.filters())?;
|
let head = DetectionHead::load(vb.pp("head"), num_classes, m.filters())?;
|
||||||
Ok(Self { net, fpn, head })
|
Ok(Self {
|
||||||
|
net,
|
||||||
|
fpn,
|
||||||
|
head,
|
||||||
|
span: tracing::span!(tracing::Level::TRACE, "yolo-v8"),
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Module for YoloV8 {
|
impl Module for YoloV8 {
|
||||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
|
let _enter = self.span.enter();
|
||||||
let (xs1, xs2, xs3) = self.net.forward(xs)?;
|
let (xs1, xs2, xs3) = self.net.forward(xs)?;
|
||||||
let (xs1, xs2, xs3) = self.fpn.forward(&xs1, &xs2, &xs3)?;
|
let (xs1, xs2, xs3) = self.fpn.forward(&xs1, &xs2, &xs3)?;
|
||||||
Ok(self.head.forward(&xs1, &xs2, &xs3)?.pred)
|
Ok(self.head.forward(&xs1, &xs2, &xs3)?.pred)
|
||||||
@ -674,6 +726,7 @@ pub struct YoloV8Pose {
|
|||||||
net: DarkNet,
|
net: DarkNet,
|
||||||
fpn: YoloV8Neck,
|
fpn: YoloV8Neck,
|
||||||
head: PoseHead,
|
head: PoseHead,
|
||||||
|
span: tracing::Span,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl YoloV8Pose {
|
impl YoloV8Pose {
|
||||||
@ -686,12 +739,18 @@ impl YoloV8Pose {
|
|||||||
let net = DarkNet::load(vb.pp("net"), m)?;
|
let net = DarkNet::load(vb.pp("net"), m)?;
|
||||||
let fpn = YoloV8Neck::load(vb.pp("fpn"), m)?;
|
let fpn = YoloV8Neck::load(vb.pp("fpn"), m)?;
|
||||||
let head = PoseHead::load(vb.pp("head"), num_classes, kpt, m.filters())?;
|
let head = PoseHead::load(vb.pp("head"), num_classes, kpt, m.filters())?;
|
||||||
Ok(Self { net, fpn, head })
|
Ok(Self {
|
||||||
|
net,
|
||||||
|
fpn,
|
||||||
|
head,
|
||||||
|
span: tracing::span!(tracing::Level::TRACE, "yolo-v8-pose"),
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Module for YoloV8Pose {
|
impl Module for YoloV8Pose {
|
||||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
|
let _enter = self.span.enter();
|
||||||
let (xs1, xs2, xs3) = self.net.forward(xs)?;
|
let (xs1, xs2, xs3) = self.net.forward(xs)?;
|
||||||
let (xs1, xs2, xs3) = self.fpn.forward(&xs1, &xs2, &xs3)?;
|
let (xs1, xs2, xs3) = self.fpn.forward(&xs1, &xs2, &xs3)?;
|
||||||
self.head.forward(&xs1, &xs2, &xs3)
|
self.head.forward(&xs1, &xs2, &xs3)
|
||||||
|
Reference in New Issue
Block a user