Add a quantized variant of whisper (#1017)

* Add the quantized-whisper model.

* Quantized the whisper model.

* Adapt the whisper example to handle quantization.

* Add the quantized flag.

* Load the proper weights.
This commit is contained in:
Laurent Mazare
2023-10-02 14:59:53 +01:00
committed by GitHub
parent 263a172202
commit e04c789230
5 changed files with 519 additions and 62 deletions

View File

@ -1,4 +1,3 @@
use crate::Whisper;
use candle::{IndexOp, Result, Tensor, D};
use tokenizers::Tokenizer;
@ -105,20 +104,28 @@ const LANGUAGES: [(&str, &str); 99] = [
];
/// Returns the token id for the selected language.
pub fn detect_language(model: &mut Whisper, tokenizer: &Tokenizer, mel: &Tensor) -> Result<u32> {
pub fn detect_language(
model: &mut super::Model,
tokenizer: &Tokenizer,
mel: &Tensor,
) -> Result<u32> {
let (_bsize, _, seq_len) = mel.dims3()?;
let mel = mel.narrow(2, 0, usize::min(seq_len, model.config.max_source_positions))?;
let mel = mel.narrow(
2,
0,
usize::min(seq_len, model.config().max_source_positions),
)?;
let device = mel.device();
let language_token_ids = LANGUAGES
.iter()
.map(|(t, _)| crate::token_id(tokenizer, &format!("<|{t}|>")))
.collect::<Result<Vec<_>>>()?;
let sot_token = crate::token_id(tokenizer, crate::m::SOT_TOKEN)?;
let audio_features = model.encoder.forward(&mel, true)?;
let audio_features = model.encoder_forward(&mel, true)?;
let tokens = Tensor::new(&[[sot_token]], device)?;
let language_token_ids = Tensor::new(language_token_ids.as_slice(), device)?;
let ys = model.decoder.forward(&tokens, &audio_features, true)?;
let logits = model.decoder.final_linear(&ys.i(..1)?)?.i(0)?.i(0)?;
let ys = model.decoder_forward(&tokens, &audio_features, true)?;
let logits = model.decoder_final_linear(&ys.i(..1)?)?.i(0)?.i(0)?;
let logits = logits.index_select(&language_token_ids, 0)?;
let probs = candle_nn::ops::softmax(&logits, D::Minus1)?;
let probs = probs.to_vec1::<f32>()?;