mirror of https://github.com/huggingface/candle.git
synced 2025-06-15 10:26:33 +00:00
Move the conv1d layer to candle_nn. (#117)
@@ -1,4 +1,4 @@
-use crate::nn::{Conv1D, ConvConfig, VarBuilder};
+use crate::nn::{conv1d, conv1d_weight_norm, Conv1d, Conv1dConfig, VarBuilder};
 use anyhow::Result;
 use candle::Tensor;
@@ -221,7 +221,7 @@ impl EncodecConvTranspose1d {
 
 #[derive(Debug)]
 struct EncodecConv1d {
-    conv: Conv1D,
+    conv: Conv1d,
 }
 
 impl EncodecConv1d {
@@ -235,19 +235,19 @@ impl EncodecConv1d {
         cfg: &Config,
     ) -> Result<Self> {
         let conv = match cfg.norm_type {
-            NormType::WeightNorm => Conv1D::load_weight_norm(
+            NormType::WeightNorm => conv1d_weight_norm(
                 in_c,
                 out_c,
                 kernel_size,
-                ConvConfig { padding: 0, stride },
+                Conv1dConfig { padding: 0, stride },
                 &format!("{p}.conv"),
                 vb,
             )?,
-            NormType::None => Conv1D::load(
+            NormType::None => conv1d(
                 in_c,
                 out_c,
                 kernel_size,
-                ConvConfig { padding: 0, stride },
+                Conv1dConfig { padding: 0, stride },
                 &format!("{p}.conv"),
                 vb,
             )?,
@@ -125,59 +125,39 @@ pub fn embedding(
     Ok(Embedding::new(embeddings, hidden_size))
 }
 
-#[derive(Debug, Clone, Copy, PartialEq, Eq)]
-pub struct ConvConfig {
-    pub padding: usize,
-    pub stride: usize,
-}
+pub type Conv1d = candle_nn::Conv1d;
+pub type Conv1dConfig = candle_nn::Conv1dConfig;
 
+// Applies weight norm for inference by recomputing the weight tensor. This
+// does not apply to training.
+// https://pytorch.org/docs/stable/generated/torch.nn.utils.weight_norm.html
+pub fn conv1d_weight_norm(
+    in_c: usize,
+    out_c: usize,
+    kernel_size: usize,
+    config: Conv1dConfig,
+    p: &str,
+    vb: &VarBuilder,
+) -> Result<Conv1d> {
+    let weight_g = vb.get((out_c, 1, 1), &format!("{p}.weight_g"))?;
+    let weight_v = vb.get((out_c, in_c, kernel_size), &format!("{p}.weight_v"))?;
+    let norm_v = (&weight_v * &weight_v)?.sum(&[1, 2])?.sqrt()?;
+    let weight = weight_v.broadcast_mul(&weight_g)?.broadcast_div(&norm_v)?;
+    let bias = vb.get(out_c, &format!("{p}.bias"))?;
+    Ok(Conv1d::new(weight, Some(bias), config))
+}
 
-#[derive(Debug)]
-pub struct Conv1D {
-    weight: Tensor,
-    bias: Option<Tensor>,
-    config: ConvConfig,
-}
-
-impl Conv1D {
-    // Applies weight norm for inference by recomputing the weight tensor. This
-    // does not apply to training.
-    // https://pytorch.org/docs/stable/generated/torch.nn.utils.weight_norm.html
-    pub fn load_weight_norm(
-        in_c: usize,
-        out_c: usize,
-        kernel_size: usize,
-        config: ConvConfig,
-        p: &str,
-        vb: &VarBuilder,
-    ) -> Result<Self> {
-        let weight_g = vb.get((out_c, 1, 1), &format!("{p}.weight_g"))?;
-        let weight_v = vb.get((out_c, in_c, kernel_size), &format!("{p}.weight_v"))?;
-        let norm_v = (&weight_v * &weight_v)?.sum(&[1, 2])?.sqrt()?;
-        let weight = weight_v.broadcast_mul(&weight_g)?.broadcast_div(&norm_v)?;
-        let bias = vb.get(out_c, &format!("{p}.bias"))?;
-        Ok(Self {
-            weight,
-            bias: Some(bias),
-            config,
-        })
-    }
-
-    pub fn load(
-        in_c: usize,
-        out_c: usize,
-        kernel_size: usize,
-        config: ConvConfig,
-        p: &str,
-        vb: &VarBuilder,
-    ) -> Result<Self> {
-        let weight = vb.get((out_c, in_c, kernel_size), &format!("{p}.weight"))?;
-        let bias = vb.get(out_c, &format!("{p}.bias"))?;
-        Ok(Self {
-            weight,
-            bias: Some(bias),
-            config,
-        })
-    }
-}
+pub fn conv1d(
+    in_c: usize,
+    out_c: usize,
+    kernel_size: usize,
+    config: Conv1dConfig,
+    p: &str,
+    vb: &VarBuilder,
+) -> Result<Conv1d> {
+    let weight = vb.get((out_c, in_c, kernel_size), &format!("{p}.weight"))?;
+    let bias = vb.get(out_c, &format!("{p}.bias"))?;
+    Ok(Conv1d::new(weight, Some(bias), config))
+}
 
 pub type HiddenAct = candle_nn::Activation;
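For reference, `conv1d_weight_norm` above follows the standard weight-norm decomposition w = g · v / ‖v‖: `weight_v` holds the direction with shape `(out_c, in_c, kernel_size)`, `weight_g` the per-output-channel gain with shape `(out_c, 1, 1)`, and the norm is taken over the input-channel and kernel dimensions. A minimal standalone sketch of just that recomputation, reusing the tensor ops from the diff (`Tensor::ones` and the concrete shapes are illustrative assumptions, not part of this commit):

```rust
use candle::{DType, Device, Result, Tensor};

// w[o, i, k] = g[o] * v[o, i, k] / ||v[o, :, :]||_2, matching the
// conv1d_weight_norm helper above.
fn recompute_weight(weight_g: &Tensor, weight_v: &Tensor) -> Result<Tensor> {
    // Per-output-channel L2 norm over the (in_c, kernel_size) dims; the
    // result keeps singleton dims so it broadcasts against weight_v.
    let norm_v = (weight_v * weight_v)?.sum(&[1, 2])?.sqrt()?;
    weight_v.broadcast_mul(weight_g)?.broadcast_div(&norm_v)
}

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Illustrative shapes: out_c = 4, in_c = 2, kernel_size = 3.
    let weight_v = Tensor::ones((4, 2, 3), DType::F32, &dev)?;
    let weight_g = Tensor::ones((4, 1, 1), DType::F32, &dev)?;
    let w = recompute_weight(&weight_g, &weight_v)?;
    // Each v[o] has norm sqrt(6), so every entry of w is 1/sqrt(6).
    println!("{:?}", w.shape());
    Ok(())
}
```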
@@ -2,7 +2,7 @@
 // back when using RUST_LIB_BACKTRACE=1.
 use anyhow::Result;
 use candle::{safetensors::SafeTensors, DType, Device, Shape, Tensor};
-use candle_nn::{Embedding, LayerNorm, Linear};
+use candle_nn::{Conv1d, Conv1dConfig, Embedding, LayerNorm, Linear};
 use serde::Deserialize;
 use std::collections::HashMap;
@@ -112,78 +112,35 @@ fn linear_no_bias(size1: usize, size2: usize, p: &str, vb: &VarBuilder) -> Resul
     Ok(Linear::new(weight, None))
 }
 
-#[derive(Debug, Clone, Copy, PartialEq, Eq)]
-struct ConvConfig {
-    padding: usize,
-    stride: usize,
-}
+fn conv1d(
+    in_channels: usize,
+    out_channels: usize,
+    kernel_size: usize,
+    config: Conv1dConfig,
+    p: &str,
+    vb: &VarBuilder,
+) -> Result<Conv1d> {
+    let weight = vb.get(
+        (out_channels, in_channels, kernel_size),
+        &format!("{p}.weight"),
+    )?;
+    let bias = vb.get(out_channels, &format!("{p}.bias"))?;
+    Ok(Conv1d::new(weight, Some(bias), config))
+}
 
-impl Default for ConvConfig {
-    fn default() -> Self {
-        Self {
-            padding: 0,
-            stride: 1,
-        }
-    }
-}
-
-struct Conv1D {
-    weight: Tensor,
-    bias: Option<Tensor>,
-    config: ConvConfig,
-}
-
-impl Conv1D {
-    fn load(
-        in_channels: usize,
-        out_channels: usize,
-        kernel_size: usize,
-        config: ConvConfig,
-        p: &str,
-        vb: &VarBuilder,
-    ) -> Result<Self> {
-        let weight = vb.get(
-            (out_channels, in_channels, kernel_size),
-            &format!("{p}.weight"),
-        )?;
-        let bias = vb.get(out_channels, &format!("{p}.bias"))?;
-        Ok(Self {
-            weight,
-            bias: Some(bias),
-            config,
-        })
-    }
-
-    fn load_no_bias(
-        in_channels: usize,
-        out_channels: usize,
-        kernel_size: usize,
-        config: ConvConfig,
-        p: &str,
-        vb: &VarBuilder,
-    ) -> Result<Self> {
-        let weight = vb.get(
-            (out_channels, in_channels, kernel_size),
-            &format!("{p}.weight"),
-        )?;
-        Ok(Self {
-            weight,
-            bias: None,
-            config,
-        })
-    }
-
-    fn forward(&self, x: &Tensor) -> Result<Tensor> {
-        let x = x.conv1d(&self.weight, self.config.padding, self.config.stride)?;
-        match &self.bias {
-            None => Ok(x),
-            Some(bias) => {
-                let b = bias.shape().r1()?;
-                let bias = bias.reshape((1, b, 1))?;
-                Ok(x.broadcast_add(&bias)?)
-            }
-        }
-    }
-}
+fn conv1d_no_bias(
+    in_channels: usize,
+    out_channels: usize,
+    kernel_size: usize,
+    config: Conv1dConfig,
+    p: &str,
+    vb: &VarBuilder,
+) -> Result<Conv1d> {
+    let weight = vb.get(
+        (out_channels, in_channels, kernel_size),
+        &format!("{p}.weight"),
+    )?;
+    Ok(Conv1d::new(weight, None, config))
+}
 
 struct Dropout {
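The `forward` logic that previously lived on the local `Conv1D`, including the bias handling, now lives in `candle_nn::Conv1d` (new file below). Its bias step reshapes the `(out_channels,)` bias to `(1, out_channels, 1)` so it broadcasts over batch and time. A minimal sketch of just that step, with illustrative shapes (`Tensor::zeros`/`Tensor::ones` with the usual `(shape, dtype, device)` signature are assumptions):

```rust
use candle::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Pretend conv output: (batch, out_channels, length) = (1, 4, 10).
    let x = Tensor::zeros((1, 4, 10), DType::F32, &dev)?;
    let bias = Tensor::ones(4, DType::F32, &dev)?;
    // Same steps as the bias branch of Conv1d::forward:
    let b = bias.shape().r1()?; // rank-1 length, i.e. out_channels = 4
    let bias = bias.reshape((1, b, 1))?;
    let y = x.broadcast_add(&bias)?;
    println!("{:?}", y.shape()); // still (1, 4, 10)
    Ok(())
}
```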
@@ -338,8 +295,8 @@ fn sinusoids(length: usize, channels: usize) -> Result<Tensor> {
 
 // https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L143
 pub struct AudioEncoder {
-    conv1: Conv1D,
-    conv2: Conv1D,
+    conv1: Conv1d,
+    conv2: Conv1d,
     positional_embedding: Tensor,
     blocks: Vec<ResidualAttentionBlock>,
     ln_post: LayerNorm,
@@ -350,15 +307,15 @@ impl AudioEncoder {
         let n_state = cfg.d_model;
         let n_head = cfg.encoder_attention_heads;
         let n_ctx = cfg.max_source_positions;
-        let cfg1 = ConvConfig {
+        let cfg1 = Conv1dConfig {
             padding: 1,
             stride: 1,
         };
-        let cfg2 = ConvConfig {
+        let cfg2 = Conv1dConfig {
             padding: 1,
             stride: 2,
         };
-        let conv1 = Conv1D::load(
+        let conv1 = conv1d(
            cfg.num_mel_bins,
            n_state,
            3,
@@ -366,7 +323,7 @@ impl AudioEncoder {
             &format!("{p}.conv1"),
             vb,
         )?;
-        let conv2 = Conv1D::load(n_state, n_state, 3, cfg2, &format!("{p}.conv2"), vb)?;
+        let conv2 = conv1d(n_state, n_state, 3, cfg2, &format!("{p}.conv2"), vb)?;
         let positional_embedding = sinusoids(n_ctx, n_state)?.to_device(&vb.device)?;
         let blocks = (0..cfg.encoder_layers)
             .map(|i| {
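On the two configs above: with kernel size 3, `cfg1` (padding 1, stride 1) preserves the sequence length, while `cfg2` (padding 1, stride 2) halves it, which is how whisper's 3000-frame mel input ends up matching the 1500 encoder positions. A quick check of the usual conv1d output-length arithmetic (plain Rust, no candle dependency; the 3000-frame length is whisper's standard 30s input):

```rust
// Output length of a 1d convolution:
// l_out = (l_in + 2 * padding - kernel_size) / stride + 1
fn conv1d_out_len(l_in: usize, padding: usize, kernel_size: usize, stride: usize) -> usize {
    (l_in + 2 * padding - kernel_size) / stride + 1
}

fn main() {
    // cfg1: padding 1, stride 1, kernel 3 -> length preserved.
    assert_eq!(conv1d_out_len(3000, 1, 3, 1), 3000);
    // cfg2: padding 1, stride 2, kernel 3 -> length halved.
    assert_eq!(conv1d_out_len(3000, 1, 3, 2), 1500);
    println!("ok");
}
```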
candle-nn/src/conv.rs (new file, +49 lines)
@@ -0,0 +1,49 @@
+use candle::{Result, Tensor};
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub struct Conv1dConfig {
+    pub padding: usize,
+    pub stride: usize,
+}
+
+impl Default for Conv1dConfig {
+    fn default() -> Self {
+        Self {
+            padding: 0,
+            stride: 1,
+        }
+    }
+}
+
+#[derive(Debug)]
+pub struct Conv1d {
+    weight: Tensor,
+    bias: Option<Tensor>,
+    config: Conv1dConfig,
+}
+
+impl Conv1d {
+    pub fn new(weight: Tensor, bias: Option<Tensor>, config: Conv1dConfig) -> Self {
+        Self {
+            weight,
+            bias,
+            config,
+        }
+    }
+
+    pub fn config(&self) -> &Conv1dConfig {
+        &self.config
+    }
+
+    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
+        let x = x.conv1d(&self.weight, self.config.padding, self.config.stride)?;
+        match &self.bias {
+            None => Ok(x),
+            Some(bias) => {
+                let b = bias.shape().r1()?;
+                let bias = bias.reshape((1, b, 1))?;
+                Ok(x.broadcast_add(&bias)?)
+            }
+        }
+    }
+}
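A minimal usage sketch of the new layer, assuming `Tensor::zeros` with the usual `(shape, dtype, device)` signature; the shapes are illustrative:

```rust
use candle::{DType, Device, Result, Tensor};
use candle_nn::{Conv1d, Conv1dConfig};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Weight layout is (out_channels, in_channels, kernel_size).
    let weight = Tensor::zeros((4, 2, 3), DType::F32, &dev)?;
    let bias = Tensor::zeros(4, DType::F32, &dev)?;
    let conv = Conv1d::new(weight, Some(bias), Conv1dConfig { padding: 1, stride: 1 });
    // Input layout is (batch, in_channels, length).
    let x = Tensor::zeros((1, 2, 10), DType::F32, &dev)?;
    let y = conv.forward(&x)?;
    // With kernel 3, padding 1, stride 1 the length is preserved: (1, 4, 10).
    println!("{:?}", y.shape());
    Ok(())
}
```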
@@ -1,11 +1,13 @@
 // For now this crate shares its error type with candle-core. We may introduce some separate
 // error type if needed or add some specialized cases on the candle-core side.
 mod activation;
+mod conv;
 mod embedding;
 mod layer_norm;
 mod linear;
 
 pub use activation::Activation;
+pub use conv::{Conv1d, Conv1dConfig};
 pub use embedding::Embedding;
 pub use layer_norm::LayerNorm;
 pub use linear::Linear;