From 5dbe46b389da4ba39131ce34752249cae640ad9e Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Sat, 23 Sep 2023 16:55:46 +0100 Subject: [PATCH] Add tracing. (#943) --- candle-examples/examples/yolo-v8/main.rs | 27 ++++++++- candle-examples/examples/yolo-v8/model.rs | 73 ++++++++++++++++++++--- 2 files changed, 90 insertions(+), 10 deletions(-) diff --git a/candle-examples/examples/yolo-v8/main.rs b/candle-examples/examples/yolo-v8/main.rs index d48bac35..dc709db4 100644 --- a/candle-examples/examples/yolo-v8/main.rs +++ b/candle-examples/examples/yolo-v8/main.rs @@ -7,7 +7,7 @@ extern crate accelerate_src; mod model; use model::{Multiples, YoloV8, YoloV8Pose}; -use candle::{DType, Device, IndexOp, Result, Tensor}; +use candle::{DType, IndexOp, Result, Tensor}; use candle_nn::{Module, VarBuilder}; use candle_transformers::object_detection::{non_maximum_suppression, Bbox, KeyPoint}; use clap::{Parser, ValueEnum}; @@ -253,6 +253,14 @@ enum YoloTask { #[derive(Parser, Debug)] #[command(author, version, about, long_about = None)] pub struct Args { + /// Run on CPU rather than on GPU. + #[arg(long)] + cpu: bool, + + /// Enable tracing (generates a trace-timestamp.json file). + #[arg(long)] + tracing: bool, + /// Model weights, in safetensors format. #[arg(long)] model: Option, @@ -363,6 +371,7 @@ impl Task for YoloV8Pose { } pub fn run(args: Args) -> anyhow::Result<()> { + let device = candle_examples::device(args.cpu)?; // Create the model and load the weights from the file. let multiples = match args.which { Which::N => Multiples::n(), @@ -374,7 +383,7 @@ pub fn run(args: Args) -> anyhow::Result<()> { let model = args.model()?; let weights = unsafe { candle::safetensors::MmapedFile::new(model)? }; let weights = weights.deserialize()?; - let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &Device::Cpu); + let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device); let model = T::load(vb, multiples)?; println!("model loaded"); for image_name in args.images.iter() { @@ -405,7 +414,7 @@ pub fn run(args: Args) -> anyhow::Result<()> { Tensor::from_vec( data, (img.height() as usize, img.width() as usize, 3), - &Device::Cpu, + &device, )? .permute((2, 0, 1))? }; @@ -430,7 +439,19 @@ pub fn run(args: Args) -> anyhow::Result<()> { } pub fn main() -> anyhow::Result<()> { + use tracing_chrome::ChromeLayerBuilder; + use tracing_subscriber::prelude::*; + let args = Args::parse(); + + let _guard = if args.tracing { + let (chrome_layer, guard) = ChromeLayerBuilder::new().build(); + tracing_subscriber::registry().with(chrome_layer).init(); + Some(guard) + } else { + None + }; + match args.task { YoloTask::Detect => run::(args)?, YoloTask::Pose => run::(args)?, diff --git a/candle-examples/examples/yolo-v8/model.rs b/candle-examples/examples/yolo-v8/model.rs index b834f967..bf48fd84 100644 --- a/candle-examples/examples/yolo-v8/model.rs +++ b/candle-examples/examples/yolo-v8/model.rs @@ -77,6 +77,7 @@ impl Module for Upsample { struct ConvBlock { conv: Conv2d, bn: BatchNorm, + span: tracing::Span, } impl ConvBlock { @@ -97,12 +98,17 @@ impl ConvBlock { }; let conv = conv2d_no_bias(c1, c2, k, cfg, vb.pp("conv"))?; let bn = batch_norm(c2, 1e-3, vb.pp("bn"))?; - Ok(Self { conv, bn }) + Ok(Self { + conv, + bn, + span: tracing::span!(tracing::Level::TRACE, "conv-block"), + }) } } impl Module for ConvBlock { fn forward(&self, xs: &Tensor) -> Result { + let _enter = self.span.enter(); let xs = self.conv.forward(xs)?; let xs = self.bn.forward(&xs)?; candle_nn::ops::silu(&xs) @@ -114,6 +120,7 @@ struct Bottleneck { cv1: ConvBlock, cv2: ConvBlock, residual: bool, + span: tracing::Span, } impl Bottleneck { @@ -123,12 +130,18 @@ impl Bottleneck { let cv1 = ConvBlock::load(vb.pp("cv1"), c1, c_, 3, 1, None)?; let cv2 = ConvBlock::load(vb.pp("cv2"), c_, c2, 3, 1, None)?; let residual = c1 == c2 && shortcut; - Ok(Self { cv1, cv2, residual }) + Ok(Self { + cv1, + cv2, + residual, + span: tracing::span!(tracing::Level::TRACE, "bottleneck"), + }) } } impl Module for Bottleneck { fn forward(&self, xs: &Tensor) -> Result { + let _enter = self.span.enter(); let ys = self.cv2.forward(&self.cv1.forward(xs)?)?; if self.residual { xs + ys @@ -143,6 +156,7 @@ struct C2f { cv1: ConvBlock, cv2: ConvBlock, bottleneck: Vec, + span: tracing::Span, } impl C2f { @@ -159,12 +173,14 @@ impl C2f { cv1, cv2, bottleneck, + span: tracing::span!(tracing::Level::TRACE, "c2f"), }) } } impl Module for C2f { fn forward(&self, xs: &Tensor) -> Result { + let _enter = self.span.enter(); let ys = self.cv1.forward(xs)?; let mut ys = ys.chunk(2, 1)?; for m in self.bottleneck.iter() { @@ -180,6 +196,7 @@ struct Sppf { cv1: ConvBlock, cv2: ConvBlock, k: usize, + span: tracing::Span, } impl Sppf { @@ -187,12 +204,18 @@ impl Sppf { let c_ = c1 / 2; let cv1 = ConvBlock::load(vb.pp("cv1"), c1, c_, 1, 1, None)?; let cv2 = ConvBlock::load(vb.pp("cv2"), c_ * 4, c2, 1, 1, None)?; - Ok(Self { cv1, cv2, k }) + Ok(Self { + cv1, + cv2, + k, + span: tracing::span!(tracing::Level::TRACE, "sppf"), + }) } } impl Module for Sppf { fn forward(&self, xs: &Tensor) -> Result { + let _enter = self.span.enter(); let (_, _, _, _) = xs.dims4()?; let xs = self.cv1.forward(xs)?; let xs2 = xs @@ -215,17 +238,23 @@ impl Module for Sppf { struct Dfl { conv: Conv2d, num_classes: usize, + span: tracing::Span, } impl Dfl { fn load(vb: VarBuilder, num_classes: usize) -> Result { let conv = conv2d_no_bias(num_classes, 1, 1, Default::default(), vb.pp("conv"))?; - Ok(Self { conv, num_classes }) + Ok(Self { + conv, + num_classes, + span: tracing::span!(tracing::Level::TRACE, "dfl"), + }) } } impl Module for Dfl { fn forward(&self, xs: &Tensor) -> Result { + let _enter = self.span.enter(); let (b_sz, _channels, anchors) = xs.dims3()?; let xs = xs .reshape((b_sz, 4, self.num_classes, anchors))? @@ -247,6 +276,7 @@ struct DarkNet { b4_0: ConvBlock, b4_1: C2f, b5: Sppf, + span: tracing::Span, } impl DarkNet { @@ -330,10 +360,12 @@ impl DarkNet { b4_0, b4_1, b5, + span: tracing::span!(tracing::Level::TRACE, "darknet"), }) } fn forward(&self, xs: &Tensor) -> Result<(Tensor, Tensor, Tensor)> { + let _enter = self.span.enter(); let x1 = self.b1_1.forward(&self.b1_0.forward(xs)?)?; let x2 = self .b2_2 @@ -354,6 +386,7 @@ struct YoloV8Neck { n4: C2f, n5: ConvBlock, n6: C2f, + span: tracing::Span, } impl YoloV8Neck { @@ -413,10 +446,12 @@ impl YoloV8Neck { n4, n5, n6, + span: tracing::span!(tracing::Level::TRACE, "neck"), }) } fn forward(&self, p3: &Tensor, p4: &Tensor, p5: &Tensor) -> Result<(Tensor, Tensor, Tensor)> { + let _enter = self.span.enter(); let x = self .n1 .forward(&Tensor::cat(&[&self.up.forward(p5)?, p4], 1)?)?; @@ -440,6 +475,7 @@ struct DetectionHead { cv3: [(ConvBlock, ConvBlock, Conv2d); 3], ch: usize, no: usize, + span: tracing::Span, } #[derive(Debug)] @@ -447,6 +483,7 @@ struct PoseHead { detect: DetectionHead, cv4: [(ConvBlock, ConvBlock, Conv2d); 3], kpt: (usize, usize), + span: tracing::Span, } fn make_anchors( @@ -519,6 +556,7 @@ impl DetectionHead { cv3, ch, no, + span: tracing::span!(tracing::Level::TRACE, "detection-head"), }) } @@ -547,6 +585,7 @@ impl DetectionHead { } fn forward(&self, xs0: &Tensor, xs1: &Tensor, xs2: &Tensor) -> Result { + let _enter = self.span.enter(); let forward_cv = |xs, i: usize| { let xs_2 = self.cv2[i].0.forward(xs)?; let xs_2 = self.cv2[i].1.forward(&xs_2)?; @@ -606,7 +645,12 @@ impl PoseHead { Self::load_cv4(vb.pp("cv4.1"), c4, nk, filters.1)?, Self::load_cv4(vb.pp("cv4.2"), c4, nk, filters.2)?, ]; - Ok(Self { detect, cv4, kpt }) + Ok(Self { + detect, + cv4, + kpt, + span: tracing::span!(tracing::Level::TRACE, "pose-head"), + }) } fn load_cv4( @@ -622,6 +666,7 @@ impl PoseHead { } fn forward(&self, xs0: &Tensor, xs1: &Tensor, xs2: &Tensor) -> Result { + let _enter = self.span.enter(); let d = self.detect.forward(xs0, xs1, xs2)?; let forward_cv = |xs: &Tensor, i: usize| { let (b_sz, _, h, w) = xs.dims4()?; @@ -650,6 +695,7 @@ pub struct YoloV8 { net: DarkNet, fpn: YoloV8Neck, head: DetectionHead, + span: tracing::Span, } impl YoloV8 { @@ -657,12 +703,18 @@ impl YoloV8 { let net = DarkNet::load(vb.pp("net"), m)?; let fpn = YoloV8Neck::load(vb.pp("fpn"), m)?; let head = DetectionHead::load(vb.pp("head"), num_classes, m.filters())?; - Ok(Self { net, fpn, head }) + Ok(Self { + net, + fpn, + head, + span: tracing::span!(tracing::Level::TRACE, "yolo-v8"), + }) } } impl Module for YoloV8 { fn forward(&self, xs: &Tensor) -> Result { + let _enter = self.span.enter(); let (xs1, xs2, xs3) = self.net.forward(xs)?; let (xs1, xs2, xs3) = self.fpn.forward(&xs1, &xs2, &xs3)?; Ok(self.head.forward(&xs1, &xs2, &xs3)?.pred) @@ -674,6 +726,7 @@ pub struct YoloV8Pose { net: DarkNet, fpn: YoloV8Neck, head: PoseHead, + span: tracing::Span, } impl YoloV8Pose { @@ -686,12 +739,18 @@ impl YoloV8Pose { let net = DarkNet::load(vb.pp("net"), m)?; let fpn = YoloV8Neck::load(vb.pp("fpn"), m)?; let head = PoseHead::load(vb.pp("head"), num_classes, kpt, m.filters())?; - Ok(Self { net, fpn, head }) + Ok(Self { + net, + fpn, + head, + span: tracing::span!(tracing::Level::TRACE, "yolo-v8-pose"), + }) } } impl Module for YoloV8Pose { fn forward(&self, xs: &Tensor) -> Result { + let _enter = self.span.enter(); let (xs1, xs2, xs3) = self.net.forward(xs)?; let (xs1, xs2, xs3) = self.fpn.forward(&xs1, &xs2, &xs3)?; self.head.forward(&xs1, &xs2, &xs3)