Add some sentence similarity comparison to the bert example.

laurent
2023-07-05 16:49:57 +01:00
parent 4e80319147
commit 914e84deec


@@ -5,6 +5,7 @@ use candle_hub::{api::Api, Cache, Repo, RepoType};
use clap::Parser;
use serde::Deserialize;
use std::collections::HashMap;
use tokenizers::Tokenizer;
const DTYPE: DType = DType::F32;
@@ -578,6 +579,7 @@ impl BertEncoder {
struct BertModel {
embeddings: BertEmbeddings,
encoder: BertEncoder,
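// Keep a handle on the device so main() can create input tensors on it.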
device: Device,
}
impl BertModel {
@@ -600,6 +602,7 @@ impl BertModel {
Ok(Self {
embeddings,
encoder,
device: vb.device.clone(),
})
}
@@ -628,30 +631,25 @@ struct Args {
#[arg(long)]
revision: Option<String>,
/// The number of times to run the prompt.
#[arg(long, default_value = "This is an example sentence")]
prompt: String,
/// When set, compute embeddings for this prompt.
#[arg(long)]
prompt: Option<String>,
/// The number of times to run the prompt.
#[arg(long, default_value = "1")]
n: usize,
}
#[tokio::main]
async fn main() -> Result<()> {
use tokenizers::Tokenizer;
let start = std::time::Instant::now();
let args = Args::parse();
let device = if args.cpu {
impl Args {
async fn build_model_and_tokenizer(&self) -> Result<(BertModel, Tokenizer)> {
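// (Factoring model and tokenizer construction out of main() lets both the
// single-prompt path and the new batch-similarity path share this setup.)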
let device = if self.cpu {
Device::Cpu
} else {
Device::new_cuda(0)?
};
let default_model = "sentence-transformers/all-MiniLM-L6-v2".to_string();
let default_revision = "refs/pr/21".to_string();
let (model_id, revision) = match (args.model_id, args.revision) {
let (model_id, revision) = match (self.model_id.to_owned(), self.revision.to_owned()) {
(Some(model_id), Some(revision)) => (model_id, revision),
(Some(model_id), None) => (model_id, "main".to_string()),
(None, Some(revision)) => (default_model, revision),
@@ -659,7 +657,7 @@ async fn main() -> Result<()> {
};
let repo = Repo::with_revision(model_id, RepoType::Model, revision);
let (config_filename, tokenizer_filename, weights_filename) = if args.offline {
let (config_filename, tokenizer_filename, weights_filename) = if self.offline {
let cache = Cache::default();
(
cache
@@ -682,27 +680,88 @@ async fn main() -> Result<()> {
};
let config = std::fs::read_to_string(config_filename)?;
let config: Config = serde_json::from_str(&config)?;
let mut tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let tokenizer = tokenizer.with_padding(None).with_truncation(None);
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let weights = unsafe { candle::safetensors::MmapedFile::new(weights_filename)? };
let weights = weights.deserialize()?;
let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, device.clone());
let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, device);
let model = BertModel::load(&vb, &config)?;
Ok((model, tokenizer))
}
}
#[tokio::main]
async fn main() -> Result<()> {
let start = std::time::Instant::now();
let args = Args::parse();
let (model, mut tokenizer) = args.build_model_and_tokenizer().await?;
let device = &model.device;
if let Some(prompt) = args.prompt {
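// Single-prompt path: padding and truncation are disabled since there is
// only one sequence.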
let tokenizer = tokenizer.with_padding(None).with_truncation(None);
let tokens = tokenizer
.encode(args.prompt, true)
.encode(prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
let token_ids = Tensor::new(&tokens[..], &device)?.unsqueeze(0)?;
let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
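// BERT uses token type ids to distinguish segments; a single sentence is all zeros.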
let token_type_ids = token_ids.zeros_like()?;
println!("Loaded and encoded {:?}", start.elapsed());
for _ in 0..args.n {
let start = std::time::Instant::now();
let _ys = model.forward(&token_ids, &token_type_ids)?;
println!("Took {:?}", start.elapsed());
// println!("Ys {:?}", ys.shape());
}
} else {
let sentences = [
"The cat sits outside",
"A man is playing guitar",
"I love pasta",
"The new movie is awesome",
"The cat plays in the garden",
"A woman watches TV",
"The new movie is so great",
"Do you like pizza?",
];
let n_sentences = sentences.len();
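// Pad every sentence to the longest in the batch so the encoded ids can be
// stacked into one rectangular tensor.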
if let Some(pp) = tokenizer.get_padding_mut() {
pp.strategy = tokenizers::PaddingStrategy::BatchLongest
}
let tokens = tokenizer
.encode_batch(sentences.to_vec(), true)
.map_err(E::msg)?;
let token_ids = tokens
.iter()
.map(|tokens| {
let tokens = tokens.get_ids().to_vec();
Ok(Tensor::new(tokens.as_slice(), device)?)
})
.collect::<Result<Vec<_>>>()?;
let token_ids = Tensor::stack(&token_ids, 0)?;
let token_type_ids = token_ids.zeros_like()?;
println!("running inference on batch {:?}", token_ids.shape());
let embeddings = model.forward(&token_ids, &token_type_ids)?;
println!("generated embeddings {:?}", embeddings.shape());
// Take the embedding for the first token of each sentence.
// TODO: mean or max pooling?
let embeddings = embeddings.narrow(1, 0, 1)?.squeeze(1)?;
let mut similarities = vec![];
for i in 0..n_sentences {
let e_i = embeddings.get(i)?;
for j in (i + 1)..n_sentences {
let e_j = embeddings.get(j)?;
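// Cosine similarity: dot(e_i, e_j) / (|e_i| * |e_j|), assembled from
// elementwise products and sums.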
let sum_ij = (&e_i * &e_j)?.sum_all()?.reshape(())?.to_scalar::<f32>()?;
let sum_i2 = (&e_i * &e_i)?.sum_all()?.reshape(())?.to_scalar::<f32>()?;
let sum_j2 = (&e_j * &e_j)?.sum_all()?.reshape(())?.to_scalar::<f32>()?;
let cosine_similarity = sum_ij / (sum_i2 * sum_j2).sqrt();
similarities.push((cosine_similarity, i, j))
}
}
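// Sort by decreasing similarity and report the five closest pairs.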
similarities.sort_by(|u, v| v.0.total_cmp(&u.0));
for &(score, i, j) in similarities[..5].iter() {
println!("score: {score:.2} '{}' '{}'", sentences[i], sentences[j])
}
}
Ok(())
}
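For readers following the math, the per-pair computation above boils down to standard cosine similarity. A minimal plain-Rust sketch over f32 slices (an illustration only, not part of this commit):

// Minimal sketch of the cosine similarity computed pair-wise above,
// written over plain f32 slices instead of candle tensors.
fn cosine_similarity(e_i: &[f32], e_j: &[f32]) -> f32 {
    let sum_ij: f32 = e_i.iter().zip(e_j).map(|(a, b)| a * b).sum();
    let sum_i2: f32 = e_i.iter().map(|a| a * a).sum();
    let sum_j2: f32 = e_j.iter().map(|a| a * a).sum();
    sum_ij / (sum_i2 * sum_j2).sqrt()
}

fn main() {
    // Toy unit vectors; real sentence embeddings have hundreds of dimensions.
    let a = [0.6f32, 0.8, 0.0];
    let b = [0.8f32, 0.6, 0.0];
    println!("{:.2}", cosine_similarity(&a, &b)); // prints 0.96
}

Identical directions score 1.0, orthogonal vectors 0.0, and opposite directions -1.0, which is why the highest-scoring pairs printed by the example correspond to the most semantically similar sentences, e.g. "The new movie is awesome" and "The new movie is so great".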