Add tracing. (#943)

This commit is contained in:
Laurent Mazare
2023-09-23 16:55:46 +01:00
committed by GitHub
parent ccf352f3d1
commit 5dbe46b389
2 changed files with 90 additions and 10 deletions

View File

@ -7,7 +7,7 @@ extern crate accelerate_src;
mod model; mod model;
use model::{Multiples, YoloV8, YoloV8Pose}; use model::{Multiples, YoloV8, YoloV8Pose};
use candle::{DType, Device, IndexOp, Result, Tensor}; use candle::{DType, IndexOp, Result, Tensor};
use candle_nn::{Module, VarBuilder}; use candle_nn::{Module, VarBuilder};
use candle_transformers::object_detection::{non_maximum_suppression, Bbox, KeyPoint}; use candle_transformers::object_detection::{non_maximum_suppression, Bbox, KeyPoint};
use clap::{Parser, ValueEnum}; use clap::{Parser, ValueEnum};
@ -253,6 +253,14 @@ enum YoloTask {
#[derive(Parser, Debug)] #[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)] #[command(author, version, about, long_about = None)]
pub struct Args { pub struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
/// Model weights, in safetensors format. /// Model weights, in safetensors format.
#[arg(long)] #[arg(long)]
model: Option<String>, model: Option<String>,
@ -363,6 +371,7 @@ impl Task for YoloV8Pose {
} }
pub fn run<T: Task>(args: Args) -> anyhow::Result<()> { pub fn run<T: Task>(args: Args) -> anyhow::Result<()> {
let device = candle_examples::device(args.cpu)?;
// Create the model and load the weights from the file. // Create the model and load the weights from the file.
let multiples = match args.which { let multiples = match args.which {
Which::N => Multiples::n(), Which::N => Multiples::n(),
@ -374,7 +383,7 @@ pub fn run<T: Task>(args: Args) -> anyhow::Result<()> {
let model = args.model()?; let model = args.model()?;
let weights = unsafe { candle::safetensors::MmapedFile::new(model)? }; let weights = unsafe { candle::safetensors::MmapedFile::new(model)? };
let weights = weights.deserialize()?; let weights = weights.deserialize()?;
let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &Device::Cpu); let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
let model = T::load(vb, multiples)?; let model = T::load(vb, multiples)?;
println!("model loaded"); println!("model loaded");
for image_name in args.images.iter() { for image_name in args.images.iter() {
@ -405,7 +414,7 @@ pub fn run<T: Task>(args: Args) -> anyhow::Result<()> {
Tensor::from_vec( Tensor::from_vec(
data, data,
(img.height() as usize, img.width() as usize, 3), (img.height() as usize, img.width() as usize, 3),
&Device::Cpu, &device,
)? )?
.permute((2, 0, 1))? .permute((2, 0, 1))?
}; };
@ -430,7 +439,19 @@ pub fn run<T: Task>(args: Args) -> anyhow::Result<()> {
} }
pub fn main() -> anyhow::Result<()> { pub fn main() -> anyhow::Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse(); let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
match args.task { match args.task {
YoloTask::Detect => run::<YoloV8>(args)?, YoloTask::Detect => run::<YoloV8>(args)?,
YoloTask::Pose => run::<YoloV8Pose>(args)?, YoloTask::Pose => run::<YoloV8Pose>(args)?,

View File

@ -77,6 +77,7 @@ impl Module for Upsample {
struct ConvBlock { struct ConvBlock {
conv: Conv2d, conv: Conv2d,
bn: BatchNorm, bn: BatchNorm,
span: tracing::Span,
} }
impl ConvBlock { impl ConvBlock {
@ -97,12 +98,17 @@ impl ConvBlock {
}; };
let conv = conv2d_no_bias(c1, c2, k, cfg, vb.pp("conv"))?; let conv = conv2d_no_bias(c1, c2, k, cfg, vb.pp("conv"))?;
let bn = batch_norm(c2, 1e-3, vb.pp("bn"))?; let bn = batch_norm(c2, 1e-3, vb.pp("bn"))?;
Ok(Self { conv, bn }) Ok(Self {
conv,
bn,
span: tracing::span!(tracing::Level::TRACE, "conv-block"),
})
} }
} }
impl Module for ConvBlock { impl Module for ConvBlock {
fn forward(&self, xs: &Tensor) -> Result<Tensor> { fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let xs = self.conv.forward(xs)?; let xs = self.conv.forward(xs)?;
let xs = self.bn.forward(&xs)?; let xs = self.bn.forward(&xs)?;
candle_nn::ops::silu(&xs) candle_nn::ops::silu(&xs)
@ -114,6 +120,7 @@ struct Bottleneck {
cv1: ConvBlock, cv1: ConvBlock,
cv2: ConvBlock, cv2: ConvBlock,
residual: bool, residual: bool,
span: tracing::Span,
} }
impl Bottleneck { impl Bottleneck {
@ -123,12 +130,18 @@ impl Bottleneck {
let cv1 = ConvBlock::load(vb.pp("cv1"), c1, c_, 3, 1, None)?; let cv1 = ConvBlock::load(vb.pp("cv1"), c1, c_, 3, 1, None)?;
let cv2 = ConvBlock::load(vb.pp("cv2"), c_, c2, 3, 1, None)?; let cv2 = ConvBlock::load(vb.pp("cv2"), c_, c2, 3, 1, None)?;
let residual = c1 == c2 && shortcut; let residual = c1 == c2 && shortcut;
Ok(Self { cv1, cv2, residual }) Ok(Self {
cv1,
cv2,
residual,
span: tracing::span!(tracing::Level::TRACE, "bottleneck"),
})
} }
} }
impl Module for Bottleneck { impl Module for Bottleneck {
fn forward(&self, xs: &Tensor) -> Result<Tensor> { fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let ys = self.cv2.forward(&self.cv1.forward(xs)?)?; let ys = self.cv2.forward(&self.cv1.forward(xs)?)?;
if self.residual { if self.residual {
xs + ys xs + ys
@ -143,6 +156,7 @@ struct C2f {
cv1: ConvBlock, cv1: ConvBlock,
cv2: ConvBlock, cv2: ConvBlock,
bottleneck: Vec<Bottleneck>, bottleneck: Vec<Bottleneck>,
span: tracing::Span,
} }
impl C2f { impl C2f {
@ -159,12 +173,14 @@ impl C2f {
cv1, cv1,
cv2, cv2,
bottleneck, bottleneck,
span: tracing::span!(tracing::Level::TRACE, "c2f"),
}) })
} }
} }
impl Module for C2f { impl Module for C2f {
fn forward(&self, xs: &Tensor) -> Result<Tensor> { fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let ys = self.cv1.forward(xs)?; let ys = self.cv1.forward(xs)?;
let mut ys = ys.chunk(2, 1)?; let mut ys = ys.chunk(2, 1)?;
for m in self.bottleneck.iter() { for m in self.bottleneck.iter() {
@ -180,6 +196,7 @@ struct Sppf {
cv1: ConvBlock, cv1: ConvBlock,
cv2: ConvBlock, cv2: ConvBlock,
k: usize, k: usize,
span: tracing::Span,
} }
impl Sppf { impl Sppf {
@ -187,12 +204,18 @@ impl Sppf {
let c_ = c1 / 2; let c_ = c1 / 2;
let cv1 = ConvBlock::load(vb.pp("cv1"), c1, c_, 1, 1, None)?; let cv1 = ConvBlock::load(vb.pp("cv1"), c1, c_, 1, 1, None)?;
let cv2 = ConvBlock::load(vb.pp("cv2"), c_ * 4, c2, 1, 1, None)?; let cv2 = ConvBlock::load(vb.pp("cv2"), c_ * 4, c2, 1, 1, None)?;
Ok(Self { cv1, cv2, k }) Ok(Self {
cv1,
cv2,
k,
span: tracing::span!(tracing::Level::TRACE, "sppf"),
})
} }
} }
impl Module for Sppf { impl Module for Sppf {
fn forward(&self, xs: &Tensor) -> Result<Tensor> { fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let (_, _, _, _) = xs.dims4()?; let (_, _, _, _) = xs.dims4()?;
let xs = self.cv1.forward(xs)?; let xs = self.cv1.forward(xs)?;
let xs2 = xs let xs2 = xs
@ -215,17 +238,23 @@ impl Module for Sppf {
struct Dfl { struct Dfl {
conv: Conv2d, conv: Conv2d,
num_classes: usize, num_classes: usize,
span: tracing::Span,
} }
impl Dfl { impl Dfl {
fn load(vb: VarBuilder, num_classes: usize) -> Result<Self> { fn load(vb: VarBuilder, num_classes: usize) -> Result<Self> {
let conv = conv2d_no_bias(num_classes, 1, 1, Default::default(), vb.pp("conv"))?; let conv = conv2d_no_bias(num_classes, 1, 1, Default::default(), vb.pp("conv"))?;
Ok(Self { conv, num_classes }) Ok(Self {
conv,
num_classes,
span: tracing::span!(tracing::Level::TRACE, "dfl"),
})
} }
} }
impl Module for Dfl { impl Module for Dfl {
fn forward(&self, xs: &Tensor) -> Result<Tensor> { fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let (b_sz, _channels, anchors) = xs.dims3()?; let (b_sz, _channels, anchors) = xs.dims3()?;
let xs = xs let xs = xs
.reshape((b_sz, 4, self.num_classes, anchors))? .reshape((b_sz, 4, self.num_classes, anchors))?
@ -247,6 +276,7 @@ struct DarkNet {
b4_0: ConvBlock, b4_0: ConvBlock,
b4_1: C2f, b4_1: C2f,
b5: Sppf, b5: Sppf,
span: tracing::Span,
} }
impl DarkNet { impl DarkNet {
@ -330,10 +360,12 @@ impl DarkNet {
b4_0, b4_0,
b4_1, b4_1,
b5, b5,
span: tracing::span!(tracing::Level::TRACE, "darknet"),
}) })
} }
fn forward(&self, xs: &Tensor) -> Result<(Tensor, Tensor, Tensor)> { fn forward(&self, xs: &Tensor) -> Result<(Tensor, Tensor, Tensor)> {
let _enter = self.span.enter();
let x1 = self.b1_1.forward(&self.b1_0.forward(xs)?)?; let x1 = self.b1_1.forward(&self.b1_0.forward(xs)?)?;
let x2 = self let x2 = self
.b2_2 .b2_2
@ -354,6 +386,7 @@ struct YoloV8Neck {
n4: C2f, n4: C2f,
n5: ConvBlock, n5: ConvBlock,
n6: C2f, n6: C2f,
span: tracing::Span,
} }
impl YoloV8Neck { impl YoloV8Neck {
@ -413,10 +446,12 @@ impl YoloV8Neck {
n4, n4,
n5, n5,
n6, n6,
span: tracing::span!(tracing::Level::TRACE, "neck"),
}) })
} }
fn forward(&self, p3: &Tensor, p4: &Tensor, p5: &Tensor) -> Result<(Tensor, Tensor, Tensor)> { fn forward(&self, p3: &Tensor, p4: &Tensor, p5: &Tensor) -> Result<(Tensor, Tensor, Tensor)> {
let _enter = self.span.enter();
let x = self let x = self
.n1 .n1
.forward(&Tensor::cat(&[&self.up.forward(p5)?, p4], 1)?)?; .forward(&Tensor::cat(&[&self.up.forward(p5)?, p4], 1)?)?;
@ -440,6 +475,7 @@ struct DetectionHead {
cv3: [(ConvBlock, ConvBlock, Conv2d); 3], cv3: [(ConvBlock, ConvBlock, Conv2d); 3],
ch: usize, ch: usize,
no: usize, no: usize,
span: tracing::Span,
} }
#[derive(Debug)] #[derive(Debug)]
@ -447,6 +483,7 @@ struct PoseHead {
detect: DetectionHead, detect: DetectionHead,
cv4: [(ConvBlock, ConvBlock, Conv2d); 3], cv4: [(ConvBlock, ConvBlock, Conv2d); 3],
kpt: (usize, usize), kpt: (usize, usize),
span: tracing::Span,
} }
fn make_anchors( fn make_anchors(
@ -519,6 +556,7 @@ impl DetectionHead {
cv3, cv3,
ch, ch,
no, no,
span: tracing::span!(tracing::Level::TRACE, "detection-head"),
}) })
} }
@ -547,6 +585,7 @@ impl DetectionHead {
} }
fn forward(&self, xs0: &Tensor, xs1: &Tensor, xs2: &Tensor) -> Result<DetectionHeadOut> { fn forward(&self, xs0: &Tensor, xs1: &Tensor, xs2: &Tensor) -> Result<DetectionHeadOut> {
let _enter = self.span.enter();
let forward_cv = |xs, i: usize| { let forward_cv = |xs, i: usize| {
let xs_2 = self.cv2[i].0.forward(xs)?; let xs_2 = self.cv2[i].0.forward(xs)?;
let xs_2 = self.cv2[i].1.forward(&xs_2)?; let xs_2 = self.cv2[i].1.forward(&xs_2)?;
@ -606,7 +645,12 @@ impl PoseHead {
Self::load_cv4(vb.pp("cv4.1"), c4, nk, filters.1)?, Self::load_cv4(vb.pp("cv4.1"), c4, nk, filters.1)?,
Self::load_cv4(vb.pp("cv4.2"), c4, nk, filters.2)?, Self::load_cv4(vb.pp("cv4.2"), c4, nk, filters.2)?,
]; ];
Ok(Self { detect, cv4, kpt }) Ok(Self {
detect,
cv4,
kpt,
span: tracing::span!(tracing::Level::TRACE, "pose-head"),
})
} }
fn load_cv4( fn load_cv4(
@ -622,6 +666,7 @@ impl PoseHead {
} }
fn forward(&self, xs0: &Tensor, xs1: &Tensor, xs2: &Tensor) -> Result<Tensor> { fn forward(&self, xs0: &Tensor, xs1: &Tensor, xs2: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let d = self.detect.forward(xs0, xs1, xs2)?; let d = self.detect.forward(xs0, xs1, xs2)?;
let forward_cv = |xs: &Tensor, i: usize| { let forward_cv = |xs: &Tensor, i: usize| {
let (b_sz, _, h, w) = xs.dims4()?; let (b_sz, _, h, w) = xs.dims4()?;
@ -650,6 +695,7 @@ pub struct YoloV8 {
net: DarkNet, net: DarkNet,
fpn: YoloV8Neck, fpn: YoloV8Neck,
head: DetectionHead, head: DetectionHead,
span: tracing::Span,
} }
impl YoloV8 { impl YoloV8 {
@ -657,12 +703,18 @@ impl YoloV8 {
let net = DarkNet::load(vb.pp("net"), m)?; let net = DarkNet::load(vb.pp("net"), m)?;
let fpn = YoloV8Neck::load(vb.pp("fpn"), m)?; let fpn = YoloV8Neck::load(vb.pp("fpn"), m)?;
let head = DetectionHead::load(vb.pp("head"), num_classes, m.filters())?; let head = DetectionHead::load(vb.pp("head"), num_classes, m.filters())?;
Ok(Self { net, fpn, head }) Ok(Self {
net,
fpn,
head,
span: tracing::span!(tracing::Level::TRACE, "yolo-v8"),
})
} }
} }
impl Module for YoloV8 { impl Module for YoloV8 {
fn forward(&self, xs: &Tensor) -> Result<Tensor> { fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let (xs1, xs2, xs3) = self.net.forward(xs)?; let (xs1, xs2, xs3) = self.net.forward(xs)?;
let (xs1, xs2, xs3) = self.fpn.forward(&xs1, &xs2, &xs3)?; let (xs1, xs2, xs3) = self.fpn.forward(&xs1, &xs2, &xs3)?;
Ok(self.head.forward(&xs1, &xs2, &xs3)?.pred) Ok(self.head.forward(&xs1, &xs2, &xs3)?.pred)
@ -674,6 +726,7 @@ pub struct YoloV8Pose {
net: DarkNet, net: DarkNet,
fpn: YoloV8Neck, fpn: YoloV8Neck,
head: PoseHead, head: PoseHead,
span: tracing::Span,
} }
impl YoloV8Pose { impl YoloV8Pose {
@ -686,12 +739,18 @@ impl YoloV8Pose {
let net = DarkNet::load(vb.pp("net"), m)?; let net = DarkNet::load(vb.pp("net"), m)?;
let fpn = YoloV8Neck::load(vb.pp("fpn"), m)?; let fpn = YoloV8Neck::load(vb.pp("fpn"), m)?;
let head = PoseHead::load(vb.pp("head"), num_classes, kpt, m.filters())?; let head = PoseHead::load(vb.pp("head"), num_classes, kpt, m.filters())?;
Ok(Self { net, fpn, head }) Ok(Self {
net,
fpn,
head,
span: tracing::span!(tracing::Level::TRACE, "yolo-v8-pose"),
})
} }
} }
impl Module for YoloV8Pose { impl Module for YoloV8Pose {
fn forward(&self, xs: &Tensor) -> Result<Tensor> { fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let (xs1, xs2, xs3) = self.net.forward(xs)?; let (xs1, xs2, xs3) = self.net.forward(xs)?;
let (xs1, xs2, xs3) = self.fpn.forward(&xs1, &xs2, &xs3)?; let (xs1, xs2, xs3) = self.fpn.forward(&xs1, &xs2, &xs3)?;
self.head.forward(&xs1, &xs2, &xs3) self.head.forward(&xs1, &xs2, &xs3)