Add a quantized variant of whisper (#1017)

* Add the quantized-whisper model.

* Quantize the whisper model.

* Adapt the whisper example to handle quantization.

* Add the quantized flag.

* Load the proper weights.
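
One way to read the last three bullets together: the example gains a CLI flag that picks the weight format, and a small enum wraps whichever Whisper variant was loaded so the rest of the transcription loop does not care about quantization. The sketch below is a minimal, self-contained illustration of that pattern, not the commit's actual code: `F32Whisper` and `QuantizedWhisper` are placeholders standing in for `model::Whisper` and `quantized_model::Whisper`.

```rust
// Sketch of the flag-plus-enum dispatch pattern (placeholder types, std only).
struct F32Whisper;       // placeholder for model::Whisper
struct QuantizedWhisper; // placeholder for quantized_model::Whisper

enum Model {
    Normal(F32Whisper),
    Quantized(QuantizedWhisper),
}

impl Model {
    /// Downstream code calls methods on `Model` and each call is forwarded
    /// to whichever variant was actually loaded.
    fn describe(&self) -> &'static str {
        match self {
            Model::Normal(_) => "f32 weights",
            Model::Quantized(_) => "quantized (gguf) weights",
        }
    }
}

fn main() {
    // In the real example the choice would come from a `--quantized` flag.
    let quantized = std::env::args().any(|a| a == "--quantized");
    let model = if quantized {
        Model::Quantized(QuantizedWhisper)
    } else {
        Model::Normal(F32Whisper)
    };
    println!("loaded whisper with {}", model.describe());
}
```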
Author:       Laurent Mazare
Date:         2023-10-02 14:59:53 +01:00
Committed by: GitHub
Parent:       263a172202
Commit:       e04c789230
5 changed files with 519 additions and 62 deletions

View File

@@ -1,5 +1,25 @@
pub mod audio;
pub mod model;
pub mod quantized_model;

use serde::Deserialize;

// The names in comments correspond to the original implementation:
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L17
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct Config {
    pub num_mel_bins: usize,            // n_mels
    pub max_source_positions: usize,    // n_audio_ctx
    pub d_model: usize,                 // n_audio_state
    pub encoder_attention_heads: usize, // n_audio_head
    pub encoder_layers: usize,          // n_audio_layer
    pub vocab_size: usize,              // n_vocab
    pub max_target_positions: usize,    // n_text_ctx
    // pub n_text_state: usize,
    pub decoder_attention_heads: usize, // n_text_head
    pub decoder_layers: usize,          // n_text_layer
    pub suppress_tokens: Vec<u32>,
}

pub const DTYPE: candle::DType = candle::DType::F32;
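
Because `Config` derives `Deserialize`, it can be read straight from a Hugging Face style `config.json`; unknown JSON fields are ignored by default. Below is a small sketch of that, assuming the `serde_json` crate and the `candle_transformers::models::whisper::Config` path; the numeric values are whisper-tiny-like and purely illustrative, and the short `suppress_tokens` list is a stand-in for the real one.

```rust
// Deserialize the whisper Config from a JSON snippet (illustrative values).
use candle_transformers::models::whisper::Config;

fn main() {
    let json = r#"{
        "num_mel_bins": 80,
        "max_source_positions": 1500,
        "d_model": 384,
        "encoder_attention_heads": 6,
        "encoder_layers": 4,
        "vocab_size": 51865,
        "max_target_positions": 448,
        "decoder_attention_heads": 6,
        "decoder_layers": 4,
        "suppress_tokens": [1, 2, 7]
    }"#;
    let config: Config = serde_json::from_str(json).expect("valid whisper config");
    println!(
        "{} mel bins, {} decoder layers",
        config.num_mel_bins, config.decoder_layers
    );
}
```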