Add more whisper variants. (#413)

This commit is contained in:
Laurent Mazare
2023-08-11 18:33:55 +02:00
committed by GitHub
parent e12372021b
commit 91dbf907d3

View File

@ -11,7 +11,7 @@ extern crate intel_mkl_src;
use anyhow::{Error as E, Result};
use candle::{DType, Device, Tensor};
use candle_nn::{ops::softmax, VarBuilder};
use clap::Parser;
use clap::{Parser, ValueEnum};
use hf_hub::{api::sync::Api, Repo, RepoType};
use rand::{distributions::Distribution, SeedableRng};
use tokenizers::Tokenizer;
@ -216,6 +216,23 @@ impl Decoder {
}
}
#[derive(Clone, Copy, Debug, ValueEnum)]
enum WhichModel {
Tiny,
Small,
Medium,
}
impl WhichModel {
fn model_and_revision(&self) -> (&'static str, &'static str) {
match self {
Self::Tiny => ("openai/whisper-tiny.en", "refs/pr/15"),
Self::Small => ("openai/whisper-small.en", "refs/pr/10"),
Self::Medium => ("openai/whisper-medium.en", "refs/pr/11"),
}
}
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
@ -231,6 +248,10 @@ struct Args {
#[arg(long)]
revision: Option<String>,
/// The model to be used, can be tiny, small, medium.
#[arg(long, default_value = "tiny")]
model: WhichModel,
/// The input to be processed, in wav format, will default to `jfk.wav`. Alternatively
/// this can be set to sample:jfk, sample:gb1, ... to fetch a sample from the following
/// repo: https://huggingface.co/datasets/Narsil/candle_demo/
@ -260,9 +281,10 @@ fn main() -> Result<()> {
None
};
let device = candle_examples::device(args.cpu)?;
let default_model = "openai/whisper-tiny.en".to_string();
let (default_model, default_revision) = args.model.model_and_revision();
let default_model = default_model.to_string();
let default_revision = default_revision.to_string();
let path = std::path::PathBuf::from(default_model.clone());
let default_revision = "refs/pr/15".to_string();
let (model_id, revision) = match (args.model_id, args.revision) {
(Some(model_id), Some(revision)) => (model_id, revision),
(Some(model_id), None) => (model_id, "main".to_string()),