mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00
Add some examples using the MT5 variants. (#1963)
This commit is contained in:
@ -23,6 +23,9 @@ enum Which {
|
|||||||
T5Base,
|
T5Base,
|
||||||
T5Small,
|
T5Small,
|
||||||
T5_3B,
|
T5_3B,
|
||||||
|
Mt5Base,
|
||||||
|
Mt5Small,
|
||||||
|
Mt5Large,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Parser, Debug, Clone)]
|
#[derive(Parser, Debug, Clone)]
|
||||||
@ -43,6 +46,15 @@ struct Args {
|
|||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
revision: Option<String>,
|
revision: Option<String>,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
model_file: Option<String>,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
tokenizer_file: Option<String>,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
config_file: Option<String>,
|
||||||
|
|
||||||
/// Enable decoding.
|
/// Enable decoding.
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
decode: bool,
|
decode: bool,
|
||||||
@ -97,6 +109,9 @@ impl T5ModelBuilder {
|
|||||||
Which::T5Base => ("t5-base", "main"),
|
Which::T5Base => ("t5-base", "main"),
|
||||||
Which::T5Small => ("t5-small", "refs/pr/15"),
|
Which::T5Small => ("t5-small", "refs/pr/15"),
|
||||||
Which::T5_3B => ("t5-3b", "main"),
|
Which::T5_3B => ("t5-3b", "main"),
|
||||||
|
Which::Mt5Base => ("google/mt5-base", "refs/pr/5"),
|
||||||
|
Which::Mt5Small => ("google/mt5-small", "refs/pr/6"),
|
||||||
|
Which::Mt5Large => ("google/mt5-large", "refs/pr/2"),
|
||||||
};
|
};
|
||||||
let default_model = default_model.to_string();
|
let default_model = default_model.to_string();
|
||||||
let default_revision = default_revision.to_string();
|
let default_revision = default_revision.to_string();
|
||||||
@ -109,14 +124,35 @@ impl T5ModelBuilder {
|
|||||||
|
|
||||||
let repo = Repo::with_revision(model_id.clone(), RepoType::Model, revision);
|
let repo = Repo::with_revision(model_id.clone(), RepoType::Model, revision);
|
||||||
let api = Api::new()?;
|
let api = Api::new()?;
|
||||||
let api = api.repo(repo);
|
let repo = api.repo(repo);
|
||||||
let config_filename = api.get("config.json")?;
|
let config_filename = match &args.config_file {
|
||||||
let tokenizer_filename = api.get("tokenizer.json")?;
|
None => repo.get("config.json")?,
|
||||||
let weights_filename = if model_id == "google/flan-t5-xxl" || model_id == "google/flan-ul2"
|
Some(f) => f.into(),
|
||||||
{
|
};
|
||||||
candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?
|
let tokenizer_filename = match &args.tokenizer_file {
|
||||||
} else {
|
None => match args.which {
|
||||||
vec![api.get("model.safetensors")?]
|
Which::Mt5Base => api
|
||||||
|
.model("lmz/mt5-tokenizers".into())
|
||||||
|
.get("mt5-base.tokenizer.json")?,
|
||||||
|
Which::Mt5Small => api
|
||||||
|
.model("lmz/mt5-tokenizers".into())
|
||||||
|
.get("mt5-small.tokenizer.json")?,
|
||||||
|
Which::Mt5Large => api
|
||||||
|
.model("lmz/mt5-tokenizers".into())
|
||||||
|
.get("mt5-large.tokenizer.json")?,
|
||||||
|
_ => repo.get("tokenizer.json")?,
|
||||||
|
},
|
||||||
|
Some(f) => f.into(),
|
||||||
|
};
|
||||||
|
let weights_filename = match &args.model_file {
|
||||||
|
Some(f) => f.split(',').map(|v| v.into()).collect::<Vec<_>>(),
|
||||||
|
None => {
|
||||||
|
if model_id == "google/flan-t5-xxl" || model_id == "google/flan-ul2" {
|
||||||
|
candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?
|
||||||
|
} else {
|
||||||
|
vec![repo.get("model.safetensors")?]
|
||||||
|
}
|
||||||
|
}
|
||||||
};
|
};
|
||||||
let config = std::fs::read_to_string(config_filename)?;
|
let config = std::fs::read_to_string(config_filename)?;
|
||||||
let mut config: t5::Config = serde_json::from_str(&config)?;
|
let mut config: t5::Config = serde_json::from_str(&config)?;
|
||||||
|
Reference in New Issue
Block a user