Some fixes for yolo-v3. (#529)

* Some fixes for yolo-v3.

* Use the running stats for inference in the batch-norm layer.

* Get some proper predictions for yolo.

* Avoid the quadratic insertion.
This commit is contained in:
Laurent Mazare
2023-08-20 23:19:15 +01:00
committed by GitHub
parent a1812f934f
commit 11c7e7bd67
6 changed files with 144 additions and 53 deletions

View File

@ -1,4 +1,4 @@
use candle::{Device, IndexOp, Result, Tensor};
use candle::{DType, Device, IndexOp, Result, Tensor};
use candle_nn::{batch_norm, conv2d, conv2d_no_bias, Func, Module, VarBuilder};
use std::collections::BTreeMap;
use std::fs::File;
@ -145,11 +145,12 @@ fn conv(vb: VarBuilder, index: usize, p: usize, b: &Block) -> Result<(usize, Bl)
Some(bn) => bn.forward(&xs)?,
None => xs,
};
if leaky {
xs.maximum(&(&xs * 0.1)?)
let xs = if leaky {
xs.maximum(&(&xs * 0.1)?)?
} else {
Ok(xs)
}
xs
};
Ok(xs)
});
Ok((filters, Bl::Layer(Box::new(func))))
}
@ -225,12 +226,13 @@ fn detect(
let grid = Tensor::arange(0u32, grid_size as u32, &Device::Cpu)?;
let a = grid.repeat((grid_size, 1))?;
let b = a.t()?.contiguous()?;
let x_offset = a.unsqueeze(1)?;
let y_offset = b.unsqueeze(1)?;
let x_offset = a.flatten_all()?.unsqueeze(1)?;
let y_offset = b.flatten_all()?.unsqueeze(1)?;
let xy_offset = Tensor::cat(&[&x_offset, &y_offset], 1)?
.repeat((1, nanchors))?
.reshape((grid_size * grid_size * nanchors, 2))?
.unsqueeze(0)?;
.unsqueeze(0)?
.to_dtype(DType::F32)?;
let anchors: Vec<f32> = anchors
.iter()
.flat_map(|&(x, y)| vec![x as f32 / stride as f32, y as f32 / stride as f32].into_iter())
@ -245,7 +247,8 @@ fn detect(
let ys02 = (candle_nn::ops::sigmoid(&ys02)?.add(&xy_offset)? * stride as f64)?;
let ys24 = (ys24.exp()?.mul(&anchors)? * stride as f64)?;
let ys4 = candle_nn::ops::sigmoid(&ys4)?;
Tensor::cat(&[ys02, ys24, ys4], 2)
let ys = Tensor::cat(&[ys02, ys24, ys4], 2)?;
Ok(ys)
}
impl Darknet {

View File

@ -0,0 +1,7 @@
def remove_prefix(text, prefix):
return text[text.startswith(prefix) and len(prefix):]
nps = {}
for k, v in model.state_dict().items():
k = remove_prefix(k, 'module_list.')
nps[k] = v.detach().numpy()
np.savez('yolo-v3.ot', **nps)

View File

@ -11,22 +11,22 @@ use anyhow::Result;
use candle::{DType, Device, Tensor};
use candle_nn::{Module, VarBuilder};
use clap::Parser;
use image::{DynamicImage, ImageBuffer};
const CONFIG_NAME: &str = "candle-examples/examples/yolo/yolo-v3.cfg";
const CONFIDENCE_THRESHOLD: f64 = 0.5;
const NMS_THRESHOLD: f64 = 0.4;
const CONFIDENCE_THRESHOLD: f32 = 0.5;
const NMS_THRESHOLD: f32 = 0.4;
#[derive(Debug, Clone, Copy)]
struct Bbox {
xmin: f64,
ymin: f64,
xmax: f64,
ymax: f64,
confidence: f64,
xmin: f32,
ymin: f32,
xmax: f32,
ymax: f32,
confidence: f32,
}
// Intersection over union of two bounding boxes.
fn iou(b1: &Bbox, b2: &Bbox) -> f64 {
fn iou(b1: &Bbox, b2: &Bbox) -> f32 {
let b1_area = (b1.xmax - b1.xmin + 1.) * (b1.ymax - b1.ymin + 1.);
let b2_area = (b2.xmax - b2.xmin + 1.) * (b2.ymax - b2.ymin + 1.);
let i_xmin = b1.xmin.max(b2.xmin);
@ -38,18 +38,35 @@ fn iou(b1: &Bbox, b2: &Bbox) -> f64 {
}
// Assumes x1 <= x2 and y1 <= y2
pub fn draw_rect(_: &mut Tensor, _x1: usize, _x2: usize, _y1: usize, _y2: usize) {
todo!()
pub fn draw_rect(
img: &mut ImageBuffer<image::Rgb<u8>, Vec<u8>>,
x1: u32,
x2: u32,
y1: u32,
y2: u32,
) {
for x in x1..=x2 {
let pixel = img.get_pixel_mut(x, y1);
*pixel = image::Rgb([255, 0, 0]);
let pixel = img.get_pixel_mut(x, y2);
*pixel = image::Rgb([255, 0, 0]);
}
for y in y1..=y2 {
let pixel = img.get_pixel_mut(x1, y);
*pixel = image::Rgb([255, 0, 0]);
let pixel = img.get_pixel_mut(x2, y);
*pixel = image::Rgb([255, 0, 0]);
}
}
pub fn report(pred: &Tensor, img: &Tensor, w: usize, h: usize) -> Result<Tensor> {
pub fn report(pred: &Tensor, img: DynamicImage, w: usize, h: usize) -> Result<DynamicImage> {
let (npreds, pred_size) = pred.dims2()?;
let nclasses = pred_size - 5;
// The bounding boxes grouped by (maximum) class index.
let mut bboxes: Vec<Vec<Bbox>> = (0..nclasses).map(|_| vec![]).collect();
// Extract the bounding boxes for which confidence is above the threshold.
for index in 0..npreds {
let pred = Vec::<f64>::try_from(pred.get(index)?)?;
let pred = Vec::<f32>::try_from(pred.get(index)?)?;
let confidence = pred[4];
if confidence > CONFIDENCE_THRESHOLD {
let mut class_index = 0;
@ -91,24 +108,21 @@ pub fn report(pred: &Tensor, img: &Tensor, w: usize, h: usize) -> Result<Tensor>
bboxes_for_class.truncate(current_index);
}
// Annotate the original image and print boxes information.
let (_, initial_h, initial_w) = img.dims3()?;
let mut img = (img.to_dtype(DType::F32)? * (1. / 255.))?;
let w_ratio = initial_w as f64 / w as f64;
let h_ratio = initial_h as f64 / h as f64;
let (initial_h, initial_w) = (img.height(), img.width());
let w_ratio = initial_w as f32 / w as f32;
let h_ratio = initial_h as f32 / h as f32;
let mut img = img.to_rgb8();
for (class_index, bboxes_for_class) in bboxes.iter().enumerate() {
for b in bboxes_for_class.iter() {
println!("{}: {:?}", coco_classes::NAMES[class_index], b);
let xmin = ((b.xmin * w_ratio) as usize).clamp(0, initial_w - 1);
let ymin = ((b.ymin * h_ratio) as usize).clamp(0, initial_h - 1);
let xmax = ((b.xmax * w_ratio) as usize).clamp(0, initial_w - 1);
let ymax = ((b.ymax * h_ratio) as usize).clamp(0, initial_h - 1);
draw_rect(&mut img, xmin, xmax, ymin, ymax.min(ymin + 2));
draw_rect(&mut img, xmin, xmax, ymin.max(ymax - 2), ymax);
draw_rect(&mut img, xmin, xmax.min(xmin + 2), ymin, ymax);
draw_rect(&mut img, xmin.max(xmax - 2), xmax, ymin, ymax);
let xmin = ((b.xmin * w_ratio) as u32).clamp(0, initial_w - 1);
let ymin = ((b.ymin * h_ratio) as u32).clamp(0, initial_h - 1);
let xmax = ((b.xmax * w_ratio) as u32).clamp(0, initial_w - 1);
let ymax = ((b.ymax * h_ratio) as u32).clamp(0, initial_h - 1);
draw_rect(&mut img, xmin, xmax, ymin, ymax);
}
}
Ok((img * 255.)?.to_dtype(DType::U8)?)
Ok(DynamicImage::ImageRgb8(img))
}
#[derive(Parser, Debug)]
@ -118,6 +132,9 @@ struct Args {
#[arg(long)]
model: String,
#[arg(long)]
config: String,
images: Vec<String>,
}
@ -128,18 +145,36 @@ pub fn main() -> Result<()> {
let weights = unsafe { candle::safetensors::MmapedFile::new(&args.model)? };
let weights = weights.deserialize()?;
let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &Device::Cpu);
let darknet = darknet::parse_config(CONFIG_NAME)?;
let darknet = darknet::parse_config(&args.config)?;
let model = darknet.build_model(vb)?;
for image in args.images.iter() {
for image_name in args.images.iter() {
println!("processing {image_name}");
let mut image_name = std::path::PathBuf::from(image_name);
// Load the image file and resize it.
let net_width = darknet.width()?;
let net_height = darknet.height()?;
let image = candle_examples::load_image_and_resize(image, net_width, net_height)?;
let original_image = image::io::Reader::open(&image_name)?
.decode()
.map_err(candle::Error::wrap)?;
let image = {
let data = original_image
.resize_exact(
net_width as u32,
net_height as u32,
image::imageops::FilterType::Triangle,
)
.to_rgb8()
.into_raw();
Tensor::from_vec(data, (net_width, net_height, 3), &Device::Cpu)?.permute((2, 0, 1))?
};
let image = (image.unsqueeze(0)?.to_dtype(DType::F32)? * (1. / 255.))?;
let predictions = model.forward(&image)?.squeeze(0)?;
let _image = report(&predictions, &image, net_width, net_height)?;
println!("converted {image}");
let image = report(&predictions, original_image, net_width, net_height)?;
image_name.set_extension("pp.jpg");
println!("writing {image_name:?}");
image.save(image_name)?
}
Ok(())
}