mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 02:38:10 +00:00
Add a quantized variant of whisper (#1017)
* Add the quantized-whisper model. * Quantized the whisper model. * Adapt the whisper example to handle quantization. * Add the quantized flag. * Load the proper weights.
This commit is contained in:
@ -1,5 +1,25 @@
|
||||
pub mod audio;
|
||||
pub mod model;
|
||||
pub mod quantized_model;
|
||||
|
||||
use serde::Deserialize;
|
||||
|
||||
// The names in comments correspond to the original implementation:
|
||||
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L17
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize)]
|
||||
pub struct Config {
|
||||
pub num_mel_bins: usize, // n_mels
|
||||
pub max_source_positions: usize, // n_audio_ctx
|
||||
pub d_model: usize, // n_audio_state
|
||||
pub encoder_attention_heads: usize, // n_audio_head
|
||||
pub encoder_layers: usize, // n_audio_layer
|
||||
pub vocab_size: usize, // n_vocab
|
||||
pub max_target_positions: usize, // n_text_ctx
|
||||
// pub n_text_state: usize,
|
||||
pub decoder_attention_heads: usize, // n_text_head
|
||||
pub decoder_layers: usize, // n_text_layer
|
||||
pub suppress_tokens: Vec<u32>,
|
||||
}
|
||||
|
||||
pub const DTYPE: candle::DType = candle::DType::F32;
|
||||
|
||||
|
@ -1,23 +1,6 @@
|
||||
use super::Config;
|
||||
use candle::{Device, IndexOp, Result, Tensor, D};
|
||||
use candle_nn::{Conv1d, Conv1dConfig, Embedding, LayerNorm, Module, VarBuilder};
|
||||
use serde::Deserialize;
|
||||
|
||||
// The names in comments correspond to the original implementation:
|
||||
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L17
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize)]
|
||||
pub struct Config {
|
||||
pub num_mel_bins: usize, // n_mels
|
||||
pub max_source_positions: usize, // n_audio_ctx
|
||||
pub d_model: usize, // n_audio_state
|
||||
pub encoder_attention_heads: usize, // n_audio_head
|
||||
pub encoder_layers: usize, // n_audio_layer
|
||||
pub vocab_size: usize, // n_vocab
|
||||
pub max_target_positions: usize, // n_text_ctx
|
||||
// pub n_text_state: usize,
|
||||
pub decoder_attention_heads: usize, // n_text_head
|
||||
pub decoder_layers: usize, // n_text_layer
|
||||
pub suppress_tokens: Vec<u32>,
|
||||
}
|
||||
|
||||
fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
|
||||
let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
|
||||
|
403
candle-transformers/src/models/whisper/quantized_model.rs
Normal file
403
candle-transformers/src/models/whisper/quantized_model.rs
Normal file
@ -0,0 +1,403 @@
|
||||
use super::Config;
|
||||
use crate::models::{quantized_t5::Embedding, with_tracing::QMatMul};
|
||||
pub use crate::quantized_var_builder::VarBuilder;
|
||||
use candle::{Device, IndexOp, Result, Tensor, D};
|
||||
use candle_nn::{Conv1d, Conv1dConfig, LayerNorm, Module};
|
||||
|
||||
#[derive(Debug)]
|
||||
struct Linear {
|
||||
weight: QMatMul,
|
||||
bias: Option<Tensor>,
|
||||
}
|
||||
|
||||
impl Module for Linear {
|
||||
fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
|
||||
let x = x.apply(&self.weight)?;
|
||||
match &self.bias {
|
||||
None => Ok(x),
|
||||
Some(bias) => x.broadcast_add(bias),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn linear(in_dim: usize, out_dim: usize, vb: VarBuilder) -> Result<Linear> {
|
||||
let bias = vb.get(out_dim, "bias")?.dequantize(vb.device())?;
|
||||
let weight = QMatMul::new(in_dim, out_dim, vb)?;
|
||||
Ok(Linear {
|
||||
weight,
|
||||
bias: Some(bias),
|
||||
})
|
||||
}
|
||||
|
||||
fn linear_no_bias(in_dim: usize, out_dim: usize, vb: VarBuilder) -> Result<Linear> {
|
||||
let weight = QMatMul::new(in_dim, out_dim, vb)?;
|
||||
Ok(Linear { weight, bias: None })
|
||||
}
|
||||
|
||||
fn conv1d(
|
||||
in_channels: usize,
|
||||
out_channels: usize,
|
||||
kernel_size: usize,
|
||||
config: Conv1dConfig,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Conv1d> {
|
||||
let weight = vb
|
||||
.get((out_channels, in_channels, kernel_size), "weight")?
|
||||
.dequantize(vb.device())?;
|
||||
let bias = vb.get(out_channels, "bias")?.dequantize(vb.device())?;
|
||||
Ok(Conv1d::new(weight, Some(bias), config))
|
||||
}
|
||||
|
||||
fn layer_norm(size: usize, vb: VarBuilder) -> Result<candle_nn::LayerNorm> {
|
||||
let weight = vb.get(size, "weight")?.dequantize(vb.device())?;
|
||||
let bias = vb.get(size, "bias")?.dequantize(vb.device())?;
|
||||
Ok(candle_nn::LayerNorm::new(weight, bias, 1e-5))
|
||||
}
|
||||
|
||||
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L62
|
||||
struct MultiHeadAttention {
|
||||
query: Linear,
|
||||
key: Linear,
|
||||
value: Linear,
|
||||
out: Linear,
|
||||
n_head: usize,
|
||||
span: tracing::Span,
|
||||
softmax_span: tracing::Span,
|
||||
matmul_span: tracing::Span,
|
||||
kv_cache: Option<(Tensor, Tensor)>,
|
||||
}
|
||||
|
||||
impl MultiHeadAttention {
|
||||
fn load(n_state: usize, n_head: usize, vb: VarBuilder) -> Result<Self> {
|
||||
let span = tracing::span!(tracing::Level::TRACE, "multi-head-attn");
|
||||
let softmax_span = tracing::span!(tracing::Level::TRACE, "multi-head-attn-softmax");
|
||||
let matmul_span = tracing::span!(tracing::Level::TRACE, "multi-head-attn-matmul");
|
||||
let query = linear(n_state, n_state, vb.pp("q_proj"))?;
|
||||
let value = linear(n_state, n_state, vb.pp("v_proj"))?;
|
||||
let key = linear_no_bias(n_state, n_state, vb.pp("k_proj"))?;
|
||||
let out = linear(n_state, n_state, vb.pp("out_proj"))?;
|
||||
Ok(Self {
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
out,
|
||||
n_head,
|
||||
span,
|
||||
softmax_span,
|
||||
matmul_span,
|
||||
kv_cache: None,
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(
|
||||
&mut self,
|
||||
x: &Tensor,
|
||||
xa: Option<&Tensor>,
|
||||
mask: Option<&Tensor>,
|
||||
flush_cache: bool,
|
||||
) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let q = self.query.forward(x)?;
|
||||
let (k, v) = match xa {
|
||||
None => {
|
||||
let k = self.key.forward(x)?;
|
||||
let v = self.value.forward(x)?;
|
||||
(k, v)
|
||||
}
|
||||
Some(x) => {
|
||||
if flush_cache {
|
||||
self.kv_cache = None;
|
||||
}
|
||||
if let Some((k, v)) = &self.kv_cache {
|
||||
(k.clone(), v.clone())
|
||||
} else {
|
||||
let k = self.key.forward(x)?;
|
||||
let v = self.value.forward(x)?;
|
||||
self.kv_cache = Some((k.clone(), v.clone()));
|
||||
(k, v)
|
||||
}
|
||||
}
|
||||
};
|
||||
let wv = self.qkv_attention(&q, &k, &v, mask)?;
|
||||
let out = self.out.forward(&wv)?;
|
||||
Ok(out)
|
||||
}
|
||||
|
||||
fn reshape_head(&self, x: &Tensor) -> Result<Tensor> {
|
||||
let (n_batch, n_ctx, n_state) = x.dims3()?;
|
||||
let target_dims = &[n_batch, n_ctx, self.n_head, n_state / self.n_head];
|
||||
x.reshape(target_dims)?.transpose(1, 2)
|
||||
}
|
||||
|
||||
fn qkv_attention(
|
||||
&self,
|
||||
q: &Tensor,
|
||||
k: &Tensor,
|
||||
v: &Tensor,
|
||||
mask: Option<&Tensor>,
|
||||
) -> Result<Tensor> {
|
||||
let (_, n_ctx, n_state) = q.dims3()?;
|
||||
let scale = ((n_state / self.n_head) as f64).powf(-0.25);
|
||||
let q = (self.reshape_head(q)? * scale)?;
|
||||
let k = (self.reshape_head(k)?.transpose(2, 3)? * scale)?;
|
||||
let v = self.reshape_head(v)?.contiguous()?;
|
||||
let mut qk = {
|
||||
let _enter = self.matmul_span.enter();
|
||||
q.matmul(&k)?
|
||||
};
|
||||
if let Some(mask) = mask {
|
||||
let mask = mask.i((0..n_ctx, 0..n_ctx))?;
|
||||
qk = qk.broadcast_add(&mask)?
|
||||
}
|
||||
let w = {
|
||||
let _enter = self.softmax_span.enter();
|
||||
candle_nn::ops::softmax_last_dim(&qk)?
|
||||
};
|
||||
let wv = {
|
||||
let _enter = self.matmul_span.enter();
|
||||
w.matmul(&v)?
|
||||
}
|
||||
.transpose(1, 2)?
|
||||
.flatten_from(2)?;
|
||||
Ok(wv)
|
||||
}
|
||||
}
|
||||
|
||||
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L111
|
||||
struct ResidualAttentionBlock {
|
||||
attn: MultiHeadAttention,
|
||||
attn_ln: LayerNorm,
|
||||
cross_attn: Option<(MultiHeadAttention, LayerNorm)>,
|
||||
mlp_linear1: Linear,
|
||||
mlp_linear2: Linear,
|
||||
mlp_ln: LayerNorm,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl ResidualAttentionBlock {
|
||||
fn load(n_state: usize, n_head: usize, ca: bool, vb: VarBuilder) -> Result<Self> {
|
||||
let span = tracing::span!(tracing::Level::TRACE, "residual-attn");
|
||||
let attn = MultiHeadAttention::load(n_state, n_head, vb.pp("self_attn"))?;
|
||||
let attn_ln = layer_norm(n_state, vb.pp("self_attn_layer_norm"))?;
|
||||
let cross_attn = if ca {
|
||||
let cross_attn = MultiHeadAttention::load(n_state, n_head, vb.pp("encoder_attn"))?;
|
||||
let cross_attn_ln = layer_norm(n_state, vb.pp("encoder_attn_layer_norm"))?;
|
||||
Some((cross_attn, cross_attn_ln))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let n_mlp = n_state * 4;
|
||||
let mlp_linear1 = linear(n_state, n_mlp, vb.pp("fc1"))?;
|
||||
let mlp_linear2 = linear(n_mlp, n_state, vb.pp("fc2"))?;
|
||||
let mlp_ln = layer_norm(n_state, vb.pp("final_layer_norm"))?;
|
||||
Ok(Self {
|
||||
attn,
|
||||
attn_ln,
|
||||
cross_attn,
|
||||
mlp_linear1,
|
||||
mlp_linear2,
|
||||
mlp_ln,
|
||||
span,
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(
|
||||
&mut self,
|
||||
x: &Tensor,
|
||||
xa: Option<&Tensor>,
|
||||
mask: Option<&Tensor>,
|
||||
flush_kv_cache: bool,
|
||||
) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let attn = self
|
||||
.attn
|
||||
.forward(&self.attn_ln.forward(x)?, None, mask, flush_kv_cache)?;
|
||||
let mut x = (x + attn)?;
|
||||
if let Some((attn, ln)) = &mut self.cross_attn {
|
||||
x = (&x + attn.forward(&ln.forward(&x)?, xa, None, flush_kv_cache)?)?;
|
||||
}
|
||||
let mlp = self.mlp_linear2.forward(
|
||||
&self
|
||||
.mlp_linear1
|
||||
.forward(&self.mlp_ln.forward(&x)?)?
|
||||
.gelu()?,
|
||||
)?;
|
||||
x + mlp
|
||||
}
|
||||
}
|
||||
|
||||
fn sinusoids(length: usize, channels: usize) -> Result<Tensor> {
|
||||
let max_timescale = 10000f32;
|
||||
let log_timescale_increment = max_timescale.ln() / (channels / 2 - 1) as f32;
|
||||
let inv_timescales: Vec<_> = (0..channels / 2)
|
||||
.map(|i| (i as f32 * (-log_timescale_increment)).exp())
|
||||
.collect();
|
||||
let inv_timescales = Tensor::new(inv_timescales.as_slice(), &Device::Cpu)?.unsqueeze(0)?;
|
||||
let arange = Tensor::arange(0, length as u32, &Device::Cpu)?
|
||||
.to_dtype(candle::DType::F32)?
|
||||
.unsqueeze(1)?;
|
||||
let sh = (length, channels / 2);
|
||||
let scaled_time = (arange.broadcast_as(sh)? * inv_timescales.broadcast_as(sh)?)?;
|
||||
let sincos = Tensor::cat(&[scaled_time.sin()?, scaled_time.cos()?], 1)?;
|
||||
Ok(sincos)
|
||||
}
|
||||
|
||||
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L143
|
||||
pub struct AudioEncoder {
|
||||
conv1: Conv1d,
|
||||
conv2: Conv1d,
|
||||
positional_embedding: Tensor,
|
||||
blocks: Vec<ResidualAttentionBlock>,
|
||||
ln_post: LayerNorm,
|
||||
span: tracing::Span,
|
||||
conv1_span: tracing::Span,
|
||||
conv2_span: tracing::Span,
|
||||
}
|
||||
|
||||
impl AudioEncoder {
|
||||
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let span = tracing::span!(tracing::Level::TRACE, "audio-encoder");
|
||||
let conv1_span = tracing::span!(tracing::Level::TRACE, "conv1");
|
||||
let conv2_span = tracing::span!(tracing::Level::TRACE, "conv2");
|
||||
let n_state = cfg.d_model;
|
||||
let n_head = cfg.encoder_attention_heads;
|
||||
let n_ctx = cfg.max_source_positions;
|
||||
let cfg1 = Conv1dConfig {
|
||||
padding: 1,
|
||||
stride: 1,
|
||||
groups: 1,
|
||||
dilation: 1,
|
||||
};
|
||||
let cfg2 = Conv1dConfig {
|
||||
padding: 1,
|
||||
stride: 2,
|
||||
groups: 1,
|
||||
dilation: 1,
|
||||
};
|
||||
let conv1 = conv1d(cfg.num_mel_bins, n_state, 3, cfg1, vb.pp("conv1"))?;
|
||||
let conv2 = conv1d(n_state, n_state, 3, cfg2, vb.pp("conv2"))?;
|
||||
let positional_embedding = sinusoids(n_ctx, n_state)?.to_device(vb.device())?;
|
||||
let blocks = (0..cfg.encoder_layers)
|
||||
.map(|i| {
|
||||
ResidualAttentionBlock::load(n_state, n_head, false, vb.pp(format!("layers.{i}")))
|
||||
})
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
let ln_post = layer_norm(n_state, vb.pp("layer_norm"))?;
|
||||
Ok(Self {
|
||||
conv1,
|
||||
conv2,
|
||||
positional_embedding,
|
||||
blocks,
|
||||
ln_post,
|
||||
conv1_span,
|
||||
conv2_span,
|
||||
span,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn forward(&mut self, x: &Tensor, flush_kv_cache: bool) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let x = {
|
||||
let _enter = self.conv1_span.enter();
|
||||
self.conv1.forward(x)?.gelu()?
|
||||
};
|
||||
let x = {
|
||||
let _enter = self.conv2_span.enter();
|
||||
self.conv2.forward(&x)?.gelu()?
|
||||
};
|
||||
let x = x.transpose(1, 2)?;
|
||||
let (_bsize, seq_len, _hidden) = x.dims3()?;
|
||||
let positional_embedding = self.positional_embedding.narrow(0, 0, seq_len)?;
|
||||
let mut x = x.broadcast_add(&positional_embedding)?;
|
||||
for block in self.blocks.iter_mut() {
|
||||
x = block.forward(&x, None, None, flush_kv_cache)?
|
||||
}
|
||||
let x = self.ln_post.forward(&x)?;
|
||||
Ok(x)
|
||||
}
|
||||
}
|
||||
|
||||
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L176
|
||||
pub struct TextDecoder {
|
||||
token_embedding: Embedding,
|
||||
positional_embedding: Tensor,
|
||||
blocks: Vec<ResidualAttentionBlock>,
|
||||
ln: LayerNorm,
|
||||
mask: Tensor,
|
||||
span: tracing::Span,
|
||||
span_final: tracing::Span,
|
||||
}
|
||||
|
||||
impl TextDecoder {
|
||||
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let span = tracing::span!(tracing::Level::TRACE, "text-decoder");
|
||||
let span_final = tracing::span!(tracing::Level::TRACE, "text-decoder-final");
|
||||
let n_state = cfg.d_model;
|
||||
let n_head = cfg.decoder_attention_heads;
|
||||
let n_ctx = cfg.max_target_positions;
|
||||
let token_embedding = Embedding::new(cfg.vocab_size, n_state, vb.pp("embed_tokens"))?;
|
||||
let positional_embedding = vb
|
||||
.get((n_ctx, n_state), "embed_positions.weight")?
|
||||
.dequantize(vb.device())?;
|
||||
let blocks = (0..cfg.decoder_layers)
|
||||
.map(|i| {
|
||||
ResidualAttentionBlock::load(n_state, n_head, true, vb.pp(format!("layers.{i}")))
|
||||
})
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
let ln = layer_norm(n_state, vb.pp("layer_norm"))?;
|
||||
let mask: Vec<_> = (0..n_ctx)
|
||||
.flat_map(|i| (0..n_ctx).map(move |j| if j > i { f32::NEG_INFINITY } else { 0f32 }))
|
||||
.collect();
|
||||
let mask = Tensor::from_vec(mask, (n_ctx, n_ctx), vb.device())?;
|
||||
Ok(Self {
|
||||
token_embedding,
|
||||
positional_embedding,
|
||||
blocks,
|
||||
ln,
|
||||
mask,
|
||||
span,
|
||||
span_final,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn forward(&mut self, x: &Tensor, xa: &Tensor, flush_kv_cache: bool) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let last = x.dim(D::Minus1)?;
|
||||
let token_embedding = self.token_embedding.forward(x)?;
|
||||
let positional_embedding = self.positional_embedding.narrow(0, 0, last)?;
|
||||
let mut x = token_embedding.broadcast_add(&positional_embedding)?;
|
||||
for block in self.blocks.iter_mut() {
|
||||
x = block.forward(&x, Some(xa), Some(&self.mask), flush_kv_cache)?;
|
||||
}
|
||||
self.ln.forward(&x)
|
||||
}
|
||||
|
||||
pub fn final_linear(&self, x: &Tensor) -> Result<Tensor> {
|
||||
let b_size = x.dim(0)?;
|
||||
let w = self.token_embedding.embeddings().broadcast_left(b_size)?;
|
||||
let logits = {
|
||||
let _enter = self.span_final.enter();
|
||||
x.matmul(&w.t()?)?
|
||||
};
|
||||
Ok(logits)
|
||||
}
|
||||
}
|
||||
|
||||
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L221
|
||||
pub struct Whisper {
|
||||
pub encoder: AudioEncoder,
|
||||
pub decoder: TextDecoder,
|
||||
pub config: Config,
|
||||
}
|
||||
|
||||
impl Whisper {
|
||||
pub fn load(vb: &VarBuilder, config: Config) -> Result<Self> {
|
||||
let encoder = AudioEncoder::load(vb.pp("model.encoder"), &config)?;
|
||||
let decoder = TextDecoder::load(vb.pp("model.decoder"), &config)?;
|
||||
Ok(Self {
|
||||
encoder,
|
||||
decoder,
|
||||
config,
|
||||
})
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user