diff --git a/candle-examples/examples/t5/README.md b/candle-examples/examples/t5/README.md
index 66952395..c6ea2125 100644
--- a/candle-examples/examples/t5/README.md
+++ b/candle-examples/examples/t5/README.md
@@ -1,17 +1,25 @@
 # candle-t5
 
-Generates embeddings using a T5 model. It doesn't support generation yet.
+## Encoder-decoder example:
 
 ```bash
-$ cargo run --example t5 -- --model-id t5-large --prompt 'how tall is obama' --n 1
-Loaded and encoded 2.014244792s
-[[[-0.3174, -0.1462, 0.0065, ..., -0.0579, -0.0581, 0.1387],
- [-0.2905, -0.1945, -0.0685, ..., -0.2457, -0.5137, -0.1760],
- [-0.0591, -0.0213, -0.0241, ..., -0.0210, 0.0491, -0.0300],
- ...
- [-0.4333, 0.0027, -0.0609, ..., 0.3069, -0.2252, 0.3306],
- [-0.1458, 0.1323, -0.0138, ..., 0.3000, -0.4550, -0.0384],
- [ 0.0397, 0.0485, -0.2373, ..., 0.2578, -0.2650, -0.4356]]]
-Tensor[[1, 9, 1024], f32]
-Took 2.1363425s
-```
\ No newline at end of file
+$ cargo run --example t5 -- --model-id "t5-small" --prompt "translate to German: A beautiful candle." --decode
+...
+Running on CPU, to run on GPU, build this example with `--features cuda`
+ Eine schöne Kerze.
+9 tokens generated (2.42 token/s)
+```
+
+## Sentence embedding example:
+
+```bash
+$ cargo run --example t5 -- --model-id "t5-small" --prompt "A beautiful candle."
+...
+[[[ 0.0515, -0.0541, -0.0761, ..., -0.0392, 0.1511, -0.0265],
+ [-0.0974, 0.0998, -0.1659, ..., -0.2450, 0.1738, -0.0164],
+ [ 0.0624, -0.1024, 0.0430, ..., -0.1388, 0.0564, -0.2962],
+ [-0.0389, -0.1173, 0.0026, ..., 0.1064, -0.1065, 0.0990],
+ [ 0.1300, 0.0027, -0.0326, ..., 0.0026, -0.0317, 0.0851]]]
+Tensor[[1, 5, 512], f32]
+Took 303.766583ms
+```
diff --git a/candle-examples/examples/t5/main.rs b/candle-examples/examples/t5/main.rs
index 1e182974..00291609 100644
--- a/candle-examples/examples/t5/main.rs
+++ b/candle-examples/examples/t5/main.rs
@@ -3,18 +3,22 @@ extern crate intel_mkl_src;
 
 #[cfg(feature = "accelerate")]
 extern crate accelerate_src;
+use std::io::Write;
+use std::path::PathBuf;
+
 use candle_transformers::models::t5;
 
 use anyhow::{anyhow, Error as E, Result};
-use candle::{DType, Tensor};
+use candle::{DType, Device, Tensor};
 use candle_nn::VarBuilder;
+use candle_transformers::generation::LogitsProcessor;
 use clap::Parser;
 use hf_hub::{api::sync::Api, Cache, Repo, RepoType};
 use tokenizers::Tokenizer;
 
 const DTYPE: DType = DType::F32;
 
-#[derive(Parser, Debug)]
+#[derive(Parser, Debug, Clone)]
 #[command(author, version, about, long_about = None)]
 struct Args {
     /// Run on CPU rather than on GPU.
@@ -36,7 +40,11 @@ struct Args {
     #[arg(long)]
     revision: Option<String>,
 
-    /// Compute embeddings for this prompt, otherwise compute sentence similarities.
+    /// Enable decoding.
+    #[arg(long)]
+    decode: bool,
+
+    /// Use this prompt, otherwise compute sentence similarities.
     #[arg(long)]
     prompt: Option<String>,
 
@@ -49,12 +57,18 @@ struct Args {
     normalize_embeddings: bool,
 }
 
-impl Args {
-    fn build_model_and_tokenizer(&self) -> Result<(t5::T5EncoderModel, Tokenizer)> {
-        let device = candle_examples::device(self.cpu)?;
+struct T5ModelBuilder {
+    device: Device,
+    config: t5::Config,
+    weights_filename: PathBuf,
+}
+
+impl T5ModelBuilder {
+    pub fn load(args: &Args) -> Result<(Self, Tokenizer)> {
+        let device = candle_examples::device(args.cpu)?;
         let default_model = "t5-small".to_string();
         let default_revision = "refs/pr/15".to_string();
-        let (model_id, revision) = match (self.model_id.to_owned(), self.revision.to_owned()) {
+        let (model_id, revision) = match (args.model_id.to_owned(), args.revision.to_owned()) {
             (Some(model_id), Some(revision)) => (model_id, revision),
             (Some(model_id), None) => (model_id, "main".to_string()),
             (None, Some(revision)) => (default_model, revision),
@@ -62,7 +76,7 @@ impl Args {
         };
 
         let repo = Repo::with_revision(model_id, RepoType::Model, revision);
-        let (config_filename, tokenizer_filename, weights_filename) = if self.offline {
+        let (config_filename, tokenizer_filename, weights_filename) = if args.offline {
            let cache = Cache::default().repo(repo);
             (
                 cache
@@ -87,18 +101,36 @@ impl Args {
         let config = std::fs::read_to_string(config_filename)?;
         let config: t5::Config = serde_json::from_str(&config)?;
         let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
+        Ok((
+            Self {
+                device,
+                config,
+                weights_filename,
+            },
+            tokenizer,
+        ))
+    }
 
-        let weights = unsafe { candle::safetensors::MmapedFile::new(weights_filename)? };
+    pub fn build_encoder(&self) -> Result<t5::T5EncoderModel> {
+        let weights =
+            unsafe { candle::safetensors::MmapedFile::new(self.weights_filename.clone())? };
         let weights = weights.deserialize()?;
-        let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &device);
-        let model = t5::T5EncoderModel::load(vb, &config)?;
-        Ok((model, tokenizer))
-    }
+        let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &self.device);
+        Ok(t5::T5EncoderModel::load(vb, &self.config)?)
+    }
+
+    pub fn build_conditional_generation(&self) -> Result<t5::T5ForConditionalGeneration> {
+        let weights =
+            unsafe { candle::safetensors::MmapedFile::new(self.weights_filename.clone())? };
+        let weights = weights.deserialize()?;
+        let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &self.device);
+        Ok(t5::T5ForConditionalGeneration::load(vb, &self.config)?)
+    }
 }
 
 fn main() -> Result<()> {
     let args = Args::parse();
-    let (model, mut tokenizer) = args.build_model_and_tokenizer()?;
+    let (builder, mut tokenizer) = T5ModelBuilder::load(&args)?;
     let tokenizer = tokenizer
         .with_padding(None)
         .with_truncation(None)
@@ -110,17 +142,51 @@ fn main() -> Result<()> {
                 .map_err(E::msg)?
                 .get_ids()
                 .to_vec();
-            let token_ids = Tensor::new(&tokens[..], model.device())?.unsqueeze(0)?;
-            for idx in 0..args.n {
-                let start = std::time::Instant::now();
-                let ys = model.forward(&token_ids)?;
-                if idx == 0 {
-                    println!("{ys}");
+            let input_token_ids = Tensor::new(&tokens[..], &builder.device)?.unsqueeze(0)?;
+            if !args.decode {
+                let model = builder.build_encoder()?;
+                for idx in 0..args.n {
+                    let start = std::time::Instant::now();
+                    let ys = model.forward(&input_token_ids)?;
+                    if idx == 0 {
+                        println!("{ys}");
+                    }
+                    println!("Took {:?}", start.elapsed());
                 }
-                println!("Took {:?}", start.elapsed());
+            } else {
+                let model = builder.build_conditional_generation()?;
+                let mut output_token_ids = [builder.config.pad_token_id as u32].to_vec();
+                let mut logits_processor = LogitsProcessor::new(299792458, None, None);
+                let start = std::time::Instant::now();
+
+                for _index in 0.. {
+                    if output_token_ids.len() > 512 {
+                        break;
+                    }
+                    let decoder_token_ids =
+                        Tensor::new(&output_token_ids[..], &builder.device)?.unsqueeze(0)?;
+                    let logits = model.forward(&input_token_ids, &decoder_token_ids)?;
+                    let next_token_id = logits_processor.sample(&logits.flatten_to(1)?)?;
+                    if (next_token_id as usize) == builder.config.eos_token_id {
+                        break;
+                    }
+                    output_token_ids.push(next_token_id);
+                    if let Some(text) = tokenizer.id_to_token(next_token_id) {
+                        let text = text.replace('▁', " ").replace("<0x0A>", "\n");
+                        print!("{text}");
+                        std::io::stdout().flush()?;
+                    }
+                }
+                let dt = start.elapsed();
+                println!(
+                    "\n{} tokens generated ({:.2} token/s)\n",
+                    tokens.len(),
+                    tokens.len() as f64 / dt.as_secs_f64(),
+                );
             }
         }
         None => {
+            let model = builder.build_encoder()?;
             let sentences = [
                 "The cat sits outside",
                 "A man is playing guitar",
diff --git a/candle-transformers/src/models/t5.rs b/candle-transformers/src/models/t5.rs
index de7de496..c35dea0b 100644
--- a/candle-transformers/src/models/t5.rs
+++ b/candle-transformers/src/models/t5.rs
@@ -18,6 +18,21 @@ fn default_use_cache() -> bool {
     true
 }
 
+fn get_mask(size: usize, device: &Device) -> Result<Tensor> {
+    let mask: Vec<_> = (0..size)
+        .flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
+        .collect();
+    let result = Tensor::from_slice(&mask, (size, size), device)?;
+    Ok(result)
+}
+
+fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
+    let shape = mask.shape();
+    let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
+    let m = mask.where_cond(&on_true, on_false)?;
+    Ok(m)
+}
+
 #[derive(Debug, Clone, PartialEq, Deserialize)]
 pub struct Config {
     vocab_size: usize,
@@ -40,8 +55,8 @@ pub struct Config {
     is_encoder_decoder: bool,
     #[serde(default = "default_use_cache")]
     use_cache: bool,
-    pad_token_id: usize,
-    eos_token_id: usize,
+    pub pad_token_id: usize,
+    pub eos_token_id: usize,
 }
 
 impl Default for Config {
@@ -233,13 +248,13 @@ struct T5Attention {
 }
 
 impl T5Attention {
-    fn load(h: bool, vb: VarBuilder, cfg: &Config) -> Result<Self> {
+    fn load(has_relative_attention_bias: bool, vb: VarBuilder, cfg: &Config) -> Result<Self> {
         let inner_dim = cfg.num_heads * cfg.d_kv;
         let q = linear_no_bias(cfg.d_model, inner_dim, vb.pp("q"))?;
         let k = linear_no_bias(cfg.d_model, inner_dim, vb.pp("k"))?;
         let v = linear_no_bias(cfg.d_model, inner_dim, vb.pp("v"))?;
         let o = linear_no_bias(inner_dim, cfg.d_model, vb.pp("o"))?;
-        let relative_attention_bias = if h {
+        let relative_attention_bias = if has_relative_attention_bias {
             let emb = embedding(
                 cfg.relative_attention_num_buckets,
                 cfg.num_heads,
@@ -267,26 +282,46 @@ impl T5Attention {
         &self,
         xs: &Tensor,
         position_bias: Option<&Tensor>,
+        key_value_states: Option<&Tensor>,
+        mask: Option<&Tensor>,
     ) -> Result<(Tensor, Option<Tensor>)> {
-        // TODO: Apply the mask(s)?
+        // Performs Self-attention (if key_value_states is None) or attention
+        // over source sentence (provided by key_value_states).
         // TODO: kv caching.
-        let (b_sz, seq_len) = (xs.dim(0)?, xs.dim(1)?);
+        let kv_input = match key_value_states {
+            None => xs,
+            Some(key_value_states) => key_value_states,
+        };
+        let (b_sz, q_len) = (xs.dim(0)?, xs.dim(1)?);
+        let kv_len = kv_input.dim(1)?;
         let q = self.q.forward(xs)?;
-        let k = self.k.forward(xs)?;
-        let v = self.v.forward(xs)?;
+        let k = self.k.forward(kv_input)?;
+        let v = self.v.forward(kv_input)?;
         let q = q
-            .reshape((b_sz, seq_len, self.n_heads, self.d_kv))?
+            .reshape((b_sz, q_len, self.n_heads, self.d_kv))?
             .transpose(1, 2)?
             .contiguous()?;
         let k = k
-            .reshape((b_sz, seq_len, self.n_heads, self.d_kv))?
+            .reshape((b_sz, kv_len, self.n_heads, self.d_kv))?
             .transpose(1, 2)?
             .contiguous()?;
         let v = v
-            .reshape((b_sz, seq_len, self.n_heads, self.d_kv))?
+            .reshape((b_sz, kv_len, self.n_heads, self.d_kv))?
             .transpose(1, 2)?
             .contiguous()?;
+        // TODO: Use flash_attn.
         let scores = q.matmul(&k.t()?)?;
+        let scores = match mask {
+            None => scores,
+            Some(mask) => masked_fill(
+                &scores,
+                &mask
+                    .unsqueeze(0)?
+                    .unsqueeze(0)?
+                    .repeat((b_sz, self.n_heads))?,
+                f32::NEG_INFINITY,
+            )?,
+        };
 
         let (scores, position_bias) = match position_bias {
             Some(position_bias) => (
@@ -296,14 +331,12 @@ impl T5Attention {
             None => match &self.relative_attention_bias {
                 None => (scores, None),
                 Some(relative_attention_bias) => {
-                    let query_length = seq_len;
-                    let key_length = seq_len;
                     // This only handles the bidirectional case.
                     let num_buckets = self.relative_attention_num_buckets as u32 / 2;
                     let max_exact = num_buckets / 2;
-                    let relative_position = (0..query_length as u32)
+                    let relative_position = (0..q_len as u32)
                         .map(|i| {
-                            (0..key_length as u32)
+                            (0..kv_len as u32)
                                 .map(|j| {
                                     if i < j {
                                         if j - i < max_exact {
@@ -348,7 +381,7 @@ impl T5Attention {
         let attn_output = attn_weights.matmul(&v)?;
         let attn_output = attn_output
             .transpose(1, 2)?
-            .reshape((b_sz, seq_len, self.inner_dim))?;
+            .reshape((b_sz, q_len, self.inner_dim))?;
         let attn_output = self.o.forward(&attn_output)?;
         Ok((attn_output, position_bias))
     }
@@ -375,24 +408,49 @@ impl T5LayerSelfAttention {
         &self,
         xs: &Tensor,
         position_bias: Option<&Tensor>,
+        mask: Option<&Tensor>,
     ) -> Result<(Tensor, Option<Tensor>)> {
         let normed_xs = self.layer_norm.forward(xs)?;
-        let (ys, position_bias) = self.self_attention.forward(&normed_xs, position_bias)?;
+        let (ys, position_bias) =
+            self.self_attention
+                .forward(&normed_xs, position_bias, None, mask)?;
         let ys = (xs + ys)?;
         Ok((ys, position_bias))
     }
 }
 
 #[derive(Debug)]
-struct T5LayerCrossAttention {}
+struct T5LayerCrossAttention {
+    cross_attention: T5Attention,
+    layer_norm: T5LayerNorm,
+}
 
 impl T5LayerCrossAttention {
-    fn load(_vb: VarBuilder, _cfg: &Config) -> Result<Self> {
-        todo!()
+    fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
+        let cross_attention = T5Attention::load(false, vb.pp("EncDecAttention"), cfg)?;
+        let layer_norm =
+            T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?;
+        Ok(Self {
+            cross_attention,
+            layer_norm,
+        })
     }
 
-    fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
-        todo!()
+    fn forward(
+        &self,
+        hidden_states: &Tensor,
+        position_bias: Option<&Tensor>,
+        key_value_states: &Tensor,
+    ) -> Result<(Tensor, Option<Tensor>)> {
+        let normed_hidden_states = self.layer_norm.forward(hidden_states)?;
+        let (ys, position_bias) = self.cross_attention.forward(
+            &normed_hidden_states,
+            position_bias,
+            Some(key_value_states),
+            None,
+        )?;
+        let ys = (hidden_states + ys)?;
+        Ok((ys, position_bias))
     }
 }
 
@@ -425,11 +483,17 @@ impl T5Block {
         &self,
         xs: &Tensor,
         position_bias: Option<&Tensor>,
+        encoder_hidden_states: Option<&Tensor>,
     ) -> Result<(Tensor, Option<Tensor>)> {
-        let (mut xs, position_bias) = self.self_attn.forward(xs, position_bias)?;
+        // TODO: Cache masks
+        let mask = match self.cross_attn.is_some() {
+            true => Some(get_mask(xs.dim(1)?, xs.device())?),
+            false => None,
+        };
+        let (mut xs, position_bias) = self.self_attn.forward(xs, position_bias, mask.as_ref())?;
         // TODO: clamp for f16?
         if let Some(cross_attn) = &self.cross_attn {
-            xs = cross_attn.forward(&xs)?;
+            (xs, _) = cross_attn.forward(&xs, None, encoder_hidden_states.unwrap())?;
             // TODO: clamp for f16?
         }
         let xs = self.ff.forward(&xs)?;
@@ -462,13 +526,20 @@ impl T5Stack {
         })
     }
 
-    fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
+    fn forward(
+        &self,
+        input_ids: &Tensor,
+        encoder_hidden_states: Option<&Tensor>,
+    ) -> Result<Tensor> {
         let input_embeds = self.shared.as_ref().forward(input_ids)?;
         let mut hidden_states = input_embeds;
         let mut position_bias = None;
         for block in self.block.iter() {
-            (hidden_states, position_bias) =
-                block.forward(&hidden_states, position_bias.as_ref())?
+            (hidden_states, position_bias) = block.forward(
+                &hidden_states,
+                position_bias.as_ref(),
+                encoder_hidden_states,
+            )?
         }
         self.final_layer_norm.forward(&hidden_states)
     }
@@ -492,7 +563,61 @@ impl T5EncoderModel {
     }
 
     pub fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
-        self.encoder.forward(input_ids)
+        self.encoder.forward(input_ids, None)
+    }
+
+    pub fn device(&self) -> &Device {
+        &self.device
+    }
+}
+
+#[derive(Debug)]
+pub struct T5ForConditionalGeneration {
+    encoder: T5Stack,
+    decoder: T5Stack,
+    shared: Arc<Embedding>,
+    device: Device,
+}
+
+impl T5ForConditionalGeneration {
+    pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
+        assert!(cfg.is_encoder_decoder);
+        let shared = embedding(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?;
+        let shared = Arc::new(shared);
+
+        let mut encoder_cfg = cfg.clone();
+        encoder_cfg.is_decoder = false;
+        encoder_cfg.use_cache = false;
+        encoder_cfg.is_encoder_decoder = false;
+        let encoder = T5Stack::load(vb.pp("encoder"), &shared, &encoder_cfg)?;
+
+        let mut decoder_cfg = cfg.clone();
+        decoder_cfg.is_decoder = true;
+        decoder_cfg.is_encoder_decoder = false;
+        decoder_cfg.num_layers = cfg.num_decoder_layers.unwrap_or(cfg.num_layers);
+        let decoder = T5Stack::load(vb.pp("decoder"), &shared, &decoder_cfg)?;
+
+        Ok(Self {
+            encoder,
+            decoder,
+            shared,
+            device: vb.device().clone(),
+        })
+    }
+
+    pub fn forward(&self, input_ids: &Tensor, decoder_input_ids: &Tensor) -> Result<Tensor> {
+        let encoder_output = self.encoder.forward(input_ids, None)?;
+        let decoder_output = self
+            .decoder
+            .forward(decoder_input_ids, Some(&encoder_output))?;
+        let sequence_output = decoder_output
+            .narrow(1, decoder_output.dim(1)? - 1, 1)?
+            .squeeze(1)?;
+        // TODO: check cfg.tie_word_embeddings to load from model instead.
+        let lm_head_weights = self.shared.embeddings().t()?;
+        let output = sequence_output.matmul(&lm_head_weights)?;
+        // TODO: Rescale output before projecting on vocab? * (self.model_dim**-0.5)
+        Ok(output)
     }
 
     pub fn device(&self) -> &Device {
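
Side note on the new masking logic in `t5.rs`: the decoder self-attention combines the upper-triangular mask from `get_mask` with the attention scores via `masked_fill`. The standalone sketch below is not part of the patch; the helper name `causal_mask` and the toy 4-token scores are illustrative only, and it assumes the `candle` core crate (imported as `candle`, as in the examples above).

```rust
use candle::{DType, Device, Result, Tensor};

// Same construction as `get_mask` in the patch: entry (i, j) is 1 when j > i,
// i.e. when key position j lies strictly in the future of query position i.
fn causal_mask(size: usize, device: &Device) -> Result<Tensor> {
    let mask: Vec<_> = (0..size)
        .flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
        .collect();
    Tensor::from_slice(&mask, (size, size), device)
}

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Toy attention scores for a 4-token decoder sequence (all zeros here).
    let scores = Tensor::zeros((4, 4), DType::F32, &device)?;
    let mask = causal_mask(4, &device)?;
    // Same pattern as `masked_fill`: wherever the mask is non-zero, replace the
    // score with -inf so that the subsequent softmax assigns zero probability.
    let on_true = Tensor::new(f32::NEG_INFINITY, &device)?.broadcast_as(scores.shape().dims())?;
    let masked = mask.where_cond(&on_true, &scores)?;
    println!("{masked}");
    Ok(())
}
```

Positions where the mask is 1 end up at `-inf`, so the softmax inside `T5Attention::forward` gives future tokens zero attention weight, which is what makes the greedy decoding loop in `main.rs` causal.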