From e04c789230c609c285991b78c29f1d6eef0d104f Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Mon, 2 Oct 2023 14:59:53 +0100
Subject: [PATCH] Add a quantized variant of whisper (#1017)

* Add the quantized-whisper model.

* Quantized the whisper model.

* Adapt the whisper example to handle quantization.

* Add the quantized flag.

* Load the proper weights.
---
 candle-examples/examples/whisper/main.rs      | 120 ++++--
 .../examples/whisper/multilingual.rs          |  19 +-
 candle-transformers/src/models/whisper/mod.rs |  20 +
 .../src/models/whisper/model.rs               |  19 +-
 .../src/models/whisper/quantized_model.rs     | 403 ++++++++++++++++++
 5 files changed, 519 insertions(+), 62 deletions(-)
 create mode 100644 candle-transformers/src/models/whisper/quantized_model.rs

diff --git a/candle-examples/examples/whisper/main.rs b/candle-examples/examples/whisper/main.rs
index 0aa4db41..5249ed34 100644
--- a/candle-examples/examples/whisper/main.rs
+++ b/candle-examples/examples/whisper/main.rs
@@ -18,8 +18,48 @@ use rand::{distributions::Distribution, SeedableRng};
 use tokenizers::Tokenizer;
 
 mod multilingual;
-use candle_transformers::models::whisper::{self as m, audio, model};
-use model::{Config, Whisper};
+use candle_transformers::models::whisper::{self as m, audio, Config};
+
+pub enum Model {
+    Normal(m::model::Whisper),
+    Quantized(m::quantized_model::Whisper),
+}
+
+// Maybe we should use some traits rather than doing the dispatch for all these
+// (see the sketch after the patch).
+impl Model {
+    pub fn config(&self) -> &Config {
+        match self {
+            Self::Normal(m) => &m.config,
+            Self::Quantized(m) => &m.config,
+        }
+    }
+
+    pub fn encoder_forward(&mut self, x: &Tensor, flush: bool) -> candle::Result<Tensor> {
+        match self {
+            Self::Normal(m) => m.encoder.forward(x, flush),
+            Self::Quantized(m) => m.encoder.forward(x, flush),
+        }
+    }
+
+    pub fn decoder_forward(
+        &mut self,
+        x: &Tensor,
+        xa: &Tensor,
+        flush: bool,
+    ) -> candle::Result<Tensor> {
+        match self {
+            Self::Normal(m) => m.decoder.forward(x, xa, flush),
+            Self::Quantized(m) => m.decoder.forward(x, xa, flush),
+        }
+    }
+
+    pub fn decoder_final_linear(&self, x: &Tensor) -> candle::Result<Tensor> {
+        match self {
+            Self::Normal(m) => m.decoder.final_linear(x),
+            Self::Quantized(m) => m.decoder.final_linear(x),
+        }
+    }
+}
 
 #[allow(dead_code)]
 #[derive(Debug, Clone)]
@@ -41,7 +81,7 @@ struct Segment {
 }
 
 struct Decoder {
-    model: Whisper,
+    model: Model,
     rng: rand::rngs::StdRng,
     task: Option<Task>,
     timestamps: bool,
@@ -60,7 +100,7 @@ impl Decoder {
     #[allow(clippy::too_many_arguments)]
     fn new(
-        model: Whisper,
+        model: Model,
         tokenizer: Tokenizer,
         seed: u64,
         device: &Device,
@@ -72,9 +112,9 @@ impl Decoder {
         let no_timestamps_token = token_id(&tokenizer, m::NO_TIMESTAMPS_TOKEN)?;
         // Suppress the notimestamps token when in timestamps mode.
        // https://github.com/openai/whisper/blob/e8622f9afc4eba139bf796c210f5c01081000472/whisper/decoding.py#L452
-        let suppress_tokens: Vec<f32> = (0..model.config.vocab_size as u32)
+        let suppress_tokens: Vec<f32> = (0..model.config().vocab_size as u32)
             .map(|i| {
-                if model.config.suppress_tokens.contains(&i)
+                if model.config().suppress_tokens.contains(&i)
                     || timestamps && i == no_timestamps_token
                 {
                     f32::NEG_INFINITY
@@ -109,11 +149,11 @@ impl Decoder {
 
     fn decode(&mut self, mel: &Tensor, t: f64) -> Result<DecodingResult> {
         let model = &mut self.model;
-        let audio_features = model.encoder.forward(mel, true)?;
+        let audio_features = model.encoder_forward(mel, true)?;
         if self.verbose {
             println!("audio features: {:?}", audio_features.dims());
         }
-        let sample_len = model.config.max_target_positions / 2;
+        let sample_len = model.config().max_target_positions / 2;
         let mut sum_logprob = 0f64;
         let mut no_speech_prob = f64::NAN;
         let mut tokens = vec![self.sot_token];
@@ -133,12 +173,12 @@ impl Decoder {
             // The model expects a batch dim but this inference loop does not handle
             // it so we add it at this point.
             let tokens_t = tokens_t.unsqueeze(0)?;
-            let ys = model.decoder.forward(&tokens_t, &audio_features, i == 0)?;
+            let ys = model.decoder_forward(&tokens_t, &audio_features, i == 0)?;
 
             // Extract the no speech probability on the first iteration by looking at the first
             // token logits and the probability for the according token.
             if i == 0 {
-                let logits = model.decoder.final_linear(&ys.i(..1)?)?.i(0)?.i(0)?;
+                let logits = model.decoder_final_linear(&ys.i(..1)?)?.i(0)?.i(0)?;
                 no_speech_prob = softmax(&logits, 0)?
                     .i(self.no_speech_token as usize)?
                     .to_scalar::<f32>()? as f64;
@@ -146,8 +186,7 @@ impl Decoder {
 
             let (_, seq_len, _) = ys.dims3()?;
             let logits = model
-                .decoder
-                .final_linear(&ys.i((..1, seq_len - 1..))?)?
+                .decoder_final_linear(&ys.i((..1, seq_len - 1..))?)?
                 .i(0)?
                 .i(0)?;
             // TODO: Besides suppress tokens, we should apply the heuristics from
@@ -176,7 +215,7 @@ impl Decoder {
             let prob = softmax(&logits, candle::D::Minus1)?
                 .i(next_token as usize)?
                 .to_scalar::<f32>()? as f64;
-            if next_token == self.eot_token || tokens.len() > model.config.max_target_positions {
+            if next_token == self.eot_token || tokens.len() > model.config().max_target_positions {
                 break;
             }
             sum_logprob += prob.ln();
@@ -333,6 +372,7 @@ impl WhichModel {
             Self::TinyEn | Self::BaseEn | Self::SmallEn | Self::MediumEn => false,
         }
     }
+
     fn model_and_revision(&self) -> (&'static str, &'static str) {
         match self {
             Self::Tiny => ("openai/whisper-tiny", "main"),
@@ -382,6 +422,9 @@ struct Args {
     #[arg(long)]
     tracing: bool,
 
+    #[arg(long)]
+    quantized: bool,
+
     /// Language.
     #[arg(long)]
     language: Option<String>,
@@ -413,10 +456,13 @@ fn main() -> Result<()> {
         None
     };
     let device = candle_examples::device(args.cpu)?;
-    let (default_model, default_revision) = args.model.model_and_revision();
+    let (default_model, default_revision) = if args.quantized {
+        ("lmz/candle-whisper", "main")
+    } else {
+        args.model.model_and_revision()
+    };
     let default_model = default_model.to_string();
     let default_revision = default_revision.to_string();
-    let path = std::path::PathBuf::from(default_model.clone());
     let (model_id, revision) = match (args.model_id, args.revision) {
         (Some(model_id), Some(revision)) => (model_id, revision),
         (Some(model_id), None) => (model_id, "main".to_string()),
         (None, None) => (default_model, default_revision),
     };
 
-    let (config_filename, tokenizer_filename, weights_filename, input) = if path.exists() {
-        let mut config_filename = path.clone();
-        config_filename.push("config.json");
-        let mut tokenizer_filename = path.clone();
-        tokenizer_filename.push("tokenizer.json");
-        let mut model_filename = path;
-        model_filename.push("model.safetensors");
-        (
-            config_filename,
-            tokenizer_filename,
-            model_filename,
-            std::path::PathBuf::from(args.input.expect("You didn't specify a file to read from yet, are using a local model, please add `--input example.wav` to read some audio file")),
-        )
-    } else {
+    let (config_filename, tokenizer_filename, weights_filename, input) = {
         let api = Api::new()?;
         let dataset = api.dataset("Narsil/candle-examples".to_string());
         let repo = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
             println!("No audio file submitted: Downloading https://huggingface.co/datasets/Narsil/candle_demo/blob/main/samples_jfk.wav");
             dataset.get("samples_jfk.wav")?
         };
-        (
-            repo.get("config.json")?,
-            repo.get("tokenizer.json")?,
-            repo.get("model.safetensors")?,
-            sample,
-        )
+        let config = if args.quantized {
+            repo.get("config-tiny.json")?
+        } else {
+            repo.get("config.json")?
+        };
+        let model = if args.quantized {
+            repo.get("model-tiny-q40.gguf")?
+        } else {
+            repo.get("model.safetensors")?
+        };
+        (config, repo.get("tokenizer.json")?, model, sample)
     };
 
     let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
     let mel = Tensor::from_vec(mel, (1, m::N_MELS, mel_len / m::N_MELS), &device)?;
     println!("loaded mel: {:?}", mel.dims());
 
-    let vb =
-        unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], m::DTYPE, &device)? };
     let config: Config = serde_json::from_str(&std::fs::read_to_string(config_filename)?)?;
-    let mut model = Whisper::load(&vb, config)?;
+    let mut model = if args.quantized {
+        let vb =
+            candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&weights_filename)?;
+        Model::Quantized(m::quantized_model::Whisper::load(&vb, config)?)
+    } else {
+        let vb =
+            unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], m::DTYPE, &device)? };
+        Model::Normal(m::model::Whisper::load(&vb, config)?)
+    };
 
     let language_token = match (args.model.is_multilingual(), args.language) {
         (true, None) => Some(multilingual::detect_language(&mut model, &tokenizer, &mel)?),
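With the example adapted, the quantized variant is exercised end to end through the new flag; a typical invocation from the repository root would look something like:

    cargo run --example whisper --release -- --quantized

This pulls config-tiny.json and model-tiny-q40.gguf from the lmz/candle-whisper repo instead of the f32 safetensors weights, as wired up in main() above.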
diff --git a/candle-examples/examples/whisper/multilingual.rs b/candle-examples/examples/whisper/multilingual.rs
index a82b09ef..9548ffe6 100644
--- a/candle-examples/examples/whisper/multilingual.rs
+++ b/candle-examples/examples/whisper/multilingual.rs
@@ -1,4 +1,3 @@
-use crate::Whisper;
 use candle::{IndexOp, Result, Tensor, D};
 use tokenizers::Tokenizer;
 
@@ -105,20 +104,28 @@ const LANGUAGES: [(&str, &str); 99] = [
 ];
 
 /// Returns the token id for the selected language.
-pub fn detect_language(model: &mut Whisper, tokenizer: &Tokenizer, mel: &Tensor) -> Result<u32> {
+pub fn detect_language(
+    model: &mut super::Model,
+    tokenizer: &Tokenizer,
+    mel: &Tensor,
+) -> Result<u32> {
     let (_bsize, _, seq_len) = mel.dims3()?;
-    let mel = mel.narrow(2, 0, usize::min(seq_len, model.config.max_source_positions))?;
+    let mel = mel.narrow(
+        2,
+        0,
+        usize::min(seq_len, model.config().max_source_positions),
+    )?;
     let device = mel.device();
     let language_token_ids = LANGUAGES
         .iter()
         .map(|(t, _)| crate::token_id(tokenizer, &format!("<|{t}|>")))
         .collect::<Result<Vec<_>>>()?;
     let sot_token = crate::token_id(tokenizer, crate::m::SOT_TOKEN)?;
-    let audio_features = model.encoder.forward(&mel, true)?;
+    let audio_features = model.encoder_forward(&mel, true)?;
     let tokens = Tensor::new(&[[sot_token]], device)?;
     let language_token_ids = Tensor::new(language_token_ids.as_slice(), device)?;
-    let ys = model.decoder.forward(&tokens, &audio_features, true)?;
-    let logits = model.decoder.final_linear(&ys.i(..1)?)?.i(0)?.i(0)?;
+    let ys = model.decoder_forward(&tokens, &audio_features, true)?;
+    let logits = model.decoder_final_linear(&ys.i(..1)?)?.i(0)?.i(0)?;
     let logits = logits.index_select(&language_token_ids, 0)?;
     let probs = candle_nn::ops::softmax(&logits, D::Minus1)?;
     let probs = probs.to_vec1::<f32>()?;
diff --git a/candle-transformers/src/models/whisper/mod.rs b/candle-transformers/src/models/whisper/mod.rs
index 7dc8107b..35d35e77 100644
--- a/candle-transformers/src/models/whisper/mod.rs
+++ b/candle-transformers/src/models/whisper/mod.rs
@@ -1,5 +1,25 @@
 pub mod audio;
 pub mod model;
+pub mod quantized_model;
+
+use serde::Deserialize;
+
+// The names in comments correspond to the original implementation:
+// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L17
+#[derive(Debug, Clone, PartialEq, Deserialize)]
+pub struct Config {
+    pub num_mel_bins: usize,            // n_mels
+    pub max_source_positions: usize,    // n_audio_ctx
+    pub d_model: usize,                 // n_audio_state
+    pub encoder_attention_heads: usize, // n_audio_head
+    pub encoder_layers: usize,          // n_audio_layer
+    pub vocab_size: usize,              // n_vocab
+    pub max_target_positions: usize,    // n_text_ctx
+    // pub n_text_state: usize,
+    pub decoder_attention_heads: usize, // n_text_head
+    pub decoder_layers: usize,          // n_text_layer
+    pub suppress_tokens: Vec<u32>,
+}
 
 pub const DTYPE: candle::DType = candle::DType::F32;
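The quantized model reuses this Config as-is. As a rough illustration of the dimensions involved, building it by hand for the tiny checkpoint would look like the sketch below; the numeric values are the ones commonly published for whisper-tiny and the suppress_tokens list is heavily truncated, so treat them as illustrative rather than authoritative:

    let tiny_config = Config {
        num_mel_bins: 80,
        max_source_positions: 1500,
        d_model: 384,
        encoder_attention_heads: 6,
        encoder_layers: 4,
        vocab_size: 51865,
        max_target_positions: 448,
        decoder_attention_heads: 6,
        decoder_layers: 4,
        suppress_tokens: vec![1, 2, 7, 8, 9, 10, 14, 25],
    };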
diff --git a/candle-transformers/src/models/whisper/model.rs b/candle-transformers/src/models/whisper/model.rs
index d2eda796..2a58afaf 100644
--- a/candle-transformers/src/models/whisper/model.rs
+++ b/candle-transformers/src/models/whisper/model.rs
@@ -1,23 +1,6 @@
+use super::Config;
 use candle::{Device, IndexOp, Result, Tensor, D};
 use candle_nn::{Conv1d, Conv1dConfig, Embedding, LayerNorm, Module, VarBuilder};
-use serde::Deserialize;
-
-// The names in comments correspond to the original implementation:
-// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L17
-#[derive(Debug, Clone, PartialEq, Deserialize)]
-pub struct Config {
-    pub num_mel_bins: usize,            // n_mels
-    pub max_source_positions: usize,    // n_audio_ctx
-    pub d_model: usize,                 // n_audio_state
-    pub encoder_attention_heads: usize, // n_audio_head
-    pub encoder_layers: usize,          // n_audio_layer
-    pub vocab_size: usize,              // n_vocab
-    pub max_target_positions: usize,    // n_text_ctx
-    // pub n_text_state: usize,
-    pub decoder_attention_heads: usize, // n_text_head
-    pub decoder_layers: usize,          // n_text_layer
-    pub suppress_tokens: Vec<u32>,
-}
 
 fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
     let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
diff --git a/candle-transformers/src/models/whisper/quantized_model.rs b/candle-transformers/src/models/whisper/quantized_model.rs
new file mode 100644
index 00000000..59942cbf
--- /dev/null
+++ b/candle-transformers/src/models/whisper/quantized_model.rs
@@ -0,0 +1,403 @@
+use super::Config;
+use crate::models::{quantized_t5::Embedding, with_tracing::QMatMul};
+pub use crate::quantized_var_builder::VarBuilder;
+use candle::{Device, IndexOp, Result, Tensor, D};
+use candle_nn::{Conv1d, Conv1dConfig, LayerNorm, Module};
+
+#[derive(Debug)]
+struct Linear {
+    weight: QMatMul,
+    bias: Option<Tensor>,
+}
+
+impl Module for Linear {
+    fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
+        let x = x.apply(&self.weight)?;
+        match &self.bias {
+            None => Ok(x),
+            Some(bias) => x.broadcast_add(bias),
+        }
+    }
+}
+
+fn linear(in_dim: usize, out_dim: usize, vb: VarBuilder) -> Result<Linear> {
+    let bias = vb.get(out_dim, "bias")?.dequantize(vb.device())?;
+    let weight = QMatMul::new(in_dim, out_dim, vb)?;
+    Ok(Linear {
+        weight,
+        bias: Some(bias),
+    })
+}
+
+fn linear_no_bias(in_dim: usize, out_dim: usize, vb: VarBuilder) -> Result<Linear> {
+    let weight = QMatMul::new(in_dim, out_dim, vb)?;
+    Ok(Linear { weight, bias: None })
+}
+
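+// Conv1d and LayerNorm have no quantized kernels here, so their (small) weights are
+// dequantized once at load time and these layers run in f32; only the Linear layers
+// keep quantized weights via QMatMul.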
+fn conv1d(
+    in_channels: usize,
+    out_channels: usize,
+    kernel_size: usize,
+    config: Conv1dConfig,
+    vb: VarBuilder,
+) -> Result<Conv1d> {
+    let weight = vb
+        .get((out_channels, in_channels, kernel_size), "weight")?
+        .dequantize(vb.device())?;
+    let bias = vb.get(out_channels, "bias")?.dequantize(vb.device())?;
+    Ok(Conv1d::new(weight, Some(bias), config))
+}
+
+fn layer_norm(size: usize, vb: VarBuilder) -> Result<LayerNorm> {
+    let weight = vb.get(size, "weight")?.dequantize(vb.device())?;
+    let bias = vb.get(size, "bias")?.dequantize(vb.device())?;
+    Ok(candle_nn::LayerNorm::new(weight, bias, 1e-5))
+}
+
+// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L62
+struct MultiHeadAttention {
+    query: Linear,
+    key: Linear,
+    value: Linear,
+    out: Linear,
+    n_head: usize,
+    span: tracing::Span,
+    softmax_span: tracing::Span,
+    matmul_span: tracing::Span,
+    kv_cache: Option<(Tensor, Tensor)>,
+}
+
+impl MultiHeadAttention {
+    fn load(n_state: usize, n_head: usize, vb: VarBuilder) -> Result<Self> {
+        let span = tracing::span!(tracing::Level::TRACE, "multi-head-attn");
+        let softmax_span = tracing::span!(tracing::Level::TRACE, "multi-head-attn-softmax");
+        let matmul_span = tracing::span!(tracing::Level::TRACE, "multi-head-attn-matmul");
+        let query = linear(n_state, n_state, vb.pp("q_proj"))?;
+        let value = linear(n_state, n_state, vb.pp("v_proj"))?;
+        let key = linear_no_bias(n_state, n_state, vb.pp("k_proj"))?;
+        let out = linear(n_state, n_state, vb.pp("out_proj"))?;
+        Ok(Self {
+            query,
+            key,
+            value,
+            out,
+            n_head,
+            span,
+            softmax_span,
+            matmul_span,
+            kv_cache: None,
+        })
+    }
+
+    fn forward(
+        &mut self,
+        x: &Tensor,
+        xa: Option<&Tensor>,
+        mask: Option<&Tensor>,
+        flush_cache: bool,
+    ) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        let q = self.query.forward(x)?;
+        let (k, v) = match xa {
+            None => {
+                let k = self.key.forward(x)?;
+                let v = self.value.forward(x)?;
+                (k, v)
+            }
+            Some(x) => {
+                if flush_cache {
+                    self.kv_cache = None;
+                }
+                if let Some((k, v)) = &self.kv_cache {
+                    (k.clone(), v.clone())
+                } else {
+                    let k = self.key.forward(x)?;
+                    let v = self.value.forward(x)?;
+                    self.kv_cache = Some((k.clone(), v.clone()));
+                    (k, v)
+                }
+            }
+        };
+        let wv = self.qkv_attention(&q, &k, &v, mask)?;
+        let out = self.out.forward(&wv)?;
+        Ok(out)
+    }
+
+    fn reshape_head(&self, x: &Tensor) -> Result<Tensor> {
+        let (n_batch, n_ctx, n_state) = x.dims3()?;
+        let target_dims = &[n_batch, n_ctx, self.n_head, n_state / self.n_head];
+        x.reshape(target_dims)?.transpose(1, 2)
+    }
+
+    fn qkv_attention(
+        &self,
+        q: &Tensor,
+        k: &Tensor,
+        v: &Tensor,
+        mask: Option<&Tensor>,
+    ) -> Result<Tensor> {
+        let (_, n_ctx, n_state) = q.dims3()?;
+        let scale = ((n_state / self.n_head) as f64).powf(-0.25);
+        let q = (self.reshape_head(q)? * scale)?;
+        let k = (self.reshape_head(k)?.transpose(2, 3)? * scale)?;
+        let v = self.reshape_head(v)?.contiguous()?;
+        let mut qk = {
+            let _enter = self.matmul_span.enter();
+            q.matmul(&k)?
+        };
+        if let Some(mask) = mask {
+            let mask = mask.i((0..n_ctx, 0..n_ctx))?;
+            qk = qk.broadcast_add(&mask)?
+        }
+        let w = {
+            let _enter = self.softmax_span.enter();
+            candle_nn::ops::softmax_last_dim(&qk)?
+        };
+        let wv = {
+            let _enter = self.matmul_span.enter();
+            w.matmul(&v)?
+        }
+        .transpose(1, 2)?
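+        // transpose(1, 2) restores (batch, ctx, head, head_dim); flatten_from(2) then
+        // merges the per-head dims back into a single n_state axis.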
+        .flatten_from(2)?;
+        Ok(wv)
+    }
+}
+
+// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L111
+struct ResidualAttentionBlock {
+    attn: MultiHeadAttention,
+    attn_ln: LayerNorm,
+    cross_attn: Option<(MultiHeadAttention, LayerNorm)>,
+    mlp_linear1: Linear,
+    mlp_linear2: Linear,
+    mlp_ln: LayerNorm,
+    span: tracing::Span,
+}
+
+impl ResidualAttentionBlock {
+    fn load(n_state: usize, n_head: usize, ca: bool, vb: VarBuilder) -> Result<Self> {
+        let span = tracing::span!(tracing::Level::TRACE, "residual-attn");
+        let attn = MultiHeadAttention::load(n_state, n_head, vb.pp("self_attn"))?;
+        let attn_ln = layer_norm(n_state, vb.pp("self_attn_layer_norm"))?;
+        let cross_attn = if ca {
+            let cross_attn = MultiHeadAttention::load(n_state, n_head, vb.pp("encoder_attn"))?;
+            let cross_attn_ln = layer_norm(n_state, vb.pp("encoder_attn_layer_norm"))?;
+            Some((cross_attn, cross_attn_ln))
+        } else {
+            None
+        };
+        let n_mlp = n_state * 4;
+        let mlp_linear1 = linear(n_state, n_mlp, vb.pp("fc1"))?;
+        let mlp_linear2 = linear(n_mlp, n_state, vb.pp("fc2"))?;
+        let mlp_ln = layer_norm(n_state, vb.pp("final_layer_norm"))?;
+        Ok(Self {
+            attn,
+            attn_ln,
+            cross_attn,
+            mlp_linear1,
+            mlp_linear2,
+            mlp_ln,
+            span,
+        })
+    }
+
+    fn forward(
+        &mut self,
+        x: &Tensor,
+        xa: Option<&Tensor>,
+        mask: Option<&Tensor>,
+        flush_kv_cache: bool,
+    ) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        let attn = self
+            .attn
+            .forward(&self.attn_ln.forward(x)?, None, mask, flush_kv_cache)?;
+        let mut x = (x + attn)?;
+        if let Some((attn, ln)) = &mut self.cross_attn {
+            x = (&x + attn.forward(&ln.forward(&x)?, xa, None, flush_kv_cache)?)?;
+        }
+        let mlp = self.mlp_linear2.forward(
+            &self
+                .mlp_linear1
+                .forward(&self.mlp_ln.forward(&x)?)?
+                .gelu()?,
+        )?;
+        x + mlp
+    }
+}
+
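+// Fixed sinusoidal positional embeddings matching the original Python implementation:
+// for position p and channel pair i, the scaled time is p * 10000^(-i / (C/2 - 1));
+// sines fill the first C/2 channels and cosines the remaining ones.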
+fn sinusoids(length: usize, channels: usize) -> Result<Tensor> {
+    let max_timescale = 10000f32;
+    let log_timescale_increment = max_timescale.ln() / (channels / 2 - 1) as f32;
+    let inv_timescales: Vec<_> = (0..channels / 2)
+        .map(|i| (i as f32 * (-log_timescale_increment)).exp())
+        .collect();
+    let inv_timescales = Tensor::new(inv_timescales.as_slice(), &Device::Cpu)?.unsqueeze(0)?;
+    let arange = Tensor::arange(0, length as u32, &Device::Cpu)?
+        .to_dtype(candle::DType::F32)?
+        .unsqueeze(1)?;
+    let sh = (length, channels / 2);
+    let scaled_time = (arange.broadcast_as(sh)? * inv_timescales.broadcast_as(sh)?)?;
+    let sincos = Tensor::cat(&[scaled_time.sin()?, scaled_time.cos()?], 1)?;
+    Ok(sincos)
+}
+
+// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L143
+pub struct AudioEncoder {
+    conv1: Conv1d,
+    conv2: Conv1d,
+    positional_embedding: Tensor,
+    blocks: Vec<ResidualAttentionBlock>,
+    ln_post: LayerNorm,
+    span: tracing::Span,
+    conv1_span: tracing::Span,
+    conv2_span: tracing::Span,
+}
+
+impl AudioEncoder {
+    fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
+        let span = tracing::span!(tracing::Level::TRACE, "audio-encoder");
+        let conv1_span = tracing::span!(tracing::Level::TRACE, "conv1");
+        let conv2_span = tracing::span!(tracing::Level::TRACE, "conv2");
+        let n_state = cfg.d_model;
+        let n_head = cfg.encoder_attention_heads;
+        let n_ctx = cfg.max_source_positions;
+        let cfg1 = Conv1dConfig {
+            padding: 1,
+            stride: 1,
+            groups: 1,
+            dilation: 1,
+        };
+        let cfg2 = Conv1dConfig {
+            padding: 1,
+            stride: 2,
+            groups: 1,
+            dilation: 1,
+        };
+        let conv1 = conv1d(cfg.num_mel_bins, n_state, 3, cfg1, vb.pp("conv1"))?;
+        let conv2 = conv1d(n_state, n_state, 3, cfg2, vb.pp("conv2"))?;
+        let positional_embedding = sinusoids(n_ctx, n_state)?.to_device(vb.device())?;
+        let blocks = (0..cfg.encoder_layers)
+            .map(|i| {
+                ResidualAttentionBlock::load(n_state, n_head, false, vb.pp(format!("layers.{i}")))
+            })
+            .collect::<Result<Vec<_>>>()?;
+        let ln_post = layer_norm(n_state, vb.pp("layer_norm"))?;
+        Ok(Self {
+            conv1,
+            conv2,
+            positional_embedding,
+            blocks,
+            ln_post,
+            conv1_span,
+            conv2_span,
+            span,
+        })
+    }
+
+    pub fn forward(&mut self, x: &Tensor, flush_kv_cache: bool) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        let x = {
+            let _enter = self.conv1_span.enter();
+            self.conv1.forward(x)?.gelu()?
+        };
+        let x = {
+            let _enter = self.conv2_span.enter();
+            self.conv2.forward(&x)?.gelu()?
+        };
+        let x = x.transpose(1, 2)?;
+        let (_bsize, seq_len, _hidden) = x.dims3()?;
+        let positional_embedding = self.positional_embedding.narrow(0, 0, seq_len)?;
+        let mut x = x.broadcast_add(&positional_embedding)?;
+        for block in self.blocks.iter_mut() {
+            x = block.forward(&x, None, None, flush_kv_cache)?
+        }
+        let x = self.ln_post.forward(&x)?;
+        Ok(x)
+    }
+}
+
+// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L176
+pub struct TextDecoder {
+    token_embedding: Embedding,
+    positional_embedding: Tensor,
+    blocks: Vec<ResidualAttentionBlock>,
+    ln: LayerNorm,
+    mask: Tensor,
+    span: tracing::Span,
+    span_final: tracing::Span,
+}
+
+impl TextDecoder {
+    fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
+        let span = tracing::span!(tracing::Level::TRACE, "text-decoder");
+        let span_final = tracing::span!(tracing::Level::TRACE, "text-decoder-final");
+        let n_state = cfg.d_model;
+        let n_head = cfg.decoder_attention_heads;
+        let n_ctx = cfg.max_target_positions;
+        let token_embedding = Embedding::new(cfg.vocab_size, n_state, vb.pp("embed_tokens"))?;
+        let positional_embedding = vb
+            .get((n_ctx, n_state), "embed_positions.weight")?
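+            // The learned positional embedding is only indexed, never matmul'ed, so it
+            // is dequantized once here at load time.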
+            .dequantize(vb.device())?;
+        let blocks = (0..cfg.decoder_layers)
+            .map(|i| {
+                ResidualAttentionBlock::load(n_state, n_head, true, vb.pp(format!("layers.{i}")))
+            })
+            .collect::<Result<Vec<_>>>()?;
+        let ln = layer_norm(n_state, vb.pp("layer_norm"))?;
+        let mask: Vec<_> = (0..n_ctx)
+            .flat_map(|i| (0..n_ctx).map(move |j| if j > i { f32::NEG_INFINITY } else { 0f32 }))
+            .collect();
+        let mask = Tensor::from_vec(mask, (n_ctx, n_ctx), vb.device())?;
+        Ok(Self {
+            token_embedding,
+            positional_embedding,
+            blocks,
+            ln,
+            mask,
+            span,
+            span_final,
+        })
+    }
+
+    pub fn forward(&mut self, x: &Tensor, xa: &Tensor, flush_kv_cache: bool) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        let last = x.dim(D::Minus1)?;
+        let token_embedding = self.token_embedding.forward(x)?;
+        let positional_embedding = self.positional_embedding.narrow(0, 0, last)?;
+        let mut x = token_embedding.broadcast_add(&positional_embedding)?;
+        for block in self.blocks.iter_mut() {
+            x = block.forward(&x, Some(xa), Some(&self.mask), flush_kv_cache)?;
+        }
+        self.ln.forward(&x)
+    }
+
+    pub fn final_linear(&self, x: &Tensor) -> Result<Tensor> {
+        let b_size = x.dim(0)?;
+        let w = self.token_embedding.embeddings().broadcast_left(b_size)?;
+        let logits = {
+            let _enter = self.span_final.enter();
+            x.matmul(&w.t()?)?
+        };
+        Ok(logits)
+    }
+}
+
+// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L221
+pub struct Whisper {
+    pub encoder: AudioEncoder,
+    pub decoder: TextDecoder,
+    pub config: Config,
+}
+
+impl Whisper {
+    pub fn load(vb: &VarBuilder, config: Config) -> Result<Self> {
+        let encoder = AudioEncoder::load(vb.pp("model.encoder"), &config)?;
+        let decoder = TextDecoder::load(vb.pp("model.decoder"), &config)?;
+        Ok(Self {
+            encoder,
+            decoder,
+            config,
+        })
+    }
+}
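As noted in main.rs, the enum-based dispatch could be replaced by a trait. A minimal sketch of that alternative (the WhisperModel trait and its name are hypothetical, not part of this patch):

    use candle::{Result, Tensor};
    use candle_transformers::models::whisper::{self as m, Config};

    /// Hypothetical trait unifying the f32 and quantized model variants.
    trait WhisperModel {
        fn config(&self) -> &Config;
        fn encoder_forward(&mut self, x: &Tensor, flush: bool) -> Result<Tensor>;
        fn decoder_forward(&mut self, x: &Tensor, xa: &Tensor, flush: bool) -> Result<Tensor>;
        fn decoder_final_linear(&self, x: &Tensor) -> Result<Tensor>;
    }

    impl WhisperModel for m::model::Whisper {
        fn config(&self) -> &Config {
            &self.config
        }
        fn encoder_forward(&mut self, x: &Tensor, flush: bool) -> Result<Tensor> {
            self.encoder.forward(x, flush)
        }
        fn decoder_forward(&mut self, x: &Tensor, xa: &Tensor, flush: bool) -> Result<Tensor> {
            self.decoder.forward(x, xa, flush)
        }
        fn decoder_final_linear(&self, x: &Tensor) -> Result<Tensor> {
            self.decoder.final_linear(x)
        }
    }

    // An identical impl for m::quantized_model::Whisper would let the example hold a
    // Box<dyn WhisperModel> and drop the per-method match statements.

The enum keeps static dispatch at the cost of some boilerplate; the trait trades that boilerplate for a little dynamic dispatch.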