mirror of
https://github.com/huggingface/candle.git
synced 2025-06-21 04:10:46 +00:00
Add fuse-conv-bn method for Conv2d (#1196)
* Add fuse-conv-bn method for Conv2d * no unwrap * run rustfmp and clippy
This commit is contained in:
@ -109,6 +109,10 @@ impl BatchNorm {
|
||||
&self.running_var
|
||||
}
|
||||
|
||||
pub fn eps(&self) -> f64 {
|
||||
self.eps
|
||||
}
|
||||
|
||||
pub fn weight_and_bias(&self) -> Option<(&Tensor, &Tensor)> {
|
||||
self.weight_and_bias.as_ref().map(|v| (&v.0, &v.1))
|
||||
}
|
||||
|
@ -1,4 +1,5 @@
|
||||
//! Convolution Layers.
|
||||
use crate::BatchNorm;
|
||||
use candle::{Result, Tensor};
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
@ -115,6 +116,26 @@ impl Conv2d {
|
||||
pub fn bias(&self) -> Option<&Tensor> {
|
||||
self.bias.as_ref()
|
||||
}
|
||||
|
||||
pub fn absorb_bn(&self, bn: &BatchNorm) -> Result<Self> {
|
||||
if let Some((w_bn, b_bn)) = bn.weight_and_bias() {
|
||||
let std_ = w_bn.div(&((bn.running_var() + bn.eps())?.sqrt()?))?;
|
||||
let weight = self
|
||||
.weight()
|
||||
.broadcast_mul(&(std_.reshape((self.weight().dims4()?.0, 1, 1, 1))?))?;
|
||||
let bias = match &self.bias {
|
||||
None => b_bn.sub(&(std_.mul(bn.running_mean())?))?,
|
||||
Some(bias) => b_bn.add(&(std_.mul(&bias.sub(bn.running_mean())?)?))?,
|
||||
};
|
||||
Ok(Self {
|
||||
weight,
|
||||
bias: Some(bias),
|
||||
config: self.config,
|
||||
})
|
||||
} else {
|
||||
candle::bail!("batch norm does not have weight_and_bias")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl crate::Module for Conv2d {
|
||||
|
Reference in New Issue
Block a user