feat: support microphone whisper streaming (#1678)

* feat: support microphone whisper streaming

* fix: cleanup print stmts and adjust how input is read

* fix: remove incorrect comment

* feat: split into new example and simplify

* fix: feature flag example file

* fix: fmt fixes

* feat: simplify and remove redundant files
This commit is contained in:
drbh
2024-02-12 12:01:21 -05:00
committed by GitHub
parent d0aa197b07
commit 13c67226e6
5 changed files with 869 additions and 0 deletions

View File

@ -129,6 +129,10 @@ impl MultiHeadAttention {
.flatten_from(2)?;
Ok(wv)
}
fn reset_kv_cache(&mut self) {
self.kv_cache = None;
}
}
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L111
@ -193,6 +197,13 @@ impl ResidualAttentionBlock {
)?;
x + mlp
}
fn reset_kv_cache(&mut self) {
self.attn.reset_kv_cache();
if let Some((attn, _)) = &mut self.cross_attn {
attn.reset_kv_cache();
}
}
}
fn sinusoids(length: usize, channels: usize, device: &Device) -> Result<Tensor> {
@ -350,6 +361,12 @@ impl TextDecoder {
};
Ok(logits)
}
pub fn reset_kv_cache(&mut self) {
for block in self.blocks.iter_mut() {
block.reset_kv_cache();
}
}
}
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L221
@ -370,4 +387,12 @@ impl Whisper {
config,
})
}
pub fn reset_kv_cache(&mut self) {
self.encoder
.blocks
.iter_mut()
.for_each(|b| b.reset_kv_cache());
self.decoder.reset_kv_cache();
}
}

View File

@ -126,6 +126,10 @@ impl MultiHeadAttention {
.flatten_from(2)?;
Ok(wv)
}
fn reset_kv_cache(&mut self) {
self.kv_cache = None;
}
}
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L111
@ -189,6 +193,13 @@ impl ResidualAttentionBlock {
.apply(&self.mlp_linear2)?;
x + mlp
}
fn reset_kv_cache(&mut self) {
self.attn.reset_kv_cache();
if let Some((attn, _)) = &mut self.cross_attn {
attn.reset_kv_cache();
}
}
}
fn sinusoids(length: usize, channels: usize, device: &Device) -> Result<Tensor> {
@ -281,6 +292,12 @@ impl AudioEncoder {
let x = self.ln_post.forward(&x)?;
Ok(x)
}
pub fn reset_kv_cache(&mut self) {
for block in self.blocks.iter_mut() {
block.reset_kv_cache();
}
}
}
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L176
@ -348,6 +365,12 @@ impl TextDecoder {
};
Ok(logits)
}
pub fn reset_kv_cache(&mut self) {
for block in self.blocks.iter_mut() {
block.reset_kv_cache();
}
}
}
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L221
@ -368,4 +391,9 @@ impl Whisper {
config,
})
}
pub fn reset_kv_cache(&mut self) {
self.encoder.reset_kv_cache();
self.decoder.reset_kv_cache();
}
}