mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 02:16:37 +00:00
Add a quantized variant of whisper (#1017)
* Add the quantized-whisper model. * Quantized the whisper model. * Adapt the whisper example to handle quantization. * Add the quantized flag. * Load the proper weights.
This commit is contained in:
@ -18,8 +18,48 @@ use rand::{distributions::Distribution, SeedableRng};
|
|||||||
use tokenizers::Tokenizer;
|
use tokenizers::Tokenizer;
|
||||||
|
|
||||||
mod multilingual;
|
mod multilingual;
|
||||||
use candle_transformers::models::whisper::{self as m, audio, model};
|
use candle_transformers::models::whisper::{self as m, audio, Config};
|
||||||
use model::{Config, Whisper};
|
|
||||||
|
pub enum Model {
|
||||||
|
Normal(m::model::Whisper),
|
||||||
|
Quantized(m::quantized_model::Whisper),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Maybe we should use some traits rather than doing the dispatch for all these.
|
||||||
|
impl Model {
|
||||||
|
pub fn config(&self) -> &Config {
|
||||||
|
match self {
|
||||||
|
Self::Normal(m) => &m.config,
|
||||||
|
Self::Quantized(m) => &m.config,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn encoder_forward(&mut self, x: &Tensor, flush: bool) -> candle::Result<Tensor> {
|
||||||
|
match self {
|
||||||
|
Self::Normal(m) => m.encoder.forward(x, flush),
|
||||||
|
Self::Quantized(m) => m.encoder.forward(x, flush),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn decoder_forward(
|
||||||
|
&mut self,
|
||||||
|
x: &Tensor,
|
||||||
|
xa: &Tensor,
|
||||||
|
flush: bool,
|
||||||
|
) -> candle::Result<Tensor> {
|
||||||
|
match self {
|
||||||
|
Self::Normal(m) => m.decoder.forward(x, xa, flush),
|
||||||
|
Self::Quantized(m) => m.decoder.forward(x, xa, flush),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn decoder_final_linear(&self, x: &Tensor) -> candle::Result<Tensor> {
|
||||||
|
match self {
|
||||||
|
Self::Normal(m) => m.decoder.final_linear(x),
|
||||||
|
Self::Quantized(m) => m.decoder.final_linear(x),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[allow(dead_code)]
|
#[allow(dead_code)]
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
@ -41,7 +81,7 @@ struct Segment {
|
|||||||
}
|
}
|
||||||
|
|
||||||
struct Decoder {
|
struct Decoder {
|
||||||
model: Whisper,
|
model: Model,
|
||||||
rng: rand::rngs::StdRng,
|
rng: rand::rngs::StdRng,
|
||||||
task: Option<Task>,
|
task: Option<Task>,
|
||||||
timestamps: bool,
|
timestamps: bool,
|
||||||
@ -60,7 +100,7 @@ struct Decoder {
|
|||||||
impl Decoder {
|
impl Decoder {
|
||||||
#[allow(clippy::too_many_arguments)]
|
#[allow(clippy::too_many_arguments)]
|
||||||
fn new(
|
fn new(
|
||||||
model: Whisper,
|
model: Model,
|
||||||
tokenizer: Tokenizer,
|
tokenizer: Tokenizer,
|
||||||
seed: u64,
|
seed: u64,
|
||||||
device: &Device,
|
device: &Device,
|
||||||
@ -72,9 +112,9 @@ impl Decoder {
|
|||||||
let no_timestamps_token = token_id(&tokenizer, m::NO_TIMESTAMPS_TOKEN)?;
|
let no_timestamps_token = token_id(&tokenizer, m::NO_TIMESTAMPS_TOKEN)?;
|
||||||
// Suppress the notimestamps token when in timestamps mode.
|
// Suppress the notimestamps token when in timestamps mode.
|
||||||
// https://github.com/openai/whisper/blob/e8622f9afc4eba139bf796c210f5c01081000472/whisper/decoding.py#L452
|
// https://github.com/openai/whisper/blob/e8622f9afc4eba139bf796c210f5c01081000472/whisper/decoding.py#L452
|
||||||
let suppress_tokens: Vec<f32> = (0..model.config.vocab_size as u32)
|
let suppress_tokens: Vec<f32> = (0..model.config().vocab_size as u32)
|
||||||
.map(|i| {
|
.map(|i| {
|
||||||
if model.config.suppress_tokens.contains(&i)
|
if model.config().suppress_tokens.contains(&i)
|
||||||
|| timestamps && i == no_timestamps_token
|
|| timestamps && i == no_timestamps_token
|
||||||
{
|
{
|
||||||
f32::NEG_INFINITY
|
f32::NEG_INFINITY
|
||||||
@ -109,11 +149,11 @@ impl Decoder {
|
|||||||
|
|
||||||
fn decode(&mut self, mel: &Tensor, t: f64) -> Result<DecodingResult> {
|
fn decode(&mut self, mel: &Tensor, t: f64) -> Result<DecodingResult> {
|
||||||
let model = &mut self.model;
|
let model = &mut self.model;
|
||||||
let audio_features = model.encoder.forward(mel, true)?;
|
let audio_features = model.encoder_forward(mel, true)?;
|
||||||
if self.verbose {
|
if self.verbose {
|
||||||
println!("audio features: {:?}", audio_features.dims());
|
println!("audio features: {:?}", audio_features.dims());
|
||||||
}
|
}
|
||||||
let sample_len = model.config.max_target_positions / 2;
|
let sample_len = model.config().max_target_positions / 2;
|
||||||
let mut sum_logprob = 0f64;
|
let mut sum_logprob = 0f64;
|
||||||
let mut no_speech_prob = f64::NAN;
|
let mut no_speech_prob = f64::NAN;
|
||||||
let mut tokens = vec![self.sot_token];
|
let mut tokens = vec![self.sot_token];
|
||||||
@ -133,12 +173,12 @@ impl Decoder {
|
|||||||
// The model expects a batch dim but this inference loop does not handle
|
// The model expects a batch dim but this inference loop does not handle
|
||||||
// it so we add it at this point.
|
// it so we add it at this point.
|
||||||
let tokens_t = tokens_t.unsqueeze(0)?;
|
let tokens_t = tokens_t.unsqueeze(0)?;
|
||||||
let ys = model.decoder.forward(&tokens_t, &audio_features, i == 0)?;
|
let ys = model.decoder_forward(&tokens_t, &audio_features, i == 0)?;
|
||||||
|
|
||||||
// Extract the no speech probability on the first iteration by looking at the first
|
// Extract the no speech probability on the first iteration by looking at the first
|
||||||
// token logits and the probability for the according token.
|
// token logits and the probability for the according token.
|
||||||
if i == 0 {
|
if i == 0 {
|
||||||
let logits = model.decoder.final_linear(&ys.i(..1)?)?.i(0)?.i(0)?;
|
let logits = model.decoder_final_linear(&ys.i(..1)?)?.i(0)?.i(0)?;
|
||||||
no_speech_prob = softmax(&logits, 0)?
|
no_speech_prob = softmax(&logits, 0)?
|
||||||
.i(self.no_speech_token as usize)?
|
.i(self.no_speech_token as usize)?
|
||||||
.to_scalar::<f32>()? as f64;
|
.to_scalar::<f32>()? as f64;
|
||||||
@ -146,8 +186,7 @@ impl Decoder {
|
|||||||
|
|
||||||
let (_, seq_len, _) = ys.dims3()?;
|
let (_, seq_len, _) = ys.dims3()?;
|
||||||
let logits = model
|
let logits = model
|
||||||
.decoder
|
.decoder_final_linear(&ys.i((..1, seq_len - 1..))?)?
|
||||||
.final_linear(&ys.i((..1, seq_len - 1..))?)?
|
|
||||||
.i(0)?
|
.i(0)?
|
||||||
.i(0)?;
|
.i(0)?;
|
||||||
// TODO: Besides suppress tokens, we should apply the heuristics from
|
// TODO: Besides suppress tokens, we should apply the heuristics from
|
||||||
@ -176,7 +215,7 @@ impl Decoder {
|
|||||||
let prob = softmax(&logits, candle::D::Minus1)?
|
let prob = softmax(&logits, candle::D::Minus1)?
|
||||||
.i(next_token as usize)?
|
.i(next_token as usize)?
|
||||||
.to_scalar::<f32>()? as f64;
|
.to_scalar::<f32>()? as f64;
|
||||||
if next_token == self.eot_token || tokens.len() > model.config.max_target_positions {
|
if next_token == self.eot_token || tokens.len() > model.config().max_target_positions {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
sum_logprob += prob.ln();
|
sum_logprob += prob.ln();
|
||||||
@ -333,6 +372,7 @@ impl WhichModel {
|
|||||||
Self::TinyEn | Self::BaseEn | Self::SmallEn | Self::MediumEn => false,
|
Self::TinyEn | Self::BaseEn | Self::SmallEn | Self::MediumEn => false,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn model_and_revision(&self) -> (&'static str, &'static str) {
|
fn model_and_revision(&self) -> (&'static str, &'static str) {
|
||||||
match self {
|
match self {
|
||||||
Self::Tiny => ("openai/whisper-tiny", "main"),
|
Self::Tiny => ("openai/whisper-tiny", "main"),
|
||||||
@ -382,6 +422,9 @@ struct Args {
|
|||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
tracing: bool,
|
tracing: bool,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
quantized: bool,
|
||||||
|
|
||||||
/// Language.
|
/// Language.
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
language: Option<String>,
|
language: Option<String>,
|
||||||
@ -413,10 +456,13 @@ fn main() -> Result<()> {
|
|||||||
None
|
None
|
||||||
};
|
};
|
||||||
let device = candle_examples::device(args.cpu)?;
|
let device = candle_examples::device(args.cpu)?;
|
||||||
let (default_model, default_revision) = args.model.model_and_revision();
|
let (default_model, default_revision) = if args.quantized {
|
||||||
|
("lmz/candle-whisper", "main")
|
||||||
|
} else {
|
||||||
|
args.model.model_and_revision()
|
||||||
|
};
|
||||||
let default_model = default_model.to_string();
|
let default_model = default_model.to_string();
|
||||||
let default_revision = default_revision.to_string();
|
let default_revision = default_revision.to_string();
|
||||||
let path = std::path::PathBuf::from(default_model.clone());
|
|
||||||
let (model_id, revision) = match (args.model_id, args.revision) {
|
let (model_id, revision) = match (args.model_id, args.revision) {
|
||||||
(Some(model_id), Some(revision)) => (model_id, revision),
|
(Some(model_id), Some(revision)) => (model_id, revision),
|
||||||
(Some(model_id), None) => (model_id, "main".to_string()),
|
(Some(model_id), None) => (model_id, "main".to_string()),
|
||||||
@ -424,20 +470,7 @@ fn main() -> Result<()> {
|
|||||||
(None, None) => (default_model, default_revision),
|
(None, None) => (default_model, default_revision),
|
||||||
};
|
};
|
||||||
|
|
||||||
let (config_filename, tokenizer_filename, weights_filename, input) = if path.exists() {
|
let (config_filename, tokenizer_filename, weights_filename, input) = {
|
||||||
let mut config_filename = path.clone();
|
|
||||||
config_filename.push("config.json");
|
|
||||||
let mut tokenizer_filename = path.clone();
|
|
||||||
tokenizer_filename.push("tokenizer.json");
|
|
||||||
let mut model_filename = path;
|
|
||||||
model_filename.push("model.safetensors");
|
|
||||||
(
|
|
||||||
config_filename,
|
|
||||||
tokenizer_filename,
|
|
||||||
model_filename,
|
|
||||||
std::path::PathBuf::from(args.input.expect("You didn't specify a file to read from yet, are using a local model, please add `--input example.wav` to read some audio file")),
|
|
||||||
)
|
|
||||||
} else {
|
|
||||||
let api = Api::new()?;
|
let api = Api::new()?;
|
||||||
let dataset = api.dataset("Narsil/candle-examples".to_string());
|
let dataset = api.dataset("Narsil/candle-examples".to_string());
|
||||||
let repo = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
|
let repo = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
|
||||||
@ -451,12 +484,17 @@ fn main() -> Result<()> {
|
|||||||
println!("No audio file submitted: Downloading https://huggingface.co/datasets/Narsil/candle_demo/blob/main/samples_jfk.wav");
|
println!("No audio file submitted: Downloading https://huggingface.co/datasets/Narsil/candle_demo/blob/main/samples_jfk.wav");
|
||||||
dataset.get("samples_jfk.wav")?
|
dataset.get("samples_jfk.wav")?
|
||||||
};
|
};
|
||||||
(
|
let config = if args.quantized {
|
||||||
repo.get("config.json")?,
|
repo.get("config-tiny.json")?
|
||||||
repo.get("tokenizer.json")?,
|
} else {
|
||||||
repo.get("model.safetensors")?,
|
repo.get("config.json")?
|
||||||
sample,
|
};
|
||||||
)
|
let model = if args.quantized {
|
||||||
|
repo.get("model-tiny-q40.gguf")?
|
||||||
|
} else {
|
||||||
|
repo.get("model.safetensors")?
|
||||||
|
};
|
||||||
|
(config, repo.get("tokenizer.json")?, model, sample)
|
||||||
};
|
};
|
||||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||||
|
|
||||||
@ -481,10 +519,16 @@ fn main() -> Result<()> {
|
|||||||
let mel = Tensor::from_vec(mel, (1, m::N_MELS, mel_len / m::N_MELS), &device)?;
|
let mel = Tensor::from_vec(mel, (1, m::N_MELS, mel_len / m::N_MELS), &device)?;
|
||||||
println!("loaded mel: {:?}", mel.dims());
|
println!("loaded mel: {:?}", mel.dims());
|
||||||
|
|
||||||
let vb =
|
|
||||||
unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], m::DTYPE, &device)? };
|
|
||||||
let config: Config = serde_json::from_str(&std::fs::read_to_string(config_filename)?)?;
|
let config: Config = serde_json::from_str(&std::fs::read_to_string(config_filename)?)?;
|
||||||
let mut model = Whisper::load(&vb, config)?;
|
let mut model = if args.quantized {
|
||||||
|
let vb =
|
||||||
|
candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&weights_filename)?;
|
||||||
|
Model::Quantized(m::quantized_model::Whisper::load(&vb, config)?)
|
||||||
|
} else {
|
||||||
|
let vb =
|
||||||
|
unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], m::DTYPE, &device)? };
|
||||||
|
Model::Normal(m::model::Whisper::load(&vb, config)?)
|
||||||
|
};
|
||||||
|
|
||||||
let language_token = match (args.model.is_multilingual(), args.language) {
|
let language_token = match (args.model.is_multilingual(), args.language) {
|
||||||
(true, None) => Some(multilingual::detect_language(&mut model, &tokenizer, &mel)?),
|
(true, None) => Some(multilingual::detect_language(&mut model, &tokenizer, &mel)?),
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
use crate::Whisper;
|
|
||||||
use candle::{IndexOp, Result, Tensor, D};
|
use candle::{IndexOp, Result, Tensor, D};
|
||||||
use tokenizers::Tokenizer;
|
use tokenizers::Tokenizer;
|
||||||
|
|
||||||
@ -105,20 +104,28 @@ const LANGUAGES: [(&str, &str); 99] = [
|
|||||||
];
|
];
|
||||||
|
|
||||||
/// Returns the token id for the selected language.
|
/// Returns the token id for the selected language.
|
||||||
pub fn detect_language(model: &mut Whisper, tokenizer: &Tokenizer, mel: &Tensor) -> Result<u32> {
|
pub fn detect_language(
|
||||||
|
model: &mut super::Model,
|
||||||
|
tokenizer: &Tokenizer,
|
||||||
|
mel: &Tensor,
|
||||||
|
) -> Result<u32> {
|
||||||
let (_bsize, _, seq_len) = mel.dims3()?;
|
let (_bsize, _, seq_len) = mel.dims3()?;
|
||||||
let mel = mel.narrow(2, 0, usize::min(seq_len, model.config.max_source_positions))?;
|
let mel = mel.narrow(
|
||||||
|
2,
|
||||||
|
0,
|
||||||
|
usize::min(seq_len, model.config().max_source_positions),
|
||||||
|
)?;
|
||||||
let device = mel.device();
|
let device = mel.device();
|
||||||
let language_token_ids = LANGUAGES
|
let language_token_ids = LANGUAGES
|
||||||
.iter()
|
.iter()
|
||||||
.map(|(t, _)| crate::token_id(tokenizer, &format!("<|{t}|>")))
|
.map(|(t, _)| crate::token_id(tokenizer, &format!("<|{t}|>")))
|
||||||
.collect::<Result<Vec<_>>>()?;
|
.collect::<Result<Vec<_>>>()?;
|
||||||
let sot_token = crate::token_id(tokenizer, crate::m::SOT_TOKEN)?;
|
let sot_token = crate::token_id(tokenizer, crate::m::SOT_TOKEN)?;
|
||||||
let audio_features = model.encoder.forward(&mel, true)?;
|
let audio_features = model.encoder_forward(&mel, true)?;
|
||||||
let tokens = Tensor::new(&[[sot_token]], device)?;
|
let tokens = Tensor::new(&[[sot_token]], device)?;
|
||||||
let language_token_ids = Tensor::new(language_token_ids.as_slice(), device)?;
|
let language_token_ids = Tensor::new(language_token_ids.as_slice(), device)?;
|
||||||
let ys = model.decoder.forward(&tokens, &audio_features, true)?;
|
let ys = model.decoder_forward(&tokens, &audio_features, true)?;
|
||||||
let logits = model.decoder.final_linear(&ys.i(..1)?)?.i(0)?.i(0)?;
|
let logits = model.decoder_final_linear(&ys.i(..1)?)?.i(0)?.i(0)?;
|
||||||
let logits = logits.index_select(&language_token_ids, 0)?;
|
let logits = logits.index_select(&language_token_ids, 0)?;
|
||||||
let probs = candle_nn::ops::softmax(&logits, D::Minus1)?;
|
let probs = candle_nn::ops::softmax(&logits, D::Minus1)?;
|
||||||
let probs = probs.to_vec1::<f32>()?;
|
let probs = probs.to_vec1::<f32>()?;
|
||||||
|
@ -1,5 +1,25 @@
|
|||||||
pub mod audio;
|
pub mod audio;
|
||||||
pub mod model;
|
pub mod model;
|
||||||
|
pub mod quantized_model;
|
||||||
|
|
||||||
|
use serde::Deserialize;
|
||||||
|
|
||||||
|
// The names in comments correspond to the original implementation:
|
||||||
|
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L17
|
||||||
|
#[derive(Debug, Clone, PartialEq, Deserialize)]
|
||||||
|
pub struct Config {
|
||||||
|
pub num_mel_bins: usize, // n_mels
|
||||||
|
pub max_source_positions: usize, // n_audio_ctx
|
||||||
|
pub d_model: usize, // n_audio_state
|
||||||
|
pub encoder_attention_heads: usize, // n_audio_head
|
||||||
|
pub encoder_layers: usize, // n_audio_layer
|
||||||
|
pub vocab_size: usize, // n_vocab
|
||||||
|
pub max_target_positions: usize, // n_text_ctx
|
||||||
|
// pub n_text_state: usize,
|
||||||
|
pub decoder_attention_heads: usize, // n_text_head
|
||||||
|
pub decoder_layers: usize, // n_text_layer
|
||||||
|
pub suppress_tokens: Vec<u32>,
|
||||||
|
}
|
||||||
|
|
||||||
pub const DTYPE: candle::DType = candle::DType::F32;
|
pub const DTYPE: candle::DType = candle::DType::F32;
|
||||||
|
|
||||||
|
@ -1,23 +1,6 @@
|
|||||||
|
use super::Config;
|
||||||
use candle::{Device, IndexOp, Result, Tensor, D};
|
use candle::{Device, IndexOp, Result, Tensor, D};
|
||||||
use candle_nn::{Conv1d, Conv1dConfig, Embedding, LayerNorm, Module, VarBuilder};
|
use candle_nn::{Conv1d, Conv1dConfig, Embedding, LayerNorm, Module, VarBuilder};
|
||||||
use serde::Deserialize;
|
|
||||||
|
|
||||||
// The names in comments correspond to the original implementation:
|
|
||||||
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L17
|
|
||||||
#[derive(Debug, Clone, PartialEq, Deserialize)]
|
|
||||||
pub struct Config {
|
|
||||||
pub num_mel_bins: usize, // n_mels
|
|
||||||
pub max_source_positions: usize, // n_audio_ctx
|
|
||||||
pub d_model: usize, // n_audio_state
|
|
||||||
pub encoder_attention_heads: usize, // n_audio_head
|
|
||||||
pub encoder_layers: usize, // n_audio_layer
|
|
||||||
pub vocab_size: usize, // n_vocab
|
|
||||||
pub max_target_positions: usize, // n_text_ctx
|
|
||||||
// pub n_text_state: usize,
|
|
||||||
pub decoder_attention_heads: usize, // n_text_head
|
|
||||||
pub decoder_layers: usize, // n_text_layer
|
|
||||||
pub suppress_tokens: Vec<u32>,
|
|
||||||
}
|
|
||||||
|
|
||||||
fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
|
fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
|
||||||
let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
|
let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
|
||||||
|
403
candle-transformers/src/models/whisper/quantized_model.rs
Normal file
403
candle-transformers/src/models/whisper/quantized_model.rs
Normal file
@ -0,0 +1,403 @@
|
|||||||
|
use super::Config;
|
||||||
|
use crate::models::{quantized_t5::Embedding, with_tracing::QMatMul};
|
||||||
|
pub use crate::quantized_var_builder::VarBuilder;
|
||||||
|
use candle::{Device, IndexOp, Result, Tensor, D};
|
||||||
|
use candle_nn::{Conv1d, Conv1dConfig, LayerNorm, Module};
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
struct Linear {
|
||||||
|
weight: QMatMul,
|
||||||
|
bias: Option<Tensor>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module for Linear {
|
||||||
|
fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
|
||||||
|
let x = x.apply(&self.weight)?;
|
||||||
|
match &self.bias {
|
||||||
|
None => Ok(x),
|
||||||
|
Some(bias) => x.broadcast_add(bias),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn linear(in_dim: usize, out_dim: usize, vb: VarBuilder) -> Result<Linear> {
|
||||||
|
let bias = vb.get(out_dim, "bias")?.dequantize(vb.device())?;
|
||||||
|
let weight = QMatMul::new(in_dim, out_dim, vb)?;
|
||||||
|
Ok(Linear {
|
||||||
|
weight,
|
||||||
|
bias: Some(bias),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn linear_no_bias(in_dim: usize, out_dim: usize, vb: VarBuilder) -> Result<Linear> {
|
||||||
|
let weight = QMatMul::new(in_dim, out_dim, vb)?;
|
||||||
|
Ok(Linear { weight, bias: None })
|
||||||
|
}
|
||||||
|
|
||||||
|
fn conv1d(
|
||||||
|
in_channels: usize,
|
||||||
|
out_channels: usize,
|
||||||
|
kernel_size: usize,
|
||||||
|
config: Conv1dConfig,
|
||||||
|
vb: VarBuilder,
|
||||||
|
) -> Result<Conv1d> {
|
||||||
|
let weight = vb
|
||||||
|
.get((out_channels, in_channels, kernel_size), "weight")?
|
||||||
|
.dequantize(vb.device())?;
|
||||||
|
let bias = vb.get(out_channels, "bias")?.dequantize(vb.device())?;
|
||||||
|
Ok(Conv1d::new(weight, Some(bias), config))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn layer_norm(size: usize, vb: VarBuilder) -> Result<candle_nn::LayerNorm> {
|
||||||
|
let weight = vb.get(size, "weight")?.dequantize(vb.device())?;
|
||||||
|
let bias = vb.get(size, "bias")?.dequantize(vb.device())?;
|
||||||
|
Ok(candle_nn::LayerNorm::new(weight, bias, 1e-5))
|
||||||
|
}
|
||||||
|
|
||||||
|
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L62
|
||||||
|
struct MultiHeadAttention {
|
||||||
|
query: Linear,
|
||||||
|
key: Linear,
|
||||||
|
value: Linear,
|
||||||
|
out: Linear,
|
||||||
|
n_head: usize,
|
||||||
|
span: tracing::Span,
|
||||||
|
softmax_span: tracing::Span,
|
||||||
|
matmul_span: tracing::Span,
|
||||||
|
kv_cache: Option<(Tensor, Tensor)>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MultiHeadAttention {
|
||||||
|
fn load(n_state: usize, n_head: usize, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let span = tracing::span!(tracing::Level::TRACE, "multi-head-attn");
|
||||||
|
let softmax_span = tracing::span!(tracing::Level::TRACE, "multi-head-attn-softmax");
|
||||||
|
let matmul_span = tracing::span!(tracing::Level::TRACE, "multi-head-attn-matmul");
|
||||||
|
let query = linear(n_state, n_state, vb.pp("q_proj"))?;
|
||||||
|
let value = linear(n_state, n_state, vb.pp("v_proj"))?;
|
||||||
|
let key = linear_no_bias(n_state, n_state, vb.pp("k_proj"))?;
|
||||||
|
let out = linear(n_state, n_state, vb.pp("out_proj"))?;
|
||||||
|
Ok(Self {
|
||||||
|
query,
|
||||||
|
key,
|
||||||
|
value,
|
||||||
|
out,
|
||||||
|
n_head,
|
||||||
|
span,
|
||||||
|
softmax_span,
|
||||||
|
matmul_span,
|
||||||
|
kv_cache: None,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn forward(
|
||||||
|
&mut self,
|
||||||
|
x: &Tensor,
|
||||||
|
xa: Option<&Tensor>,
|
||||||
|
mask: Option<&Tensor>,
|
||||||
|
flush_cache: bool,
|
||||||
|
) -> Result<Tensor> {
|
||||||
|
let _enter = self.span.enter();
|
||||||
|
let q = self.query.forward(x)?;
|
||||||
|
let (k, v) = match xa {
|
||||||
|
None => {
|
||||||
|
let k = self.key.forward(x)?;
|
||||||
|
let v = self.value.forward(x)?;
|
||||||
|
(k, v)
|
||||||
|
}
|
||||||
|
Some(x) => {
|
||||||
|
if flush_cache {
|
||||||
|
self.kv_cache = None;
|
||||||
|
}
|
||||||
|
if let Some((k, v)) = &self.kv_cache {
|
||||||
|
(k.clone(), v.clone())
|
||||||
|
} else {
|
||||||
|
let k = self.key.forward(x)?;
|
||||||
|
let v = self.value.forward(x)?;
|
||||||
|
self.kv_cache = Some((k.clone(), v.clone()));
|
||||||
|
(k, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let wv = self.qkv_attention(&q, &k, &v, mask)?;
|
||||||
|
let out = self.out.forward(&wv)?;
|
||||||
|
Ok(out)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn reshape_head(&self, x: &Tensor) -> Result<Tensor> {
|
||||||
|
let (n_batch, n_ctx, n_state) = x.dims3()?;
|
||||||
|
let target_dims = &[n_batch, n_ctx, self.n_head, n_state / self.n_head];
|
||||||
|
x.reshape(target_dims)?.transpose(1, 2)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn qkv_attention(
|
||||||
|
&self,
|
||||||
|
q: &Tensor,
|
||||||
|
k: &Tensor,
|
||||||
|
v: &Tensor,
|
||||||
|
mask: Option<&Tensor>,
|
||||||
|
) -> Result<Tensor> {
|
||||||
|
let (_, n_ctx, n_state) = q.dims3()?;
|
||||||
|
let scale = ((n_state / self.n_head) as f64).powf(-0.25);
|
||||||
|
let q = (self.reshape_head(q)? * scale)?;
|
||||||
|
let k = (self.reshape_head(k)?.transpose(2, 3)? * scale)?;
|
||||||
|
let v = self.reshape_head(v)?.contiguous()?;
|
||||||
|
let mut qk = {
|
||||||
|
let _enter = self.matmul_span.enter();
|
||||||
|
q.matmul(&k)?
|
||||||
|
};
|
||||||
|
if let Some(mask) = mask {
|
||||||
|
let mask = mask.i((0..n_ctx, 0..n_ctx))?;
|
||||||
|
qk = qk.broadcast_add(&mask)?
|
||||||
|
}
|
||||||
|
let w = {
|
||||||
|
let _enter = self.softmax_span.enter();
|
||||||
|
candle_nn::ops::softmax_last_dim(&qk)?
|
||||||
|
};
|
||||||
|
let wv = {
|
||||||
|
let _enter = self.matmul_span.enter();
|
||||||
|
w.matmul(&v)?
|
||||||
|
}
|
||||||
|
.transpose(1, 2)?
|
||||||
|
.flatten_from(2)?;
|
||||||
|
Ok(wv)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L111
|
||||||
|
struct ResidualAttentionBlock {
|
||||||
|
attn: MultiHeadAttention,
|
||||||
|
attn_ln: LayerNorm,
|
||||||
|
cross_attn: Option<(MultiHeadAttention, LayerNorm)>,
|
||||||
|
mlp_linear1: Linear,
|
||||||
|
mlp_linear2: Linear,
|
||||||
|
mlp_ln: LayerNorm,
|
||||||
|
span: tracing::Span,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ResidualAttentionBlock {
|
||||||
|
fn load(n_state: usize, n_head: usize, ca: bool, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let span = tracing::span!(tracing::Level::TRACE, "residual-attn");
|
||||||
|
let attn = MultiHeadAttention::load(n_state, n_head, vb.pp("self_attn"))?;
|
||||||
|
let attn_ln = layer_norm(n_state, vb.pp("self_attn_layer_norm"))?;
|
||||||
|
let cross_attn = if ca {
|
||||||
|
let cross_attn = MultiHeadAttention::load(n_state, n_head, vb.pp("encoder_attn"))?;
|
||||||
|
let cross_attn_ln = layer_norm(n_state, vb.pp("encoder_attn_layer_norm"))?;
|
||||||
|
Some((cross_attn, cross_attn_ln))
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
let n_mlp = n_state * 4;
|
||||||
|
let mlp_linear1 = linear(n_state, n_mlp, vb.pp("fc1"))?;
|
||||||
|
let mlp_linear2 = linear(n_mlp, n_state, vb.pp("fc2"))?;
|
||||||
|
let mlp_ln = layer_norm(n_state, vb.pp("final_layer_norm"))?;
|
||||||
|
Ok(Self {
|
||||||
|
attn,
|
||||||
|
attn_ln,
|
||||||
|
cross_attn,
|
||||||
|
mlp_linear1,
|
||||||
|
mlp_linear2,
|
||||||
|
mlp_ln,
|
||||||
|
span,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn forward(
|
||||||
|
&mut self,
|
||||||
|
x: &Tensor,
|
||||||
|
xa: Option<&Tensor>,
|
||||||
|
mask: Option<&Tensor>,
|
||||||
|
flush_kv_cache: bool,
|
||||||
|
) -> Result<Tensor> {
|
||||||
|
let _enter = self.span.enter();
|
||||||
|
let attn = self
|
||||||
|
.attn
|
||||||
|
.forward(&self.attn_ln.forward(x)?, None, mask, flush_kv_cache)?;
|
||||||
|
let mut x = (x + attn)?;
|
||||||
|
if let Some((attn, ln)) = &mut self.cross_attn {
|
||||||
|
x = (&x + attn.forward(&ln.forward(&x)?, xa, None, flush_kv_cache)?)?;
|
||||||
|
}
|
||||||
|
let mlp = self.mlp_linear2.forward(
|
||||||
|
&self
|
||||||
|
.mlp_linear1
|
||||||
|
.forward(&self.mlp_ln.forward(&x)?)?
|
||||||
|
.gelu()?,
|
||||||
|
)?;
|
||||||
|
x + mlp
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn sinusoids(length: usize, channels: usize) -> Result<Tensor> {
|
||||||
|
let max_timescale = 10000f32;
|
||||||
|
let log_timescale_increment = max_timescale.ln() / (channels / 2 - 1) as f32;
|
||||||
|
let inv_timescales: Vec<_> = (0..channels / 2)
|
||||||
|
.map(|i| (i as f32 * (-log_timescale_increment)).exp())
|
||||||
|
.collect();
|
||||||
|
let inv_timescales = Tensor::new(inv_timescales.as_slice(), &Device::Cpu)?.unsqueeze(0)?;
|
||||||
|
let arange = Tensor::arange(0, length as u32, &Device::Cpu)?
|
||||||
|
.to_dtype(candle::DType::F32)?
|
||||||
|
.unsqueeze(1)?;
|
||||||
|
let sh = (length, channels / 2);
|
||||||
|
let scaled_time = (arange.broadcast_as(sh)? * inv_timescales.broadcast_as(sh)?)?;
|
||||||
|
let sincos = Tensor::cat(&[scaled_time.sin()?, scaled_time.cos()?], 1)?;
|
||||||
|
Ok(sincos)
|
||||||
|
}
|
||||||
|
|
||||||
|
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L143
|
||||||
|
pub struct AudioEncoder {
|
||||||
|
conv1: Conv1d,
|
||||||
|
conv2: Conv1d,
|
||||||
|
positional_embedding: Tensor,
|
||||||
|
blocks: Vec<ResidualAttentionBlock>,
|
||||||
|
ln_post: LayerNorm,
|
||||||
|
span: tracing::Span,
|
||||||
|
conv1_span: tracing::Span,
|
||||||
|
conv2_span: tracing::Span,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AudioEncoder {
|
||||||
|
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||||
|
let span = tracing::span!(tracing::Level::TRACE, "audio-encoder");
|
||||||
|
let conv1_span = tracing::span!(tracing::Level::TRACE, "conv1");
|
||||||
|
let conv2_span = tracing::span!(tracing::Level::TRACE, "conv2");
|
||||||
|
let n_state = cfg.d_model;
|
||||||
|
let n_head = cfg.encoder_attention_heads;
|
||||||
|
let n_ctx = cfg.max_source_positions;
|
||||||
|
let cfg1 = Conv1dConfig {
|
||||||
|
padding: 1,
|
||||||
|
stride: 1,
|
||||||
|
groups: 1,
|
||||||
|
dilation: 1,
|
||||||
|
};
|
||||||
|
let cfg2 = Conv1dConfig {
|
||||||
|
padding: 1,
|
||||||
|
stride: 2,
|
||||||
|
groups: 1,
|
||||||
|
dilation: 1,
|
||||||
|
};
|
||||||
|
let conv1 = conv1d(cfg.num_mel_bins, n_state, 3, cfg1, vb.pp("conv1"))?;
|
||||||
|
let conv2 = conv1d(n_state, n_state, 3, cfg2, vb.pp("conv2"))?;
|
||||||
|
let positional_embedding = sinusoids(n_ctx, n_state)?.to_device(vb.device())?;
|
||||||
|
let blocks = (0..cfg.encoder_layers)
|
||||||
|
.map(|i| {
|
||||||
|
ResidualAttentionBlock::load(n_state, n_head, false, vb.pp(format!("layers.{i}")))
|
||||||
|
})
|
||||||
|
.collect::<Result<Vec<_>>>()?;
|
||||||
|
let ln_post = layer_norm(n_state, vb.pp("layer_norm"))?;
|
||||||
|
Ok(Self {
|
||||||
|
conv1,
|
||||||
|
conv2,
|
||||||
|
positional_embedding,
|
||||||
|
blocks,
|
||||||
|
ln_post,
|
||||||
|
conv1_span,
|
||||||
|
conv2_span,
|
||||||
|
span,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn forward(&mut self, x: &Tensor, flush_kv_cache: bool) -> Result<Tensor> {
|
||||||
|
let _enter = self.span.enter();
|
||||||
|
let x = {
|
||||||
|
let _enter = self.conv1_span.enter();
|
||||||
|
self.conv1.forward(x)?.gelu()?
|
||||||
|
};
|
||||||
|
let x = {
|
||||||
|
let _enter = self.conv2_span.enter();
|
||||||
|
self.conv2.forward(&x)?.gelu()?
|
||||||
|
};
|
||||||
|
let x = x.transpose(1, 2)?;
|
||||||
|
let (_bsize, seq_len, _hidden) = x.dims3()?;
|
||||||
|
let positional_embedding = self.positional_embedding.narrow(0, 0, seq_len)?;
|
||||||
|
let mut x = x.broadcast_add(&positional_embedding)?;
|
||||||
|
for block in self.blocks.iter_mut() {
|
||||||
|
x = block.forward(&x, None, None, flush_kv_cache)?
|
||||||
|
}
|
||||||
|
let x = self.ln_post.forward(&x)?;
|
||||||
|
Ok(x)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L176
|
||||||
|
pub struct TextDecoder {
|
||||||
|
token_embedding: Embedding,
|
||||||
|
positional_embedding: Tensor,
|
||||||
|
blocks: Vec<ResidualAttentionBlock>,
|
||||||
|
ln: LayerNorm,
|
||||||
|
mask: Tensor,
|
||||||
|
span: tracing::Span,
|
||||||
|
span_final: tracing::Span,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TextDecoder {
|
||||||
|
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||||
|
let span = tracing::span!(tracing::Level::TRACE, "text-decoder");
|
||||||
|
let span_final = tracing::span!(tracing::Level::TRACE, "text-decoder-final");
|
||||||
|
let n_state = cfg.d_model;
|
||||||
|
let n_head = cfg.decoder_attention_heads;
|
||||||
|
let n_ctx = cfg.max_target_positions;
|
||||||
|
let token_embedding = Embedding::new(cfg.vocab_size, n_state, vb.pp("embed_tokens"))?;
|
||||||
|
let positional_embedding = vb
|
||||||
|
.get((n_ctx, n_state), "embed_positions.weight")?
|
||||||
|
.dequantize(vb.device())?;
|
||||||
|
let blocks = (0..cfg.decoder_layers)
|
||||||
|
.map(|i| {
|
||||||
|
ResidualAttentionBlock::load(n_state, n_head, true, vb.pp(format!("layers.{i}")))
|
||||||
|
})
|
||||||
|
.collect::<Result<Vec<_>>>()?;
|
||||||
|
let ln = layer_norm(n_state, vb.pp("layer_norm"))?;
|
||||||
|
let mask: Vec<_> = (0..n_ctx)
|
||||||
|
.flat_map(|i| (0..n_ctx).map(move |j| if j > i { f32::NEG_INFINITY } else { 0f32 }))
|
||||||
|
.collect();
|
||||||
|
let mask = Tensor::from_vec(mask, (n_ctx, n_ctx), vb.device())?;
|
||||||
|
Ok(Self {
|
||||||
|
token_embedding,
|
||||||
|
positional_embedding,
|
||||||
|
blocks,
|
||||||
|
ln,
|
||||||
|
mask,
|
||||||
|
span,
|
||||||
|
span_final,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn forward(&mut self, x: &Tensor, xa: &Tensor, flush_kv_cache: bool) -> Result<Tensor> {
|
||||||
|
let _enter = self.span.enter();
|
||||||
|
let last = x.dim(D::Minus1)?;
|
||||||
|
let token_embedding = self.token_embedding.forward(x)?;
|
||||||
|
let positional_embedding = self.positional_embedding.narrow(0, 0, last)?;
|
||||||
|
let mut x = token_embedding.broadcast_add(&positional_embedding)?;
|
||||||
|
for block in self.blocks.iter_mut() {
|
||||||
|
x = block.forward(&x, Some(xa), Some(&self.mask), flush_kv_cache)?;
|
||||||
|
}
|
||||||
|
self.ln.forward(&x)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn final_linear(&self, x: &Tensor) -> Result<Tensor> {
|
||||||
|
let b_size = x.dim(0)?;
|
||||||
|
let w = self.token_embedding.embeddings().broadcast_left(b_size)?;
|
||||||
|
let logits = {
|
||||||
|
let _enter = self.span_final.enter();
|
||||||
|
x.matmul(&w.t()?)?
|
||||||
|
};
|
||||||
|
Ok(logits)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L221
|
||||||
|
pub struct Whisper {
|
||||||
|
pub encoder: AudioEncoder,
|
||||||
|
pub decoder: TextDecoder,
|
||||||
|
pub config: Config,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Whisper {
|
||||||
|
pub fn load(vb: &VarBuilder, config: Config) -> Result<Self> {
|
||||||
|
let encoder = AudioEncoder::load(vb.pp("model.encoder"), &config)?;
|
||||||
|
let decoder = TextDecoder::load(vb.pp("model.decoder"), &config)?;
|
||||||
|
Ok(Self {
|
||||||
|
encoder,
|
||||||
|
decoder,
|
||||||
|
config,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
Reference in New Issue
Block a user