mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 10:26:33 +00:00
Add a --language argument. (#425)
This commit is contained in:
@ -1,4 +1,4 @@
|
||||
// https://github.com/openai/whisper/blob/main/whisper/model.py
|
||||
// https://github.com/openai/whisper/blob/main/whisper/model.py
|
||||
// TODO:
|
||||
// - kv-cache support?
|
||||
// - Batch size greater than 1.
|
||||
@ -301,6 +301,10 @@ struct Args {
|
||||
/// Enable tracing (generates a trace-timestamp.json file).
|
||||
#[arg(long)]
|
||||
tracing: bool,
|
||||
|
||||
/// Language.
|
||||
#[arg(long)]
|
||||
language: Option<String>,
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
@ -391,10 +395,16 @@ fn main() -> Result<()> {
|
||||
let config: Config = serde_json::from_str(&std::fs::read_to_string(config_filename)?)?;
|
||||
let model = Whisper::load(&vb, config)?;
|
||||
|
||||
let language_token = if args.model.is_multilingual() {
|
||||
Some(multilingual::detect_language(&model, &tokenizer, &mel)?)
|
||||
} else {
|
||||
None
|
||||
let language_token = match (args.model.is_multilingual(), args.language) {
|
||||
(true, None) => Some(multilingual::detect_language(&model, &tokenizer, &mel)?),
|
||||
(false, None) => None,
|
||||
(true, Some(language)) => match token_id(&tokenizer, &format!("<|{language}|>")) {
|
||||
Ok(token_id) => Some(token_id),
|
||||
Err(_) => anyhow::bail!("language {language} is not supported"),
|
||||
},
|
||||
(false, Some(_)) => {
|
||||
anyhow::bail!("a language cannot be set for non-multilingual models")
|
||||
}
|
||||
};
|
||||
let mut dc = Decoder::new(model, tokenizer, args.seed, &device, language_token)?;
|
||||
dc.run(&mel)?;
|
||||
|
Reference in New Issue
Block a user