Mirror of https://github.com/huggingface/candle.git, synced 2025-06-16 02:38:10 +00:00
Implement T5 decoding (#864)
* Load t5 decoder
* Run enc, dec, and lm head, but no cross attn
* Cross-attention over key_value_states
* New arg for decoder input ids
* Add mask, don't forward position biases through decoder
* Update t5 examples
* Clippy + rustfmt
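The pivotal item in that list is cross-attention: in a T5 decoder block, queries come from the decoder stream while keys and values are projected from the encoder output that is passed in as `key_value_states`. Below is a minimal sketch of that shape flow using candle tensors; the `cross_attention` helper and the `wq`/`wk`/`wv` weights are illustrative stand-ins, not the `candle-transformers` implementation:

```rust
use candle::{Device, Result, Tensor};

// Illustrative single-head cross-attention: queries from the decoder,
// keys/values from the encoder output (`key_value_states`).
fn cross_attention(
    decoder_hidden: &Tensor,   // [batch, dec_len, d_model]
    key_value_states: &Tensor, // [batch, enc_len, d_model], the encoder output
    wq: &Tensor,               // [d_model, d_model]
    wk: &Tensor,               // [d_model, d_model]
    wv: &Tensor,               // [d_model, d_model]
) -> Result<Tensor> {
    let q = decoder_hidden.broadcast_matmul(wq)?;
    let k = key_value_states.broadcast_matmul(wk)?;
    let v = key_value_states.broadcast_matmul(wv)?;
    // Scores mix decoder positions with encoder positions. T5 applies its
    // relative position biases in self-attention but not here, in line with
    // the commit's "don't forward position biases through decoder".
    // (T5 also skips the 1/sqrt(d) scaling; omitted here as well.)
    let scores = q.matmul(&k.transpose(1, 2)?)?; // [batch, dec_len, enc_len]
    let weights = candle_nn::ops::softmax(&scores, 2)?;
    weights.matmul(&v) // [batch, dec_len, d_model]
}

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let (dec_len, enc_len, d_model) = (3, 5, 8);
    let x = Tensor::rand(0f32, 1f32, (1, dec_len, d_model), &dev)?;
    let kv = Tensor::rand(0f32, 1f32, (1, enc_len, d_model), &dev)?;
    let wq = Tensor::rand(0f32, 1f32, (d_model, d_model), &dev)?;
    let wk = Tensor::rand(0f32, 1f32, (d_model, d_model), &dev)?;
    let wv = Tensor::rand(0f32, 1f32, (d_model, d_model), &dev)?;
    let out = cross_attention(&x, &kv, &wq, &wk, &wv)?;
    println!("{:?}", out.dims()); // [1, 3, 8]
    Ok(())
}
```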
diff --git a/candle-examples/examples/t5/README.md b/candle-examples/examples/t5/README.md
@@ -1,17 +1,25 @@
 # candle-t5
 
-Generates embeddings using a T5 model. It doesn't support generation yet.
+## Encoder-decoder example:
 
 ```bash
-$ cargo run --example t5 -- --model-id t5-large --prompt 'how tall is obama' --n 1
-Loaded and encoded 2.014244792s
-[[[-0.3174, -0.1462,  0.0065, ..., -0.0579, -0.0581,  0.1387],
-  [-0.2905, -0.1945, -0.0685, ..., -0.2457, -0.5137, -0.1760],
-  [-0.0591, -0.0213, -0.0241, ..., -0.0210,  0.0491, -0.0300],
-  ...
-  [-0.4333,  0.0027, -0.0609, ...,  0.3069, -0.2252,  0.3306],
-  [-0.1458,  0.1323, -0.0138, ...,  0.3000, -0.4550, -0.0384],
-  [ 0.0397,  0.0485, -0.2373, ...,  0.2578, -0.2650, -0.4356]]]
-Tensor[[1, 9, 1024], f32]
-Took 2.1363425s
-```
+$ cargo run --example t5 -- --model-id "t5-small" --prompt "translate to German: A beautiful candle." --decode
+...
+Running on CPU, to run on GPU, build this example with `--features cuda`
+Eine schöne Kerze.
+9 tokens generated (2.42 token/s)
+```
+
+## Sentence embedding example:
+
+```bash
+$ cargo run --example t5 -- --model-id "t5-small" --prompt "A beautiful candle."
+...
+[[[ 0.0515, -0.0541, -0.0761, ..., -0.0392,  0.1511, -0.0265],
+  [-0.0974,  0.0998, -0.1659, ..., -0.2450,  0.1738, -0.0164],
+  [ 0.0624, -0.1024,  0.0430, ..., -0.1388,  0.0564, -0.2962],
+  [-0.0389, -0.1173,  0.0026, ...,  0.1064, -0.1065,  0.0990],
+  [ 0.1300,  0.0027, -0.0326, ...,  0.0026, -0.0317,  0.0851]]]
+Tensor[[1, 5, 512], f32]
+Took 303.766583ms
+```
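The embedding output above is the raw encoder state; the example's `Args` struct (in the diff below) also exposes a `normalize_embeddings` flag. As a hedged sketch, L2 normalization of such an embedding tensor typically looks like the following; the `[1, 5, 512]` shape is taken from the README output, and the flag's exact behavior is an assumption here, not shown by this diff:

```rust
use candle::{Device, Result, Tensor};

fn main() -> Result<()> {
    // Stand-in for the encoder output printed above: [batch, seq_len, d_model].
    let embeddings = Tensor::rand(0f32, 1f32, (1, 5, 512), &Device::Cpu)?;
    // L2-normalize each token vector along the last dimension, so that
    // downstream cosine similarities reduce to dot products.
    let norms = embeddings.sqr()?.sum_keepdim(2)?.sqrt()?;
    let normalized = embeddings.broadcast_div(&norms)?;
    println!("{:?}", normalized.dims()); // [1, 5, 512]
    Ok(())
}
```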
diff --git a/candle-examples/examples/t5/main.rs b/candle-examples/examples/t5/main.rs
@@ -3,18 +3,22 @@ extern crate intel_mkl_src;
 #[cfg(feature = "accelerate")]
 extern crate accelerate_src;
 
+use std::io::Write;
+use std::path::PathBuf;
+
 use candle_transformers::models::t5;
 
 use anyhow::{anyhow, Error as E, Result};
-use candle::{DType, Tensor};
+use candle::{DType, Device, Tensor};
 use candle_nn::VarBuilder;
+use candle_transformers::generation::LogitsProcessor;
 use clap::Parser;
 use hf_hub::{api::sync::Api, Cache, Repo, RepoType};
 use tokenizers::Tokenizer;
 
 const DTYPE: DType = DType::F32;
 
-#[derive(Parser, Debug)]
+#[derive(Parser, Debug, Clone)]
 #[command(author, version, about, long_about = None)]
 struct Args {
     /// Run on CPU rather than on GPU.
@@ -36,7 +40,11 @@ struct Args {
     #[arg(long)]
     revision: Option<String>,
 
-    /// Compute embeddings for this prompt, otherwise compute sentence similarities.
+    /// Enable decoding.
+    #[arg(long)]
+    decode: bool,
+
+    /// Use this prompt, otherwise compute sentence similarities.
     #[arg(long)]
     prompt: Option<String>,
 
@@ -49,12 +57,18 @@ struct Args {
     normalize_embeddings: bool,
 }
 
-impl Args {
-    fn build_model_and_tokenizer(&self) -> Result<(t5::T5EncoderModel, Tokenizer)> {
-        let device = candle_examples::device(self.cpu)?;
+struct T5ModelBuilder {
+    device: Device,
+    config: t5::Config,
+    weights_filename: PathBuf,
+}
+
+impl T5ModelBuilder {
+    pub fn load(args: &Args) -> Result<(Self, Tokenizer)> {
+        let device = candle_examples::device(args.cpu)?;
         let default_model = "t5-small".to_string();
         let default_revision = "refs/pr/15".to_string();
-        let (model_id, revision) = match (self.model_id.to_owned(), self.revision.to_owned()) {
+        let (model_id, revision) = match (args.model_id.to_owned(), args.revision.to_owned()) {
             (Some(model_id), Some(revision)) => (model_id, revision),
             (Some(model_id), None) => (model_id, "main".to_string()),
             (None, Some(revision)) => (default_model, revision),
@@ -62,7 +76,7 @@ impl Args {
         };
 
         let repo = Repo::with_revision(model_id, RepoType::Model, revision);
-        let (config_filename, tokenizer_filename, weights_filename) = if self.offline {
+        let (config_filename, tokenizer_filename, weights_filename) = if args.offline {
             let cache = Cache::default().repo(repo);
             (
                 cache
@@ -87,18 +101,36 @@ impl Args {
         let config = std::fs::read_to_string(config_filename)?;
         let config: t5::Config = serde_json::from_str(&config)?;
         let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
+        Ok((
+            Self {
+                device,
+                config,
+                weights_filename,
+            },
+            tokenizer,
+        ))
+    }
 
-        let weights = unsafe { candle::safetensors::MmapedFile::new(weights_filename)? };
+    pub fn build_encoder(&self) -> Result<t5::T5EncoderModel> {
+        let weights =
+            unsafe { candle::safetensors::MmapedFile::new(self.weights_filename.clone())? };
         let weights = weights.deserialize()?;
-        let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &device);
-        let model = t5::T5EncoderModel::load(vb, &config)?;
-        Ok((model, tokenizer))
+        let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &self.device);
+        Ok(t5::T5EncoderModel::load(vb, &self.config)?)
+    }
+
+    pub fn build_conditional_generation(&self) -> Result<t5::T5ForConditionalGeneration> {
+        let weights =
+            unsafe { candle::safetensors::MmapedFile::new(self.weights_filename.clone())? };
+        let weights = weights.deserialize()?;
+        let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &self.device);
+        Ok(t5::T5ForConditionalGeneration::load(vb, &self.config)?)
     }
 }
 
 fn main() -> Result<()> {
     let args = Args::parse();
-    let (model, mut tokenizer) = args.build_model_and_tokenizer()?;
+    let (builder, mut tokenizer) = T5ModelBuilder::load(&args)?;
     let tokenizer = tokenizer
         .with_padding(None)
         .with_truncation(None)
@@ -110,17 +142,51 @@ fn main() -> Result<()> {
                 .map_err(E::msg)?
                 .get_ids()
                 .to_vec();
-            let token_ids = Tensor::new(&tokens[..], model.device())?.unsqueeze(0)?;
-            for idx in 0..args.n {
-                let start = std::time::Instant::now();
-                let ys = model.forward(&token_ids)?;
-                if idx == 0 {
-                    println!("{ys}");
-                }
-                println!("Took {:?}", start.elapsed());
-            }
+            let input_token_ids = Tensor::new(&tokens[..], &builder.device)?.unsqueeze(0)?;
+            if !args.decode {
+                let model = builder.build_encoder()?;
+                for idx in 0..args.n {
+                    let start = std::time::Instant::now();
+                    let ys = model.forward(&input_token_ids)?;
+                    if idx == 0 {
+                        println!("{ys}");
+                    }
+                    println!("Took {:?}", start.elapsed());
+                }
+            } else {
+                let model = builder.build_conditional_generation()?;
+                let mut output_token_ids = [builder.config.pad_token_id as u32].to_vec();
+                let mut logits_processor = LogitsProcessor::new(299792458, None, None);
+                let start = std::time::Instant::now();
+
+                for _index in 0.. {
+                    if output_token_ids.len() > 512 {
+                        break;
+                    }
+                    let decoder_token_ids =
+                        Tensor::new(&output_token_ids[..], &builder.device)?.unsqueeze(0)?;
+                    let logits = model.forward(&input_token_ids, &decoder_token_ids)?;
+                    let next_token_id = logits_processor.sample(&logits.flatten_to(1)?)?;
+                    if (next_token_id as usize) == builder.config.eos_token_id {
+                        break;
+                    }
+                    output_token_ids.push(next_token_id);
+                    if let Some(text) = tokenizer.id_to_token(next_token_id) {
+                        let text = text.replace('▁', " ").replace("<0x0A>", "\n");
+                        print!("{text}");
+                        std::io::stdout().flush()?;
+                    }
+                }
+                let dt = start.elapsed();
+                println!(
+                    "\n{} tokens generated ({:.2} token/s)\n",
+                    tokens.len(),
+                    tokens.len() as f64 / dt.as_secs_f64(),
+                );
+            }
         }
         None => {
+            let model = builder.build_encoder()?;
             let sentences = [
                 "The cat sits outside",
                 "A man is playing guitar",
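A note on the sampling step in the decode loop above: `LogitsProcessor::new(299792458, None, None)` passes `None` for both temperature and top-p, so `sample` reduces to a greedy argmax and the seed never influences the output. A tiny sketch of that behavior, with a made-up three-entry logits tensor:

```rust
use candle::{Device, Result, Tensor};
use candle_transformers::generation::LogitsProcessor;

fn main() -> Result<()> {
    // With no temperature and no top-p, sampling degenerates to argmax,
    // so the seed has no effect on the chosen token.
    let mut logits_processor = LogitsProcessor::new(299792458, None, None);
    let logits = Tensor::new(&[0.1f32, 2.5, 0.3], &Device::Cpu)?;
    let next_token_id = logits_processor.sample(&logits)?;
    assert_eq!(next_token_id, 1); // index of the largest logit
    Ok(())
}
```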