mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Add some sentence similarity comparision to the bert example.
This commit is contained in:
@ -5,6 +5,7 @@ use candle_hub::{api::Api, Cache, Repo, RepoType};
|
|||||||
use clap::Parser;
|
use clap::Parser;
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
use tokenizers::Tokenizer;
|
||||||
|
|
||||||
const DTYPE: DType = DType::F32;
|
const DTYPE: DType = DType::F32;
|
||||||
|
|
||||||
@ -578,6 +579,7 @@ impl BertEncoder {
|
|||||||
struct BertModel {
|
struct BertModel {
|
||||||
embeddings: BertEmbeddings,
|
embeddings: BertEmbeddings,
|
||||||
encoder: BertEncoder,
|
encoder: BertEncoder,
|
||||||
|
device: Device,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl BertModel {
|
impl BertModel {
|
||||||
@ -600,6 +602,7 @@ impl BertModel {
|
|||||||
Ok(Self {
|
Ok(Self {
|
||||||
embeddings,
|
embeddings,
|
||||||
encoder,
|
encoder,
|
||||||
|
device: vb.device.clone(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -628,30 +631,25 @@ struct Args {
|
|||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
revision: Option<String>,
|
revision: Option<String>,
|
||||||
|
|
||||||
/// The number of times to run the prompt.
|
/// When set, compute embeddings for this prompt.
|
||||||
#[arg(long, default_value = "This is an example sentence")]
|
#[arg(long)]
|
||||||
prompt: String,
|
prompt: Option<String>,
|
||||||
|
|
||||||
/// The number of times to run the prompt.
|
/// The number of times to run the prompt.
|
||||||
#[arg(long, default_value = "1")]
|
#[arg(long, default_value = "1")]
|
||||||
n: usize,
|
n: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::main]
|
impl Args {
|
||||||
async fn main() -> Result<()> {
|
async fn build_model_and_tokenizer(&self) -> Result<(BertModel, Tokenizer)> {
|
||||||
use tokenizers::Tokenizer;
|
let device = if self.cpu {
|
||||||
let start = std::time::Instant::now();
|
|
||||||
|
|
||||||
let args = Args::parse();
|
|
||||||
let device = if args.cpu {
|
|
||||||
Device::Cpu
|
Device::Cpu
|
||||||
} else {
|
} else {
|
||||||
Device::new_cuda(0)?
|
Device::new_cuda(0)?
|
||||||
};
|
};
|
||||||
|
|
||||||
let default_model = "sentence-transformers/all-MiniLM-L6-v2".to_string();
|
let default_model = "sentence-transformers/all-MiniLM-L6-v2".to_string();
|
||||||
let default_revision = "refs/pr/21".to_string();
|
let default_revision = "refs/pr/21".to_string();
|
||||||
let (model_id, revision) = match (args.model_id, args.revision) {
|
let (model_id, revision) = match (self.model_id.to_owned(), self.revision.to_owned()) {
|
||||||
(Some(model_id), Some(revision)) => (model_id, revision),
|
(Some(model_id), Some(revision)) => (model_id, revision),
|
||||||
(Some(model_id), None) => (model_id, "main".to_string()),
|
(Some(model_id), None) => (model_id, "main".to_string()),
|
||||||
(None, Some(revision)) => (default_model, revision),
|
(None, Some(revision)) => (default_model, revision),
|
||||||
@ -659,7 +657,7 @@ async fn main() -> Result<()> {
|
|||||||
};
|
};
|
||||||
|
|
||||||
let repo = Repo::with_revision(model_id, RepoType::Model, revision);
|
let repo = Repo::with_revision(model_id, RepoType::Model, revision);
|
||||||
let (config_filename, tokenizer_filename, weights_filename) = if args.offline {
|
let (config_filename, tokenizer_filename, weights_filename) = if self.offline {
|
||||||
let cache = Cache::default();
|
let cache = Cache::default();
|
||||||
(
|
(
|
||||||
cache
|
cache
|
||||||
@ -682,27 +680,88 @@ async fn main() -> Result<()> {
|
|||||||
};
|
};
|
||||||
let config = std::fs::read_to_string(config_filename)?;
|
let config = std::fs::read_to_string(config_filename)?;
|
||||||
let config: Config = serde_json::from_str(&config)?;
|
let config: Config = serde_json::from_str(&config)?;
|
||||||
let mut tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||||
let tokenizer = tokenizer.with_padding(None).with_truncation(None);
|
|
||||||
|
|
||||||
let weights = unsafe { candle::safetensors::MmapedFile::new(weights_filename)? };
|
let weights = unsafe { candle::safetensors::MmapedFile::new(weights_filename)? };
|
||||||
let weights = weights.deserialize()?;
|
let weights = weights.deserialize()?;
|
||||||
let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, device.clone());
|
let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, device);
|
||||||
let model = BertModel::load(&vb, &config)?;
|
let model = BertModel::load(&vb, &config)?;
|
||||||
|
Ok((model, tokenizer))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::main]
|
||||||
|
async fn main() -> Result<()> {
|
||||||
|
let start = std::time::Instant::now();
|
||||||
|
|
||||||
|
let args = Args::parse();
|
||||||
|
let (model, mut tokenizer) = args.build_model_and_tokenizer().await?;
|
||||||
|
let device = &model.device;
|
||||||
|
|
||||||
|
if let Some(prompt) = args.prompt {
|
||||||
|
let tokenizer = tokenizer.with_padding(None).with_truncation(None);
|
||||||
let tokens = tokenizer
|
let tokens = tokenizer
|
||||||
.encode(args.prompt, true)
|
.encode(prompt, true)
|
||||||
.map_err(E::msg)?
|
.map_err(E::msg)?
|
||||||
.get_ids()
|
.get_ids()
|
||||||
.to_vec();
|
.to_vec();
|
||||||
let token_ids = Tensor::new(&tokens[..], &device)?.unsqueeze(0)?;
|
let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
|
||||||
let token_type_ids = token_ids.zeros_like()?;
|
let token_type_ids = token_ids.zeros_like()?;
|
||||||
println!("Loaded and encoded {:?}", start.elapsed());
|
println!("Loaded and encoded {:?}", start.elapsed());
|
||||||
for _ in 0..args.n {
|
for _ in 0..args.n {
|
||||||
let start = std::time::Instant::now();
|
let start = std::time::Instant::now();
|
||||||
let _ys = model.forward(&token_ids, &token_type_ids)?;
|
let _ys = model.forward(&token_ids, &token_type_ids)?;
|
||||||
println!("Took {:?}", start.elapsed());
|
println!("Took {:?}", start.elapsed());
|
||||||
// println!("Ys {:?}", ys.shape());
|
}
|
||||||
|
} else {
|
||||||
|
let sentences = [
|
||||||
|
"The cat sits outside",
|
||||||
|
"A man is playing guitar",
|
||||||
|
"I love pasta",
|
||||||
|
"The new movie is awesome",
|
||||||
|
"The cat plays in the garden",
|
||||||
|
"A woman watches TV",
|
||||||
|
"The new movie is so great",
|
||||||
|
"Do you like pizza?",
|
||||||
|
];
|
||||||
|
let n_sentences = sentences.len();
|
||||||
|
if let Some(pp) = tokenizer.get_padding_mut() {
|
||||||
|
pp.strategy = tokenizers::PaddingStrategy::BatchLongest
|
||||||
|
}
|
||||||
|
let tokens = tokenizer
|
||||||
|
.encode_batch(sentences.to_vec(), true)
|
||||||
|
.map_err(E::msg)?;
|
||||||
|
let token_ids = tokens
|
||||||
|
.iter()
|
||||||
|
.map(|tokens| {
|
||||||
|
let tokens = tokens.get_ids().to_vec();
|
||||||
|
Ok(Tensor::new(tokens.as_slice(), device)?)
|
||||||
|
})
|
||||||
|
.collect::<Result<Vec<_>>>()?;
|
||||||
|
let token_ids = Tensor::stack(&token_ids, 0)?;
|
||||||
|
let token_type_ids = token_ids.zeros_like()?;
|
||||||
|
println!("running inference on batch {:?}", token_ids.shape());
|
||||||
|
let embeddings = model.forward(&token_ids, &token_type_ids)?;
|
||||||
|
println!("generated embeddings {:?}", embeddings.shape());
|
||||||
|
// Take the embedding for the first token of each sentence.
|
||||||
|
// TODO: mean or max pooling?
|
||||||
|
let embeddings = embeddings.narrow(1, 0, 1)?.squeeze(1)?;
|
||||||
|
let mut similarities = vec![];
|
||||||
|
for i in 0..n_sentences {
|
||||||
|
let e_i = embeddings.get(i)?;
|
||||||
|
for j in (i + 1)..n_sentences {
|
||||||
|
let e_j = embeddings.get(j)?;
|
||||||
|
let sum_ij = (&e_i * &e_j)?.sum_all()?.reshape(())?.to_scalar::<f32>()?;
|
||||||
|
let sum_i2 = (&e_i * &e_i)?.sum_all()?.reshape(())?.to_scalar::<f32>()?;
|
||||||
|
let sum_j2 = (&e_j * &e_j)?.sum_all()?.reshape(())?.to_scalar::<f32>()?;
|
||||||
|
let cosine_similarity = sum_ij / (sum_i2 * sum_j2).sqrt();
|
||||||
|
similarities.push((cosine_similarity, i, j))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
similarities.sort_by(|u, v| v.0.total_cmp(&u.0));
|
||||||
|
for &(score, i, j) in similarities[..5].iter() {
|
||||||
|
println!("score: {score:.2} '{}' '{}'", sentences[i], sentences[j])
|
||||||
|
}
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user