From a0908d212c1e874d21d87f1587d8ff2394158740 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Sat, 12 Aug 2023 18:08:40 +0200
Subject: [PATCH] Add a -language argument. (#425)

---
 candle-examples/examples/whisper/main.rs | 20 +++++++++++++++-----
 1 file changed, 15 insertions(+), 5 deletions(-)

diff --git a/candle-examples/examples/whisper/main.rs b/candle-examples/examples/whisper/main.rs
index 61db40bc..1c24de60 100644
--- a/candle-examples/examples/whisper/main.rs
+++ b/candle-examples/examples/whisper/main.rs
@@ -1,4 +1,4 @@
-// https://github.com/openai/whisper/blob/main/whisper/model.py
+// https://github.com/openai/whisper/blob/main/whisper/model.py/rgs
 // TODO:
 // - kv-cache support?
 // - Batch size greater than 1.
@@ -301,6 +301,10 @@ struct Args {
     /// Enable tracing (generates a trace-timestamp.json file).
     #[arg(long)]
     tracing: bool,
+
+    /// Language.
+    #[arg(long)]
+    language: Option<String>,
 }
 
 fn main() -> Result<()> {
@@ -391,10 +395,16 @@ fn main() -> Result<()> {
     let config: Config = serde_json::from_str(&std::fs::read_to_string(config_filename)?)?;
     let model = Whisper::load(&vb, config)?;
 
-    let language_token = if args.model.is_multilingual() {
-        Some(multilingual::detect_language(&model, &tokenizer, &mel)?)
-    } else {
-        None
+    let language_token = match (args.model.is_multilingual(), args.language) {
+        (true, None) => Some(multilingual::detect_language(&model, &tokenizer, &mel)?),
+        (false, None) => None,
+        (true, Some(language)) => match token_id(&tokenizer, &format!("<|{language}|>")) {
+            Ok(token_id) => Some(token_id),
+            Err(_) => anyhow::bail!("language {language} is not supported"),
+        },
+        (false, Some(_)) => {
+            anyhow::bail!("a language cannot be set for non-multilingual models")
+        }
     };
     let mut dc = Decoder::new(model, tokenizer, args.seed, &device, language_token)?;
     dc.run(&mel)?;
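
For readers skimming the second hunk, here is a minimal, self-contained sketch of the new language-token selection logic. It is not the example's real API: the slice-based tokenizer, the `detect_language` stub, and the token ids are illustrative stand-ins for the actual `Tokenizer`, `multilingual::detect_language`, and Whisper special-token ids.

```rust
// Sketch only: a slice of (token, id) pairs stands in for the real Tokenizer.
fn token_id(tokenizer: &[(&str, u32)], token: &str) -> Result<u32, String> {
    tokenizer
        .iter()
        .find(|(t, _)| *t == token)
        .map(|(_, id)| *id)
        .ok_or_else(|| format!("no token-id for {token}"))
}

// Stand-in for multilingual::detect_language, which in the real example scores
// the language tokens against the audio's mel spectrogram.
fn detect_language(_mel: &[f32]) -> u32 {
    50259 // illustrative id for <|en|>
}

// Mirrors the match on (is_multilingual, language) introduced by the patch.
fn language_token(
    is_multilingual: bool,
    language: Option<String>,
    tokenizer: &[(&str, u32)],
    mel: &[f32],
) -> Result<Option<u32>, String> {
    match (is_multilingual, language) {
        // No --language given: auto-detect for multilingual models, none otherwise.
        (true, None) => Ok(Some(detect_language(mel))),
        (false, None) => Ok(None),
        // Explicit --language: map e.g. "fr" to its <|fr|> special token.
        (true, Some(language)) => match token_id(tokenizer, &format!("<|{language}|>")) {
            Ok(id) => Ok(Some(id)),
            Err(_) => Err(format!("language {language} is not supported")),
        },
        (false, Some(_)) => {
            Err("a language cannot be set for non-multilingual models".to_string())
        }
    }
}

fn main() {
    let tokenizer = [("<|en|>", 50259u32), ("<|fr|>", 50265u32)];
    let mel = vec![0.0f32; 80 * 3000];
    println!("{:?}", language_token(true, Some("fr".to_string()), &tokenizer, &mel));
    println!("{:?}", language_token(false, Some("fr".to_string()), &tokenizer, &mel));
}
```

Matching on the `(is_multilingual, language)` pair keeps the two error cases (unknown language, language forced on a non-multilingual model) explicit instead of nesting them inside the previous `if`/`else`.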