Add some examples using the MT5 variants. (#1963)

Laurent Mazare
2024-03-29 18:09:29 +01:00
committed by GitHub
parent eb1b27abcd
commit 8ad12a0e81


@@ -23,6 +23,9 @@ enum Which {
     T5Base,
     T5Small,
     T5_3B,
+    Mt5Base,
+    Mt5Small,
+    Mt5Large,
 }
 
 #[derive(Parser, Debug, Clone)]
@@ -43,6 +46,15 @@ struct Args {
     #[arg(long)]
     revision: Option<String>,
 
+    #[arg(long)]
+    model_file: Option<String>,
+
+    #[arg(long)]
+    tokenizer_file: Option<String>,
+
+    #[arg(long)]
+    config_file: Option<String>,
+
     /// Enable decoding.
     #[arg(long)]
     decode: bool,
@@ -97,6 +109,9 @@ impl T5ModelBuilder {
             Which::T5Base => ("t5-base", "main"),
             Which::T5Small => ("t5-small", "refs/pr/15"),
             Which::T5_3B => ("t5-3b", "main"),
+            Which::Mt5Base => ("google/mt5-base", "refs/pr/5"),
+            Which::Mt5Small => ("google/mt5-small", "refs/pr/6"),
+            Which::Mt5Large => ("google/mt5-large", "refs/pr/2"),
         };
         let default_model = default_model.to_string();
         let default_revision = default_revision.to_string();
@@ -109,14 +124,35 @@ impl T5ModelBuilder {
         let repo = Repo::with_revision(model_id.clone(), RepoType::Model, revision);
         let api = Api::new()?;
-        let api = api.repo(repo);
-        let config_filename = api.get("config.json")?;
-        let tokenizer_filename = api.get("tokenizer.json")?;
-        let weights_filename = if model_id == "google/flan-t5-xxl" || model_id == "google/flan-ul2"
-        {
-            candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?
-        } else {
-            vec![api.get("model.safetensors")?]
+        let repo = api.repo(repo);
+        let config_filename = match &args.config_file {
+            None => repo.get("config.json")?,
+            Some(f) => f.into(),
+        };
+        let tokenizer_filename = match &args.tokenizer_file {
+            None => match args.which {
+                Which::Mt5Base => api
+                    .model("lmz/mt5-tokenizers".into())
+                    .get("mt5-base.tokenizer.json")?,
+                Which::Mt5Small => api
+                    .model("lmz/mt5-tokenizers".into())
+                    .get("mt5-small.tokenizer.json")?,
+                Which::Mt5Large => api
+                    .model("lmz/mt5-tokenizers".into())
+                    .get("mt5-large.tokenizer.json")?,
+                _ => repo.get("tokenizer.json")?,
+            },
+            Some(f) => f.into(),
+        };
+        let weights_filename = match &args.model_file {
+            Some(f) => f.split(',').map(|v| v.into()).collect::<Vec<_>>(),
+            None => {
+                if model_id == "google/flan-t5-xxl" || model_id == "google/flan-ul2" {
+                    candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?
+                } else {
+                    vec![repo.get("model.safetensors")?]
+                }
+            }
         };
         let config = std::fs::read_to_string(config_filename)?;
         let mut config: t5::Config = serde_json::from_str(&config)?;
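For context on how the files resolved above end up being used, here is a minimal, hypothetical sketch (not part of this commit) of turning the downloaded config and safetensors paths into a runnable (M)T5 model. The crate names assume the published candle-core / candle-nn / candle-transformers crates, and the load_model helper is invented for illustration; in the real example this wiring lives in T5ModelBuilder.

use std::path::{Path, PathBuf};

use candle_core::{DType, Device};
use candle_nn::VarBuilder;
use candle_transformers::models::t5;

// Hypothetical helper, not part of the commit: mmap the safetensors shards and
// load the T5/MT5 weights with the parsed config.
fn load_model(
    config_filename: &Path,
    weights_filename: &[PathBuf],
    device: &Device,
) -> anyhow::Result<t5::T5ForConditionalGeneration> {
    // Parse the config fetched from the hub (or supplied via --config-file).
    let config = std::fs::read_to_string(config_filename)?;
    let config: t5::Config = serde_json::from_str(&config)?;
    // A comma-separated --model-file value arrives here as several paths, so
    // sharded checkpoints load the same way as a single model.safetensors.
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(weights_filename, DType::F32, device)? };
    Ok(t5::T5ForConditionalGeneration::load(vb, &config)?)
}

One design choice worth noting: for the Mt5* variants the tokenizer is fetched from the separate lmz/mt5-tokenizers repository rather than from the model repository itself, presumably because the upstream google/mt5-* repositories do not ship a pre-converted tokenizer.json; the --tokenizer-file flag still lets users point at a local file instead.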