use candle::{Device, IndexOp, Result, Tensor};
use candle_nn::{ops::softmax, Conv1d, Conv1dConfig, Embedding, LayerNorm, Module, VarBuilder};
use serde::Deserialize;

// The names in comments correspond to the original implementation:
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L17
/// Architecture hyper-parameters of a Whisper model; deserializable through serde.
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct Config {
    pub num_mel_bins: usize,            // n_mels
    pub max_source_positions: usize,    // n_audio_ctx
    pub d_model: usize,                 // n_audio_state
    pub encoder_attention_heads: usize, // n_audio_head
    pub encoder_layers: usize,          // n_audio_layer
    pub vocab_size: usize,              // n_vocab
    pub max_target_positions: usize,    // n_text_ctx
    // n_text_state has no dedicated field; presumably it always equals `d_model` (TODO confirm).
    pub decoder_attention_heads: usize, // n_text_head
    pub decoder_layers: usize,          // n_text_layer
    pub suppress_tokens: Vec<u32>,
}

/// Token ids used as `suppress_tokens` by the [`Config::tiny_en`] preset.
const TINY_EN_SUPPRESS_TOKENS: &[u32] = &[
    1, 2, 7, 8, 9, 10, 14, 25, 26, 27, 28, 29, 31, 58, 59, 60, 61, 62, 63, 90, 91, 92, 93, 357,
    366, 438, 532, 685, 705, 796, 930, 1058, 1220, 1267, 1279, 1303, 1343, 1377, 1391, 1635, 1782,
    1875, 2162, 2361, 2488, 3467, 4008, 4211, 4600, 4808, 5299, 5855, 6329, 7203, 9609, 9959,
    10563, 10786, 11420, 11709, 11907, 13163, 13697, 13700, 14808, 15306, 16410, 16791, 17992,
    19203, 19510, 20724, 22305, 22935, 27007, 30109, 30420, 33409, 34949, 40283, 40493, 40549,
    47282, 49146, 50257, 50357, 50358, 50359, 50360, 50361, 50362,
];

impl Config {
    /// Hard-coded configuration for the `tiny.en` model size: 80 mel bins, width 384,
    /// 6 attention heads and 4 layers on both the encoder and the decoder side.
    // The preset is a convenience constructor that may be unused by a given binary.
    #[allow(dead_code)]
    pub fn tiny_en() -> Self {
        Self {
            num_mel_bins: 80,
            max_source_positions: 1500,
            d_model: 384,
            encoder_attention_heads: 6,
            encoder_layers: 4,
            vocab_size: 51864,
            max_target_positions: 448,
            decoder_attention_heads: 6,
            decoder_layers: 4,
            suppress_tokens: TINY_EN_SUPPRESS_TOKENS.to_vec(),
        }
    }
}

/// Builds an [`Embedding`] whose `weight` tensor, of shape `(vocab_size, hidden_size)`,
/// is read from `vb`.
fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
    vb.get((vocab_size, hidden_size), "weight")
        .map(|weights| Embedding::new(weights, hidden_size))
}

//
// We wrap the `Linear` layer here to add some tracing so that it's easier to profile
the resulting // model. #[derive(Debug)] pub struct Linear { inner: candle_nn::Linear, span: tracing::Span, } impl Linear { fn forward(&self, x: &Tensor) -> Result { let _enter = self.span.enter(); self.inner.forward(x) } } fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result { let span = tracing::span!(tracing::Level::TRACE, "linear"); let inner = candle_nn::linear(size1, size2, vb)?; Ok(Linear { inner, span }) } fn linear_no_bias(size1: usize, size2: usize, vb: VarBuilder) -> Result { let span = tracing::span!(tracing::Level::TRACE, "linear"); let inner = candle_nn::linear_no_bias(size1, size2, vb)?; Ok(Linear { inner, span }) } fn conv1d( in_channels: usize, out_channels: usize, kernel_size: usize, config: Conv1dConfig, vb: VarBuilder, ) -> Result { let weight = vb.get((out_channels, in_channels, kernel_size), "weight")?; let bias = vb.get(out_channels, "bias")?; Ok(Conv1d::new(weight, Some(bias), config)) } fn layer_norm(size: usize, vb: VarBuilder) -> Result { let weight = vb.get(size, "weight")?; let bias = vb.get(size, "bias")?; Ok(LayerNorm::new(weight, bias, 1e-5)) } // https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L62 struct MultiHeadAttention { query: Linear, key: Linear, value: Linear, out: Linear, n_head: usize, span: tracing::Span, softmax_span: tracing::Span, matmul_span: tracing::Span, kv_cache: Option<(Tensor, Tensor)>, } impl MultiHeadAttention { fn load(n_state: usize, n_head: usize, vb: VarBuilder) -> Result { let span = tracing::span!(tracing::Level::TRACE, "multi-head-attn"); let softmax_span = tracing::span!(tracing::Level::TRACE, "multi-head-attn-softmax"); let matmul_span = tracing::span!(tracing::Level::TRACE, "multi-head-attn-matmul"); let query = linear(n_state, n_state, vb.pp("q_proj"))?; let value = linear(n_state, n_state, vb.pp("v_proj"))?; let key = linear_no_bias(n_state, n_state, vb.pp("k_proj"))?; let out = linear(n_state, n_state, vb.pp("out_proj"))?; Ok(Self { query, 
key, value, out, n_head, span, softmax_span, matmul_span, kv_cache: None, }) } fn forward( &mut self, x: &Tensor, xa: Option<&Tensor>, mask: Option<&Tensor>, flush_cache: bool, ) -> Result { let _enter = self.span.enter(); let q = self.query.forward(x)?; let (k, v) = match xa { None => { let k = self.key.forward(x)?; let v = self.value.forward(x)?; (k, v) } Some(x) => { if flush_cache { self.kv_cache = None; } if let Some((k, v)) = &self.kv_cache { (k.clone(), v.clone()) } else { let k = self.key.forward(x)?; let v = self.value.forward(x)?; self.kv_cache = Some((k.clone(), v.clone())); (k, v) } } }; let wv = self.qkv_attention(&q, &k, &v, mask)?; let out = self.out.forward(&wv)?; Ok(out) } fn reshape_head(&self, x: &Tensor) -> Result { let (n_batch, n_ctx, n_state) = x.dims3()?; let target_dims = &[n_batch, n_ctx, self.n_head, n_state / self.n_head]; x.reshape(target_dims)?.transpose(1, 2) } fn qkv_attention( &self, q: &Tensor, k: &Tensor, v: &Tensor, mask: Option<&Tensor>, ) -> Result { let (_, n_ctx, n_state) = q.dims3()?; let scale = ((n_state / self.n_head) as f64).powf(-0.25); let q = (self.reshape_head(q)? * scale)?; let k = (self.reshape_head(k)?.transpose(2, 3)? * scale)?; let v = self.reshape_head(v)?.contiguous()?; let mut qk = { let _enter = self.matmul_span.enter(); q.matmul(&k)? }; if let Some(mask) = mask { let mask = mask.i((0..n_ctx, 0..n_ctx))?; qk = qk.broadcast_add(&mask)? } let w = { let _enter = self.softmax_span.enter(); softmax(&qk, candle::D::Minus1)? }; let wv = { let _enter = self.matmul_span.enter(); w.matmul(&v)? } .transpose(1, 2)? 
.flatten_from(2)?; Ok(wv) } } // https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L111 struct ResidualAttentionBlock { attn: MultiHeadAttention, attn_ln: LayerNorm, cross_attn: Option<(MultiHeadAttention, LayerNorm)>, mlp_linear1: Linear, mlp_linear2: Linear, mlp_ln: LayerNorm, span: tracing::Span, } impl ResidualAttentionBlock { fn load(n_state: usize, n_head: usize, ca: bool, vb: VarBuilder) -> Result { let span = tracing::span!(tracing::Level::TRACE, "residual-attn"); let attn = MultiHeadAttention::load(n_state, n_head, vb.pp("self_attn"))?; let attn_ln = layer_norm(n_state, vb.pp("self_attn_layer_norm"))?; let cross_attn = if ca { let cross_attn = MultiHeadAttention::load(n_state, n_head, vb.pp("encoder_attn"))?; let cross_attn_ln = layer_norm(n_state, vb.pp("encoder_attn_layer_norm"))?; Some((cross_attn, cross_attn_ln)) } else { None }; let n_mlp = n_state * 4; let mlp_linear1 = linear(n_state, n_mlp, vb.pp("fc1"))?; let mlp_linear2 = linear(n_mlp, n_state, vb.pp("fc2"))?; let mlp_ln = layer_norm(n_state, vb.pp("final_layer_norm"))?; Ok(Self { attn, attn_ln, cross_attn, mlp_linear1, mlp_linear2, mlp_ln, span, }) } fn forward( &mut self, x: &Tensor, xa: Option<&Tensor>, mask: Option<&Tensor>, flush_kv_cache: bool, ) -> Result { let _enter = self.span.enter(); let attn = self .attn .forward(&self.attn_ln.forward(x)?, None, mask, flush_kv_cache)?; let mut x = (x + attn)?; if let Some((attn, ln)) = &mut self.cross_attn { x = (&x + attn.forward(&ln.forward(&x)?, xa, None, flush_kv_cache)?)?; } let mlp = self.mlp_linear2.forward( &self .mlp_linear1 .forward(&self.mlp_ln.forward(&x)?)? 
.gelu()?, )?; x + mlp } } fn sinusoids(length: usize, channels: usize) -> Result { let max_timescale = 10000f32; let log_timescale_increment = max_timescale.ln() / (channels / 2 - 1) as f32; let inv_timescales: Vec<_> = (0..channels / 2) .map(|i| (i as f32 * (-log_timescale_increment)).exp()) .collect(); let inv_timescales = Tensor::new(inv_timescales.as_slice(), &Device::Cpu)?.unsqueeze(0)?; let arange = Tensor::arange(0, length as u32, &Device::Cpu)? .to_dtype(candle::DType::F32)? .unsqueeze(1)?; let sh = (length, channels / 2); let scaled_time = (arange.broadcast_as(sh)? * inv_timescales.broadcast_as(sh)?)?; let sincos = Tensor::cat(&[scaled_time.sin()?, scaled_time.cos()?], 1)?; Ok(sincos) } // https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L143 pub struct AudioEncoder { conv1: Conv1d, conv2: Conv1d, positional_embedding: Tensor, blocks: Vec, ln_post: LayerNorm, span: tracing::Span, conv1_span: tracing::Span, conv2_span: tracing::Span, } impl AudioEncoder { fn load(vb: VarBuilder, cfg: &Config) -> Result { let span = tracing::span!(tracing::Level::TRACE, "audio-encoder"); let conv1_span = tracing::span!(tracing::Level::TRACE, "conv1"); let conv2_span = tracing::span!(tracing::Level::TRACE, "conv2"); let n_state = cfg.d_model; let n_head = cfg.encoder_attention_heads; let n_ctx = cfg.max_source_positions; let cfg1 = Conv1dConfig { padding: 1, stride: 1, groups: 1, }; let cfg2 = Conv1dConfig { padding: 1, stride: 2, groups: 1, }; let conv1 = conv1d(cfg.num_mel_bins, n_state, 3, cfg1, vb.pp("conv1"))?; let conv2 = conv1d(n_state, n_state, 3, cfg2, vb.pp("conv2"))?; let positional_embedding = sinusoids(n_ctx, n_state)?.to_device(vb.device())?; let blocks = (0..cfg.encoder_layers) .map(|i| { ResidualAttentionBlock::load(n_state, n_head, false, vb.pp(&format!("layers.{i}"))) }) .collect::>>()?; let ln_post = layer_norm(n_state, vb.pp("layer_norm"))?; Ok(Self { conv1, conv2, positional_embedding, blocks, ln_post, 
conv1_span, conv2_span, span, }) } pub fn forward(&mut self, x: &Tensor, flush_kv_cache: bool) -> Result { let _enter = self.span.enter(); let x = { let _enter = self.conv1_span.enter(); self.conv1.forward(x)?.gelu()? }; let x = { let _enter = self.conv2_span.enter(); self.conv2.forward(&x)?.gelu()? }; let x = x.transpose(1, 2)?; let (_bsize, seq_len, _hidden) = x.dims3()?; let positional_embedding = self.positional_embedding.narrow(0, 0, seq_len)?; let mut x = x.broadcast_add(&positional_embedding)?; for block in self.blocks.iter_mut() { x = block.forward(&x, None, None, flush_kv_cache)? } let x = self.ln_post.forward(&x)?; Ok(x) } } // https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L176 pub struct TextDecoder { token_embedding: Embedding, positional_embedding: Tensor, blocks: Vec, ln: LayerNorm, mask: Tensor, span: tracing::Span, span_final: tracing::Span, } impl TextDecoder { fn load(vb: VarBuilder, cfg: &Config) -> Result { let span = tracing::span!(tracing::Level::TRACE, "text-decoder"); let span_final = tracing::span!(tracing::Level::TRACE, "text-decoder-final"); let n_state = cfg.d_model; let n_head = cfg.decoder_attention_heads; let n_ctx = cfg.max_target_positions; let token_embedding = embedding(cfg.vocab_size, n_state, vb.pp("embed_tokens"))?; let positional_embedding = vb.get((n_ctx, n_state), "embed_positions.weight")?; let blocks = (0..cfg.decoder_layers) .map(|i| { ResidualAttentionBlock::load(n_state, n_head, true, vb.pp(&format!("layers.{i}"))) }) .collect::>>()?; let ln = layer_norm(n_state, vb.pp("layer_norm"))?; let mask: Vec<_> = (0..n_ctx) .flat_map(|i| (0..n_ctx).map(move |j| if j > i { f32::NEG_INFINITY } else { 0f32 })) .collect(); let mask = Tensor::from_vec(mask, (n_ctx, n_ctx), vb.device())?; Ok(Self { token_embedding, positional_embedding, blocks, ln, mask, span, span_final, }) } pub fn forward(&mut self, x: &Tensor, xa: &Tensor, flush_kv_cache: bool) -> Result { let _enter = 
self.span.enter(); let x_dims = x.dims(); let last = x_dims[x_dims.len() - 1]; let token_embedding = self.token_embedding.forward(x)?; let positional_embedding = self.positional_embedding.narrow(0, 0, last)?; let mut x = token_embedding.broadcast_add(&positional_embedding)?; for block in self.blocks.iter_mut() { x = block.forward(&x, Some(xa), Some(&self.mask), flush_kv_cache)?; } self.ln.forward(&x) } pub fn final_linear(&self, x: &Tensor) -> Result { let b_size = x.dim(0)?; let w = self.token_embedding.embeddings().broadcast_left(b_size)?; let logits = { let _enter = self.span_final.enter(); x.matmul(&w.t()?)? }; Ok(logits) } } // https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L221 pub struct Whisper { pub encoder: AudioEncoder, pub decoder: TextDecoder, pub config: Config, } impl Whisper { pub fn load(vb: &VarBuilder, config: Config) -> Result { let encoder = AudioEncoder::load(vb.pp("model.encoder"), &config)?; let decoder = TextDecoder::load(vb.pp("model.decoder"), &config)?; Ok(Self { encoder, decoder, config, }) } }