Add some configurable legend for yolo detection. (#603)

* Add some configurable legend for yolo detection.

* Clippyness.
This commit is contained in:
Laurent Mazare
2023-08-25 13:50:31 +01:00
committed by GitHub
parent 97909e5068
commit 0afbc435df
4 changed files with 45 additions and 1 deletions

View File

@ -46,6 +46,7 @@ num-traits = "0.2.15"
rand = "0.8.5"
rand_distr = "0.4.3"
rayon = "1.7.0"
rusttype = { version = "0.9", default-features = false }
safetensors = "0.3.1"
serde = { version = "1.0.171", features = ["derive"] }
serde_json = "1.0.99"

View File

@ -33,6 +33,7 @@ hf-hub = { workspace = true, features=["tokio"]}
imageproc = { workspace = true }
memmap2 = { workspace = true }
rand = { workspace = true }
rusttype = { workspace = true }
tokenizers = { workspace = true, features = ["onig"] }
tracing = { workspace = true }
tracing-chrome = { workspace = true }

View File

@ -59,6 +59,7 @@ pub fn report_detect(
h: usize,
confidence_threshold: f32,
nms_threshold: f32,
legend_size: u32,
) -> Result<DynamicImage> {
let (pred_size, npreds) = pred.dims2()?;
let nclasses = pred_size - 4;
@ -96,6 +97,8 @@ pub fn report_detect(
let w_ratio = initial_w as f32 / w as f32;
let h_ratio = initial_h as f32 / h as f32;
let mut img = img.to_rgb8();
let font = Vec::from(include_bytes!("roboto-mono-stripped.ttf") as &[u8]);
let font = rusttype::Font::try_from_vec(font);
for (class_index, bboxes_for_class) in bboxes.iter().enumerate() {
for b in bboxes_for_class.iter() {
println!(
@ -114,6 +117,29 @@ pub fn report_detect(
image::Rgb([255, 0, 0]),
);
}
if legend_size > 0 {
if let Some(font) = font.as_ref() {
imageproc::drawing::draw_filled_rect_mut(
&mut img,
imageproc::rect::Rect::at(xmin, ymin).of_size(dx as u32, legend_size),
image::Rgb([170, 0, 0]),
);
let legend = format!(
"{} {:.0}%",
candle_examples::coco_classes::NAMES[class_index],
100. * b.confidence
);
imageproc::drawing::draw_text_mut(
&mut img,
image::Rgb([255, 255, 255]),
xmin,
ymin,
rusttype::Scale::uniform(legend_size as f32 - 1.),
font,
&legend,
)
}
}
}
}
Ok(DynamicImage::ImageRgb8(img))
@ -255,6 +281,10 @@ pub struct Args {
/// The task to be run.
#[arg(long, default_value = "detect")]
task: YoloTask,
/// The size for the legend, 0 means no legend.
#[arg(long, default_value_t = 14)]
legend_size: u32,
}
impl Args {
@ -291,6 +321,7 @@ pub trait Task: Module + Sized {
h: usize,
confidence_threshold: f32,
nms_threshold: f32,
legend_size: u32,
) -> Result<DynamicImage>;
}
@ -306,8 +337,17 @@ impl Task for YoloV8 {
h: usize,
confidence_threshold: f32,
nms_threshold: f32,
legend_size: u32,
) -> Result<DynamicImage> {
report_detect(pred, img, w, h, confidence_threshold, nms_threshold)
report_detect(
pred,
img,
w,
h,
confidence_threshold,
nms_threshold,
legend_size,
)
}
}
@ -323,6 +363,7 @@ impl Task for YoloV8Pose {
h: usize,
confidence_threshold: f32,
nms_threshold: f32,
_legend_size: u32,
) -> Result<DynamicImage> {
report_pose(pred, img, w, h, confidence_threshold, nms_threshold)
}
@ -385,6 +426,7 @@ pub fn run<T: Task>(args: Args) -> anyhow::Result<()> {
height,
args.confidence_threshold,
args.nms_threshold,
args.legend_size,
)?;
image_name.set_extension("pp.jpg");
println!("writing {image_name:?}");