Add some examples using the MT5 variants. (#1963)

This commit is contained in:
Laurent Mazare
2024-03-29 18:09:29 +01:00
committed by GitHub
parent eb1b27abcd
commit 8ad12a0e81

View File

@ -23,6 +23,9 @@ enum Which {
T5Base,
T5Small,
T5_3B,
Mt5Base,
Mt5Small,
Mt5Large,
}
#[derive(Parser, Debug, Clone)]
@ -43,6 +46,15 @@ struct Args {
#[arg(long)]
revision: Option<String>,
#[arg(long)]
model_file: Option<String>,
#[arg(long)]
tokenizer_file: Option<String>,
#[arg(long)]
config_file: Option<String>,
/// Enable decoding.
#[arg(long)]
decode: bool,
@ -97,6 +109,9 @@ impl T5ModelBuilder {
Which::T5Base => ("t5-base", "main"),
Which::T5Small => ("t5-small", "refs/pr/15"),
Which::T5_3B => ("t5-3b", "main"),
Which::Mt5Base => ("google/mt5-base", "refs/pr/5"),
Which::Mt5Small => ("google/mt5-small", "refs/pr/6"),
Which::Mt5Large => ("google/mt5-large", "refs/pr/2"),
};
let default_model = default_model.to_string();
let default_revision = default_revision.to_string();
@ -109,14 +124,35 @@ impl T5ModelBuilder {
let repo = Repo::with_revision(model_id.clone(), RepoType::Model, revision);
let api = Api::new()?;
let api = api.repo(repo);
let config_filename = api.get("config.json")?;
let tokenizer_filename = api.get("tokenizer.json")?;
let weights_filename = if model_id == "google/flan-t5-xxl" || model_id == "google/flan-ul2"
{
candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?
} else {
vec![api.get("model.safetensors")?]
let repo = api.repo(repo);
let config_filename = match &args.config_file {
None => repo.get("config.json")?,
Some(f) => f.into(),
};
let tokenizer_filename = match &args.tokenizer_file {
None => match args.which {
Which::Mt5Base => api
.model("lmz/mt5-tokenizers".into())
.get("mt5-base.tokenizer.json")?,
Which::Mt5Small => api
.model("lmz/mt5-tokenizers".into())
.get("mt5-small.tokenizer.json")?,
Which::Mt5Large => api
.model("lmz/mt5-tokenizers".into())
.get("mt5-large.tokenizer.json")?,
_ => repo.get("tokenizer.json")?,
},
Some(f) => f.into(),
};
let weights_filename = match &args.model_file {
Some(f) => f.split(',').map(|v| v.into()).collect::<Vec<_>>(),
None => {
if model_id == "google/flan-t5-xxl" || model_id == "google/flan-ul2" {
candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?
} else {
vec![repo.get("model.safetensors")?]
}
}
};
let config = std::fs::read_to_string(config_filename)?;
let mut config: t5::Config = serde_json::from_str(&config)?;