mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Some fixes for yolo-v3. (#529)
* Some fixes for yolo-v3. * Use the running stats for inference in the batch-norm layer. * Get some proper predictions for yolo. * Avoid the quadratic insertion.
This commit is contained in:
@ -40,6 +40,8 @@ impl From<f64> for BatchNormConfig {
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct BatchNorm {
|
||||
running_mean: Tensor,
|
||||
running_var: Tensor,
|
||||
weight_and_bias: Option<(Tensor, Tensor)>,
|
||||
remove_mean: bool,
|
||||
eps: f64,
|
||||
@ -47,7 +49,14 @@ pub struct BatchNorm {
|
||||
}
|
||||
|
||||
impl BatchNorm {
|
||||
pub fn new(num_features: usize, weight: Tensor, bias: Tensor, eps: f64) -> Result<Self> {
|
||||
pub fn new(
|
||||
num_features: usize,
|
||||
running_mean: Tensor,
|
||||
running_var: Tensor,
|
||||
weight: Tensor,
|
||||
bias: Tensor,
|
||||
eps: f64,
|
||||
) -> Result<Self> {
|
||||
if eps < 0. {
|
||||
candle::bail!("batch-norm eps cannot be negative {eps}")
|
||||
}
|
||||
@ -64,6 +73,8 @@ impl BatchNorm {
|
||||
)
|
||||
}
|
||||
Ok(Self {
|
||||
running_mean,
|
||||
running_var,
|
||||
weight_and_bias: Some((weight, bias)),
|
||||
remove_mean: true,
|
||||
eps,
|
||||
@ -71,11 +82,18 @@ impl BatchNorm {
|
||||
})
|
||||
}
|
||||
|
||||
pub fn new_no_bias(num_features: usize, eps: f64) -> Result<Self> {
|
||||
pub fn new_no_bias(
|
||||
num_features: usize,
|
||||
running_mean: Tensor,
|
||||
running_var: Tensor,
|
||||
eps: f64,
|
||||
) -> Result<Self> {
|
||||
if eps < 0. {
|
||||
candle::bail!("batch-norm eps cannot be negative {eps}")
|
||||
}
|
||||
Ok(Self {
|
||||
running_mean,
|
||||
running_var,
|
||||
weight_and_bias: None,
|
||||
remove_mean: true,
|
||||
eps,
|
||||
@ -84,8 +102,8 @@ impl BatchNorm {
|
||||
}
|
||||
}
|
||||
|
||||
impl crate::Module for BatchNorm {
|
||||
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||
impl BatchNorm {
|
||||
pub fn forward_learning(&self, x: &Tensor) -> Result<Tensor> {
|
||||
let x_dtype = x.dtype();
|
||||
let internal_dtype = match x_dtype {
|
||||
DType::F16 | DType::BF16 => DType::F32,
|
||||
@ -129,6 +147,29 @@ impl crate::Module for BatchNorm {
|
||||
}
|
||||
}
|
||||
|
||||
impl crate::Module for BatchNorm {
|
||||
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||
let target_shape: Vec<usize> = x
|
||||
.dims()
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(idx, v)| if idx == 1 { *v } else { 1 })
|
||||
.collect();
|
||||
let target_shape = target_shape.as_slice();
|
||||
let x = x
|
||||
.broadcast_sub(&self.running_mean.reshape(target_shape)?)?
|
||||
.broadcast_div(&(self.running_var.reshape(target_shape)? + self.eps)?.sqrt()?)?;
|
||||
match &self.weight_and_bias {
|
||||
None => Ok(x),
|
||||
Some((weight, bias)) => {
|
||||
let weight = weight.reshape(target_shape)?;
|
||||
let bias = bias.reshape(target_shape)?;
|
||||
x.broadcast_mul(&weight)?.broadcast_add(&bias)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn batch_norm<C: Into<BatchNormConfig>>(
|
||||
num_features: usize,
|
||||
config: C,
|
||||
@ -138,6 +179,8 @@ pub fn batch_norm<C: Into<BatchNormConfig>>(
|
||||
if config.eps < 0. {
|
||||
candle::bail!("batch-norm eps cannot be negative {}", config.eps)
|
||||
}
|
||||
let running_mean = vb.get_or_init(num_features, "running_mean", crate::Init::Const(0.))?;
|
||||
let running_var = vb.get_or_init(num_features, "running_var", crate::Init::Const(1.))?;
|
||||
let weight_and_bias = if config.affine {
|
||||
let weight = vb.get_or_init(num_features, "weight", crate::Init::Const(1.))?;
|
||||
let bias = vb.get_or_init(num_features, "bias", crate::Init::Const(0.))?;
|
||||
@ -146,6 +189,8 @@ pub fn batch_norm<C: Into<BatchNormConfig>>(
|
||||
None
|
||||
};
|
||||
Ok(BatchNorm {
|
||||
running_mean,
|
||||
running_var,
|
||||
weight_and_bias,
|
||||
remove_mean: config.remove_mean,
|
||||
eps: config.eps,
|
||||
|
@ -7,8 +7,8 @@ extern crate accelerate_src;
|
||||
mod test_utils;
|
||||
|
||||
use anyhow::Result;
|
||||
use candle::{Device, Tensor};
|
||||
use candle_nn::{BatchNorm, Module};
|
||||
use candle::{DType, Device, Tensor};
|
||||
use candle_nn::BatchNorm;
|
||||
|
||||
/* The test below has been generated using the following PyTorch code:
|
||||
import torch
|
||||
@ -21,7 +21,9 @@ print(output.flatten())
|
||||
*/
|
||||
#[test]
|
||||
fn batch_norm() -> Result<()> {
|
||||
let bn = BatchNorm::new_no_bias(5, 1e-8)?;
|
||||
let running_mean = Tensor::zeros(5, DType::F32, &Device::Cpu)?;
|
||||
let running_var = Tensor::ones(5, DType::F32, &Device::Cpu)?;
|
||||
let bn = BatchNorm::new_no_bias(5, running_mean.clone(), running_var.clone(), 1e-8)?;
|
||||
let input: [f32; 120] = [
|
||||
-0.7493, -1.0410, 1.6977, -0.6579, 1.7982, -0.0087, 0.2812, -0.1190, 0.2908, -0.5975,
|
||||
-0.0278, -0.2138, -1.3130, -1.6048, -2.2028, 0.9452, 0.4002, 0.0831, 1.0004, 0.1860,
|
||||
@ -37,7 +39,7 @@ fn batch_norm() -> Result<()> {
|
||||
1.4252, -0.9115, -0.1093, -0.3100, -0.6734, -1.4357, 0.9205,
|
||||
];
|
||||
let input = Tensor::new(&input, &Device::Cpu)?.reshape((2, 5, 3, 4))?;
|
||||
let output = bn.forward(&input)?;
|
||||
let output = bn.forward_learning(&input)?;
|
||||
assert_eq!(output.dims(), &[2, 5, 3, 4]);
|
||||
let output = output.flatten_all()?;
|
||||
assert_eq!(
|
||||
@ -59,11 +61,13 @@ fn batch_norm() -> Result<()> {
|
||||
);
|
||||
let bn2 = BatchNorm::new(
|
||||
5,
|
||||
running_mean.clone(),
|
||||
running_var.clone(),
|
||||
Tensor::new(&[0.5f32], &Device::Cpu)?.broadcast_as(5)?,
|
||||
Tensor::new(&[-1.5f32], &Device::Cpu)?.broadcast_as(5)?,
|
||||
1e-8,
|
||||
)?;
|
||||
let output2 = bn2.forward(&input)?;
|
||||
let output2 = bn2.forward_learning(&input)?;
|
||||
assert_eq!(output2.dims(), &[2, 5, 3, 4]);
|
||||
let output2 = output2.flatten_all()?;
|
||||
let diff2 = ((output2 - (output * 0.5)?)? + 1.5)?.sqr()?;
|
||||
|
Reference in New Issue
Block a user