Add quantized t5 args for weight and config (#1029)

This commit is contained in:
Juarez Bochi
2023-10-04 12:02:49 -04:00
committed by GitHub
parent 11d3687cc6
commit 3349c89252

View File

@ -12,7 +12,7 @@ use anyhow::{Error as E, Result};
use candle::{Device, Tensor}; use candle::{Device, Tensor};
use candle_transformers::generation::LogitsProcessor; use candle_transformers::generation::LogitsProcessor;
use clap::{Parser, ValueEnum}; use clap::{Parser, ValueEnum};
use hf_hub::{api::sync::Api, Repo, RepoType}; use hf_hub::{api::sync::Api, api::sync::ApiRepo, Repo, RepoType};
use tokenizers::Tokenizer; use tokenizers::Tokenizer;
#[derive(Clone, Debug, Copy, ValueEnum)] #[derive(Clone, Debug, Copy, ValueEnum)]
@ -42,6 +42,9 @@ struct Args {
#[arg(long)] #[arg(long)]
weight_file: Option<String>, weight_file: Option<String>,
#[arg(long)]
config_file: Option<String>,
// Enable/disable decoding. // Enable/disable decoding.
#[arg(long, default_value = "false")] #[arg(long, default_value = "false")]
disable_cache: bool, disable_cache: bool,
@ -91,17 +94,20 @@ impl T5ModelBuilder {
let repo = Repo::with_revision(model_id, RepoType::Model, revision); let repo = Repo::with_revision(model_id, RepoType::Model, revision);
let api = Api::new()?; let api = Api::new()?;
let api = api.repo(repo); let api = api.repo(repo);
let config_filename = match args.which { let config_filename = match &args.config_file {
Some(filename) => Self::get_local_or_remote_file(filename, &api)?,
None => match args.which {
Which::T5Small => api.get("config.json")?, Which::T5Small => api.get("config.json")?,
Which::FlanT5Small => api.get("config-flan-t5-small.json")?, Which::FlanT5Small => api.get("config-flan-t5-small.json")?,
Which::FlanT5Base => api.get("config-flan-t5-base.json")?, Which::FlanT5Base => api.get("config-flan-t5-base.json")?,
Which::FlanT5Large => api.get("config-flan-t5-large.json")?, Which::FlanT5Large => api.get("config-flan-t5-large.json")?,
Which::FlanT5Xl => api.get("config-flan-t5-xl.json")?, Which::FlanT5Xl => api.get("config-flan-t5-xl.json")?,
Which::FlanT5Xxl => api.get("config-flan-t5-xxl.json")?, Which::FlanT5Xxl => api.get("config-flan-t5-xxl.json")?,
},
}; };
let tokenizer_filename = api.get("tokenizer.json")?; let tokenizer_filename = api.get("tokenizer.json")?;
let weights_filename = match &args.weight_file { let weights_filename = match &args.weight_file {
Some(filename) => std::path::PathBuf::from(filename), Some(filename) => Self::get_local_or_remote_file(filename, &api)?,
None => match args.which { None => match args.which {
Which::T5Small => api.get("model.gguf")?, Which::T5Small => api.get("model.gguf")?,
Which::FlanT5Small => api.get("model-flan-t5-small.gguf")?, Which::FlanT5Small => api.get("model-flan-t5-small.gguf")?,
@ -129,6 +135,15 @@ impl T5ModelBuilder {
let vb = t5::VarBuilder::from_gguf(&self.weights_filename)?; let vb = t5::VarBuilder::from_gguf(&self.weights_filename)?;
Ok(t5::T5ForConditionalGeneration::load(vb, &self.config)?) Ok(t5::T5ForConditionalGeneration::load(vb, &self.config)?)
} }
/// Resolves `filename` to a concrete path.
///
/// When `filename` names a path that exists on the local filesystem, that
/// path is returned unchanged. Otherwise the same string is passed to
/// `api.get`, and whatever path that call yields is returned.
///
/// # Errors
///
/// Propagates any error produced by `api.get`.
fn get_local_or_remote_file(filename: &str, api: &ApiRepo) -> Result<PathBuf> {
    let candidate = PathBuf::from(filename);
    // Local files win: only consult the repo handle when nothing exists at this path.
    if !candidate.exists() {
        let fetched = api.get(filename)?;
        return Ok(fetched);
    }
    Ok(candidate)
}
} }
fn main() -> Result<()> { fn main() -> Result<()> {