Move the conv1d layer to candle_nn. (#117)

commit 89a5b602a6 (parent b06e1a7e54)
Author: Laurent Mazare
Date:   2023-07-10 11:02:06 +01:00
Committed by: GitHub

5 changed files with 122 additions and 134 deletions
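
As context for the move: with the layer in candle_nn, any crate can build it directly from tensors instead of going through the example-local Conv1D. A minimal sketch, not part of this commit; it assumes the candle_nn API at this revision (a two-field Conv1dConfig and an inherent forward on Conv1d), and the shapes are chosen purely for illustration:

use candle::{DType, Device, Result, Tensor};
use candle_nn::{Conv1d, Conv1dConfig};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Conv1d weights are (out_channels, in_channels, kernel_size), the same
    // layout the loader helpers in this commit fetch from a VarBuilder.
    let weight = Tensor::ones((8, 4, 3), DType::F32, &dev)?;
    let bias = Tensor::zeros(8, DType::F32, &dev)?;
    let conv = Conv1d::new(weight, Some(bias), Conv1dConfig { padding: 0, stride: 1 });
    // Input is (batch, in_channels, length); with padding 0 and stride 1 the
    // output length is length - kernel_size + 1.
    let xs = Tensor::ones((1, 4, 10), DType::F32, &dev)?;
    let ys = conv.forward(&xs)?;
    assert_eq!(ys.dims(), &[1, 8, 8]);
    Ok(())
}

Later candle_nn revisions add dilation/groups to Conv1dConfig and route forward through the Module trait, so the exact calls may differ from the above.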


@@ -1,4 +1,4 @@
-use crate::nn::{Conv1D, ConvConfig, VarBuilder};
+use crate::nn::{conv1d, conv1d_weight_norm, Conv1d, Conv1dConfig, VarBuilder};
 use anyhow::Result;
 use candle::Tensor;
@@ -221,7 +221,7 @@ impl EncodecConvTranspose1d {
 
 #[derive(Debug)]
 struct EncodecConv1d {
-    conv: Conv1D,
+    conv: Conv1d,
 }
 
 impl EncodecConv1d {
@@ -235,19 +235,19 @@ impl EncodecConv1d {
         cfg: &Config,
     ) -> Result<Self> {
         let conv = match cfg.norm_type {
-            NormType::WeightNorm => Conv1D::load_weight_norm(
+            NormType::WeightNorm => conv1d_weight_norm(
                 in_c,
                 out_c,
                 kernel_size,
-                ConvConfig { padding: 0, stride },
+                Conv1dConfig { padding: 0, stride },
                 &format!("{p}.conv"),
                 vb,
             )?,
-            NormType::None => Conv1D::load(
+            NormType::None => conv1d(
                 in_c,
                 out_c,
                 kernel_size,
-                ConvConfig { padding: 0, stride },
+                Conv1dConfig { padding: 0, stride },
                 &format!("{p}.conv"),
                 vb,
             )?,
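
Both match arms now produce the shared candle_nn::Conv1d; they differ only in which checkpoint tensors they read. For a prefix p, the keys and shapes expected by the two loaders (taken from their bodies in the next file) are:

// NormType::None       -> conv1d reads:
//   {p}.conv.weight     shape (out_c, in_c, kernel_size)
//   {p}.conv.bias       shape (out_c,)
// NormType::WeightNorm -> conv1d_weight_norm reads:
//   {p}.conv.weight_g   shape (out_c, 1, 1)
//   {p}.conv.weight_v   shape (out_c, in_c, kernel_size)
//   {p}.conv.bias       shape (out_c,)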


@@ -125,59 +125,39 @@ pub fn embedding(
     Ok(Embedding::new(embeddings, hidden_size))
 }
 
-#[derive(Debug, Clone, Copy, PartialEq, Eq)]
-pub struct ConvConfig {
-    pub padding: usize,
-    pub stride: usize,
-}
+pub type Conv1d = candle_nn::Conv1d;
+pub type Conv1dConfig = candle_nn::Conv1dConfig;
+
+// Applies weight norm for inference by recomputing the weight tensor. This
+// does not apply to training.
+// https://pytorch.org/docs/stable/generated/torch.nn.utils.weight_norm.html
+pub fn conv1d_weight_norm(
+    in_c: usize,
+    out_c: usize,
+    kernel_size: usize,
+    config: Conv1dConfig,
+    p: &str,
+    vb: &VarBuilder,
+) -> Result<Conv1d> {
+    let weight_g = vb.get((out_c, 1, 1), &format!("{p}.weight_g"))?;
+    let weight_v = vb.get((out_c, in_c, kernel_size), &format!("{p}.weight_v"))?;
+    let norm_v = (&weight_v * &weight_v)?.sum(&[1, 2])?.sqrt()?;
+    let weight = weight_v.broadcast_mul(&weight_g)?.broadcast_div(&norm_v)?;
+    let bias = vb.get(out_c, &format!("{p}.bias"))?;
+    Ok(Conv1d::new(weight, Some(bias), config))
+}
-#[derive(Debug)]
-pub struct Conv1D {
-    weight: Tensor,
-    bias: Option<Tensor>,
-    config: ConvConfig,
-}
-
-impl Conv1D {
-    // Applies weight norm for inference by recomputing the weight tensor. This
-    // does not apply to training.
-    // https://pytorch.org/docs/stable/generated/torch.nn.utils.weight_norm.html
-    pub fn load_weight_norm(
-        in_c: usize,
-        out_c: usize,
-        kernel_size: usize,
-        config: ConvConfig,
-        p: &str,
-        vb: &VarBuilder,
-    ) -> Result<Self> {
-        let weight_g = vb.get((out_c, 1, 1), &format!("{p}.weight_g"))?;
-        let weight_v = vb.get((out_c, in_c, kernel_size), &format!("{p}.weight_v"))?;
-        let norm_v = (&weight_v * &weight_v)?.sum(&[1, 2])?.sqrt()?;
-        let weight = weight_v.broadcast_mul(&weight_g)?.broadcast_div(&norm_v)?;
-        let bias = vb.get(out_c, &format!("{p}.bias"))?;
-        Ok(Self {
-            weight,
-            bias: Some(bias),
-            config,
-        })
-    }
-
-    pub fn load(
-        in_c: usize,
-        out_c: usize,
-        kernel_size: usize,
-        config: ConvConfig,
-        p: &str,
-        vb: &VarBuilder,
-    ) -> Result<Self> {
-        let weight = vb.get((out_c, in_c, kernel_size), &format!("{p}.weight"))?;
-        let bias = vb.get(out_c, &format!("{p}.bias"))?;
-        Ok(Self {
-            weight,
-            bias: Some(bias),
-            config,
-        })
-    }
-}
+pub fn conv1d(
+    in_c: usize,
+    out_c: usize,
+    kernel_size: usize,
+    config: Conv1dConfig,
+    p: &str,
+    vb: &VarBuilder,
+) -> Result<Conv1d> {
+    let weight = vb.get((out_c, in_c, kernel_size), &format!("{p}.weight"))?;
+    let bias = vb.get(out_c, &format!("{p}.bias"))?;
+    Ok(Conv1d::new(weight, Some(bias), config))
+}
 
 pub type HiddenAct = candle_nn::Activation;
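
For reference, conv1d_weight_norm recomputes PyTorch's weight_norm decomposition (default dim = 0) at load time: for each output channel i,

    w_i = g_i * v_i / ||v_i||_2

with the norm taken over the (in_c, kernel_size) slice of weight_v. For example, v_i = [3, 4] with g_i = 10 gives ||v_i||_2 = 5 and w_i = [6, 8]. Note that the broadcasts only line up if sum(&[1, 2]) keeps the reduced axes as size-1 dimensions, so that norm_v comes out as (out_c, 1, 1) alongside weight_g; that matches candle's sum behavior at this revision (the keep-dim reduction was later split out as sum_keepdim).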