Move the conv1d layer to candle_nn. (#117)

This commit is contained in:
Laurent Mazare
2023-07-10 11:02:06 +01:00
committed by GitHub
parent b06e1a7e54
commit 89a5b602a6
5 changed files with 122 additions and 134 deletions

View File

@ -1,4 +1,4 @@
use crate::nn::{Conv1D, ConvConfig, VarBuilder};
use crate::nn::{conv1d, conv1d_weight_norm, Conv1d, Conv1dConfig, VarBuilder};
use anyhow::Result;
use candle::Tensor;
@ -221,7 +221,7 @@ impl EncodecConvTranspose1d {
#[derive(Debug)]
struct EncodecConv1d {
conv: Conv1D,
conv: Conv1d,
}
impl EncodecConv1d {
@ -235,19 +235,19 @@ impl EncodecConv1d {
cfg: &Config,
) -> Result<Self> {
let conv = match cfg.norm_type {
NormType::WeightNorm => Conv1D::load_weight_norm(
NormType::WeightNorm => conv1d_weight_norm(
in_c,
out_c,
kernel_size,
ConvConfig { padding: 0, stride },
Conv1dConfig { padding: 0, stride },
&format!("{p}.conv"),
vb,
)?,
NormType::None => Conv1D::load(
NormType::None => conv1d(
in_c,
out_c,
kernel_size,
ConvConfig { padding: 0, stride },
Conv1dConfig { padding: 0, stride },
&format!("{p}.conv"),
vb,
)?,