mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 19:47:12 +00:00
Add more quantized flan t5 variants (#923)
* Add the quantized flan-t5-large variant.
* Add more sizes.
This commit is contained in:
@ -11,10 +11,20 @@ use candle_transformers::models::quantized_t5 as t5;
|
|||||||
use anyhow::{Error as E, Result};
|
use anyhow::{Error as E, Result};
|
||||||
use candle::{Device, Tensor};
|
use candle::{Device, Tensor};
|
||||||
use candle_transformers::generation::LogitsProcessor;
|
use candle_transformers::generation::LogitsProcessor;
|
||||||
use clap::Parser;
|
use clap::{Parser, ValueEnum};
|
||||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||||
use tokenizers::Tokenizer;
|
use tokenizers::Tokenizer;
|
||||||
|
|
||||||
|
// Model variant selected through the `which` command-line argument of `Args`
// (declared there with the default value "t5-small").
// The variant decides which config file is requested via `api.get`, and, when
// no explicit `weight_file` is supplied, which GGUF weight file is requested.
// NOTE: plain `//` comments are used on purpose. `///` doc comments on the
// variants of a clap `ValueEnum` are picked up as help text for the possible
// values, which would change the program's `--help` output.
#[derive(Clone, Debug, Copy, ValueEnum)]
|
||||||
|
enum Which {
|
||||||
|
// T5 small (the default): `config.json` / `model.gguf`.
T5Small,
|
||||||
|
// Flan-T5 small: `config-flan-t5-small.json` / `model-flan-t5-small.gguf`.
FlanT5Small,
|
||||||
|
// Flan-T5 base: `config-flan-t5-base.json` / `model-flan-t5-base.gguf`.
FlanT5Base,
|
||||||
|
// Flan-T5 large: `config-flan-t5-large.json` / `model-flan-t5-large.gguf`.
FlanT5Large,
|
||||||
|
// Flan-T5 XL: `config-flan-t5-xl.json` / `model-flan-t5-xl.gguf`.
FlanT5Xl,
|
||||||
|
// Flan-T5 XXL: `config-flan-t5-xxl.json` / `model-flan-t5-xxl.gguf`.
FlanT5Xxl,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Parser, Debug, Clone)]
|
#[derive(Parser, Debug, Clone)]
|
||||||
#[command(author, version, about, long_about = None)]
|
#[command(author, version, about, long_about = None)]
|
||||||
struct Args {
|
struct Args {
|
||||||
@ -55,6 +65,10 @@ struct Args {
|
|||||||
/// The context size to consider for the repeat penalty.
|
/// The context size to consider for the repeat penalty.
|
||||||
#[arg(long, default_value_t = 64)]
|
#[arg(long, default_value_t = 64)]
|
||||||
repeat_last_n: usize,
|
repeat_last_n: usize,
|
||||||
|
|
||||||
|
/// The model size to use.
|
||||||
|
#[arg(long, default_value = "t5-small")]
|
||||||
|
which: Which,
|
||||||
}
|
}
|
||||||
|
|
||||||
struct T5ModelBuilder {
|
struct T5ModelBuilder {
|
||||||
@ -77,11 +91,25 @@ impl T5ModelBuilder {
|
|||||||
let repo = Repo::with_revision(model_id, RepoType::Model, revision);
|
let repo = Repo::with_revision(model_id, RepoType::Model, revision);
|
||||||
let api = Api::new()?;
|
let api = Api::new()?;
|
||||||
let api = api.repo(repo);
|
let api = api.repo(repo);
|
||||||
let config_filename = api.get("config.json")?;
|
let config_filename = match args.which {
|
||||||
|
Which::T5Small => api.get("config.json")?,
|
||||||
|
Which::FlanT5Small => api.get("config-flan-t5-small.json")?,
|
||||||
|
Which::FlanT5Base => api.get("config-flan-t5-base.json")?,
|
||||||
|
Which::FlanT5Large => api.get("config-flan-t5-large.json")?,
|
||||||
|
Which::FlanT5Xl => api.get("config-flan-t5-xl.json")?,
|
||||||
|
Which::FlanT5Xxl => api.get("config-flan-t5-xxl.json")?,
|
||||||
|
};
|
||||||
let tokenizer_filename = api.get("tokenizer.json")?;
|
let tokenizer_filename = api.get("tokenizer.json")?;
|
||||||
let weights_filename = match &args.weight_file {
|
let weights_filename = match &args.weight_file {
|
||||||
Some(filename) => std::path::PathBuf::from(filename),
|
Some(filename) => std::path::PathBuf::from(filename),
|
||||||
None => api.get("model.gguf")?,
|
None => match args.which {
|
||||||
|
Which::T5Small => api.get("model.gguf")?,
|
||||||
|
Which::FlanT5Small => api.get("model-flan-t5-small.gguf")?,
|
||||||
|
Which::FlanT5Base => api.get("model-flan-t5-base.gguf")?,
|
||||||
|
Which::FlanT5Large => api.get("model-flan-t5-large.gguf")?,
|
||||||
|
Which::FlanT5Xl => api.get("model-flan-t5-xl.gguf")?,
|
||||||
|
Which::FlanT5Xxl => api.get("model-flan-t5-xxl.gguf")?,
|
||||||
|
},
|
||||||
};
|
};
|
||||||
let config = std::fs::read_to_string(config_filename)?;
|
let config = std::fs::read_to_string(config_filename)?;
|
||||||
let mut config: t5::Config = serde_json::from_str(&config)?;
|
let mut config: t5::Config = serde_json::from_str(&config)?;
|
||||||
|
Reference in New Issue
Block a user