quantized-t5 example: add a `--which` flag that selects which checkpoint's config and
weights file names are fetched. The per-variant file names live in one table,
`Which::filenames`, so the config and weights lookups cannot drift apart.
Hunk line counts were recomputed for this revision of the patch; the post-image blob id
on the `index` line was not regenerated.

diff --git a/candle-examples/examples/quantized-t5/main.rs b/candle-examples/examples/quantized-t5/main.rs
index 86d3762e..93a86309 100644
--- a/candle-examples/examples/quantized-t5/main.rs
+++ b/candle-examples/examples/quantized-t5/main.rs
@@ -11,10 +11,37 @@ use candle_transformers::models::quantized_t5 as t5;
 use anyhow::{Error as E, Result};
 use candle::{Device, Tensor};
 use candle_transformers::generation::LogitsProcessor;
-use clap::Parser;
+use clap::{Parser, ValueEnum};
 use hf_hub::{api::sync::Api, Repo, RepoType};
 use tokenizers::Tokenizer;
 
+/// Model variants selectable with `--which`.
+#[derive(Clone, Debug, Copy, ValueEnum)]
+enum Which {
+    T5Small,
+    FlanT5Small,
+    FlanT5Base,
+    FlanT5Large,
+    FlanT5Xl,
+    FlanT5Xxl,
+}
+
+impl Which {
+    /// Returns the `(config, weights)` file names to fetch for this variant.
+    ///
+    /// Keeping both names in one table prevents the two lookups from drifting apart.
+    fn filenames(self) -> (&'static str, &'static str) {
+        match self {
+            Self::T5Small => ("config.json", "model.gguf"),
+            Self::FlanT5Small => ("config-flan-t5-small.json", "model-flan-t5-small.gguf"),
+            Self::FlanT5Base => ("config-flan-t5-base.json", "model-flan-t5-base.gguf"),
+            Self::FlanT5Large => ("config-flan-t5-large.json", "model-flan-t5-large.gguf"),
+            Self::FlanT5Xl => ("config-flan-t5-xl.json", "model-flan-t5-xl.gguf"),
+            Self::FlanT5Xxl => ("config-flan-t5-xxl.json", "model-flan-t5-xxl.gguf"),
+        }
+    }
+}
+
 #[derive(Parser, Debug, Clone)]
 #[command(author, version, about, long_about = None)]
 struct Args {
@@ -55,6 +82,10 @@ struct Args {
     /// The context size to consider for the repeat penalty.
     #[arg(long, default_value_t = 64)]
     repeat_last_n: usize,
+
+    /// The model size to use.
+    #[arg(long, default_value = "t5-small")]
+    which: Which,
 }
 
 struct T5ModelBuilder {
@@ -77,11 +108,12 @@ impl T5ModelBuilder {
         let repo = Repo::with_revision(model_id, RepoType::Model, revision);
         let api = Api::new()?;
         let api = api.repo(repo);
-        let config_filename = api.get("config.json")?;
+        let (config_name, weights_name) = args.which.filenames();
+        let config_filename = api.get(config_name)?;
         let tokenizer_filename = api.get("tokenizer.json")?;
         let weights_filename = match &args.weight_file {
             Some(filename) => std::path::PathBuf::from(filename),
-            None => api.get("model.gguf")?,
+            None => api.get(weights_name)?,
         };
         let config = std::fs::read_to_string(config_filename)?;
         let mut config: t5::Config = serde_json::from_str(&config)?;