T5 quantized example (#922)

* Load gguf files for the quantized t5.

* Add the quantized t5 example.

* Allow for loading local files.

* Add some support for quantizing safetensor files.

* Transpose before quantizing.

* Quantized t5.

* Retrieve the weights from the hub.
This commit is contained in:
Laurent Mazare
2023-09-21 12:33:15 +01:00
committed by GitHub
parent 2619c4307f
commit 3b557765e8
4 changed files with 272 additions and 1 deletions

View File

@ -0,0 +1,17 @@
# candle-quantized-t5
This example uses a quantized version of the t5 model.
```bash
$ cargo run --example quantized-t5 --release -- --prompt "translate to German: A beautiful candle."
...
Eine schöne Kerze.
```
The weight file is automatically retrieved from the hub. It is also possible to
generate quantized weight files from the original safetensors file by using the
`tensor-tools` command line utility via:
```bash
cargo run --example tensor-tools --release -- quantize --quantization q6k PATH/TO/T5/model.safetensors /tmp/model.gguf
```

View File

@ -0,0 +1,186 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use std::io::Write;
use std::path::PathBuf;
use candle_transformers::models::quantized_t5 as t5;
use anyhow::{Error as E, Result};
use candle::{Device, Tensor};
use candle_transformers::generation::LogitsProcessor;
use clap::Parser;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;
#[derive(Parser, Debug, Clone)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
/// The model repository to use on the HuggingFace hub.
#[arg(long)]
model_id: Option<String>,
#[arg(long)]
revision: Option<String>,
#[arg(long)]
weight_file: Option<String>,
// Enable/disable decoding.
#[arg(long, default_value = "false")]
disable_cache: bool,
/// Use this prompt, otherwise compute sentence similarities.
#[arg(long)]
prompt: String,
/// The temperature used to generate samples.
#[arg(long, default_value_t = 0.8)]
temperature: f64,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.1)]
repeat_penalty: f32,
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
}
struct T5ModelBuilder {
device: Device,
config: t5::Config,
weights_filename: PathBuf,
}
impl T5ModelBuilder {
pub fn load(args: &Args) -> Result<(Self, Tokenizer)> {
let device = Device::Cpu;
let default_model = "lmz/candle-quantized-t5".to_string();
let (model_id, revision) = match (args.model_id.to_owned(), args.revision.to_owned()) {
(Some(model_id), Some(revision)) => (model_id, revision),
(Some(model_id), None) => (model_id, "main".to_string()),
(None, Some(revision)) => (default_model, revision),
(None, None) => (default_model, "main".to_string()),
};
let repo = Repo::with_revision(model_id, RepoType::Model, revision);
let api = Api::new()?;
let api = api.repo(repo);
let config_filename = api.get("config.json")?;
let tokenizer_filename = api.get("tokenizer.json")?;
let weights_filename = match &args.weight_file {
Some(filename) => std::path::PathBuf::from(filename),
None => api.get("model.gguf")?,
};
let config = std::fs::read_to_string(config_filename)?;
let mut config: t5::Config = serde_json::from_str(&config)?;
config.use_cache = !args.disable_cache;
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
Ok((
Self {
device,
config,
weights_filename,
},
tokenizer,
))
}
pub fn build_model(&self) -> Result<t5::T5ForConditionalGeneration> {
let vb = t5::VarBuilder::from_gguf(&self.weights_filename)?;
Ok(t5::T5ForConditionalGeneration::load(vb, &self.config)?)
}
}
fn main() -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
println!("tracing...");
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
let (builder, mut tokenizer) = T5ModelBuilder::load(&args)?;
let device = &builder.device;
let tokenizer = tokenizer
.with_padding(None)
.with_truncation(None)
.map_err(E::msg)?;
let tokens = tokenizer
.encode(args.prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
let input_token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
let mut model = builder.build_model()?;
let mut output_token_ids = [builder.config.pad_token_id as u32].to_vec();
let temperature = if args.temperature <= 0. {
None
} else {
Some(args.temperature)
};
let mut logits_processor = LogitsProcessor::new(299792458, temperature, args.top_p);
let encoder_output = model.encode(&input_token_ids)?;
let start = std::time::Instant::now();
for index in 0.. {
if output_token_ids.len() > 512 {
break;
}
let decoder_token_ids = if index == 0 || !builder.config.use_cache {
Tensor::new(output_token_ids.as_slice(), device)?.unsqueeze(0)?
} else {
let last_token = *output_token_ids.last().unwrap();
Tensor::new(&[last_token], device)?.unsqueeze(0)?
};
let logits = model
.decode(&decoder_token_ids, &encoder_output)?
.squeeze(0)?;
let logits = if args.repeat_penalty == 1. {
logits
} else {
let start_at = output_token_ids.len().saturating_sub(args.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
args.repeat_penalty,
&output_token_ids[start_at..],
)?
};
let next_token_id = logits_processor.sample(&logits)?;
if next_token_id as usize == builder.config.eos_token_id {
break;
}
output_token_ids.push(next_token_id);
if let Some(text) = tokenizer.id_to_token(next_token_id) {
let text = text.replace('▁', " ").replace("<0x0A>", "\n");
print!("{text}");
std::io::stdout().flush()?;
}
}
let dt = start.elapsed();
println!(
"\n{} tokens generated ({:.2} token/s)\n",
output_token_ids.len(),
output_token_ids.len() as f64 / dt.as_secs_f64(),
);
Ok(())
}