From c3739d001bfb1a1305fc1ff398761e380f87bfb1 Mon Sep 17 00:00:00 2001 From: laurent Date: Tue, 4 Jul 2023 14:06:09 +0100 Subject: [PATCH] Get the audio-encoder to return some values. --- candle-examples/examples/whisper/main.rs | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/candle-examples/examples/whisper/main.rs b/candle-examples/examples/whisper/main.rs index 71b03e72..839dfc13 100644 --- a/candle-examples/examples/whisper/main.rs +++ b/candle-examples/examples/whisper/main.rs @@ -233,13 +233,15 @@ impl Conv1D { }) } - fn forward(&self, x: &Tensor) -> candle::Result { - let (bsize, _, _) = x.shape().r3()?; - let w = self.weight.broadcast_left(bsize)?.t()?; - let x = x.conv1d(&w, self.config.padding, self.config.stride)?; + fn forward(&self, x: &Tensor) -> Result { + let x = x.conv1d(&self.weight, self.config.padding, self.config.stride)?; match &self.bias { None => Ok(x), - Some(bias) => x.broadcast_add(bias), + Some(bias) => { + let b = bias.shape().r1()?; + let bias = bias.reshape((1, b, 1))?; + Ok(x.broadcast_add(&bias)?) + } } } } @@ -381,11 +383,9 @@ impl ResidualAttentionBlock { fn forward(&self, x: &Tensor, xa: Option<&Tensor>) -> Result { let attn = self.attn.forward(&self.attn_ln.forward(x)?, None)?; let mut x = (x + attn)?; - // Cross-Attn if let Some((attn, ln)) = &self.cross_attn { x = (&x + attn.forward(&ln.forward(&x)?, xa)?)?; } - // Mlp let mlp = self.mlp_linear2.forward( &self .mlp_linear1 @@ -557,8 +557,8 @@ fn main() -> Result<()> { let input = unsafe { candle::safetensors::MmapedFile::new(args.input)? }; let input = input.deserialize()?; - let x = input.tensor("x", &device)?.to_dtype(DType::U32)?; - let xa = input.tensor("xa", &device)?; + let tokens = input.tensor("tokens", &device)?.to_dtype(DType::U32)?; + let mel = input.tensor("mel", &device)?; let weights = unsafe { candle::safetensors::MmapedFile::new(args.weights)? }; let weights = weights.deserialize()?; @@ -566,8 +566,8 @@ fn main() -> Result<()> { let cfg = Config::tiny_en(); let model = Whisper::load(&vb, &cfg)?; - let logits = model.decoder.forward(&x, &xa)?; + let logits = model.forward(&mel, &tokens)?; println!("{logits}"); - println!("python logits: {}", input.tensor("logits", &device)?); + println!("python logits: {}", input.tensor("dec", &device)?); Ok(()) }