mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 10:26:33 +00:00
Flesh out the whisper example.
This commit is contained in:
@ -305,6 +305,90 @@ impl ResidualAttentionBlock {
|
||||
}
|
||||
}
|
||||
|
||||
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L143
|
||||
struct AudioEncoder {
|
||||
conv1: Linear, // TODO
|
||||
conv2: Linear, // TODO
|
||||
blocks: Vec<ResidualAttentionBlock>,
|
||||
ln_post: LayerNorm,
|
||||
}
|
||||
|
||||
impl AudioEncoder {
|
||||
fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let n_state = cfg.n_audio_state;
|
||||
let n_head = cfg.n_audio_head;
|
||||
let conv1 = Linear::load(cfg.n_mels, n_state, &format!("{p}.conv1"), vb)?;
|
||||
let conv2 = Linear::load(n_state, n_state, &format!("{p}.conv2"), vb)?;
|
||||
let blocks = (0..cfg.n_audio_layer)
|
||||
.map(|i| {
|
||||
ResidualAttentionBlock::load(n_state, n_head, false, &format!("{p}.blocks.{i}"), vb)
|
||||
})
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
let ln_post = LayerNorm::load(n_state, &format!("{p}.ln_post"), vb)?;
|
||||
Ok(Self {
|
||||
conv1,
|
||||
conv2,
|
||||
blocks,
|
||||
ln_post,
|
||||
})
|
||||
}
|
||||
fn forward(&self, _x: &Tensor) -> Result<Tensor> {
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
|
||||
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L176
|
||||
struct TextDecoder {
|
||||
token_embedding: Embedding,
|
||||
blocks: Vec<ResidualAttentionBlock>,
|
||||
ln: LayerNorm,
|
||||
mask: Tensor,
|
||||
}
|
||||
|
||||
impl TextDecoder {
|
||||
fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let n_state = cfg.n_text_state;
|
||||
let n_head = cfg.n_text_head;
|
||||
let token_embedding =
|
||||
Embedding::load(cfg.n_vocab, n_state, &format!("{p}.token_embedding"), vb)?;
|
||||
let blocks = (0..cfg.n_text_layer)
|
||||
.map(|i| {
|
||||
ResidualAttentionBlock::load(n_state, n_head, false, &format!("{p}.blocks.{i}"), vb)
|
||||
})
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
let ln = LayerNorm::load(n_state, &format!("{p}.ln"), vb)?;
|
||||
let mask = Tensor::new(&[0u32], &vb.device)?;
|
||||
Ok(Self {
|
||||
token_embedding,
|
||||
blocks,
|
||||
ln,
|
||||
mask,
|
||||
})
|
||||
}
|
||||
fn forward(&self, _tokens: &Tensor, _enc: &Tensor) -> Result<Tensor> {
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
|
||||
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L221
|
||||
struct Whisper {
|
||||
encoder: AudioEncoder,
|
||||
decoder: TextDecoder,
|
||||
}
|
||||
|
||||
impl Whisper {
|
||||
fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let encoder = AudioEncoder::load(&format!("{p}.encoder"), vb, cfg)?;
|
||||
let decoder = TextDecoder::load(&format!("{p}.decoder"), vb, cfg)?;
|
||||
Ok(Self { encoder, decoder })
|
||||
}
|
||||
fn forward(&self, mel: &Tensor, tokens: &Tensor) -> Result<Tensor> {
|
||||
let enc = self.encoder.forward(mel)?;
|
||||
let dec = self.decoder.forward(tokens, &enc)?;
|
||||
Ok(dec)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
struct Args {
|
||||
|
Reference in New Issue
Block a user