Some fixes for yolo-v3. (#529)

* Some fixes for yolo-v3.

* Use the running stats for inference in the batch-norm layer.

* Get some proper predictions for yolo.

* Avoid the quadratic insertion.
This commit is contained in:
Laurent Mazare
2023-08-20 23:19:15 +01:00
committed by GitHub
parent a1812f934f
commit 11c7e7bd67
6 changed files with 144 additions and 53 deletions

View File

@ -1,4 +1,4 @@
use candle::{Device, IndexOp, Result, Tensor};
use candle::{DType, Device, IndexOp, Result, Tensor};
use candle_nn::{batch_norm, conv2d, conv2d_no_bias, Func, Module, VarBuilder};
use std::collections::BTreeMap;
use std::fs::File;
@ -145,11 +145,12 @@ fn conv(vb: VarBuilder, index: usize, p: usize, b: &Block) -> Result<(usize, Bl)
Some(bn) => bn.forward(&xs)?,
None => xs,
};
if leaky {
xs.maximum(&(&xs * 0.1)?)
let xs = if leaky {
xs.maximum(&(&xs * 0.1)?)?
} else {
Ok(xs)
}
xs
};
Ok(xs)
});
Ok((filters, Bl::Layer(Box::new(func))))
}
@ -225,12 +226,13 @@ fn detect(
let grid = Tensor::arange(0u32, grid_size as u32, &Device::Cpu)?;
let a = grid.repeat((grid_size, 1))?;
let b = a.t()?.contiguous()?;
let x_offset = a.unsqueeze(1)?;
let y_offset = b.unsqueeze(1)?;
let x_offset = a.flatten_all()?.unsqueeze(1)?;
let y_offset = b.flatten_all()?.unsqueeze(1)?;
let xy_offset = Tensor::cat(&[&x_offset, &y_offset], 1)?
.repeat((1, nanchors))?
.reshape((grid_size * grid_size * nanchors, 2))?
.unsqueeze(0)?;
.unsqueeze(0)?
.to_dtype(DType::F32)?;
let anchors: Vec<f32> = anchors
.iter()
.flat_map(|&(x, y)| vec![x as f32 / stride as f32, y as f32 / stride as f32].into_iter())
@ -245,7 +247,8 @@ fn detect(
let ys02 = (candle_nn::ops::sigmoid(&ys02)?.add(&xy_offset)? * stride as f64)?;
let ys24 = (ys24.exp()?.mul(&anchors)? * stride as f64)?;
let ys4 = candle_nn::ops::sigmoid(&ys4)?;
Tensor::cat(&[ys02, ys24, ys4], 2)
let ys = Tensor::cat(&[ys02, ys24, ys4], 2)?;
Ok(ys)
}
impl Darknet {