mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 02:38:10 +00:00
Add more whisper variants. (#413)
This commit is contained in:
@ -11,7 +11,7 @@ extern crate intel_mkl_src;
|
||||
use anyhow::{Error as E, Result};
|
||||
use candle::{DType, Device, Tensor};
|
||||
use candle_nn::{ops::softmax, VarBuilder};
|
||||
use clap::Parser;
|
||||
use clap::{Parser, ValueEnum};
|
||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||
use rand::{distributions::Distribution, SeedableRng};
|
||||
use tokenizers::Tokenizer;
|
||||
@ -216,6 +216,23 @@ impl Decoder {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, ValueEnum)]
|
||||
enum WhichModel {
|
||||
Tiny,
|
||||
Small,
|
||||
Medium,
|
||||
}
|
||||
|
||||
impl WhichModel {
|
||||
fn model_and_revision(&self) -> (&'static str, &'static str) {
|
||||
match self {
|
||||
Self::Tiny => ("openai/whisper-tiny.en", "refs/pr/15"),
|
||||
Self::Small => ("openai/whisper-small.en", "refs/pr/10"),
|
||||
Self::Medium => ("openai/whisper-medium.en", "refs/pr/11"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
struct Args {
|
||||
@ -231,6 +248,10 @@ struct Args {
|
||||
#[arg(long)]
|
||||
revision: Option<String>,
|
||||
|
||||
/// The model to be used, can be tiny, small, medium.
|
||||
#[arg(long, default_value = "tiny")]
|
||||
model: WhichModel,
|
||||
|
||||
/// The input to be processed, in wav format, will default to `jfk.wav`. Alternatively
|
||||
/// this can be set to sample:jfk, sample:gb1, ... to fetch a sample from the following
|
||||
/// repo: https://huggingface.co/datasets/Narsil/candle_demo/
|
||||
@ -260,9 +281,10 @@ fn main() -> Result<()> {
|
||||
None
|
||||
};
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
let default_model = "openai/whisper-tiny.en".to_string();
|
||||
let (default_model, default_revision) = args.model.model_and_revision();
|
||||
let default_model = default_model.to_string();
|
||||
let default_revision = default_revision.to_string();
|
||||
let path = std::path::PathBuf::from(default_model.clone());
|
||||
let default_revision = "refs/pr/15".to_string();
|
||||
let (model_id, revision) = match (args.model_id, args.revision) {
|
||||
(Some(model_id), Some(revision)) => (model_id, revision),
|
||||
(Some(model_id), None) => (model_id, "main".to_string()),
|
||||
|
Reference in New Issue
Block a user