From 8696f64baef3da878a6974b1fa82cba406789f13 Mon Sep 17 00:00:00 2001
From: Juarez Bochi
Date: Tue, 19 Sep 2023 12:36:15 -0700
Subject: [PATCH] Fix T5 kv cache (#899)

* Fix T5 kv cache

* Add argument for decoder prompt

* Fix range
---
 candle-examples/examples/t5/main.rs  | 26 ++++++++++++++++++++------
 candle-transformers/src/models/t5.rs |  7 ++++++-
 2 files changed, 26 insertions(+), 7 deletions(-)

diff --git a/candle-examples/examples/t5/main.rs b/candle-examples/examples/t5/main.rs
index 117aed13..f5972754 100644
--- a/candle-examples/examples/t5/main.rs
+++ b/candle-examples/examples/t5/main.rs
@@ -46,12 +46,16 @@ struct Args {
     // Enable/disable decoding.
     #[arg(long, default_value = "false")]
-    use_cache: bool,
+    disable_cache: bool,
 
     /// Use this prompt, otherwise compute sentence similarities.
     #[arg(long)]
     prompt: Option<String>,
 
+    /// If set along with --decode, will use this prompt to initialize the decoder.
+    #[arg(long)]
+    decoder_prompt: Option<String>,
+
     /// L2 normalization for embeddings.
     #[arg(long, default_value = "true")]
     normalize_embeddings: bool,
@@ -116,7 +120,7 @@ impl T5ModelBuilder {
         };
         let config = std::fs::read_to_string(config_filename)?;
         let mut config: t5::Config = serde_json::from_str(&config)?;
-        config.use_cache = args.use_cache;
+        config.use_cache = !args.disable_cache;
         let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
         Ok((
             Self {
@@ -170,6 +174,16 @@ fn main() -> Result<()> {
     } else {
         let mut model = builder.build_conditional_generation()?;
         let mut output_token_ids = [builder.config.pad_token_id as u32].to_vec();
+        if let Some(decoder_prompt) = &args.decoder_prompt {
+            print!("{decoder_prompt}");
+            output_token_ids.extend(
+                tokenizer
+                    .encode(decoder_prompt.to_string(), false)
+                    .map_err(E::msg)?
+                    .get_ids()
+                    .to_vec(),
+            );
+        }
         let temperature = if args.temperature <= 0. {
             None
         } else {
@@ -195,11 +209,11 @@ fn main() -> Result<()> {
             let logits = if args.repeat_penalty == 1. {
                 logits
             } else {
-                let start_at = tokens.len().saturating_sub(args.repeat_last_n);
+                let start_at = output_token_ids.len().saturating_sub(args.repeat_last_n);
                 candle_transformers::utils::apply_repeat_penalty(
                     &logits,
                     args.repeat_penalty,
-                    &tokens[start_at..],
+                    &output_token_ids[start_at..],
                 )?
             };
@@ -217,8 +231,8 @@ fn main() -> Result<()> {
         let dt = start.elapsed();
         println!(
             "\n{} tokens generated ({:.2} token/s)\n",
-            tokens.len(),
-            tokens.len() as f64 / dt.as_secs_f64(),
+            output_token_ids.len(),
+            output_token_ids.len() as f64 / dt.as_secs_f64(),
         );
     }
 }
diff --git a/candle-transformers/src/models/t5.rs b/candle-transformers/src/models/t5.rs
index 2ffc2ee1..b1f3a3aa 100644
--- a/candle-transformers/src/models/t5.rs
+++ b/candle-transformers/src/models/t5.rs
@@ -348,9 +348,14 @@ impl T5Attention {
             None => (scores, None),
             Some(relative_attention_bias) => {
                 // This only handles the bidirectional case.
+                let kv_len = k.dim(2)?;
+                let (q_start, q_end) = match self.use_cache {
+                    true => ((kv_len - q_len) as u32, kv_len as u32),
+                    false => (0_u32, kv_len as u32),
+                };
                 let num_buckets = self.relative_attention_num_buckets as u32 / 2;
                 let max_exact = num_buckets / 2;
-                let relative_position = (0..q_len as u32)
+                let relative_position = (q_start..q_end)
                     .map(|i| {
                         (0..kv_len as u32)
                             .map(|j| {
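
Note on the relative-position fix in t5.rs: with the kv cache enabled, the
decoder processes one new token per step, so q_len is 1 while kv_len (the
number of cached key positions) keeps growing. The pre-patch code enumerated
query positions as (0..q_len), which pinned every new query to absolute
position 0; the patch offsets the range to (kv_len - q_len..kv_len) so the
query is scored at its true position. Below is a minimal, self-contained
sketch of just that index arithmetic, in plain Rust rather than candle's
tensor code; the function name and the example step counts are illustrative
only:

// Relative positions (key_pos - query_pos) for every (query, key) pair,
// mirroring the `(q_start..q_end)` / `(0..kv_len)` loops in the patch.
fn relative_positions(q_len: usize, kv_len: usize, use_cache: bool) -> Vec<Vec<i64>> {
    // With the cache, the q_len queries are the *last* q_len positions,
    // not the first ones.
    let q_start = if use_cache { kv_len - q_len } else { 0 };
    (q_start..q_start + q_len)
        .map(|i| (0..kv_len).map(|j| j as i64 - i as i64).collect())
        .collect()
}

fn main() {
    // Third decoding step: 1 new query token, 3 key/value positions so far.
    // Pre-patch behaviour (query treated as position 0): [[0, 1, 2]]
    println!("{:?}", relative_positions(1, 3, false));
    // Patched behaviour (query at position kv_len - 1 = 2): [[-2, -1, 0]]
    println!("{:?}", relative_positions(1, 3, true));
}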
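
These relative positions are what the surrounding code buckets (via
num_buckets and max_exact) to index into relative_attention_bias, so the
pre-patch bug effectively reused the position-0 bias row at every cached
decoding step. The subtraction kv_len - q_len assumes kv_len >= q_len, which
holds here because k already includes the current step's tokens by the time
the bias is computed.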