From 3349c892523426a00e16dd094837f5d786754ce1 Mon Sep 17 00:00:00 2001 From: Juarez Bochi Date: Wed, 4 Oct 2023 12:02:49 -0400 Subject: [PATCH] Add quantized t5 args for weight and config (#1029) --- candle-examples/examples/quantized-t5/main.rs | 33 ++++++++++++++----- 1 file changed, 24 insertions(+), 9 deletions(-) diff --git a/candle-examples/examples/quantized-t5/main.rs b/candle-examples/examples/quantized-t5/main.rs index 93a86309..2bc050ee 100644 --- a/candle-examples/examples/quantized-t5/main.rs +++ b/candle-examples/examples/quantized-t5/main.rs @@ -12,7 +12,7 @@ use anyhow::{Error as E, Result}; use candle::{Device, Tensor}; use candle_transformers::generation::LogitsProcessor; use clap::{Parser, ValueEnum}; -use hf_hub::{api::sync::Api, Repo, RepoType}; +use hf_hub::{api::sync::Api, api::sync::ApiRepo, Repo, RepoType}; use tokenizers::Tokenizer; #[derive(Clone, Debug, Copy, ValueEnum)] @@ -42,6 +42,9 @@ struct Args { #[arg(long)] weight_file: Option, + #[arg(long)] + config_file: Option, + // Enable/disable decoding. #[arg(long, default_value = "false")] disable_cache: bool, @@ -91,17 +94,20 @@ impl T5ModelBuilder { let repo = Repo::with_revision(model_id, RepoType::Model, revision); let api = Api::new()?; let api = api.repo(repo); - let config_filename = match args.which { - Which::T5Small => api.get("config.json")?, - Which::FlanT5Small => api.get("config-flan-t5-small.json")?, - Which::FlanT5Base => api.get("config-flan-t5-base.json")?, - Which::FlanT5Large => api.get("config-flan-t5-large.json")?, - Which::FlanT5Xl => api.get("config-flan-t5-xl.json")?, - Which::FlanT5Xxl => api.get("config-flan-t5-xxl.json")?, + let config_filename = match &args.config_file { + Some(filename) => Self::get_local_or_remote_file(filename, &api)?, + None => match args.which { + Which::T5Small => api.get("config.json")?, + Which::FlanT5Small => api.get("config-flan-t5-small.json")?, + Which::FlanT5Base => api.get("config-flan-t5-base.json")?, + Which::FlanT5Large => api.get("config-flan-t5-large.json")?, + Which::FlanT5Xl => api.get("config-flan-t5-xl.json")?, + Which::FlanT5Xxl => api.get("config-flan-t5-xxl.json")?, + }, }; let tokenizer_filename = api.get("tokenizer.json")?; let weights_filename = match &args.weight_file { - Some(filename) => std::path::PathBuf::from(filename), + Some(filename) => Self::get_local_or_remote_file(filename, &api)?, None => match args.which { Which::T5Small => api.get("model.gguf")?, Which::FlanT5Small => api.get("model-flan-t5-small.gguf")?, @@ -129,6 +135,15 @@ impl T5ModelBuilder { let vb = t5::VarBuilder::from_gguf(&self.weights_filename)?; Ok(t5::T5ForConditionalGeneration::load(vb, &self.config)?) } + + fn get_local_or_remote_file(filename: &str, api: &ApiRepo) -> Result { + let local_filename = std::path::PathBuf::from(filename); + if local_filename.exists() { + Ok(local_filename) + } else { + Ok(api.get(filename)?) + } + } } fn main() -> Result<()> {