diff --git a/candle-examples/examples/whisper/main.rs b/candle-examples/examples/whisper/main.rs
index e0292e17..3e780a2c 100644
--- a/candle-examples/examples/whisper/main.rs
+++ b/candle-examples/examples/whisper/main.rs
@@ -1,5 +1,7 @@
 #![allow(dead_code)]
 // https://github.com/openai/whisper/blob/main/whisper/model.py
+// TODO:
+// - kv-cache support?
 
 use anyhow::{Error as E, Result};
 use candle::{safetensors::SafeTensors, DType, Device, Shape, Tensor};
@@ -309,6 +311,7 @@ impl ResidualAttentionBlock {
 struct AudioEncoder {
     conv1: Linear, // TODO
     conv2: Linear, // TODO
+    positional_embedding: Tensor,
     blocks: Vec<ResidualAttentionBlock>,
     ln_post: LayerNorm,
 }
@@ -319,6 +322,7 @@ impl AudioEncoder {
         let n_head = cfg.n_audio_head;
         let conv1 = Linear::load(cfg.n_mels, n_state, &format!("{p}.conv1"), vb)?;
         let conv2 = Linear::load(n_state, n_state, &format!("{p}.conv2"), vb)?;
+        let positional_embedding = Tensor::new(&[0u32], &vb.device)?; // TODO
         let blocks = (0..cfg.n_audio_layer)
             .map(|i| {
                 ResidualAttentionBlock::load(n_state, n_head, false, &format!("{p}.blocks.{i}"), vb)
@@ -328,12 +332,21 @@
         Ok(Self {
             conv1,
             conv2,
+            positional_embedding,
             blocks,
             ln_post,
         })
     }
 
-    fn forward(&self, _x: &Tensor) -> Result<Tensor> {
-        todo!()
+    fn forward(&self, x: &Tensor) -> Result<Tensor> {
+        let x = self.conv1.forward(x)?.gelu()?;
+        let x = self.conv2.forward(&x)?.gelu()?;
+        let x = x.transpose(1, 2)?;
+        let mut x = x.broadcast_add(&self.positional_embedding)?;
+        for block in self.blocks.iter() {
+            x = block.forward(&x)?
+        }
+        let x = self.ln_post.forward(&x)?;
+        Ok(x)
     }
 }
@@ -357,7 +370,7 @@ impl TextDecoder {
             })
             .collect::<Result<Vec<_>>>()?;
         let ln = LayerNorm::load(n_state, &format!("{p}.ln"), vb)?;
-        let mask = Tensor::new(&[0u32], &vb.device)?;
+        let mask = Tensor::new(&[0u32], &vb.device)?; // TODO
         Ok(Self {
             token_embedding,
             blocks,
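
For the `positional_embedding` TODO in `AudioEncoder::load`: the reference `whisper/model.py` registers fixed sinusoidal embeddings via its `sinusoids(length, channels)` helper, so the `Tensor::new(&[0u32], ...)` placeholder could be filled without any learned weights. Below is a minimal sketch of a host-side port; the helper name and the use of `Tensor::from_vec` are assumptions, not part of this diff, and the `Result` alias follows the file's existing `anyhow::Result` import.

```rust
// Assumes the imports already present in main.rs:
// use anyhow::Result;
// use candle::{Device, Tensor};

// Hypothetical port of `sinusoids` from whisper/model.py (not in this diff):
// rows are positions, columns are [sin half | cos half], matching
// torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1).
fn sinusoids(length: usize, channels: usize, device: &Device) -> Result<Tensor> {
    assert_eq!(channels % 2, 0, "number of channels must be even");
    let half = channels / 2;
    // log(10000) / (half - 1), as in the reference implementation.
    let log_timescale_increment = 10_000f32.ln() / (half - 1) as f32;
    let mut data = Vec::with_capacity(length * channels);
    for pos in 0..length {
        for i in 0..half {
            let scaled_time = pos as f32 * (-(i as f32) * log_timescale_increment).exp();
            data.push(scaled_time.sin());
        }
        for i in 0..half {
            let scaled_time = pos as f32 * (-(i as f32) * log_timescale_increment).exp();
            data.push(scaled_time.cos());
        }
    }
    Ok(Tensor::from_vec(data, (length, channels), device)?)
}
```

The placeholder line could then become something like `let positional_embedding = sinusoids(cfg.n_audio_ctx, n_state, &vb.device)?;`, assuming the config struct exposes an `n_audio_ctx` field as the OpenAI model dimensions do.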
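The remaining `mask` TODO in `TextDecoder::load` corresponds to the causal attention mask that `model.py` builds with `torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)`. A sketch under the same assumptions (hypothetical helper name, additive mask convention with `-inf` strictly above the diagonal):

```rust
// Hypothetical helper, not part of this diff: an additive causal mask,
// 0 on and below the diagonal, -inf above it, mirroring triu_(1) on a
// matrix pre-filled with -inf.
fn causal_mask(n_ctx: usize, device: &Device) -> Result<Tensor> {
    let mask: Vec<f32> = (0..n_ctx)
        .flat_map(|i| (0..n_ctx).map(move |j| if j > i { f32::NEG_INFINITY } else { 0f32 }))
        .collect();
    Ok(Tensor::from_vec(mask, (n_ctx, n_ctx), device)?)
}
```

This mask would be added to the attention scores before the softmax in the self-attention blocks, which also explains why the decoder's `ResidualAttentionBlock`s are the ones loaded with the causal flag set.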