Preliminary support for whisper v3. (#1294)

* Preliminary support for whisper v3.

* Add the missing files.
This commit is contained in:
Laurent Mazare
2023-11-08 06:42:52 +01:00
committed by GitHub
parent f3a4f3db76
commit 2d28497197
5 changed files with 31 additions and 15 deletions

View File

@ -345,7 +345,7 @@ enum Task {
Translate,
}
#[derive(Clone, Copy, Debug, ValueEnum)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, ValueEnum)]
enum WhichModel {
Tiny,
#[value(name = "tiny.en")]
@ -361,6 +361,7 @@ enum WhichModel {
MediumEn,
Large,
LargeV2,
LargeV3,
#[value(name = "distil-medium.en")]
DistilMediumEn,
#[value(name = "distil-large-v2")]
@ -376,6 +377,7 @@ impl WhichModel {
| Self::Medium
| Self::Large
| Self::LargeV2
| Self::LargeV3
| Self::DistilLargeV2 => true,
Self::TinyEn | Self::BaseEn | Self::SmallEn | Self::MediumEn | Self::DistilMediumEn => {
false
@ -395,6 +397,7 @@ impl WhichModel {
Self::MediumEn => ("openai/whisper-medium.en", "main"),
Self::Large => ("openai/whisper-large", "refs/pr/36"),
Self::LargeV2 => ("openai/whisper-large-v2", "refs/pr/57"),
Self::LargeV3 => ("openai/whisper-large-v3", "main"),
Self::DistilMediumEn => ("distil-whisper/distil-medium.en", "main"),
Self::DistilLargeV2 => ("distil-whisper/distil-large-v2", "main"),
}
@ -508,14 +511,18 @@ fn main() -> Result<()> {
repo.get(&format!("model-{ext}-q80.gguf"))?,
)
} else {
(
repo.get("config.json")?,
repo.get("tokenizer.json")?,
repo.get("model.safetensors")?,
)
let config = repo.get("config.json")?;
let tokenizer = if args.model == WhichModel::LargeV3 {
panic!("openai/whisper-large-v3 does not provide a compatible tokenizer.json config at the moment")
} else {
repo.get("tokenizer.json")?
};
let model = repo.get("model.safetensors")?;
(config, tokenizer, model)
};
(config, tokenizer, model, sample)
};
let config: Config = serde_json::from_str(&std::fs::read_to_string(config_filename)?)?;
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let mel_bytes = include_bytes!("melfilters.bytes");
@ -534,12 +541,15 @@ fn main() -> Result<()> {
.map(|v| *v as f32 / 32768.)
.collect();
println!("pcm data loaded {}", pcm_data.len());
let mel = audio::pcm_to_mel(&pcm_data, &mel_filters);
let mel = audio::pcm_to_mel(&config, &pcm_data, &mel_filters);
let mel_len = mel.len();
let mel = Tensor::from_vec(mel, (1, m::N_MELS, mel_len / m::N_MELS), &device)?;
let mel = Tensor::from_vec(
mel,
(1, config.num_mel_bins, mel_len / config.num_mel_bins),
&device,
)?;
println!("loaded mel: {:?}", mel.dims());
let config: Config = serde_json::from_str(&std::fs::read_to_string(config_filename)?)?;
let mut model = if args.quantized {
let vb =
candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&weights_filename)?;