Mirror of https://github.com/huggingface/candle.git (synced 2025-06-16 02:38:10 +00:00)
Add some examples using the MT5 variants. (#1963)
This commit is contained in:
@ -23,6 +23,9 @@ enum Which {
|
||||
T5Base,
|
||||
T5Small,
|
||||
T5_3B,
|
||||
Mt5Base,
|
||||
Mt5Small,
|
||||
Mt5Large,
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug, Clone)]
|
||||
@ -43,6 +46,15 @@ struct Args {
|
||||
#[arg(long)]
|
||||
revision: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
model_file: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
tokenizer_file: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
config_file: Option<String>,
|
||||
|
||||
/// Enable decoding.
|
||||
#[arg(long)]
|
||||
decode: bool,
|
||||
@ -97,6 +109,9 @@ impl T5ModelBuilder {
|
||||
Which::T5Base => ("t5-base", "main"),
|
||||
Which::T5Small => ("t5-small", "refs/pr/15"),
|
||||
Which::T5_3B => ("t5-3b", "main"),
|
||||
Which::Mt5Base => ("google/mt5-base", "refs/pr/5"),
|
||||
Which::Mt5Small => ("google/mt5-small", "refs/pr/6"),
|
||||
Which::Mt5Large => ("google/mt5-large", "refs/pr/2"),
|
||||
};
|
||||
let default_model = default_model.to_string();
|
||||
let default_revision = default_revision.to_string();
|
||||
@ -109,14 +124,35 @@ impl T5ModelBuilder {
|
||||
|
||||
let repo = Repo::with_revision(model_id.clone(), RepoType::Model, revision);
|
||||
let api = Api::new()?;
|
||||
let api = api.repo(repo);
|
||||
let config_filename = api.get("config.json")?;
|
||||
let tokenizer_filename = api.get("tokenizer.json")?;
|
||||
let weights_filename = if model_id == "google/flan-t5-xxl" || model_id == "google/flan-ul2"
|
||||
{
|
||||
candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?
|
||||
} else {
|
||||
vec![api.get("model.safetensors")?]
|
||||
let repo = api.repo(repo);
|
||||
let config_filename = match &args.config_file {
|
||||
None => repo.get("config.json")?,
|
||||
Some(f) => f.into(),
|
||||
};
|
||||
let tokenizer_filename = match &args.tokenizer_file {
|
||||
None => match args.which {
|
||||
Which::Mt5Base => api
|
||||
.model("lmz/mt5-tokenizers".into())
|
||||
.get("mt5-base.tokenizer.json")?,
|
||||
Which::Mt5Small => api
|
||||
.model("lmz/mt5-tokenizers".into())
|
||||
.get("mt5-small.tokenizer.json")?,
|
||||
Which::Mt5Large => api
|
||||
.model("lmz/mt5-tokenizers".into())
|
||||
.get("mt5-large.tokenizer.json")?,
|
||||
_ => repo.get("tokenizer.json")?,
|
||||
},
|
||||
Some(f) => f.into(),
|
||||
};
|
||||
let weights_filename = match &args.model_file {
|
||||
Some(f) => f.split(',').map(|v| v.into()).collect::<Vec<_>>(),
|
||||
None => {
|
||||
if model_id == "google/flan-t5-xxl" || model_id == "google/flan-ul2" {
|
||||
candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?
|
||||
} else {
|
||||
vec![repo.get("model.safetensors")?]
|
||||
}
|
||||
}
|
||||
};
|
||||
let config = std::fs::read_to_string(config_filename)?;
|
||||
let mut config: t5::Config = serde_json::from_str(&config)?;
|
||||
|
Reference in New Issue
Block a user