mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Add quantized t5 args for weight and config (#1029)
This commit is contained in:
@ -12,7 +12,7 @@ use anyhow::{Error as E, Result};
|
|||||||
use candle::{Device, Tensor};
|
use candle::{Device, Tensor};
|
||||||
use candle_transformers::generation::LogitsProcessor;
|
use candle_transformers::generation::LogitsProcessor;
|
||||||
use clap::{Parser, ValueEnum};
|
use clap::{Parser, ValueEnum};
|
||||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
use hf_hub::{api::sync::Api, api::sync::ApiRepo, Repo, RepoType};
|
||||||
use tokenizers::Tokenizer;
|
use tokenizers::Tokenizer;
|
||||||
|
|
||||||
#[derive(Clone, Debug, Copy, ValueEnum)]
|
#[derive(Clone, Debug, Copy, ValueEnum)]
|
||||||
@ -42,6 +42,9 @@ struct Args {
|
|||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
weight_file: Option<String>,
|
weight_file: Option<String>,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
config_file: Option<String>,
|
||||||
|
|
||||||
// Enable/disable decoding.
|
// Enable/disable decoding.
|
||||||
#[arg(long, default_value = "false")]
|
#[arg(long, default_value = "false")]
|
||||||
disable_cache: bool,
|
disable_cache: bool,
|
||||||
@ -91,17 +94,20 @@ impl T5ModelBuilder {
|
|||||||
let repo = Repo::with_revision(model_id, RepoType::Model, revision);
|
let repo = Repo::with_revision(model_id, RepoType::Model, revision);
|
||||||
let api = Api::new()?;
|
let api = Api::new()?;
|
||||||
let api = api.repo(repo);
|
let api = api.repo(repo);
|
||||||
let config_filename = match args.which {
|
let config_filename = match &args.config_file {
|
||||||
Which::T5Small => api.get("config.json")?,
|
Some(filename) => Self::get_local_or_remote_file(filename, &api)?,
|
||||||
Which::FlanT5Small => api.get("config-flan-t5-small.json")?,
|
None => match args.which {
|
||||||
Which::FlanT5Base => api.get("config-flan-t5-base.json")?,
|
Which::T5Small => api.get("config.json")?,
|
||||||
Which::FlanT5Large => api.get("config-flan-t5-large.json")?,
|
Which::FlanT5Small => api.get("config-flan-t5-small.json")?,
|
||||||
Which::FlanT5Xl => api.get("config-flan-t5-xl.json")?,
|
Which::FlanT5Base => api.get("config-flan-t5-base.json")?,
|
||||||
Which::FlanT5Xxl => api.get("config-flan-t5-xxl.json")?,
|
Which::FlanT5Large => api.get("config-flan-t5-large.json")?,
|
||||||
|
Which::FlanT5Xl => api.get("config-flan-t5-xl.json")?,
|
||||||
|
Which::FlanT5Xxl => api.get("config-flan-t5-xxl.json")?,
|
||||||
|
},
|
||||||
};
|
};
|
||||||
let tokenizer_filename = api.get("tokenizer.json")?;
|
let tokenizer_filename = api.get("tokenizer.json")?;
|
||||||
let weights_filename = match &args.weight_file {
|
let weights_filename = match &args.weight_file {
|
||||||
Some(filename) => std::path::PathBuf::from(filename),
|
Some(filename) => Self::get_local_or_remote_file(filename, &api)?,
|
||||||
None => match args.which {
|
None => match args.which {
|
||||||
Which::T5Small => api.get("model.gguf")?,
|
Which::T5Small => api.get("model.gguf")?,
|
||||||
Which::FlanT5Small => api.get("model-flan-t5-small.gguf")?,
|
Which::FlanT5Small => api.get("model-flan-t5-small.gguf")?,
|
||||||
@ -129,6 +135,15 @@ impl T5ModelBuilder {
|
|||||||
let vb = t5::VarBuilder::from_gguf(&self.weights_filename)?;
|
let vb = t5::VarBuilder::from_gguf(&self.weights_filename)?;
|
||||||
Ok(t5::T5ForConditionalGeneration::load(vb, &self.config)?)
|
Ok(t5::T5ForConditionalGeneration::load(vb, &self.config)?)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn get_local_or_remote_file(filename: &str, api: &ApiRepo) -> Result<PathBuf> {
|
||||||
|
let local_filename = std::path::PathBuf::from(filename);
|
||||||
|
if local_filename.exists() {
|
||||||
|
Ok(local_filename)
|
||||||
|
} else {
|
||||||
|
Ok(api.get(filename)?)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn main() -> Result<()> {
|
fn main() -> Result<()> {
|
||||||
|
Reference in New Issue
Block a user