Add tracing. (#943)

This commit is contained in:
Laurent Mazare
2023-09-23 16:55:46 +01:00
committed by GitHub
parent ccf352f3d1
commit 5dbe46b389
2 changed files with 90 additions and 10 deletions

View File

@ -7,7 +7,7 @@ extern crate accelerate_src;
mod model;
use model::{Multiples, YoloV8, YoloV8Pose};
use candle::{DType, Device, IndexOp, Result, Tensor};
use candle::{DType, IndexOp, Result, Tensor};
use candle_nn::{Module, VarBuilder};
use candle_transformers::object_detection::{non_maximum_suppression, Bbox, KeyPoint};
use clap::{Parser, ValueEnum};
@ -253,6 +253,14 @@ enum YoloTask {
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
pub struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
/// Model weights, in safetensors format.
#[arg(long)]
model: Option<String>,
@ -363,6 +371,7 @@ impl Task for YoloV8Pose {
}
pub fn run<T: Task>(args: Args) -> anyhow::Result<()> {
let device = candle_examples::device(args.cpu)?;
// Create the model and load the weights from the file.
let multiples = match args.which {
Which::N => Multiples::n(),
@ -374,7 +383,7 @@ pub fn run<T: Task>(args: Args) -> anyhow::Result<()> {
let model = args.model()?;
let weights = unsafe { candle::safetensors::MmapedFile::new(model)? };
let weights = weights.deserialize()?;
let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &Device::Cpu);
let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
let model = T::load(vb, multiples)?;
println!("model loaded");
for image_name in args.images.iter() {
@ -405,7 +414,7 @@ pub fn run<T: Task>(args: Args) -> anyhow::Result<()> {
Tensor::from_vec(
data,
(img.height() as usize, img.width() as usize, 3),
&Device::Cpu,
&device,
)?
.permute((2, 0, 1))?
};
@ -430,7 +439,19 @@ pub fn run<T: Task>(args: Args) -> anyhow::Result<()> {
}
pub fn main() -> anyhow::Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
match args.task {
YoloTask::Detect => run::<YoloV8>(args)?,
YoloTask::Pose => run::<YoloV8Pose>(args)?,