mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00
Support for timegroupnorm in encodec. (#1291)
This commit is contained in:
@ -8,6 +8,7 @@ use candle_nn::{conv1d, Conv1d, Conv1dConfig, VarBuilder};
|
|||||||
#[derive(Debug, Clone, PartialEq)]
|
#[derive(Debug, Clone, PartialEq)]
|
||||||
enum NormType {
|
enum NormType {
|
||||||
WeightNorm,
|
WeightNorm,
|
||||||
|
TimeGroupNorm,
|
||||||
None,
|
None,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -268,6 +269,7 @@ impl Module for EncodecConvTranspose1d {
|
|||||||
struct EncodecConv1d {
|
struct EncodecConv1d {
|
||||||
causal: bool,
|
causal: bool,
|
||||||
conv: Conv1d,
|
conv: Conv1d,
|
||||||
|
norm: Option<candle_nn::GroupNorm>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl EncodecConv1d {
|
impl EncodecConv1d {
|
||||||
@ -292,7 +294,7 @@ impl EncodecConv1d {
|
|||||||
},
|
},
|
||||||
vb.pp("conv"),
|
vb.pp("conv"),
|
||||||
)?,
|
)?,
|
||||||
NormType::None => conv1d(
|
NormType::None | NormType::TimeGroupNorm => conv1d(
|
||||||
in_c,
|
in_c,
|
||||||
out_c,
|
out_c,
|
||||||
kernel_size,
|
kernel_size,
|
||||||
@ -305,9 +307,17 @@ impl EncodecConv1d {
|
|||||||
vb.pp("conv"),
|
vb.pp("conv"),
|
||||||
)?,
|
)?,
|
||||||
};
|
};
|
||||||
|
let norm = match cfg.norm_type {
|
||||||
|
NormType::None | NormType::WeightNorm => None,
|
||||||
|
NormType::TimeGroupNorm => {
|
||||||
|
let gn = candle_nn::group_norm(1, out_c, 1e-5, vb.pp("norm"))?;
|
||||||
|
Some(gn)
|
||||||
|
}
|
||||||
|
};
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
causal: cfg.use_causal_conv,
|
causal: cfg.use_causal_conv,
|
||||||
conv,
|
conv,
|
||||||
|
norm,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -316,8 +326,10 @@ impl Module for EncodecConv1d {
|
|||||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
// TODO: padding, depending on causal.
|
// TODO: padding, depending on causal.
|
||||||
let xs = self.conv.forward(xs)?;
|
let xs = self.conv.forward(xs)?;
|
||||||
// If we add support for NormType "time_group_norm", we should add some normalization here.
|
match &self.norm {
|
||||||
Ok(xs)
|
None => Ok(xs),
|
||||||
|
Some(norm) => xs.apply(norm),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user