From 91dbf907d3ee45dd4777efa82c1f431907ce8125 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Fri, 11 Aug 2023 18:33:55 +0200 Subject: [PATCH] Add more whisper variants. (#413) --- candle-examples/examples/whisper/main.rs | 28 +++++++++++++++++++++--- 1 file changed, 25 insertions(+), 3 deletions(-) diff --git a/candle-examples/examples/whisper/main.rs b/candle-examples/examples/whisper/main.rs index dfe7a27f..d5f91053 100644 --- a/candle-examples/examples/whisper/main.rs +++ b/candle-examples/examples/whisper/main.rs @@ -11,7 +11,7 @@ extern crate intel_mkl_src; use anyhow::{Error as E, Result}; use candle::{DType, Device, Tensor}; use candle_nn::{ops::softmax, VarBuilder}; -use clap::Parser; +use clap::{Parser, ValueEnum}; use hf_hub::{api::sync::Api, Repo, RepoType}; use rand::{distributions::Distribution, SeedableRng}; use tokenizers::Tokenizer; @@ -216,6 +216,23 @@ impl Decoder { } } +#[derive(Clone, Copy, Debug, ValueEnum)] +enum WhichModel { + Tiny, + Small, + Medium, +} + +impl WhichModel { + fn model_and_revision(&self) -> (&'static str, &'static str) { + match self { + Self::Tiny => ("openai/whisper-tiny.en", "refs/pr/15"), + Self::Small => ("openai/whisper-small.en", "refs/pr/10"), + Self::Medium => ("openai/whisper-medium.en", "refs/pr/11"), + } + } +} + #[derive(Parser, Debug)] #[command(author, version, about, long_about = None)] struct Args { @@ -231,6 +248,10 @@ struct Args { #[arg(long)] revision: Option<String>, + /// The model to be used, can be tiny, small, medium. + #[arg(long, default_value = "tiny")] + model: WhichModel, + /// The input to be processed, in wav format, will default to `jfk.wav`. Alternatively /// this can be set to sample:jfk, sample:gb1, ... 
to fetch a sample from the following /// repo: https://huggingface.co/datasets/Narsil/candle_demo/ @@ -260,9 +281,10 @@ fn main() -> Result<()> { None }; let device = candle_examples::device(args.cpu)?; - let default_model = "openai/whisper-tiny.en".to_string(); + let (default_model, default_revision) = args.model.model_and_revision(); + let default_model = default_model.to_string(); + let default_revision = default_revision.to_string(); let path = std::path::PathBuf::from(default_model.clone()); - let default_revision = "refs/pr/15".to_string(); let (model_id, revision) = match (args.model_id, args.revision) { (Some(model_id), Some(revision)) => (model_id, revision), (Some(model_id), None) => (model_id, "main".to_string()),