diff --git a/Cargo.toml b/Cargo.toml index 89b3c63a..f1ad66eb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -46,6 +46,7 @@ num-traits = "0.2.15" rand = "0.8.5" rand_distr = "0.4.3" rayon = "1.7.0" +rusttype = { version = "0.9", default-features = false } safetensors = "0.3.1" serde = { version = "1.0.171", features = ["derive"] } serde_json = "1.0.99" diff --git a/candle-examples/Cargo.toml b/candle-examples/Cargo.toml index 731052ea..8383ab37 100644 --- a/candle-examples/Cargo.toml +++ b/candle-examples/Cargo.toml @@ -33,6 +33,7 @@ hf-hub = { workspace = true, features=["tokio"]} imageproc = { workspace = true } memmap2 = { workspace = true } rand = { workspace = true } +rusttype = { workspace = true } tokenizers = { workspace = true, features = ["onig"] } tracing = { workspace = true } tracing-chrome = { workspace = true } diff --git a/candle-examples/examples/yolo-v8/main.rs b/candle-examples/examples/yolo-v8/main.rs index 1a378680..ab047304 100644 --- a/candle-examples/examples/yolo-v8/main.rs +++ b/candle-examples/examples/yolo-v8/main.rs @@ -59,6 +59,7 @@ pub fn report_detect( h: usize, confidence_threshold: f32, nms_threshold: f32, + legend_size: u32, ) -> Result { let (pred_size, npreds) = pred.dims2()?; let nclasses = pred_size - 4; @@ -96,6 +97,8 @@ pub fn report_detect( let w_ratio = initial_w as f32 / w as f32; let h_ratio = initial_h as f32 / h as f32; let mut img = img.to_rgb8(); + let font = Vec::from(include_bytes!("roboto-mono-stripped.ttf") as &[u8]); + let font = rusttype::Font::try_from_vec(font); for (class_index, bboxes_for_class) in bboxes.iter().enumerate() { for b in bboxes_for_class.iter() { println!( @@ -114,6 +117,29 @@ pub fn report_detect( image::Rgb([255, 0, 0]), ); } + if legend_size > 0 { + if let Some(font) = font.as_ref() { + imageproc::drawing::draw_filled_rect_mut( + &mut img, + imageproc::rect::Rect::at(xmin, ymin).of_size(dx as u32, legend_size), + image::Rgb([170, 0, 0]), + ); + let legend = format!( + "{} {:.0}%", + candle_examples::coco_classes::NAMES[class_index], + 100. * b.confidence + ); + imageproc::drawing::draw_text_mut( + &mut img, + image::Rgb([255, 255, 255]), + xmin, + ymin, + rusttype::Scale::uniform(legend_size as f32 - 1.), + font, + &legend, + ) + } + } } } Ok(DynamicImage::ImageRgb8(img)) @@ -255,6 +281,10 @@ pub struct Args { /// The task to be run. #[arg(long, default_value = "detect")] task: YoloTask, + + /// The size for the legend, 0 means no legend. + #[arg(long, default_value_t = 14)] + legend_size: u32, } impl Args { @@ -291,6 +321,7 @@ pub trait Task: Module + Sized { h: usize, confidence_threshold: f32, nms_threshold: f32, + legend_size: u32, ) -> Result; } @@ -306,8 +337,17 @@ impl Task for YoloV8 { h: usize, confidence_threshold: f32, nms_threshold: f32, + legend_size: u32, ) -> Result { - report_detect(pred, img, w, h, confidence_threshold, nms_threshold) + report_detect( + pred, + img, + w, + h, + confidence_threshold, + nms_threshold, + legend_size, + ) } } @@ -323,6 +363,7 @@ impl Task for YoloV8Pose { h: usize, confidence_threshold: f32, nms_threshold: f32, + _legend_size: u32, ) -> Result { report_pose(pred, img, w, h, confidence_threshold, nms_threshold) } @@ -385,6 +426,7 @@ pub fn run(args: Args) -> anyhow::Result<()> { height, args.confidence_threshold, args.nms_threshold, + args.legend_size, )?; image_name.set_extension("pp.jpg"); println!("writing {image_name:?}"); diff --git a/candle-examples/examples/yolo-v8/roboto-mono-stripped.ttf b/candle-examples/examples/yolo-v8/roboto-mono-stripped.ttf new file mode 100644 index 00000000..6d807fc4 Binary files /dev/null and b/candle-examples/examples/yolo-v8/roboto-mono-stripped.ttf differ