From e6b01d0c18d27b2363f5aad7a19da38afc51f7d1 Mon Sep 17 00:00:00 2001 From: laurent Date: Tue, 4 Jul 2023 10:01:05 +0100 Subject: [PATCH] Add the conv1d layer (but not the op). --- candle-examples/examples/whisper/main.rs | 89 ++++++++++++++++++++++-- 1 file changed, 85 insertions(+), 4 deletions(-) diff --git a/candle-examples/examples/whisper/main.rs b/candle-examples/examples/whisper/main.rs index dac6df84..75ab2189 100644 --- a/candle-examples/examples/whisper/main.rs +++ b/candle-examples/examples/whisper/main.rs @@ -172,6 +172,79 @@ impl Linear { } } +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +struct ConvConfig { + padding: usize, + stride: usize, +} + +impl Default for ConvConfig { + fn default() -> Self { + Self { + padding: 0, + stride: 1, + } + } +} + +struct Conv1D { + weight: Tensor, + bias: Option, + config: ConvConfig, +} + +impl Conv1D { + fn load( + in_channels: usize, + out_channels: usize, + kernel_size: usize, + config: ConvConfig, + p: &str, + vb: &VarBuilder, + ) -> Result { + let weight = vb.get( + (out_channels, in_channels, kernel_size), + &format!("{p}.weight"), + )?; + let bias = vb.get(out_channels, &format!("{p}.bias"))?; + Ok(Self { + weight, + bias: Some(bias), + config, + }) + } + + fn load_no_bias( + in_channels: usize, + out_channels: usize, + kernel_size: usize, + config: ConvConfig, + p: &str, + vb: &VarBuilder, + ) -> Result { + let weight = vb.get( + (out_channels, in_channels, kernel_size), + &format!("{p}.weight"), + )?; + Ok(Self { + weight, + bias: None, + config, + }) + } + + fn forward(&self, x: &Tensor) -> candle::Result { + let (bsize, _, _) = x.shape().r3()?; + let w = self.weight.broadcast_left(bsize)?.t()?; + // TODO: Add the conv1d operation + let x = x.matmul(&w)?; + match &self.bias { + None => Ok(x), + Some(bias) => x.broadcast_add(bias), + } + } +} + struct Dropout { pr: f64, } @@ -341,8 +414,8 @@ fn sinusoids(length: usize, channels: usize) -> Result { // https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L143 struct AudioEncoder { - conv1: Linear, // TODO - conv2: Linear, // TODO + conv1: Conv1D, + conv2: Conv1D, positional_embedding: Tensor, blocks: Vec, ln_post: LayerNorm, @@ -352,8 +425,16 @@ impl AudioEncoder { fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result { let n_state = cfg.n_audio_state; let n_head = cfg.n_audio_head; - let conv1 = Linear::load(cfg.n_mels, n_state, &format!("{p}.conv1"), vb)?; - let conv2 = Linear::load(n_state, n_state, &format!("{p}.conv2"), vb)?; + let cfg1 = ConvConfig { + padding: 1, + stride: 1, + }; + let cfg2 = ConvConfig { + padding: 1, + stride: 2, + }; + let conv1 = Conv1D::load(cfg.n_mels, n_state, 3, cfg1, &format!("{p}.conv1"), vb)?; + let conv2 = Conv1D::load(n_state, n_state, 3, cfg2, &format!("{p}.conv2"), vb)?; let positional_embedding = sinusoids(cfg.n_audio_ctx, n_state)?.to_device(&vb.device)?; let blocks = (0..cfg.n_audio_layer) .map(|i| {