Flesh out the whisper example.

This commit is contained in:
laurent
2023-07-04 08:01:06 +01:00
parent 0ca2af6940
commit 512dd9e4d6

View File

@ -305,6 +305,90 @@ impl ResidualAttentionBlock {
}
}
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L143
struct AudioEncoder {
conv1: Linear, // TODO
conv2: Linear, // TODO
blocks: Vec<ResidualAttentionBlock>,
ln_post: LayerNorm,
}
impl AudioEncoder {
fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
let n_state = cfg.n_audio_state;
let n_head = cfg.n_audio_head;
let conv1 = Linear::load(cfg.n_mels, n_state, &format!("{p}.conv1"), vb)?;
let conv2 = Linear::load(n_state, n_state, &format!("{p}.conv2"), vb)?;
let blocks = (0..cfg.n_audio_layer)
.map(|i| {
ResidualAttentionBlock::load(n_state, n_head, false, &format!("{p}.blocks.{i}"), vb)
})
.collect::<Result<Vec<_>>>()?;
let ln_post = LayerNorm::load(n_state, &format!("{p}.ln_post"), vb)?;
Ok(Self {
conv1,
conv2,
blocks,
ln_post,
})
}
fn forward(&self, _x: &Tensor) -> Result<Tensor> {
todo!()
}
}
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L176
struct TextDecoder {
token_embedding: Embedding,
blocks: Vec<ResidualAttentionBlock>,
ln: LayerNorm,
mask: Tensor,
}
impl TextDecoder {
fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
let n_state = cfg.n_text_state;
let n_head = cfg.n_text_head;
let token_embedding =
Embedding::load(cfg.n_vocab, n_state, &format!("{p}.token_embedding"), vb)?;
let blocks = (0..cfg.n_text_layer)
.map(|i| {
ResidualAttentionBlock::load(n_state, n_head, false, &format!("{p}.blocks.{i}"), vb)
})
.collect::<Result<Vec<_>>>()?;
let ln = LayerNorm::load(n_state, &format!("{p}.ln"), vb)?;
let mask = Tensor::new(&[0u32], &vb.device)?;
Ok(Self {
token_embedding,
blocks,
ln,
mask,
})
}
fn forward(&self, _tokens: &Tensor, _enc: &Tensor) -> Result<Tensor> {
todo!()
}
}
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L221
struct Whisper {
encoder: AudioEncoder,
decoder: TextDecoder,
}
impl Whisper {
fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
let encoder = AudioEncoder::load(&format!("{p}.encoder"), vb, cfg)?;
let decoder = TextDecoder::load(&format!("{p}.decoder"), vb, cfg)?;
Ok(Self { encoder, decoder })
}
fn forward(&self, mel: &Tensor, tokens: &Tensor) -> Result<Tensor> {
let enc = self.encoder.forward(mel)?;
let dec = self.decoder.forward(tokens, &enc)?;
Ok(dec)
}
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {