Move the conv1d layer to candle_nn. (#117)

This commit is contained in:
Laurent Mazare
2023-07-10 11:02:06 +01:00
committed by GitHub
parent b06e1a7e54
commit 89a5b602a6
5 changed files with 122 additions and 134 deletions

View File

@ -2,7 +2,7 @@
// back when using RUST_LIB_BACKTRACE=1.
use anyhow::Result;
use candle::{safetensors::SafeTensors, DType, Device, Shape, Tensor};
use candle_nn::{Embedding, LayerNorm, Linear};
use candle_nn::{Conv1d, Conv1dConfig, Embedding, LayerNorm, Linear};
use serde::Deserialize;
use std::collections::HashMap;
@ -112,78 +112,35 @@ fn linear_no_bias(size1: usize, size2: usize, p: &str, vb: &VarBuilder) -> Resul
Ok(Linear::new(weight, None))
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct ConvConfig {
padding: usize,
stride: usize,
fn conv1d(
in_channels: usize,
out_channels: usize,
kernel_size: usize,
config: Conv1dConfig,
p: &str,
vb: &VarBuilder,
) -> Result<Conv1d> {
let weight = vb.get(
(out_channels, in_channels, kernel_size),
&format!("{p}.weight"),
)?;
let bias = vb.get(out_channels, &format!("{p}.bias"))?;
Ok(Conv1d::new(weight, Some(bias), config))
}
impl Default for ConvConfig {
fn default() -> Self {
Self {
padding: 0,
stride: 1,
}
}
}
struct Conv1D {
weight: Tensor,
bias: Option<Tensor>,
config: ConvConfig,
}
impl Conv1D {
fn load(
in_channels: usize,
out_channels: usize,
kernel_size: usize,
config: ConvConfig,
p: &str,
vb: &VarBuilder,
) -> Result<Self> {
let weight = vb.get(
(out_channels, in_channels, kernel_size),
&format!("{p}.weight"),
)?;
let bias = vb.get(out_channels, &format!("{p}.bias"))?;
Ok(Self {
weight,
bias: Some(bias),
config,
})
}
fn load_no_bias(
in_channels: usize,
out_channels: usize,
kernel_size: usize,
config: ConvConfig,
p: &str,
vb: &VarBuilder,
) -> Result<Self> {
let weight = vb.get(
(out_channels, in_channels, kernel_size),
&format!("{p}.weight"),
)?;
Ok(Self {
weight,
bias: None,
config,
})
}
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let x = x.conv1d(&self.weight, self.config.padding, self.config.stride)?;
match &self.bias {
None => Ok(x),
Some(bias) => {
let b = bias.shape().r1()?;
let bias = bias.reshape((1, b, 1))?;
Ok(x.broadcast_add(&bias)?)
}
}
}
fn conv1d_no_bias(
in_channels: usize,
out_channels: usize,
kernel_size: usize,
config: Conv1dConfig,
p: &str,
vb: &VarBuilder,
) -> Result<Conv1d> {
let weight = vb.get(
(out_channels, in_channels, kernel_size),
&format!("{p}.weight"),
)?;
Ok(Conv1d::new(weight, None, config))
}
struct Dropout {
@ -338,8 +295,8 @@ fn sinusoids(length: usize, channels: usize) -> Result<Tensor> {
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L143
pub struct AudioEncoder {
conv1: Conv1D,
conv2: Conv1D,
conv1: Conv1d,
conv2: Conv1d,
positional_embedding: Tensor,
blocks: Vec<ResidualAttentionBlock>,
ln_post: LayerNorm,
@ -350,15 +307,15 @@ impl AudioEncoder {
let n_state = cfg.d_model;
let n_head = cfg.encoder_attention_heads;
let n_ctx = cfg.max_source_positions;
let cfg1 = ConvConfig {
let cfg1 = Conv1dConfig {
padding: 1,
stride: 1,
};
let cfg2 = ConvConfig {
let cfg2 = Conv1dConfig {
padding: 1,
stride: 2,
};
let conv1 = Conv1D::load(
let conv1 = conv1d(
cfg.num_mel_bins,
n_state,
3,
@ -366,7 +323,7 @@ impl AudioEncoder {
&format!("{p}.conv1"),
vb,
)?;
let conv2 = Conv1D::load(n_state, n_state, 3, cfg2, &format!("{p}.conv2"), vb)?;
let conv2 = conv1d(n_state, n_state, 3, cfg2, &format!("{p}.conv2"), vb)?;
let positional_embedding = sinusoids(n_ctx, n_state)?.to_device(&vb.device)?;
let blocks = (0..cfg.encoder_layers)
.map(|i| {