From 7920b45c8ac737b67e23f04297f6bd7e4860f373 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Tue, 7 Nov 2023 22:39:59 +0100 Subject: [PATCH] Support for timegroupnorm in encodec. (#1291) --- .../examples/musicgen/encodec_model.rs | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/candle-examples/examples/musicgen/encodec_model.rs b/candle-examples/examples/musicgen/encodec_model.rs index 095c90a9..60149e45 100644 --- a/candle-examples/examples/musicgen/encodec_model.rs +++ b/candle-examples/examples/musicgen/encodec_model.rs @@ -8,6 +8,7 @@ use candle_nn::{conv1d, Conv1d, Conv1dConfig, VarBuilder}; #[derive(Debug, Clone, PartialEq)] enum NormType { WeightNorm, + TimeGroupNorm, None, } @@ -268,6 +269,7 @@ impl Module for EncodecConvTranspose1d { struct EncodecConv1d { causal: bool, conv: Conv1d, + norm: Option, } impl EncodecConv1d { @@ -292,7 +294,7 @@ impl EncodecConv1d { }, vb.pp("conv"), )?, - NormType::None => conv1d( + NormType::None | NormType::TimeGroupNorm => conv1d( in_c, out_c, kernel_size, @@ -305,9 +307,17 @@ impl EncodecConv1d { vb.pp("conv"), )?, }; + let norm = match cfg.norm_type { + NormType::None | NormType::WeightNorm => None, + NormType::TimeGroupNorm => { + let gn = candle_nn::group_norm(1, out_c, 1e-5, vb.pp("norm"))?; + Some(gn) + } + }; Ok(Self { causal: cfg.use_causal_conv, conv, + norm, }) } } @@ -316,8 +326,10 @@ impl Module for EncodecConv1d { fn forward(&self, xs: &Tensor) -> Result { // TODO: padding, depending on causal. let xs = self.conv.forward(xs)?; - // If we add support for NormType "time_group_norm", we should add some normalization here. - Ok(xs) + match &self.norm { + None => Ok(xs), + Some(norm) => xs.apply(norm), + } } }