Add more to the whisper inference.

This commit is contained in:
laurent
2023-07-04 08:32:03 +01:00
parent 512dd9e4d6
commit 6728a85676

View File

@ -1,5 +1,7 @@
#![allow(dead_code)]
// https://github.com/openai/whisper/blob/main/whisper/model.py
// TODO:
// - kv-cache support?
use anyhow::{Error as E, Result};
use candle::{safetensors::SafeTensors, DType, Device, Shape, Tensor};
@ -309,6 +311,7 @@ impl ResidualAttentionBlock {
// Audio-side encoder of the Whisper model: two input projections, a positional
// embedding, a stack of residual attention blocks and a final layer norm.
// Field names follow the reference implementation linked at the top of the file.
struct AudioEncoder {
// First input projection. TODO: currently a `Linear` placeholder — presumably
// meant to become a convolution layer, as the field name suggests; confirm.
conv1: Linear, // TODO
// Second input projection. TODO: same placeholder status as `conv1`.
conv2: Linear, // TODO
// Tensor that is broadcast-added to the projected input before the blocks run.
positional_embedding: Tensor,
// Attention blocks, applied in order.
blocks: Vec<ResidualAttentionBlock>,
// Layer normalisation applied to the output of the last block.
ln_post: LayerNorm,
}
@ -319,6 +322,7 @@ impl AudioEncoder {
let n_head = cfg.n_audio_head;
let conv1 = Linear::load(cfg.n_mels, n_state, &format!("{p}.conv1"), vb)?;
let conv2 = Linear::load(n_state, n_state, &format!("{p}.conv2"), vb)?;
let positional_embedding = Tensor::new(&[0u32], &vb.device)?; // TODO
let blocks = (0..cfg.n_audio_layer)
.map(|i| {
ResidualAttentionBlock::load(n_state, n_head, false, &format!("{p}.blocks.{i}"), vb)
@ -328,12 +332,21 @@ impl AudioEncoder {
Ok(Self {
conv1,
conv2,
positional_embedding,
blocks,
ln_post,
})
}
fn forward(&self, _x: &Tensor) -> Result<Tensor> {
todo!()
/// Runs the audio encoder on `x` and returns the encoded tensor.
///
/// Steps, in order: `conv1` followed by GELU, `conv2` followed by GELU, a swap
/// of dimensions 1 and 2, a broadcast-add of `positional_embedding`, every
/// block in `blocks`, and finally `ln_post`.
///
/// Any error raised by one of these steps is propagated to the caller.
///
/// NOTE(review): the expected layout of `x` is not visible here — presumably
/// (batch, n_mels, n_frames) as in the reference model; confirm once the
/// convolution TODOs are resolved.
fn forward(&self, x: &Tensor) -> Result<Tensor> {
    // Input stem: two projections, each followed by a GELU activation.
    let hidden = self.conv1.forward(x)?.gelu()?;
    let hidden = self.conv2.forward(&hidden)?.gelu()?;

    // Swap dimensions 1 and 2 before adding the positional embedding.
    let hidden = hidden.transpose(1, 2)?;
    let hidden = hidden.broadcast_add(&self.positional_embedding)?;

    // Thread the activations through each attention block, stopping at the
    // first block that fails.
    let hidden = self
        .blocks
        .iter()
        .try_fold(hidden, |acc, block| block.forward(&acc))?;

    let normed = self.ln_post.forward(&hidden)?;
    Ok(normed)
}
}
@ -357,7 +370,7 @@ impl TextDecoder {
})
.collect::<Result<Vec<_>>>()?;
let ln = LayerNorm::load(n_state, &format!("{p}.ln"), vb)?;
let mask = Tensor::new(&[0u32], &vb.device)?;
let mask = Tensor::new(&[0u32], &vb.device)?; // TODO
Ok(Self {
token_embedding,
blocks,