mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Support for timegroupnorm in encodec. (#1291)
This commit is contained in:
@ -8,6 +8,7 @@ use candle_nn::{conv1d, Conv1d, Conv1dConfig, VarBuilder};
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
enum NormType {
|
||||
WeightNorm,
|
||||
TimeGroupNorm,
|
||||
None,
|
||||
}
|
||||
|
||||
@ -268,6 +269,7 @@ impl Module for EncodecConvTranspose1d {
|
||||
struct EncodecConv1d {
|
||||
causal: bool,
|
||||
conv: Conv1d,
|
||||
norm: Option<candle_nn::GroupNorm>,
|
||||
}
|
||||
|
||||
impl EncodecConv1d {
|
||||
@ -292,7 +294,7 @@ impl EncodecConv1d {
|
||||
},
|
||||
vb.pp("conv"),
|
||||
)?,
|
||||
NormType::None => conv1d(
|
||||
NormType::None | NormType::TimeGroupNorm => conv1d(
|
||||
in_c,
|
||||
out_c,
|
||||
kernel_size,
|
||||
@ -305,9 +307,17 @@ impl EncodecConv1d {
|
||||
vb.pp("conv"),
|
||||
)?,
|
||||
};
|
||||
let norm = match cfg.norm_type {
|
||||
NormType::None | NormType::WeightNorm => None,
|
||||
NormType::TimeGroupNorm => {
|
||||
let gn = candle_nn::group_norm(1, out_c, 1e-5, vb.pp("norm"))?;
|
||||
Some(gn)
|
||||
}
|
||||
};
|
||||
Ok(Self {
|
||||
causal: cfg.use_causal_conv,
|
||||
conv,
|
||||
norm,
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -316,8 +326,10 @@ impl Module for EncodecConv1d {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
// TODO: padding, depending on causal.
|
||||
let xs = self.conv.forward(xs)?;
|
||||
// If we add support for NormType "time_group_norm", we should add some normalization here.
|
||||
Ok(xs)
|
||||
match &self.norm {
|
||||
None => Ok(xs),
|
||||
Some(norm) => xs.apply(norm),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user