Support for timegroupnorm in encodec. (#1291)

This commit is contained in:
Laurent Mazare
2023-11-07 22:39:59 +01:00
committed by GitHub
parent d4a45c936a
commit 7920b45c8a

View File

@ -8,6 +8,7 @@ use candle_nn::{conv1d, Conv1d, Conv1dConfig, VarBuilder};
#[derive(Debug, Clone, PartialEq)]
enum NormType {
WeightNorm,
TimeGroupNorm,
None,
}
@ -268,6 +269,7 @@ impl Module for EncodecConvTranspose1d {
struct EncodecConv1d {
causal: bool,
conv: Conv1d,
norm: Option<candle_nn::GroupNorm>,
}
impl EncodecConv1d {
@ -292,7 +294,7 @@ impl EncodecConv1d {
},
vb.pp("conv"),
)?,
NormType::None => conv1d(
NormType::None | NormType::TimeGroupNorm => conv1d(
in_c,
out_c,
kernel_size,
@ -305,9 +307,17 @@ impl EncodecConv1d {
vb.pp("conv"),
)?,
};
let norm = match cfg.norm_type {
NormType::None | NormType::WeightNorm => None,
NormType::TimeGroupNorm => {
let gn = candle_nn::group_norm(1, out_c, 1e-5, vb.pp("norm"))?;
Some(gn)
}
};
Ok(Self {
causal: cfg.use_causal_conv,
conv,
norm,
})
}
}
@ -316,8 +326,10 @@ impl Module for EncodecConv1d {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
// TODO: padding, depending on causal.
let xs = self.conv.forward(xs)?;
// If we add support for NormType "time_group_norm", we should add some normalization here.
Ok(xs)
match &self.norm {
None => Ok(xs),
Some(norm) => xs.apply(norm),
}
}
}