diff --git a/candle-core/examples/tensor-tools.rs b/candle-core/examples/tensor-tools.rs
index 2bc1fa2e..c3459004 100644
--- a/candle-core/examples/tensor-tools.rs
+++ b/candle-core/examples/tensor-tools.rs
@@ -218,12 +218,65 @@ fn run_ls(file: &std::path::PathBuf, format: Option<Format>, verbose: bool) -> R
     Ok(())
 }
 
+fn run_quantize_safetensors(
+    in_file: std::path::PathBuf,
+    out_file: std::path::PathBuf,
+    q: Quantization,
+) -> Result<()> {
+    let mut out_file = std::fs::File::create(out_file)?;
+    let tensors = candle_core::safetensors::load(in_file, &Device::Cpu)?;
+    println!("tensors: {}", tensors.len());
+
+    let quantize_fn = match q {
+        Quantization::Q4_0 => QTensor::quantize::<k_quants::BlockQ4_0>,
+        Quantization::Q4_1 => QTensor::quantize::<k_quants::BlockQ4_1>,
+        Quantization::Q5_0 => QTensor::quantize::<k_quants::BlockQ5_0>,
+        Quantization::Q5_1 => QTensor::quantize::<k_quants::BlockQ5_1>,
+        Quantization::Q8_0 => QTensor::quantize::<k_quants::BlockQ8_0>,
+        Quantization::Q8_1 => QTensor::quantize::<k_quants::BlockQ8_1>,
+        Quantization::Q2k => QTensor::quantize::<k_quants::BlockQ2K>,
+        Quantization::Q3k => QTensor::quantize::<k_quants::BlockQ3K>,
+        Quantization::Q4k => QTensor::quantize::<k_quants::BlockQ4K>,
+        Quantization::Q5k => QTensor::quantize::<k_quants::BlockQ5K>,
+        Quantization::Q6k => QTensor::quantize::<k_quants::BlockQ6K>,
+        Quantization::Q8k => QTensor::quantize::<k_quants::BlockQ8K>,
+        Quantization::F16 => QTensor::quantize::<half::f16>,
+        Quantization::F32 => QTensor::quantize::<f32>,
+    };
+
+    let qtensors = tensors
+        .into_par_iter()
+        .map(|(name, tensor)| {
+            println!("  quantizing {name} {tensor:?}");
+            let should_quantize = tensor.rank() == 2 && tensor.dim(0)? % 256 == 0;
+            let tensor = if should_quantize {
+                quantize_fn(&tensor)?
+            } else {
+                QTensor::quantize::<f32>(&tensor)?
+            };
+            Ok((name, tensor))
+        })
+        .collect::<Result<Vec<_>>>()?;
+    let qtensors = qtensors
+        .iter()
+        .map(|(k, v)| (k.as_str(), v))
+        .collect::<Vec<_>>();
+    gguf_file::write(&mut out_file, &[], &qtensors)?;
+    Ok(())
+}
+
 fn run_quantize(
     in_file: std::path::PathBuf,
     out_file: std::path::PathBuf,
     q: Quantization,
     qmode: QuantizationMode,
 ) -> Result<()> {
+    if let Some(extension) = in_file.extension() {
+        if extension == "safetensors" {
+            return run_quantize_safetensors(in_file, out_file, q);
+        }
+    }
+
     // Open the out file early so as to fail directly on missing directories etc.
     let mut out_file = std::fs::File::create(out_file)?;
     let mut in_ = std::fs::File::open(&in_file)?;
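The per-tensor rule above is the crux of this hunk: only rank-2 tensors whose first dimension is a multiple of 256 (the k-quant super-block size) receive the requested scheme, and everything else falls back to plain f32 storage so that layer norms, biases, and other oddly-shaped tensors survive the round trip. A minimal sketch of the underlying call, assuming the `QTensor::quantize::<T>` API that this patch itself uses (later candle versions switched to passing a `GgmlDType` value instead):

```rust
use candle_core::quantized::{k_quants, QTensor};
use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    // 512 x 512: rank 2 and dim(0) % 256 == 0, so run_quantize_safetensors
    // above would apply the selected scheme (q6k here).
    let weight = Tensor::zeros((512, 512), DType::F32, &Device::Cpu)?;
    let quantized = QTensor::quantize::<k_quants::BlockQ6K>(&weight)?;
    println!("quantized shape: {:?}", quantized.shape());

    // A 100 x 512 tensor (dim(0) == 100) fails the check and would be
    // written out as f32 instead.
    Ok(())
}
```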
diff --git a/candle-examples/examples/quantized-t5/README.md b/candle-examples/examples/quantized-t5/README.md
new file mode 100644
index 00000000..1f6b99eb
--- /dev/null
+++ b/candle-examples/examples/quantized-t5/README.md
@@ -0,0 +1,17 @@
+# candle-quantized-t5
+
+This example uses a quantized version of the t5 model.
+
+```bash
+$ cargo run --example quantized-t5 --release -- --prompt "translate to German: A beautiful candle."
+...
+ Eine schöne Kerze.
+```
+
+The weight file is automatically retrieved from the hub. It is also possible to
+generate quantized weight files from the original safetensors file by using the
+`tensor-tools` command line utility via:
+
+```bash
+cargo run --example tensor-tools --release -- quantize --quantization q6k PATH/TO/T5/model.safetensors /tmp/model.gguf
+```
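Once generated, the gguf file can be passed back to the example via the `--weight-file` flag declared in `main.rs` below (for instance `cargo run --example quantized-t5 --release -- --weight-file /tmp/model.gguf --prompt "translate to German: A beautiful candle."`), which skips the `model.gguf` download from the hub; `config.json` and `tokenizer.json` are still fetched from the default repository.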
diff --git a/candle-examples/examples/quantized-t5/main.rs b/candle-examples/examples/quantized-t5/main.rs
new file mode 100644
index 00000000..86d3762e
--- /dev/null
+++ b/candle-examples/examples/quantized-t5/main.rs
@@ -0,0 +1,186 @@
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+
+use std::io::Write;
+use std::path::PathBuf;
+
+use candle_transformers::models::quantized_t5 as t5;
+
+use anyhow::{Error as E, Result};
+use candle::{Device, Tensor};
+use candle_transformers::generation::LogitsProcessor;
+use clap::Parser;
+use hf_hub::{api::sync::Api, Repo, RepoType};
+use tokenizers::Tokenizer;
+
+#[derive(Parser, Debug, Clone)]
+#[command(author, version, about, long_about = None)]
+struct Args {
+    /// Enable tracing (generates a trace-timestamp.json file).
+    #[arg(long)]
+    tracing: bool,
+
+    /// The model repository to use on the HuggingFace hub.
+    #[arg(long)]
+    model_id: Option<String>,
+
+    #[arg(long)]
+    revision: Option<String>,
+
+    #[arg(long)]
+    weight_file: Option<String>,
+
+    /// Disable the key-value cache during decoding.
+    #[arg(long, default_value = "false")]
+    disable_cache: bool,
+
+    /// The prompt to be fed to the model.
+    #[arg(long)]
+    prompt: String,
+
+    /// The temperature used to generate samples.
+    #[arg(long, default_value_t = 0.8)]
+    temperature: f64,
+
+    /// Nucleus sampling probability cutoff.
+    #[arg(long)]
+    top_p: Option<f64>,
+
+    /// Penalty to be applied for repeating tokens, 1. means no penalty.
+    #[arg(long, default_value_t = 1.1)]
+    repeat_penalty: f32,
+
+    /// The context size to consider for the repeat penalty.
+    #[arg(long, default_value_t = 64)]
+    repeat_last_n: usize,
+}
+
+struct T5ModelBuilder {
+    device: Device,
+    config: t5::Config,
+    weights_filename: PathBuf,
+}
+
+impl T5ModelBuilder {
+    pub fn load(args: &Args) -> Result<(Self, Tokenizer)> {
+        let device = Device::Cpu;
+        let default_model = "lmz/candle-quantized-t5".to_string();
+        let (model_id, revision) = match (args.model_id.to_owned(), args.revision.to_owned()) {
+            (Some(model_id), Some(revision)) => (model_id, revision),
+            (Some(model_id), None) => (model_id, "main".to_string()),
+            (None, Some(revision)) => (default_model, revision),
+            (None, None) => (default_model, "main".to_string()),
+        };
+
+        let repo = Repo::with_revision(model_id, RepoType::Model, revision);
+        let api = Api::new()?;
+        let api = api.repo(repo);
+        let config_filename = api.get("config.json")?;
+        let tokenizer_filename = api.get("tokenizer.json")?;
+        let weights_filename = match &args.weight_file {
+            Some(filename) => std::path::PathBuf::from(filename),
+            None => api.get("model.gguf")?,
+        };
+        let config = std::fs::read_to_string(config_filename)?;
+        let mut config: t5::Config = serde_json::from_str(&config)?;
+        config.use_cache = !args.disable_cache;
+        let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
+        Ok((
+            Self {
+                device,
+                config,
+                weights_filename,
+            },
+            tokenizer,
+        ))
+    }
+
+    pub fn build_model(&self) -> Result<t5::T5ForConditionalGeneration> {
+        let vb = t5::VarBuilder::from_gguf(&self.weights_filename)?;
+        Ok(t5::T5ForConditionalGeneration::load(vb, &self.config)?)
+    }
+}
+
+fn main() -> Result<()> {
+    use tracing_chrome::ChromeLayerBuilder;
+    use tracing_subscriber::prelude::*;
+
+    let args = Args::parse();
+
+    let _guard = if args.tracing {
+        println!("tracing...");
+        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
+        tracing_subscriber::registry().with(chrome_layer).init();
+        Some(guard)
+    } else {
+        None
+    };
+
+    let (builder, mut tokenizer) = T5ModelBuilder::load(&args)?;
+    let device = &builder.device;
+    let tokenizer = tokenizer
+        .with_padding(None)
+        .with_truncation(None)
+        .map_err(E::msg)?;
+    let tokens = tokenizer
+        .encode(args.prompt, true)
+        .map_err(E::msg)?
+        .get_ids()
+        .to_vec();
+    let input_token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
+    let mut model = builder.build_model()?;
+    let mut output_token_ids = [builder.config.pad_token_id as u32].to_vec();
+    let temperature = if args.temperature <= 0. {
+        None
+    } else {
+        Some(args.temperature)
+    };
+    let mut logits_processor = LogitsProcessor::new(299792458, temperature, args.top_p);
+    let encoder_output = model.encode(&input_token_ids)?;
+    let start = std::time::Instant::now();
+
+    for index in 0.. {
+        if output_token_ids.len() > 512 {
+            break;
+        }
+        let decoder_token_ids = if index == 0 || !builder.config.use_cache {
+            Tensor::new(output_token_ids.as_slice(), device)?.unsqueeze(0)?
+        } else {
+            let last_token = *output_token_ids.last().unwrap();
+            Tensor::new(&[last_token], device)?.unsqueeze(0)?
+        };
+        let logits = model
+            .decode(&decoder_token_ids, &encoder_output)?
+            .squeeze(0)?;
+        let logits = if args.repeat_penalty == 1. {
+            logits
+        } else {
+            let start_at = output_token_ids.len().saturating_sub(args.repeat_last_n);
+            candle_transformers::utils::apply_repeat_penalty(
+                &logits,
+                args.repeat_penalty,
+                &output_token_ids[start_at..],
+            )?
+        };
+
+        let next_token_id = logits_processor.sample(&logits)?;
+        if next_token_id as usize == builder.config.eos_token_id {
+            break;
+        }
+        output_token_ids.push(next_token_id);
+        if let Some(text) = tokenizer.id_to_token(next_token_id) {
+            let text = text.replace('▁', " ").replace("<0x0A>", "\n");
+            print!("{text}");
+            std::io::stdout().flush()?;
+        }
+    }
+    let dt = start.elapsed();
+    println!(
+        "\n{} tokens generated ({:.2} token/s)\n",
+        output_token_ids.len(),
+        output_token_ids.len() as f64 / dt.as_secs_f64(),
+    );
+    Ok(())
+}
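For context on the sampling step: when `temperature <= 0` the example passes `None` to `LogitsProcessor`, which makes `sample` a plain argmax; with a positive temperature it samples from the (optionally top-p truncated) softmax. A minimal, self-contained sketch of that behaviour, reusing the fixed seed the example hard-codes:

```rust
use candle::{Device, Tensor};
use candle_transformers::generation::LogitsProcessor;

fn main() -> anyhow::Result<()> {
    // A `None` temperature means greedy decoding: sample() reduces to argmax.
    let mut processor = LogitsProcessor::new(299792458, None, None);
    let logits = Tensor::new(&[0.1f32, 2.0, 0.3], &Device::Cpu)?;
    let token = processor.sample(&logits)?;
    assert_eq!(token, 1); // index of the largest logit
    Ok(())
}
```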
diff --git a/candle-transformers/src/models/quantized_t5.rs b/candle-transformers/src/models/quantized_t5.rs
index c14500ba..a10c3b80 100644
--- a/candle-transformers/src/models/quantized_t5.rs
+++ b/candle-transformers/src/models/quantized_t5.rs
@@ -15,6 +15,21 @@ pub struct VarBuilder {
 }
 
 impl VarBuilder {
+    pub fn from_gguf<P: AsRef<std::path::Path>>(p: P) -> Result<Self> {
+        let mut file = std::fs::File::open(p)?;
+        let content = candle::quantized::gguf_file::Content::read(&mut file)?;
+        let mut data = std::collections::HashMap::new();
+        for tensor_name in content.tensor_infos.keys() {
+            let tensor = content.tensor(&mut file, tensor_name)?;
+            data.insert(tensor_name.to_string(), Arc::new(tensor));
+        }
+        Ok(Self {
+            data: Arc::new(data),
+            path: Vec::new(),
+            device: Device::Cpu,
+        })
+    }
+
     fn pp<S: ToString>(&self, s: S) -> Self {
         let mut path = self.path.clone();
         path.push(s.to_string());
@@ -87,7 +102,7 @@ struct QMatMul {
 }
 impl QMatMul {
     fn new(out_dim: usize, in_dim: usize, vb: VarBuilder) -> Result<Self> {
-        let ws = vb.get((out_dim, in_dim), "weight")?;
+        let ws = vb.get((in_dim, out_dim), "weight")?;
         let inner = candle::quantized::QMatMul::from_arc(ws);
         let span = tracing::span!(tracing::Level::TRACE, "qmatmul");
         Ok(Self { inner, span })
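Tying the pieces together, here is a sketch of how the new `VarBuilder::from_gguf` entry point can be driven outside of the example binary. The two file paths are placeholders for whatever `tensor-tools` produced; the calls mirror `T5ModelBuilder::build_model` and the encode step in `main.rs` above:

```rust
use candle::{Device, Tensor};
use candle_transformers::models::quantized_t5 as t5;

fn main() -> anyhow::Result<()> {
    // Placeholder paths: any config.json / model.gguf pair generated as above.
    let config: t5::Config = serde_json::from_str(&std::fs::read_to_string("config.json")?)?;
    let vb = t5::VarBuilder::from_gguf("model.gguf")?;
    let mut model = t5::T5ForConditionalGeneration::load(vb, &config)?;

    // Encode a dummy token sequence; the expected shape is (batch, seq_len).
    let input_ids = Tensor::new(&[0u32, 1, 2], &Device::Cpu)?.unsqueeze(0)?;
    let encoder_output = model.encode(&input_ids)?;
    println!("encoder output: {:?}", encoder_output.shape());
    Ok(())
}
```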