Support for timegroupnorm in encodec. (#1291)

This commit is contained in:
Laurent Mazare
2023-11-07 22:39:59 +01:00
committed by GitHub
parent d4a45c936a
commit 7920b45c8a

View File

@ -8,6 +8,7 @@ use candle_nn::{conv1d, Conv1d, Conv1dConfig, VarBuilder};
#[derive(Debug, Clone, PartialEq)] #[derive(Debug, Clone, PartialEq)]
enum NormType { enum NormType {
WeightNorm, WeightNorm,
TimeGroupNorm,
None, None,
} }
@ -268,6 +269,7 @@ impl Module for EncodecConvTranspose1d {
struct EncodecConv1d { struct EncodecConv1d {
causal: bool, causal: bool,
conv: Conv1d, conv: Conv1d,
norm: Option<candle_nn::GroupNorm>,
} }
impl EncodecConv1d { impl EncodecConv1d {
@ -292,7 +294,7 @@ impl EncodecConv1d {
}, },
vb.pp("conv"), vb.pp("conv"),
)?, )?,
NormType::None => conv1d( NormType::None | NormType::TimeGroupNorm => conv1d(
in_c, in_c,
out_c, out_c,
kernel_size, kernel_size,
@ -305,9 +307,17 @@ impl EncodecConv1d {
vb.pp("conv"), vb.pp("conv"),
)?, )?,
}; };
let norm = match cfg.norm_type {
NormType::None | NormType::WeightNorm => None,
NormType::TimeGroupNorm => {
let gn = candle_nn::group_norm(1, out_c, 1e-5, vb.pp("norm"))?;
Some(gn)
}
};
Ok(Self { Ok(Self {
causal: cfg.use_causal_conv, causal: cfg.use_causal_conv,
conv, conv,
norm,
}) })
} }
} }
@ -316,8 +326,10 @@ impl Module for EncodecConv1d {
fn forward(&self, xs: &Tensor) -> Result<Tensor> { fn forward(&self, xs: &Tensor) -> Result<Tensor> {
// TODO: padding, depending on causal. // TODO: padding, depending on causal.
let xs = self.conv.forward(xs)?; let xs = self.conv.forward(xs)?;
// If we add support for NormType "time_group_norm", we should add some normalization here. match &self.norm {
Ok(xs) None => Ok(xs),
Some(norm) => xs.apply(norm),
}
} }
} }