mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 18:28:24 +00:00
Preliminary support for whisper v3. (#1294)
* Preliminary support for whisper v3. * Add the missing files.
This commit is contained in:
@ -345,7 +345,7 @@ enum Task {
|
||||
Translate,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, ValueEnum)]
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq, ValueEnum)]
|
||||
enum WhichModel {
|
||||
Tiny,
|
||||
#[value(name = "tiny.en")]
|
||||
@ -361,6 +361,7 @@ enum WhichModel {
|
||||
MediumEn,
|
||||
Large,
|
||||
LargeV2,
|
||||
LargeV3,
|
||||
#[value(name = "distil-medium.en")]
|
||||
DistilMediumEn,
|
||||
#[value(name = "distil-large-v2")]
|
||||
@ -376,6 +377,7 @@ impl WhichModel {
|
||||
| Self::Medium
|
||||
| Self::Large
|
||||
| Self::LargeV2
|
||||
| Self::LargeV3
|
||||
| Self::DistilLargeV2 => true,
|
||||
Self::TinyEn | Self::BaseEn | Self::SmallEn | Self::MediumEn | Self::DistilMediumEn => {
|
||||
false
|
||||
@ -395,6 +397,7 @@ impl WhichModel {
|
||||
Self::MediumEn => ("openai/whisper-medium.en", "main"),
|
||||
Self::Large => ("openai/whisper-large", "refs/pr/36"),
|
||||
Self::LargeV2 => ("openai/whisper-large-v2", "refs/pr/57"),
|
||||
Self::LargeV3 => ("openai/whisper-large-v3", "main"),
|
||||
Self::DistilMediumEn => ("distil-whisper/distil-medium.en", "main"),
|
||||
Self::DistilLargeV2 => ("distil-whisper/distil-large-v2", "main"),
|
||||
}
|
||||
@ -508,14 +511,18 @@ fn main() -> Result<()> {
|
||||
repo.get(&format!("model-{ext}-q80.gguf"))?,
|
||||
)
|
||||
} else {
|
||||
(
|
||||
repo.get("config.json")?,
|
||||
repo.get("tokenizer.json")?,
|
||||
repo.get("model.safetensors")?,
|
||||
)
|
||||
let config = repo.get("config.json")?;
|
||||
let tokenizer = if args.model == WhichModel::LargeV3 {
|
||||
panic!("openai/whisper-large-v3 does not provide a compatible tokenizer.json config at the moment")
|
||||
} else {
|
||||
repo.get("tokenizer.json")?
|
||||
};
|
||||
let model = repo.get("model.safetensors")?;
|
||||
(config, tokenizer, model)
|
||||
};
|
||||
(config, tokenizer, model, sample)
|
||||
};
|
||||
let config: Config = serde_json::from_str(&std::fs::read_to_string(config_filename)?)?;
|
||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||
|
||||
let mel_bytes = include_bytes!("melfilters.bytes");
|
||||
@ -534,12 +541,15 @@ fn main() -> Result<()> {
|
||||
.map(|v| *v as f32 / 32768.)
|
||||
.collect();
|
||||
println!("pcm data loaded {}", pcm_data.len());
|
||||
let mel = audio::pcm_to_mel(&pcm_data, &mel_filters);
|
||||
let mel = audio::pcm_to_mel(&config, &pcm_data, &mel_filters);
|
||||
let mel_len = mel.len();
|
||||
let mel = Tensor::from_vec(mel, (1, m::N_MELS, mel_len / m::N_MELS), &device)?;
|
||||
let mel = Tensor::from_vec(
|
||||
mel,
|
||||
(1, config.num_mel_bins, mel_len / config.num_mel_bins),
|
||||
&device,
|
||||
)?;
|
||||
println!("loaded mel: {:?}", mel.dims());
|
||||
|
||||
let config: Config = serde_json::from_str(&std::fs::read_to_string(config_filename)?)?;
|
||||
let mut model = if args.quantized {
|
||||
let vb =
|
||||
candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&weights_filename)?;
|
||||
|
@ -198,13 +198,17 @@ fn log_mel_spectrogram_<T: Float + std::fmt::Display>(
|
||||
mel
|
||||
}
|
||||
|
||||
pub fn pcm_to_mel<T: Float + std::fmt::Display>(samples: &[T], filters: &[T]) -> Vec<T> {
|
||||
pub fn pcm_to_mel<T: Float + std::fmt::Display>(
|
||||
cfg: &super::Config,
|
||||
samples: &[T],
|
||||
filters: &[T],
|
||||
) -> Vec<T> {
|
||||
log_mel_spectrogram_(
|
||||
samples,
|
||||
filters,
|
||||
super::N_FFT,
|
||||
super::HOP_LENGTH,
|
||||
super::N_MELS,
|
||||
cfg.num_mel_bins,
|
||||
false,
|
||||
)
|
||||
}
|
||||
|
@ -18,6 +18,7 @@ pub struct Config {
|
||||
// pub n_text_state: usize,
|
||||
pub decoder_attention_heads: usize, // n_text_head
|
||||
pub decoder_layers: usize, // n_text_layer
|
||||
#[serde(default)]
|
||||
pub suppress_tokens: Vec<u32>,
|
||||
}
|
||||
|
||||
@ -26,7 +27,6 @@ pub const DTYPE: candle::DType = candle::DType::F32;
|
||||
// Audio parameters.
|
||||
pub const SAMPLE_RATE: usize = 16000;
|
||||
pub const N_FFT: usize = 400;
|
||||
pub const N_MELS: usize = 80;
|
||||
pub const HOP_LENGTH: usize = 160;
|
||||
pub const CHUNK_LENGTH: usize = 30;
|
||||
pub const N_SAMPLES: usize = CHUNK_LENGTH * SAMPLE_RATE; // 480000 samples in a 30-second chunk
|
||||
|
@ -200,6 +200,7 @@ fn log_mel_spectrogram_<T: Float + std::fmt::Display>(
|
||||
}
|
||||
|
||||
pub fn pcm_to_mel<T: Float + std::fmt::Display>(
|
||||
cfg: &worker::m::Config,
|
||||
samples: &[T],
|
||||
filters: &[T],
|
||||
) -> anyhow::Result<Vec<T>> {
|
||||
@ -208,7 +209,7 @@ pub fn pcm_to_mel<T: Float + std::fmt::Display>(
|
||||
filters,
|
||||
worker::m::N_FFT,
|
||||
worker::m::HOP_LENGTH,
|
||||
worker::m::N_MELS,
|
||||
cfg.num_mel_bins,
|
||||
false,
|
||||
);
|
||||
Ok(mel)
|
||||
|
@ -349,9 +349,10 @@ impl Decoder {
|
||||
.map(|v| *v as f32 / 32768.)
|
||||
.collect();
|
||||
console_log!("pcm data loaded {}", pcm_data.len());
|
||||
let mel = crate::audio::pcm_to_mel(&pcm_data, &self.mel_filters)?;
|
||||
let mel = crate::audio::pcm_to_mel(self.model.config(), &pcm_data, &self.mel_filters)?;
|
||||
let mel_len = mel.len();
|
||||
let mel = Tensor::from_vec(mel, (1, m::N_MELS, mel_len / m::N_MELS), &device)?;
|
||||
let n_mels = self.model.config().num_mel_bins;
|
||||
let mel = Tensor::from_vec(mel, (1, n_mels, mel_len / n_mels), &device)?;
|
||||
console_log!("loaded mel: {:?}", mel.dims());
|
||||
let segments = self.run(&mel)?;
|
||||
Ok(segments)
|
||||
|
Reference in New Issue
Block a user