diff --git a/candle-examples/examples/bert/main.rs b/candle-examples/examples/bert/main.rs
index 4de0aeac..8795faa9 100644
--- a/candle-examples/examples/bert/main.rs
+++ b/candle-examples/examples/bert/main.rs
@@ -5,6 +5,7 @@ use candle_hub::{api::Api, Cache, Repo, RepoType};
 use clap::Parser;
 use serde::Deserialize;
 use std::collections::HashMap;
+use tokenizers::Tokenizer;
 
 const DTYPE: DType = DType::F32;
 
@@ -578,6 +579,7 @@ impl BertEncoder {
 struct BertModel {
     embeddings: BertEmbeddings,
     encoder: BertEncoder,
+    device: Device,
 }
 
 impl BertModel {
@@ -600,6 +602,7 @@ impl BertModel {
         Ok(Self {
             embeddings,
             encoder,
+            device: vb.device.clone(),
         })
     }
 
@@ -628,81 +631,137 @@ struct Args {
     #[arg(long)]
     revision: Option<String>,
 
-    /// The number of times to run the prompt.
-    #[arg(long, default_value = "This is an example sentence")]
-    prompt: String,
+    /// When set, compute embeddings for this prompt.
+    #[arg(long)]
+    prompt: Option<String>,
 
     /// The number of times to run the prompt.
     #[arg(long, default_value = "1")]
     n: usize,
 }
 
+impl Args {
+    async fn build_model_and_tokenizer(&self) -> Result<(BertModel, Tokenizer)> {
+        let device = if self.cpu {
+            Device::Cpu
+        } else {
+            Device::new_cuda(0)?
+        };
+        let default_model = "sentence-transformers/all-MiniLM-L6-v2".to_string();
+        let default_revision = "refs/pr/21".to_string();
+        let (model_id, revision) = match (self.model_id.to_owned(), self.revision.to_owned()) {
+            (Some(model_id), Some(revision)) => (model_id, revision),
+            (Some(model_id), None) => (model_id, "main".to_string()),
+            (None, Some(revision)) => (default_model, revision),
+            (None, None) => (default_model, default_revision),
+        };
+
+        let repo = Repo::with_revision(model_id, RepoType::Model, revision);
+        let (config_filename, tokenizer_filename, weights_filename) = if self.offline {
+            let cache = Cache::default();
+            (
+                cache
+                    .get(&repo, "config.json")
+                    .ok_or(anyhow!("Missing config file in cache"))?,
+                cache
+                    .get(&repo, "tokenizer.json")
+                    .ok_or(anyhow!("Missing tokenizer file in cache"))?,
+                cache
+                    .get(&repo, "model.safetensors")
+                    .ok_or(anyhow!("Missing weights file in cache"))?,
+            )
+        } else {
+            let api = Api::new()?;
+            (
+                api.get(&repo, "config.json").await?,
+                api.get(&repo, "tokenizer.json").await?,
+                api.get(&repo, "model.safetensors").await?,
+            )
+        };
+        let config = std::fs::read_to_string(config_filename)?;
+        let config: Config = serde_json::from_str(&config)?;
+        let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
+
+        let weights = unsafe { candle::safetensors::MmapedFile::new(weights_filename)? };
+        let weights = weights.deserialize()?;
+        let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, device);
+        let model = BertModel::load(&vb, &config)?;
+        Ok((model, tokenizer))
+    }
+}
+
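+// With --prompt set, the example times `n` forward passes over that single
+// sentence; otherwise it embeds a fixed batch of sentences and ranks their
+// pairwise cosine similarities.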
 #[tokio::main]
 async fn main() -> Result<()> {
-    use tokenizers::Tokenizer;
     let start = std::time::Instant::now();
     let args = Args::parse();
-    let device = if args.cpu {
-        Device::Cpu
+    let (model, mut tokenizer) = args.build_model_and_tokenizer().await?;
+    let device = &model.device;
+
+    if let Some(prompt) = args.prompt {
+        let tokenizer = tokenizer.with_padding(None).with_truncation(None);
+        let tokens = tokenizer
+            .encode(prompt, true)
+            .map_err(E::msg)?
+            .get_ids()
+            .to_vec();
+        let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
+        let token_type_ids = token_ids.zeros_like()?;
+        println!("Loaded and encoded {:?}", start.elapsed());
+        for _ in 0..args.n {
+            let start = std::time::Instant::now();
+            let _ys = model.forward(&token_ids, &token_type_ids)?;
+            println!("Took {:?}", start.elapsed());
+        }
     } else {
-        Device::new_cuda(0)?
-    };
-
-    let default_model = "sentence-transformers/all-MiniLM-L6-v2".to_string();
-    let default_revision = "refs/pr/21".to_string();
-    let (model_id, revision) = match (args.model_id, args.revision) {
-        (Some(model_id), Some(revision)) => (model_id, revision),
-        (Some(model_id), None) => (model_id, "main".to_string()),
-        (None, Some(revision)) => (default_model, revision),
-        (None, None) => (default_model, default_revision),
-    };
-
-    let repo = Repo::with_revision(model_id, RepoType::Model, revision);
-    let (config_filename, tokenizer_filename, weights_filename) = if args.offline {
-        let cache = Cache::default();
-        (
-            cache
-                .get(&repo, "config.json")
-                .ok_or(anyhow!("Missing config file in cache"))?,
-            cache
-                .get(&repo, "tokenizer.json")
-                .ok_or(anyhow!("Missing tokenizer file in cache"))?,
-            cache
-                .get(&repo, "model.safetensors")
-                .ok_or(anyhow!("Missing weights file in cache"))?,
-        )
-    } else {
-        let api = Api::new()?;
-        (
-            api.get(&repo, "config.json").await?,
-            api.get(&repo, "tokenizer.json").await?,
-            api.get(&repo, "model.safetensors").await?,
-        )
-    };
-    let config = std::fs::read_to_string(config_filename)?;
-    let config: Config = serde_json::from_str(&config)?;
-    let mut tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
-    let tokenizer = tokenizer.with_padding(None).with_truncation(None);
-
-    let weights = unsafe { candle::safetensors::MmapedFile::new(weights_filename)? };
-    let weights = weights.deserialize()?;
-    let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, device.clone());
-    let model = BertModel::load(&vb, &config)?;
-
-    let tokens = tokenizer
-        .encode(args.prompt, true)
-        .map_err(E::msg)?
-        .get_ids()
-        .to_vec();
-    let token_ids = Tensor::new(&tokens[..], &device)?.unsqueeze(0)?;
-    let token_type_ids = token_ids.zeros_like()?;
-    println!("Loaded and encoded {:?}", start.elapsed());
-    for _ in 0..args.n {
-        let start = std::time::Instant::now();
-        let _ys = model.forward(&token_ids, &token_type_ids)?;
-        println!("Took {:?}", start.elapsed());
-        // println!("Ys {:?}", ys.shape());
+        let sentences = [
+            "The cat sits outside",
+            "A man is playing guitar",
+            "I love pasta",
+            "The new movie is awesome",
+            "The cat plays in the garden",
+            "A woman watches TV",
+            "The new movie is so great",
+            "Do you like pizza?",
+        ];
+        let n_sentences = sentences.len();
+        if let Some(pp) = tokenizer.get_padding_mut() {
+            pp.strategy = tokenizers::PaddingStrategy::BatchLongest
+        }
+        let tokens = tokenizer
+            .encode_batch(sentences.to_vec(), true)
+            .map_err(E::msg)?;
+        let token_ids = tokens
+            .iter()
+            .map(|tokens| {
+                let tokens = tokens.get_ids().to_vec();
+                Ok(Tensor::new(tokens.as_slice(), device)?)
+            })
+            .collect::<Result<Vec<_>>>()?;
+        let token_ids = Tensor::stack(&token_ids, 0)?;
+        let token_type_ids = token_ids.zeros_like()?;
+        println!("running inference on batch {:?}", token_ids.shape());
+        let embeddings = model.forward(&token_ids, &token_type_ids)?;
+        println!("generated embeddings {:?}", embeddings.shape());
+        // Take the embedding for the first token of each sentence.
+        // TODO: mean or max pooling?
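+        // Position 0 holds the [CLS] token; sentence-transformers models such
+        // as all-MiniLM-L6-v2 are typically used with mean pooling over all
+        // token embeddings instead.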
+        let embeddings = embeddings.narrow(1, 0, 1)?.squeeze(1)?;
+        let mut similarities = vec![];
+        for i in 0..n_sentences {
+            let e_i = embeddings.get(i)?;
+            for j in (i + 1)..n_sentences {
+                let e_j = embeddings.get(j)?;
+                let sum_ij = (&e_i * &e_j)?.sum_all()?.reshape(())?.to_scalar::<f32>()?;
+                let sum_i2 = (&e_i * &e_i)?.sum_all()?.reshape(())?.to_scalar::<f32>()?;
+                let sum_j2 = (&e_j * &e_j)?.sum_all()?.reshape(())?.to_scalar::<f32>()?;
+                let cosine_similarity = sum_ij / (sum_i2 * sum_j2).sqrt();
+                similarities.push((cosine_similarity, i, j))
+            }
+        }
+        similarities.sort_by(|u, v| v.0.total_cmp(&u.0));
+        for &(score, i, j) in similarities[..5].iter() {
+            println!("score: {score:.2} '{}' '{}'", sentences[i], sentences[j])
+        }
     }
     Ok(())
 }
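The cosine similarity computed above is the usual dot product over the product of norms, sum_ij / sqrt(sum_i2 * sum_j2). A minimal standalone sketch of the same arithmetic over plain f32 slices (the helper name and signature are illustrative, not part of this patch):

    fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
        // Dot product and the two squared norms, mirroring sum_ij, sum_i2 and sum_j2.
        let sum_ij: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
        let sum_i2: f32 = a.iter().map(|x| x * x).sum();
        let sum_j2: f32 = b.iter().map(|y| y * y).sum();
        sum_ij / (sum_i2 * sum_j2).sqrt()
    }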