mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Add tracing. (#943)
This commit is contained in:
@ -7,7 +7,7 @@ extern crate accelerate_src;
|
||||
mod model;
|
||||
use model::{Multiples, YoloV8, YoloV8Pose};
|
||||
|
||||
use candle::{DType, Device, IndexOp, Result, Tensor};
|
||||
use candle::{DType, IndexOp, Result, Tensor};
|
||||
use candle_nn::{Module, VarBuilder};
|
||||
use candle_transformers::object_detection::{non_maximum_suppression, Bbox, KeyPoint};
|
||||
use clap::{Parser, ValueEnum};
|
||||
@ -253,6 +253,14 @@ enum YoloTask {
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
pub struct Args {
|
||||
/// Run on CPU rather than on GPU.
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
|
||||
/// Enable tracing (generates a trace-timestamp.json file).
|
||||
#[arg(long)]
|
||||
tracing: bool,
|
||||
|
||||
/// Model weights, in safetensors format.
|
||||
#[arg(long)]
|
||||
model: Option<String>,
|
||||
@ -363,6 +371,7 @@ impl Task for YoloV8Pose {
|
||||
}
|
||||
|
||||
pub fn run<T: Task>(args: Args) -> anyhow::Result<()> {
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
// Create the model and load the weights from the file.
|
||||
let multiples = match args.which {
|
||||
Which::N => Multiples::n(),
|
||||
@ -374,7 +383,7 @@ pub fn run<T: Task>(args: Args) -> anyhow::Result<()> {
|
||||
let model = args.model()?;
|
||||
let weights = unsafe { candle::safetensors::MmapedFile::new(model)? };
|
||||
let weights = weights.deserialize()?;
|
||||
let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &Device::Cpu);
|
||||
let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
|
||||
let model = T::load(vb, multiples)?;
|
||||
println!("model loaded");
|
||||
for image_name in args.images.iter() {
|
||||
@ -405,7 +414,7 @@ pub fn run<T: Task>(args: Args) -> anyhow::Result<()> {
|
||||
Tensor::from_vec(
|
||||
data,
|
||||
(img.height() as usize, img.width() as usize, 3),
|
||||
&Device::Cpu,
|
||||
&device,
|
||||
)?
|
||||
.permute((2, 0, 1))?
|
||||
};
|
||||
@ -430,7 +439,19 @@ pub fn run<T: Task>(args: Args) -> anyhow::Result<()> {
|
||||
}
|
||||
|
||||
pub fn main() -> anyhow::Result<()> {
|
||||
use tracing_chrome::ChromeLayerBuilder;
|
||||
use tracing_subscriber::prelude::*;
|
||||
|
||||
let args = Args::parse();
|
||||
|
||||
let _guard = if args.tracing {
|
||||
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
||||
tracing_subscriber::registry().with(chrome_layer).init();
|
||||
Some(guard)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
match args.task {
|
||||
YoloTask::Detect => run::<YoloV8>(args)?,
|
||||
YoloTask::Pose => run::<YoloV8Pose>(args)?,
|
||||
|
Reference in New Issue
Block a user