From 8ad12a0e81849d0bdb2e2b59d0f18e2b54174cd0 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Fri, 29 Mar 2024 18:09:29 +0100 Subject: [PATCH] Add some examples using the MT5 variants. (#1963) --- candle-examples/examples/t5/main.rs | 52 ++++++++++++++++++++++++----- 1 file changed, 44 insertions(+), 8 deletions(-) diff --git a/candle-examples/examples/t5/main.rs b/candle-examples/examples/t5/main.rs index be6bc6b5..34ae0ead 100644 --- a/candle-examples/examples/t5/main.rs +++ b/candle-examples/examples/t5/main.rs @@ -23,6 +23,9 @@ enum Which { T5Base, T5Small, T5_3B, + Mt5Base, + Mt5Small, + Mt5Large, } #[derive(Parser, Debug, Clone)] @@ -43,6 +46,15 @@ struct Args { #[arg(long)] revision: Option, + #[arg(long)] + model_file: Option, + + #[arg(long)] + tokenizer_file: Option, + + #[arg(long)] + config_file: Option, + /// Enable decoding. #[arg(long)] decode: bool, @@ -97,6 +109,9 @@ impl T5ModelBuilder { Which::T5Base => ("t5-base", "main"), Which::T5Small => ("t5-small", "refs/pr/15"), Which::T5_3B => ("t5-3b", "main"), + Which::Mt5Base => ("google/mt5-base", "refs/pr/5"), + Which::Mt5Small => ("google/mt5-small", "refs/pr/6"), + Which::Mt5Large => ("google/mt5-large", "refs/pr/2"), }; let default_model = default_model.to_string(); let default_revision = default_revision.to_string(); @@ -109,14 +124,35 @@ impl T5ModelBuilder { let repo = Repo::with_revision(model_id.clone(), RepoType::Model, revision); let api = Api::new()?; - let api = api.repo(repo); - let config_filename = api.get("config.json")?; - let tokenizer_filename = api.get("tokenizer.json")?; - let weights_filename = if model_id == "google/flan-t5-xxl" || model_id == "google/flan-ul2" - { - candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")? - } else { - vec![api.get("model.safetensors")?] + let repo = api.repo(repo); + let config_filename = match &args.config_file { + None => repo.get("config.json")?, + Some(f) => f.into(), + }; + let tokenizer_filename = match &args.tokenizer_file { + None => match args.which { + Which::Mt5Base => api + .model("lmz/mt5-tokenizers".into()) + .get("mt5-base.tokenizer.json")?, + Which::Mt5Small => api + .model("lmz/mt5-tokenizers".into()) + .get("mt5-small.tokenizer.json")?, + Which::Mt5Large => api + .model("lmz/mt5-tokenizers".into()) + .get("mt5-large.tokenizer.json")?, + _ => repo.get("tokenizer.json")?, + }, + Some(f) => f.into(), + }; + let weights_filename = match &args.model_file { + Some(f) => f.split(',').map(|v| v.into()).collect::>(), + None => { + if model_id == "google/flan-t5-xxl" || model_id == "google/flan-ul2" { + candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")? + } else { + vec![repo.get("model.safetensors")?] + } + } }; let config = std::fs::read_to_string(config_filename)?; let mut config: t5::Config = serde_json::from_str(&config)?;