Add a -language argument. (#425)

This commit is contained in:
Laurent Mazare
2023-08-12 18:08:40 +02:00
committed by GitHub
parent 972078e1ae
commit a0908d212c

View File

@ -1,4 +1,4 @@
// https://github.com/openai/whisper/blob/main/whisper/model.py
// https://github.com/openai/whisper/blob/main/whisper/model.py/rgs
// TODO:
// - kv-cache support?
// - Batch size greater than 1.
@ -301,6 +301,10 @@ struct Args {
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
/// Language.
#[arg(long)]
language: Option<String>,
}
fn main() -> Result<()> {
@ -391,10 +395,16 @@ fn main() -> Result<()> {
let config: Config = serde_json::from_str(&std::fs::read_to_string(config_filename)?)?;
let model = Whisper::load(&vb, config)?;
let language_token = if args.model.is_multilingual() {
Some(multilingual::detect_language(&model, &tokenizer, &mel)?)
} else {
None
let language_token = match (args.model.is_multilingual(), args.language) {
(true, None) => Some(multilingual::detect_language(&model, &tokenizer, &mel)?),
(false, None) => None,
(true, Some(language)) => match token_id(&tokenizer, &format!("<|{language}|>")) {
Ok(token_id) => Some(token_id),
Err(_) => anyhow::bail!("language {language} is not supported"),
},
(false, Some(_)) => {
anyhow::bail!("a language cannot be set for non-multilingual models")
}
};
let mut dc = Decoder::new(model, tokenizer, args.seed, &device, language_token)?;
dc.run(&mel)?;