mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 10:26:33 +00:00
Add fuse-conv-bn method for Conv2d (#1196)
* Add fuse-conv-bn method for Conv2d * no unwrap * run rustfmp and clippy
This commit is contained in:
@ -1,7 +1,5 @@
|
||||
use candle::{DType, IndexOp, Result, Tensor, D};
|
||||
use candle_nn::{
|
||||
batch_norm, conv2d, conv2d_no_bias, BatchNorm, Conv2d, Conv2dConfig, Module, VarBuilder,
|
||||
};
|
||||
use candle_nn::{batch_norm, conv2d, conv2d_no_bias, Conv2d, Conv2dConfig, Module, VarBuilder};
|
||||
|
||||
#[derive(Clone, Copy, PartialEq, Debug)]
|
||||
pub struct Multiples {
|
||||
@ -76,7 +74,6 @@ impl Module for Upsample {
|
||||
#[derive(Debug)]
|
||||
struct ConvBlock {
|
||||
conv: Conv2d,
|
||||
bn: BatchNorm,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
@ -96,11 +93,10 @@ impl ConvBlock {
|
||||
groups: 1,
|
||||
dilation: 1,
|
||||
};
|
||||
let conv = conv2d_no_bias(c1, c2, k, cfg, vb.pp("conv"))?;
|
||||
let bn = batch_norm(c2, 1e-3, vb.pp("bn"))?;
|
||||
let conv = conv2d_no_bias(c1, c2, k, cfg, vb.pp("conv"))?.absorb_bn(&bn)?;
|
||||
Ok(Self {
|
||||
conv,
|
||||
bn,
|
||||
span: tracing::span!(tracing::Level::TRACE, "conv-block"),
|
||||
})
|
||||
}
|
||||
@ -110,7 +106,6 @@ impl Module for ConvBlock {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let xs = self.conv.forward(xs)?;
|
||||
let xs = self.bn.forward(&xs)?;
|
||||
candle_nn::ops::silu(&xs)
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user