Mirror of https://github.com/huggingface/candle.git
Synced 2025-06-16 18:48:51 +00:00

T5 quantized example (#922)

* Load gguf files for the quantized t5.
* Add the quantized t5 example.
* Allow for loading local files.
* Add some support for quantizing safetensor files.
* Transpose before quantizing.
* Quantized t5.
* Retrieve the weights from the hub.

@@ -218,12 +218,65 @@ fn run_ls(file: &std::path::PathBuf, format: Option<Format>, verbose: bool) -> R
    Ok(())
}

fn run_quantize_safetensors(
    in_file: std::path::PathBuf,
    out_file: std::path::PathBuf,
    q: Quantization,
) -> Result<()> {
    let mut out_file = std::fs::File::create(out_file)?;
    let tensors = candle_core::safetensors::load(in_file, &Device::Cpu)?;
    println!("tensors: {}", tensors.len());
    let quantize_fn = match q {
        Quantization::Q4_0 => QTensor::quantize::<k_quants::BlockQ4_0>,
        Quantization::Q4_1 => QTensor::quantize::<k_quants::BlockQ4_1>,
        Quantization::Q5_0 => QTensor::quantize::<k_quants::BlockQ5_0>,
        Quantization::Q5_1 => QTensor::quantize::<k_quants::BlockQ5_1>,
        Quantization::Q8_0 => QTensor::quantize::<k_quants::BlockQ8_0>,
        Quantization::Q8_1 => QTensor::quantize::<k_quants::BlockQ8_1>,
        Quantization::Q2k => QTensor::quantize::<k_quants::BlockQ2K>,
        Quantization::Q3k => QTensor::quantize::<k_quants::BlockQ3K>,
        Quantization::Q4k => QTensor::quantize::<k_quants::BlockQ4K>,
        Quantization::Q5k => QTensor::quantize::<k_quants::BlockQ5K>,
        Quantization::Q6k => QTensor::quantize::<k_quants::BlockQ6K>,
        Quantization::Q8k => QTensor::quantize::<k_quants::BlockQ8K>,
        Quantization::F16 => QTensor::quantize::<half::f16>,
        Quantization::F32 => QTensor::quantize::<f32>,
    };
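
    // Quantize all tensors in parallel (into_par_iter comes from rayon).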
    let qtensors = tensors
        .into_par_iter()
        .map(|(name, tensor)| {
            println!(" quantizing {name} {tensor:?}");
            let should_quantize = tensor.rank() == 2 && tensor.dim(0)? % 256 == 0;
            let tensor = if should_quantize {
                quantize_fn(&tensor)?
            } else {
                QTensor::quantize::<f32>(&tensor)?
            };
            Ok((name, tensor))
        })
        .collect::<Result<Vec<_>>>()?;
    let qtensors = qtensors
        .iter()
        .map(|(k, v)| (k.as_str(), v))
        .collect::<Vec<_>>();
    gguf_file::write(&mut out_file, &[], &qtensors)?;
    Ok(())
}

fn run_quantize(
    in_file: std::path::PathBuf,
    out_file: std::path::PathBuf,
    q: Quantization,
    qmode: QuantizationMode,
) -> Result<()> {
    if let Some(extension) = in_file.extension() {
        if extension == "safetensors" {
            return run_quantize_safetensors(in_file, out_file, q);
        }
    }
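
    // Safetensors inputs are fully handled above; `qmode` is only used for
    // gguf inputs, which are processed below.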
    // Open the out file early so as to fail directly on missing directories etc.
    let mut out_file = std::fs::File::create(out_file)?;
    let mut in_ = std::fs::File::open(&in_file)?;

candle-examples/examples/quantized-t5/README.md (new file, 17 lines)
@@ -0,0 +1,17 @@
# candle-quantized-t5

This example uses a quantized version of the T5 model.

```bash
$ cargo run --example quantized-t5 --release -- --prompt "translate to German: A beautiful candle."
...
Eine schöne Kerze.
```

The weight file is automatically retrieved from the hub. It is also possible to
generate quantized weight files from the original safetensors file by using the
`tensor-tools` command line utility via:

```bash
cargo run --example tensor-tools --release -- quantize --quantization q6k PATH/TO/T5/model.safetensors /tmp/model.gguf
```
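
A locally generated file can then be passed to the example through its
`--weight-file` flag. A sketch of such an invocation, assuming the gguf file
was written to `/tmp/model.gguf` as above:

```bash
cargo run --example quantized-t5 --release -- \
  --weight-file /tmp/model.gguf \
  --prompt "translate to German: A beautiful candle."
```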

candle-examples/examples/quantized-t5/main.rs (new file, 186 lines)
@@ -0,0 +1,186 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use std::io::Write;
use std::path::PathBuf;

use candle_transformers::models::quantized_t5 as t5;

use anyhow::{Error as E, Result};
use candle::{Device, Tensor};
use candle_transformers::generation::LogitsProcessor;
use clap::Parser;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;

#[derive(Parser, Debug, Clone)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Enable tracing (generates a trace-timestamp.json file).
    #[arg(long)]
    tracing: bool,

    /// The model repository to use on the HuggingFace hub.
    #[arg(long)]
    model_id: Option<String>,

    #[arg(long)]
    revision: Option<String>,

    #[arg(long)]
    weight_file: Option<String>,

    /// Disable the key-value cache during decoding.
    #[arg(long, default_value = "false")]
    disable_cache: bool,

    /// The prompt to translate or complete.
    #[arg(long)]
    prompt: String,

    /// The temperature used to generate samples.
    #[arg(long, default_value_t = 0.8)]
    temperature: f64,

    /// Nucleus sampling probability cutoff.
    #[arg(long)]
    top_p: Option<f64>,

    /// Penalty to be applied for repeating tokens, 1. means no penalty.
    #[arg(long, default_value_t = 1.1)]
    repeat_penalty: f32,

    /// The context size to consider for the repeat penalty.
    #[arg(long, default_value_t = 64)]
    repeat_last_n: usize,
}

struct T5ModelBuilder {
    device: Device,
    config: t5::Config,
    weights_filename: PathBuf,
}

impl T5ModelBuilder {
    pub fn load(args: &Args) -> Result<(Self, Tokenizer)> {
        let device = Device::Cpu;
        let default_model = "lmz/candle-quantized-t5".to_string();
        let (model_id, revision) = match (args.model_id.to_owned(), args.revision.to_owned()) {
            (Some(model_id), Some(revision)) => (model_id, revision),
            (Some(model_id), None) => (model_id, "main".to_string()),
            (None, Some(revision)) => (default_model, revision),
            (None, None) => (default_model, "main".to_string()),
        };

        let repo = Repo::with_revision(model_id, RepoType::Model, revision);
        let api = Api::new()?;
        let api = api.repo(repo);
        let config_filename = api.get("config.json")?;
        let tokenizer_filename = api.get("tokenizer.json")?;
        let weights_filename = match &args.weight_file {
            Some(filename) => std::path::PathBuf::from(filename),
            None => api.get("model.gguf")?,
        };
        let config = std::fs::read_to_string(config_filename)?;
        let mut config: t5::Config = serde_json::from_str(&config)?;
        config.use_cache = !args.disable_cache;
        let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
        Ok((
            Self {
                device,
                config,
                weights_filename,
            },
            tokenizer,
        ))
    }

    pub fn build_model(&self) -> Result<t5::T5ForConditionalGeneration> {
        let vb = t5::VarBuilder::from_gguf(&self.weights_filename)?;
        Ok(t5::T5ForConditionalGeneration::load(vb, &self.config)?)
    }
}

fn main() -> Result<()> {
    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;

    let args = Args::parse();

    let _guard = if args.tracing {
        println!("tracing...");
        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();
        Some(guard)
    } else {
        None
    };

    let (builder, mut tokenizer) = T5ModelBuilder::load(&args)?;
    let device = &builder.device;
    let tokenizer = tokenizer
        .with_padding(None)
        .with_truncation(None)
        .map_err(E::msg)?;
    let tokens = tokenizer
        .encode(args.prompt, true)
        .map_err(E::msg)?
        .get_ids()
        .to_vec();
    let input_token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
    let mut model = builder.build_model()?;
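    // T5 uses the pad token as the decoder start token, so generation begins
    // from it.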
    let mut output_token_ids = [builder.config.pad_token_id as u32].to_vec();
    let temperature = if args.temperature <= 0. {
        None
    } else {
        Some(args.temperature)
    };
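    // A fixed seed (299792458) keeps the sampling reproducible across runs.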
    let mut logits_processor = LogitsProcessor::new(299792458, temperature, args.top_p);
    let encoder_output = model.encode(&input_token_ids)?;
    let start = std::time::Instant::now();

    for index in 0.. {
        if output_token_ids.len() > 512 {
            break;
        }
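        // With the key-value cache enabled, past decoder states are reused, so
        // after the first step only the most recently generated token has to
        // be fed back in.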
        let decoder_token_ids = if index == 0 || !builder.config.use_cache {
            Tensor::new(output_token_ids.as_slice(), device)?.unsqueeze(0)?
        } else {
            let last_token = *output_token_ids.last().unwrap();
            Tensor::new(&[last_token], device)?.unsqueeze(0)?
        };
        let logits = model
            .decode(&decoder_token_ids, &encoder_output)?
            .squeeze(0)?;
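        // Optionally penalize tokens that already appeared among the last
        // `repeat_last_n` generated tokens.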
        let logits = if args.repeat_penalty == 1. {
            logits
        } else {
            let start_at = output_token_ids.len().saturating_sub(args.repeat_last_n);
            candle_transformers::utils::apply_repeat_penalty(
                &logits,
                args.repeat_penalty,
                &output_token_ids[start_at..],
            )?
        };

        let next_token_id = logits_processor.sample(&logits)?;
        if next_token_id as usize == builder.config.eos_token_id {
            break;
        }
        output_token_ids.push(next_token_id);
        if let Some(text) = tokenizer.id_to_token(next_token_id) {
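            // "▁" is the SentencePiece word-boundary marker and "<0x0A>" the
            // byte-level newline token.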
            let text = text.replace('▁', " ").replace("<0x0A>", "\n");
            print!("{text}");
            std::io::stdout().flush()?;
        }
    }
    let dt = start.elapsed();
    println!(
        "\n{} tokens generated ({:.2} token/s)\n",
        output_token_ids.len(),
        output_token_ids.len() as f64 / dt.as_secs_f64(),
    );
    Ok(())
}

@@ -15,6 +15,21 @@ pub struct VarBuilder {
}

impl VarBuilder {
    pub fn from_gguf<P: AsRef<std::path::Path>>(p: P) -> Result<Self> {
        let mut file = std::fs::File::open(p)?;
        let content = candle::quantized::gguf_file::Content::read(&mut file)?;
        let mut data = std::collections::HashMap::new();
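        // Eagerly read every tensor listed in the gguf metadata into a
        // name -> tensor map.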
        for tensor_name in content.tensor_infos.keys() {
            let tensor = content.tensor(&mut file, tensor_name)?;
            data.insert(tensor_name.to_string(), Arc::new(tensor));
        }
        Ok(Self {
            data: Arc::new(data),
            path: Vec::new(),
            device: Device::Cpu,
        })
    }

    fn pp<S: ToString>(&self, s: S) -> Self {
        let mut path = self.path.clone();
        path.push(s.to_string());

@@ -87,7 +102,7 @@ struct QMatMul {

impl QMatMul {
    fn new(out_dim: usize, in_dim: usize, vb: VarBuilder) -> Result<Self> {
-        let ws = vb.get((out_dim, in_dim), "weight")?;
+        let ws = vb.get((in_dim, out_dim), "weight")?;
        let inner = candle::quantized::QMatMul::from_arc(ws);
        let span = tracing::span!(tracing::Level::TRACE, "qmatmul");
        Ok(Self { inner, span })
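
(The weight is now fetched with shape `(in_dim, out_dim)` instead of `(out_dim, in_dim)`, matching the "Transpose before quantizing" step from the commit message: the generated gguf files store these weights in transposed layout.)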