diff --git a/candle-examples/examples/t5/main.rs b/candle-examples/examples/t5/main.rs
index 84be0204..03c861c1 100644
--- a/candle-examples/examples/t5/main.rs
+++ b/candle-examples/examples/t5/main.rs
@@ -13,7 +13,6 @@ use hf_hub::{api::sync::Api, Cache, Repo, RepoType};
 use tokenizers::Tokenizer;
 
 const DTYPE: DType = DType::F32;
-const DEFAULT_PROMPT: &str = "Translate English to German: That is good.";
 
 #[derive(Parser, Debug)]
 #[command(author, version, about, long_about = None)]
@@ -37,13 +36,17 @@ struct Args {
     #[arg(long)]
     revision: Option<String>,
 
-    /// Compute embeddings for this prompt or use the DEFAULT_PROMPT.
+    /// Compute embeddings for this prompt, otherwise compute sentence similarities.
     #[arg(long)]
     prompt: Option<String>,
 
     /// The number of times to run the prompt.
     #[arg(long, default_value = "1")]
     n: usize,
+
+    /// L2 normalization for embeddings.
+    #[arg(long, default_value = "true")]
+    normalize_embeddings: bool,
 }
 
 impl Args {
@@ -95,28 +98,95 @@ impl Args {
 
 fn main() -> Result<()> {
     let args = Args::parse();
-    let start = std::time::Instant::now();
-
     let (model, mut tokenizer) = args.build_model_and_tokenizer()?;
-    let prompt = args.prompt.unwrap_or_else(|| DEFAULT_PROMPT.to_string());
     let tokenizer = tokenizer
         .with_padding(None)
         .with_truncation(None)
         .map_err(E::msg)?;
-    let tokens = tokenizer
-        .encode(prompt, true)
-        .map_err(E::msg)?
-        .get_ids()
-        .to_vec();
-    let token_ids = Tensor::new(&tokens[..], model.device())?.unsqueeze(0)?;
-    println!("Loaded and encoded {:?}", start.elapsed());
-    for idx in 0..args.n {
-        let start = std::time::Instant::now();
-        let ys = model.forward(&token_ids)?;
-        if idx == 0 {
-            println!("{ys}");
+    match args.prompt {
+        Some(prompt) => {
+            let tokens = tokenizer
+                .encode(prompt, true)
+                .map_err(E::msg)?
+                .get_ids()
+                .to_vec();
+            let token_ids = Tensor::new(&tokens[..], model.device())?.unsqueeze(0)?;
+            for idx in 0..args.n {
+                let start = std::time::Instant::now();
+                let ys = model.forward(&token_ids)?;
+                if idx == 0 {
+                    println!("{ys}");
+                }
+                println!("Took {:?}", start.elapsed());
+            }
+        }
+        None => {
+            let sentences = [
+                "The cat sits outside",
+                "A man is playing guitar",
+                "I love pasta",
+                "The new movie is awesome",
+                "The cat plays in the garden",
+                "A woman watches TV",
+                "The new movie is so great",
+                "Do you like pizza?",
+            ];
+            let n_sentences = sentences.len();
+            if let Some(pp) = tokenizer.get_padding_mut() {
+                pp.strategy = tokenizers::PaddingStrategy::BatchLongest
+            } else {
+                let pp = tokenizers::PaddingParams {
+                    strategy: tokenizers::PaddingStrategy::BatchLongest,
+                    ..Default::default()
+                };
+                tokenizer.with_padding(Some(pp));
+            }
+            let tokens = tokenizer
+                .encode_batch(sentences.to_vec(), true)
+                .map_err(E::msg)?;
+            let token_ids = tokens
+                .iter()
+                .map(|tokens| {
+                    let tokens = tokens.get_ids().to_vec();
+                    Ok(Tensor::new(tokens.as_slice(), model.device())?)
+                })
+                .collect::<Result<Vec<_>>>()?;
+
+            let token_ids = Tensor::stack(&token_ids, 0)?;
+            println!("running inference on batch {:?}", token_ids.shape());
+            let embeddings = model.forward(&token_ids)?;
+            println!("generated embeddings {:?}", embeddings.shape());
+            // Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
+            let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?;
+            let embeddings = (embeddings.sum(1)? / (n_tokens as f64))?;
+            let embeddings = if args.normalize_embeddings {
+                normalize_l2(&embeddings)?
+            } else {
+                embeddings
+            };
+            println!("pooled embeddings {:?}", embeddings.shape());
+
+            let mut similarities = vec![];
+            for i in 0..n_sentences {
+                let e_i = embeddings.get(i)?;
+                for j in (i + 1)..n_sentences {
+                    let e_j = embeddings.get(j)?;
+                    let sum_ij = (&e_i * &e_j)?.sum_all()?.to_scalar::<f32>()?;
+                    let sum_i2 = (&e_i * &e_i)?.sum_all()?.to_scalar::<f32>()?;
+                    let sum_j2 = (&e_j * &e_j)?.sum_all()?.to_scalar::<f32>()?;
+                    let cosine_similarity = sum_ij / (sum_i2 * sum_j2).sqrt();
+                    similarities.push((cosine_similarity, i, j))
+                }
+            }
+            similarities.sort_by(|u, v| v.0.total_cmp(&u.0));
+            for &(score, i, j) in similarities[..5].iter() {
+                println!("score: {score:.2} '{}' '{}'", sentences[i], sentences[j])
+            }
         }
-        println!("Took {:?}", start.elapsed());
     }
     Ok(())
 }
+
+pub fn normalize_l2(v: &Tensor) -> Result<Tensor> {
+    Ok(v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)?)
+}
diff --git a/candle-transformers/src/models/t5.rs b/candle-transformers/src/models/t5.rs
index 691817d1..325eb752 100644
--- a/candle-transformers/src/models/t5.rs
+++ b/candle-transformers/src/models/t5.rs
@@ -245,7 +245,10 @@ impl T5Attention {
         let scores = q.matmul(&k.t()?)?;
 
         let (scores, position_bias) = match position_bias {
-            Some(position_bias) => ((scores + position_bias)?, Some(position_bias.clone())),
+            Some(position_bias) => (
+                scores.broadcast_add(position_bias)?,
+                Some(position_bias.clone()),
+            ),
             None => match &self.relative_attention_bias {
                 None => (scores, None),
                 Some(relative_attention_bias) => {
@@ -291,7 +294,7 @@
                         .forward(&relative_buckets)?
                         .permute((2, 0, 1))?
                         .unsqueeze(0)?;
-                    ((scores + &position_bias)?, Some(position_bias))
+                    (scores.broadcast_add(&position_bias)?, Some(position_bias))
                     // TODO: position_bias_masked?
                 }
            },
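
A side note on the math in the new `None` branch: once the mean-pooled embeddings are L2-normalized (the `normalize_embeddings` flag, handled by `normalize_l2`), the cosine similarity computed in the ranking loop reduces to a plain dot product. Below is a minimal standalone sketch of that equivalence in plain Rust; the `l2_normalize` and `cosine_similarity` helpers are illustrative only and are not part of the patch or the candle API.

// Illustrative only: cosine similarity over L2-normalized vectors is a dot product.
fn l2_normalize(v: &[f32]) -> Vec<f32> {
    // Divide every component by the vector's Euclidean norm.
    let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt();
    v.iter().map(|x| x / norm).collect()
}

fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    // dot(a, b) / (|a| * |b|), mirroring the sum_ij / sqrt(sum_i2 * sum_j2) computation.
    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    let norm_a = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let norm_b = b.iter().map(|x| x * x).sum::<f32>().sqrt();
    dot / (norm_a * norm_b)
}

fn main() {
    let a = l2_normalize(&[1.0, 2.0, 3.0]);
    let b = l2_normalize(&[2.0, 4.0, 6.5]);
    // For unit-norm vectors the dot product equals the full cosine formula.
    let dot: f32 = a.iter().zip(&b).map(|(x, y)| x * y).sum();
    assert!((dot - cosine_similarity(&a, &b)).abs() < 1e-6);
    println!("cosine similarity: {dot:.4}");
}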