mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 02:38:10 +00:00
Add a quantized variant of whisper (#1017)
* Add the quantized-whisper model. * Quantized the whisper model. * Adapt the whisper example to handle quantization. * Add the quantized flag. * Load the proper weights.
This commit is contained in:
@ -18,8 +18,48 @@ use rand::{distributions::Distribution, SeedableRng};
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
mod multilingual;
|
||||
use candle_transformers::models::whisper::{self as m, audio, model};
|
||||
use model::{Config, Whisper};
|
||||
use candle_transformers::models::whisper::{self as m, audio, Config};
|
||||
|
||||
pub enum Model {
|
||||
Normal(m::model::Whisper),
|
||||
Quantized(m::quantized_model::Whisper),
|
||||
}
|
||||
|
||||
// Maybe we should use some traits rather than doing the dispatch for all these.
|
||||
impl Model {
|
||||
pub fn config(&self) -> &Config {
|
||||
match self {
|
||||
Self::Normal(m) => &m.config,
|
||||
Self::Quantized(m) => &m.config,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn encoder_forward(&mut self, x: &Tensor, flush: bool) -> candle::Result<Tensor> {
|
||||
match self {
|
||||
Self::Normal(m) => m.encoder.forward(x, flush),
|
||||
Self::Quantized(m) => m.encoder.forward(x, flush),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn decoder_forward(
|
||||
&mut self,
|
||||
x: &Tensor,
|
||||
xa: &Tensor,
|
||||
flush: bool,
|
||||
) -> candle::Result<Tensor> {
|
||||
match self {
|
||||
Self::Normal(m) => m.decoder.forward(x, xa, flush),
|
||||
Self::Quantized(m) => m.decoder.forward(x, xa, flush),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn decoder_final_linear(&self, x: &Tensor) -> candle::Result<Tensor> {
|
||||
match self {
|
||||
Self::Normal(m) => m.decoder.final_linear(x),
|
||||
Self::Quantized(m) => m.decoder.final_linear(x),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
#[derive(Debug, Clone)]
|
||||
@ -41,7 +81,7 @@ struct Segment {
|
||||
}
|
||||
|
||||
struct Decoder {
|
||||
model: Whisper,
|
||||
model: Model,
|
||||
rng: rand::rngs::StdRng,
|
||||
task: Option<Task>,
|
||||
timestamps: bool,
|
||||
@ -60,7 +100,7 @@ struct Decoder {
|
||||
impl Decoder {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn new(
|
||||
model: Whisper,
|
||||
model: Model,
|
||||
tokenizer: Tokenizer,
|
||||
seed: u64,
|
||||
device: &Device,
|
||||
@ -72,9 +112,9 @@ impl Decoder {
|
||||
let no_timestamps_token = token_id(&tokenizer, m::NO_TIMESTAMPS_TOKEN)?;
|
||||
// Suppress the notimestamps token when in timestamps mode.
|
||||
// https://github.com/openai/whisper/blob/e8622f9afc4eba139bf796c210f5c01081000472/whisper/decoding.py#L452
|
||||
let suppress_tokens: Vec<f32> = (0..model.config.vocab_size as u32)
|
||||
let suppress_tokens: Vec<f32> = (0..model.config().vocab_size as u32)
|
||||
.map(|i| {
|
||||
if model.config.suppress_tokens.contains(&i)
|
||||
if model.config().suppress_tokens.contains(&i)
|
||||
|| timestamps && i == no_timestamps_token
|
||||
{
|
||||
f32::NEG_INFINITY
|
||||
@ -109,11 +149,11 @@ impl Decoder {
|
||||
|
||||
fn decode(&mut self, mel: &Tensor, t: f64) -> Result<DecodingResult> {
|
||||
let model = &mut self.model;
|
||||
let audio_features = model.encoder.forward(mel, true)?;
|
||||
let audio_features = model.encoder_forward(mel, true)?;
|
||||
if self.verbose {
|
||||
println!("audio features: {:?}", audio_features.dims());
|
||||
}
|
||||
let sample_len = model.config.max_target_positions / 2;
|
||||
let sample_len = model.config().max_target_positions / 2;
|
||||
let mut sum_logprob = 0f64;
|
||||
let mut no_speech_prob = f64::NAN;
|
||||
let mut tokens = vec![self.sot_token];
|
||||
@ -133,12 +173,12 @@ impl Decoder {
|
||||
// The model expects a batch dim but this inference loop does not handle
|
||||
// it so we add it at this point.
|
||||
let tokens_t = tokens_t.unsqueeze(0)?;
|
||||
let ys = model.decoder.forward(&tokens_t, &audio_features, i == 0)?;
|
||||
let ys = model.decoder_forward(&tokens_t, &audio_features, i == 0)?;
|
||||
|
||||
// Extract the no speech probability on the first iteration by looking at the first
|
||||
// token logits and the probability for the according token.
|
||||
if i == 0 {
|
||||
let logits = model.decoder.final_linear(&ys.i(..1)?)?.i(0)?.i(0)?;
|
||||
let logits = model.decoder_final_linear(&ys.i(..1)?)?.i(0)?.i(0)?;
|
||||
no_speech_prob = softmax(&logits, 0)?
|
||||
.i(self.no_speech_token as usize)?
|
||||
.to_scalar::<f32>()? as f64;
|
||||
@ -146,8 +186,7 @@ impl Decoder {
|
||||
|
||||
let (_, seq_len, _) = ys.dims3()?;
|
||||
let logits = model
|
||||
.decoder
|
||||
.final_linear(&ys.i((..1, seq_len - 1..))?)?
|
||||
.decoder_final_linear(&ys.i((..1, seq_len - 1..))?)?
|
||||
.i(0)?
|
||||
.i(0)?;
|
||||
// TODO: Besides suppress tokens, we should apply the heuristics from
|
||||
@ -176,7 +215,7 @@ impl Decoder {
|
||||
let prob = softmax(&logits, candle::D::Minus1)?
|
||||
.i(next_token as usize)?
|
||||
.to_scalar::<f32>()? as f64;
|
||||
if next_token == self.eot_token || tokens.len() > model.config.max_target_positions {
|
||||
if next_token == self.eot_token || tokens.len() > model.config().max_target_positions {
|
||||
break;
|
||||
}
|
||||
sum_logprob += prob.ln();
|
||||
@ -333,6 +372,7 @@ impl WhichModel {
|
||||
Self::TinyEn | Self::BaseEn | Self::SmallEn | Self::MediumEn => false,
|
||||
}
|
||||
}
|
||||
|
||||
fn model_and_revision(&self) -> (&'static str, &'static str) {
|
||||
match self {
|
||||
Self::Tiny => ("openai/whisper-tiny", "main"),
|
||||
@ -382,6 +422,9 @@ struct Args {
|
||||
#[arg(long)]
|
||||
tracing: bool,
|
||||
|
||||
#[arg(long)]
|
||||
quantized: bool,
|
||||
|
||||
/// Language.
|
||||
#[arg(long)]
|
||||
language: Option<String>,
|
||||
@ -413,10 +456,13 @@ fn main() -> Result<()> {
|
||||
None
|
||||
};
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
let (default_model, default_revision) = args.model.model_and_revision();
|
||||
let (default_model, default_revision) = if args.quantized {
|
||||
("lmz/candle-whisper", "main")
|
||||
} else {
|
||||
args.model.model_and_revision()
|
||||
};
|
||||
let default_model = default_model.to_string();
|
||||
let default_revision = default_revision.to_string();
|
||||
let path = std::path::PathBuf::from(default_model.clone());
|
||||
let (model_id, revision) = match (args.model_id, args.revision) {
|
||||
(Some(model_id), Some(revision)) => (model_id, revision),
|
||||
(Some(model_id), None) => (model_id, "main".to_string()),
|
||||
@ -424,20 +470,7 @@ fn main() -> Result<()> {
|
||||
(None, None) => (default_model, default_revision),
|
||||
};
|
||||
|
||||
let (config_filename, tokenizer_filename, weights_filename, input) = if path.exists() {
|
||||
let mut config_filename = path.clone();
|
||||
config_filename.push("config.json");
|
||||
let mut tokenizer_filename = path.clone();
|
||||
tokenizer_filename.push("tokenizer.json");
|
||||
let mut model_filename = path;
|
||||
model_filename.push("model.safetensors");
|
||||
(
|
||||
config_filename,
|
||||
tokenizer_filename,
|
||||
model_filename,
|
||||
std::path::PathBuf::from(args.input.expect("You didn't specify a file to read from yet, are using a local model, please add `--input example.wav` to read some audio file")),
|
||||
)
|
||||
} else {
|
||||
let (config_filename, tokenizer_filename, weights_filename, input) = {
|
||||
let api = Api::new()?;
|
||||
let dataset = api.dataset("Narsil/candle-examples".to_string());
|
||||
let repo = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
|
||||
@ -451,12 +484,17 @@ fn main() -> Result<()> {
|
||||
println!("No audio file submitted: Downloading https://huggingface.co/datasets/Narsil/candle_demo/blob/main/samples_jfk.wav");
|
||||
dataset.get("samples_jfk.wav")?
|
||||
};
|
||||
(
|
||||
repo.get("config.json")?,
|
||||
repo.get("tokenizer.json")?,
|
||||
repo.get("model.safetensors")?,
|
||||
sample,
|
||||
)
|
||||
let config = if args.quantized {
|
||||
repo.get("config-tiny.json")?
|
||||
} else {
|
||||
repo.get("config.json")?
|
||||
};
|
||||
let model = if args.quantized {
|
||||
repo.get("model-tiny-q40.gguf")?
|
||||
} else {
|
||||
repo.get("model.safetensors")?
|
||||
};
|
||||
(config, repo.get("tokenizer.json")?, model, sample)
|
||||
};
|
||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||
|
||||
@ -481,10 +519,16 @@ fn main() -> Result<()> {
|
||||
let mel = Tensor::from_vec(mel, (1, m::N_MELS, mel_len / m::N_MELS), &device)?;
|
||||
println!("loaded mel: {:?}", mel.dims());
|
||||
|
||||
let vb =
|
||||
unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], m::DTYPE, &device)? };
|
||||
let config: Config = serde_json::from_str(&std::fs::read_to_string(config_filename)?)?;
|
||||
let mut model = Whisper::load(&vb, config)?;
|
||||
let mut model = if args.quantized {
|
||||
let vb =
|
||||
candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&weights_filename)?;
|
||||
Model::Quantized(m::quantized_model::Whisper::load(&vb, config)?)
|
||||
} else {
|
||||
let vb =
|
||||
unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], m::DTYPE, &device)? };
|
||||
Model::Normal(m::model::Whisper::load(&vb, config)?)
|
||||
};
|
||||
|
||||
let language_token = match (args.model.is_multilingual(), args.language) {
|
||||
(true, None) => Some(multilingual::detect_language(&mut model, &tokenizer, &mel)?),
|
||||
|
Reference in New Issue
Block a user