Some fixes for yolo-v3. (#529)
* Some fixes for yolo-v3.
* Use the running stats for inference in the batch-norm layer.
* Get some proper predictions for yolo.
* Avoid the quadratic insertion.
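The key behavioral change in this diff: BatchNorm's Module::forward previously normalized with per-batch statistics even at inference time. After this commit, forward uses the stored running_mean/running_var buffers, and the batch-statistics path survives as forward_learning. A minimal sketch of the two entry points; the crate aliases (candle_core, candle_nn), the demo function, the CPU device, and all tensor values are illustrative, not part of the commit:

    use candle_core::{DType, Device, Result, Tensor};
    use candle_nn::{BatchNorm, Module};

    fn demo() -> Result<()> {
        let dev = Device::Cpu;
        // Running stats as they would come from a trained checkpoint.
        let bn = BatchNorm::new_no_bias(
            4,
            Tensor::zeros(4, DType::F32, &dev)?, // running_mean
            Tensor::ones(4, DType::F32, &dev)?,  // running_var
            1e-5,
        )?;
        let x = Tensor::randn(0f32, 1f32, (8, 4, 16, 16), &dev)?;
        let _y_eval = bn.forward(&x)?; // inference: normalizes with running stats
        let _y_train = bn.forward_learning(&x)?; // training: per-batch stats
        Ok(())
    }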
--- a/candle-nn/src/batch_norm.rs
+++ b/candle-nn/src/batch_norm.rs
@@ -40,6 +40,8 @@ impl From<f64> for BatchNormConfig {
 #[derive(Debug)]
 pub struct BatchNorm {
+    running_mean: Tensor,
+    running_var: Tensor,
     weight_and_bias: Option<(Tensor, Tensor)>,
     remove_mean: bool,
     eps: f64,
@@ -47,7 +49,14 @@ pub struct BatchNorm {
 }
 
 impl BatchNorm {
-    pub fn new(num_features: usize, weight: Tensor, bias: Tensor, eps: f64) -> Result<Self> {
+    pub fn new(
+        num_features: usize,
+        running_mean: Tensor,
+        running_var: Tensor,
+        weight: Tensor,
+        bias: Tensor,
+        eps: f64,
+    ) -> Result<Self> {
         if eps < 0. {
             candle::bail!("batch-norm eps cannot be negative {eps}")
         }
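The affine constructor now takes the two buffers ahead of weight and bias. For weights exported from PyTorch, these correspond exactly to the BatchNorm2d state-dict entries running_mean and running_var. A hedged construction sketch; the placeholder tensors and the dev binding from the sketch above are illustrative:

    let c = 32usize;
    let bn = BatchNorm::new(
        c,
        Tensor::zeros(c, DType::F32, &dev)?, // running_mean buffer
        Tensor::ones(c, DType::F32, &dev)?,  // running_var buffer
        Tensor::ones(c, DType::F32, &dev)?,  // weight (gamma)
        Tensor::zeros(c, DType::F32, &dev)?, // bias (beta)
        1e-5,
    )?;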
@@ -64,6 +73,8 @@ impl BatchNorm {
             )
         }
         Ok(Self {
+            running_mean,
+            running_var,
             weight_and_bias: Some((weight, bias)),
             remove_mean: true,
             eps,
@@ -71,11 +82,18 @@ impl BatchNorm {
         })
     }
 
-    pub fn new_no_bias(num_features: usize, eps: f64) -> Result<Self> {
+    pub fn new_no_bias(
+        num_features: usize,
+        running_mean: Tensor,
+        running_var: Tensor,
+        eps: f64,
+    ) -> Result<Self> {
         if eps < 0. {
             candle::bail!("batch-norm eps cannot be negative {eps}")
         }
         Ok(Self {
+            running_mean,
+            running_var,
             weight_and_bias: None,
             remove_mean: true,
             eps,
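Both constructors validate eps up front, so a bad value fails at construction instead of surfacing as NaNs during inference. A quick check of that error path, in the same illustrative scope as the sketches above:

    let bad = BatchNorm::new_no_bias(
        3,
        Tensor::zeros(3, DType::F32, &dev)?,
        Tensor::ones(3, DType::F32, &dev)?,
        -1e-5, // rejected: "batch-norm eps cannot be negative ..."
    );
    assert!(bad.is_err());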
@@ -84,8 +102,8 @@ impl BatchNorm {
     }
 }
 
-impl crate::Module for BatchNorm {
-    fn forward(&self, x: &Tensor) -> Result<Tensor> {
+impl BatchNorm {
+    pub fn forward_learning(&self, x: &Tensor) -> Result<Tensor> {
         let x_dtype = x.dtype();
         let internal_dtype = match x_dtype {
             DType::F16 | DType::BF16 => DType::F32,
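The training path keeps its existing trick of upcasting f16/bf16 inputs to f32 before the mean/variance reductions, which would otherwise lose precision. The same pattern in isolation, as a standalone sketch under the assumptions of the earlier snippets:

    let x = Tensor::randn(0f32, 1f32, (2, 4), &dev)?.to_dtype(DType::F16)?;
    let internal_dtype = match x.dtype() {
        DType::F16 | DType::BF16 => DType::F32,
        d => d,
    };
    // Stats are computed at the wider dtype; forward_learning casts the
    // final result back to the input dtype.
    let x_wide = x.to_dtype(internal_dtype)?;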
@@ -129,6 +147,29 @@ impl crate::Module for BatchNorm {
     }
 }
 
+impl crate::Module for BatchNorm {
+    fn forward(&self, x: &Tensor) -> Result<Tensor> {
+        let target_shape: Vec<usize> = x
+            .dims()
+            .iter()
+            .enumerate()
+            .map(|(idx, v)| if idx == 1 { *v } else { 1 })
+            .collect();
+        let target_shape = target_shape.as_slice();
+        let x = x
+            .broadcast_sub(&self.running_mean.reshape(target_shape)?)?
+            .broadcast_div(&(self.running_var.reshape(target_shape)? + self.eps)?.sqrt()?)?;
+        match &self.weight_and_bias {
+            None => Ok(x),
+            Some((weight, bias)) => {
+                let weight = weight.reshape(target_shape)?;
+                let bias = bias.reshape(target_shape)?;
+                x.broadcast_mul(&weight)?.broadcast_add(&bias)
+            }
+        }
+    }
+}
+
 pub fn batch_norm<C: Into<BatchNormConfig>>(
     num_features: usize,
     config: C,
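The new inference forward implements y = (x - running_mean) / sqrt(running_var + eps), optionally followed by the affine weight and bias. The target_shape computation reshapes the per-channel (C,) buffers so they are 1 on every axis except dim 1: for an NCHW input of shape (n, c, h, w) the stats become (1, c, 1, 1) and broadcast across the batch and spatial dimensions. The shape logic replayed outside the impl, for illustration:

    // Same map as in forward(): keep dim 1 (channels), collapse the rest to 1.
    let x = Tensor::randn(0f32, 1f32, (2, 3, 4, 4), &dev)?;
    let target_shape: Vec<usize> = x
        .dims()
        .iter()
        .enumerate()
        .map(|(idx, v)| if idx == 1 { *v } else { 1 })
        .collect();
    assert_eq!(target_shape, vec![1, 3, 1, 1]);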
@@ -138,6 +179,8 @@ pub fn batch_norm<C: Into<BatchNormConfig>>(
     if config.eps < 0. {
         candle::bail!("batch-norm eps cannot be negative {}", config.eps)
     }
+    let running_mean = vb.get_or_init(num_features, "running_mean", crate::Init::Const(0.))?;
+    let running_var = vb.get_or_init(num_features, "running_var", crate::Init::Const(1.))?;
     let weight_and_bias = if config.affine {
         let weight = vb.get_or_init(num_features, "weight", crate::Init::Const(1.))?;
         let bias = vb.get_or_init(num_features, "bias", crate::Init::Const(0.))?;
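The builder now pulls the running-stat buffers from the VarBuilder under the names "running_mean" and "running_var" (matching PyTorch's keys), defaulting to mean 0 / variance 1 via Init::Const when they are absent. Since the first hunk's context shows impl From<f64> for BatchNormConfig, a bare eps can serve as the config. A hedged usage sketch; VarBuilder::zeros and the "bn1" prefix are illustrative assumptions, not taken from the commit:

    use candle_nn::{batch_norm, VarBuilder};

    let vb = VarBuilder::zeros(DType::F32, &dev);
    // 1e-5 converts into a BatchNormConfig through the From<f64> impl.
    let bn = batch_norm(32, 1e-5, vb.pp("bn1"))?;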
@@ -146,6 +189,8 @@ pub fn batch_norm<C: Into<BatchNormConfig>>(
         None
     };
     Ok(BatchNorm {
+        running_mean,
+        running_var,
         weight_and_bias,
         remove_mean: config.remove_mean,
         eps: config.eps,
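With the defaults above (mean 0, variance 1) and no affine transform, the inference path reduces to x / sqrt(1 + eps), i.e. nearly the identity, which gives a cheap sanity check of the new forward. Same illustrative scope as the earlier sketches:

    let bn = BatchNorm::new_no_bias(
        3,
        Tensor::zeros(3, DType::F32, &dev)?,
        Tensor::ones(3, DType::F32, &dev)?,
        1e-5,
    )?;
    let x = Tensor::randn(0f32, 1f32, (1, 3, 2, 2), &dev)?;
    let _y = bn.forward(&x)?; // ≈ x, since (x - 0) / sqrt(1 + 1e-5)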