Mirror of https://github.com/huggingface/candle.git (synced 2025-06-17 02:58:50 +00:00)

Add more to the whisper inference.
@@ -1,5 +1,7 @@
 #![allow(dead_code)]
 // https://github.com/openai/whisper/blob/main/whisper/model.py
+// TODO:
+// - kv-cache support?
 
 use anyhow::{Error as E, Result};
 use candle::{safetensors::SafeTensors, DType, Device, Shape, Tensor};
@@ -309,6 +311,7 @@ impl ResidualAttentionBlock {
 struct AudioEncoder {
     conv1: Linear, // TODO
     conv2: Linear, // TODO
+    positional_embedding: Tensor,
     blocks: Vec<ResidualAttentionBlock>,
     ln_post: LayerNorm,
 }
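For context: in the upstream model.py, `conv1` and `conv2` are `nn.Conv1d` layers (kernel size 3, padding 1, strides 1 and 2 respectively), which is what the `Linear // TODO` stubs above stand in for. A minimal single-channel reference sketch of that convolution, not part of this commit:

```rust
// Reference sketch only (not from this commit): a naive single-channel 1-D
// convolution with kernel size 3 and zero padding 1. With stride 1 the output
// keeps the input length (conv1); with stride 2 it is halved (conv2).
fn conv1d_k3(x: &[f32], weight: &[f32; 3], bias: f32, stride: usize) -> Vec<f32> {
    let n = x.len() as isize;
    let mut out = Vec::new();
    let mut start = -1isize; // padding = 1: the first window begins one step left
    while start + 2 <= n {
        let mut acc = bias;
        for (k, w) in weight.iter().enumerate() {
            let idx = start + k as isize;
            if (0..n).contains(&idx) {
                acc += w * x[idx as usize]; // out-of-range taps read zero padding
            }
        }
        out.push(acc);
        start += stride as isize;
    }
    out
}
```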
@@ -319,6 +322,7 @@ impl AudioEncoder {
         let n_head = cfg.n_audio_head;
         let conv1 = Linear::load(cfg.n_mels, n_state, &format!("{p}.conv1"), vb)?;
         let conv2 = Linear::load(n_state, n_state, &format!("{p}.conv2"), vb)?;
+        let positional_embedding = Tensor::new(&[0u32], &vb.device)?; // TODO
         let blocks = (0..cfg.n_audio_layer)
             .map(|i| {
                 ResidualAttentionBlock::load(n_state, n_head, false, &format!("{p}.blocks.{i}"), vb)
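The `Tensor::new(&[0u32], ...)` above is a placeholder: in model.py the audio encoder registers a fixed sinusoidal table, `sinusoids(n_ctx, n_state)`. A plain-Rust sketch of that computation (not from this commit; it returns nested `Vec`s rather than a `Tensor`):

```rust
// Reference sketch only: Whisper's fixed sinusoidal positional embeddings,
// following model.py's sinusoids(). Row `pos` holds sin terms for the first
// half of the channels and cos terms for the second half.
fn sinusoids(length: usize, channels: usize) -> Vec<Vec<f32>> {
    assert_eq!(channels % 2, 0, "channels must be even");
    let half = channels / 2;
    // log(10_000) spread evenly across the half-dimension, as in the Python code.
    let log_timescale_increment = 10_000f32.ln() / (half - 1) as f32;
    (0..length)
        .map(|pos| {
            let scaled: Vec<f32> = (0..half)
                .map(|i| pos as f32 * (-(i as f32) * log_timescale_increment).exp())
                .collect();
            // Concatenate sin and cos along the channel dimension.
            scaled
                .iter()
                .map(|t| t.sin())
                .chain(scaled.iter().map(|t| t.cos()))
                .collect()
        })
        .collect()
}
```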
@@ -328,12 +332,21 @@ impl AudioEncoder {
         Ok(Self {
             conv1,
             conv2,
+            positional_embedding,
             blocks,
             ln_post,
         })
     }
-    fn forward(&self, _x: &Tensor) -> Result<Tensor> {
-        todo!()
+    fn forward(&self, x: &Tensor) -> Result<Tensor> {
+        let x = self.conv1.forward(x)?.gelu()?;
+        let x = self.conv2.forward(&x)?.gelu()?;
+        let x = x.transpose(1, 2)?;
+        let mut x = x.broadcast_add(&self.positional_embedding)?;
+        for block in self.blocks.iter() {
+            x = block.forward(&x)?
+        }
+        let x = self.ln_post.forward(&x)?;
+        Ok(x)
     }
 }
 
@@ -357,7 +370,7 @@ impl TextDecoder {
             })
             .collect::<Result<Vec<_>>>()?;
         let ln = LayerNorm::load(n_state, &format!("{p}.ln"), vb)?;
-        let mask = Tensor::new(&[0u32], &vb.device)?;
+        let mask = Tensor::new(&[0u32], &vb.device)?; // TODO
         Ok(Self {
             token_embedding,
             blocks,
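The decoder `mask` is stubbed the same way; upstream it is an `n_ctx` x `n_ctx` matrix with `-inf` strictly above the diagonal, added to the attention logits so each position only attends to itself and earlier tokens. A plain-Rust sketch (not from this commit):

```rust
// Reference sketch only: the causal mask from model.py, -inf strictly above
// the diagonal, stored row-major as a flat (n_ctx * n_ctx) buffer.
fn causal_mask(n_ctx: usize) -> Vec<f32> {
    let mut mask = vec![0f32; n_ctx * n_ctx];
    for i in 0..n_ctx {
        for j in (i + 1)..n_ctx {
            mask[i * n_ctx + j] = f32::NEG_INFINITY; // block future positions
        }
    }
    mask
}
```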