mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 10:26:33 +00:00
Some fixes for yolo-v3. (#529)
* Some fixes for yolo-v3. * Use the running stats for inference in the batch-norm layer. * Get some proper predictions for yolo. * Avoid the quadratic insertion.
This commit is contained in:
@ -1,4 +1,4 @@
|
||||
use candle::{Device, IndexOp, Result, Tensor};
|
||||
use candle::{DType, Device, IndexOp, Result, Tensor};
|
||||
use candle_nn::{batch_norm, conv2d, conv2d_no_bias, Func, Module, VarBuilder};
|
||||
use std::collections::BTreeMap;
|
||||
use std::fs::File;
|
||||
@ -145,11 +145,12 @@ fn conv(vb: VarBuilder, index: usize, p: usize, b: &Block) -> Result<(usize, Bl)
|
||||
Some(bn) => bn.forward(&xs)?,
|
||||
None => xs,
|
||||
};
|
||||
if leaky {
|
||||
xs.maximum(&(&xs * 0.1)?)
|
||||
let xs = if leaky {
|
||||
xs.maximum(&(&xs * 0.1)?)?
|
||||
} else {
|
||||
Ok(xs)
|
||||
}
|
||||
xs
|
||||
};
|
||||
Ok(xs)
|
||||
});
|
||||
Ok((filters, Bl::Layer(Box::new(func))))
|
||||
}
|
||||
@ -225,12 +226,13 @@ fn detect(
|
||||
let grid = Tensor::arange(0u32, grid_size as u32, &Device::Cpu)?;
|
||||
let a = grid.repeat((grid_size, 1))?;
|
||||
let b = a.t()?.contiguous()?;
|
||||
let x_offset = a.unsqueeze(1)?;
|
||||
let y_offset = b.unsqueeze(1)?;
|
||||
let x_offset = a.flatten_all()?.unsqueeze(1)?;
|
||||
let y_offset = b.flatten_all()?.unsqueeze(1)?;
|
||||
let xy_offset = Tensor::cat(&[&x_offset, &y_offset], 1)?
|
||||
.repeat((1, nanchors))?
|
||||
.reshape((grid_size * grid_size * nanchors, 2))?
|
||||
.unsqueeze(0)?;
|
||||
.unsqueeze(0)?
|
||||
.to_dtype(DType::F32)?;
|
||||
let anchors: Vec<f32> = anchors
|
||||
.iter()
|
||||
.flat_map(|&(x, y)| vec![x as f32 / stride as f32, y as f32 / stride as f32].into_iter())
|
||||
@ -245,7 +247,8 @@ fn detect(
|
||||
let ys02 = (candle_nn::ops::sigmoid(&ys02)?.add(&xy_offset)? * stride as f64)?;
|
||||
let ys24 = (ys24.exp()?.mul(&anchors)? * stride as f64)?;
|
||||
let ys4 = candle_nn::ops::sigmoid(&ys4)?;
|
||||
Tensor::cat(&[ys02, ys24, ys4], 2)
|
||||
let ys = Tensor::cat(&[ys02, ys24, ys4], 2)?;
|
||||
Ok(ys)
|
||||
}
|
||||
|
||||
impl Darknet {
|
||||
|
Reference in New Issue
Block a user