Merge pull request #79 from LaurentMazare/bert-similarities

Add a sentence similarity comparison task to the bert example: when run without `--prompt`, the example now embeds a small batch of hard-coded sentences and prints the most similar pairs by cosine similarity.
Laurent Mazare
2023-07-05 16:51:25 +01:00
committed by GitHub


@@ -5,6 +5,7 @@ use candle_hub::{api::Api, Cache, Repo, RepoType};
 use clap::Parser;
 use serde::Deserialize;
 use std::collections::HashMap;
+use tokenizers::Tokenizer;
 
 const DTYPE: DType = DType::F32;
@@ -578,6 +579,7 @@ impl BertEncoder {
 struct BertModel {
     embeddings: BertEmbeddings,
     encoder: BertEncoder,
+    device: Device,
 }
 
 impl BertModel {
@@ -600,6 +602,7 @@ impl BertModel {
         Ok(Self {
            embeddings,
            encoder,
+           device: vb.device.clone(),
        })
    }
@@ -628,81 +631,137 @@ struct Args {
     #[arg(long)]
     revision: Option<String>,
 
-    /// The number of times to run the prompt.
-    #[arg(long, default_value = "This is an example sentence")]
-    prompt: String,
+    /// When set, compute embeddings for this prompt.
+    #[arg(long)]
+    prompt: Option<String>,
 
     /// The number of times to run the prompt.
     #[arg(long, default_value = "1")]
     n: usize,
 }
 
+impl Args {
+    async fn build_model_and_tokenizer(&self) -> Result<(BertModel, Tokenizer)> {
+        let device = if self.cpu {
+            Device::Cpu
+        } else {
+            Device::new_cuda(0)?
+        };
+        let default_model = "sentence-transformers/all-MiniLM-L6-v2".to_string();
+        let default_revision = "refs/pr/21".to_string();
+        let (model_id, revision) = match (self.model_id.to_owned(), self.revision.to_owned()) {
+            (Some(model_id), Some(revision)) => (model_id, revision),
+            (Some(model_id), None) => (model_id, "main".to_string()),
+            (None, Some(revision)) => (default_model, revision),
+            (None, None) => (default_model, default_revision),
+        };
+        let repo = Repo::with_revision(model_id, RepoType::Model, revision);
+        let (config_filename, tokenizer_filename, weights_filename) = if self.offline {
+            let cache = Cache::default();
+            (
+                cache
+                    .get(&repo, "config.json")
+                    .ok_or(anyhow!("Missing config file in cache"))?,
+                cache
+                    .get(&repo, "tokenizer.json")
+                    .ok_or(anyhow!("Missing tokenizer file in cache"))?,
+                cache
+                    .get(&repo, "model.safetensors")
+                    .ok_or(anyhow!("Missing weights file in cache"))?,
+            )
+        } else {
+            let api = Api::new()?;
+            (
+                api.get(&repo, "config.json").await?,
+                api.get(&repo, "tokenizer.json").await?,
+                api.get(&repo, "model.safetensors").await?,
+            )
+        };
+        let config = std::fs::read_to_string(config_filename)?;
+        let config: Config = serde_json::from_str(&config)?;
+        let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
+        let weights = unsafe { candle::safetensors::MmapedFile::new(weights_filename)? };
+        let weights = weights.deserialize()?;
+        let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, device);
+        let model = BertModel::load(&vb, &config)?;
+        Ok((model, tokenizer))
+    }
+}
+
 #[tokio::main]
 async fn main() -> Result<()> {
-    use tokenizers::Tokenizer;
     let start = std::time::Instant::now();
     let args = Args::parse();
-    let device = if args.cpu {
-        Device::Cpu
-    } else {
-        Device::new_cuda(0)?
-    };
-
-    let default_model = "sentence-transformers/all-MiniLM-L6-v2".to_string();
-    let default_revision = "refs/pr/21".to_string();
-    let (model_id, revision) = match (args.model_id, args.revision) {
-        (Some(model_id), Some(revision)) => (model_id, revision),
-        (Some(model_id), None) => (model_id, "main".to_string()),
-        (None, Some(revision)) => (default_model, revision),
-        (None, None) => (default_model, default_revision),
-    };
-
-    let repo = Repo::with_revision(model_id, RepoType::Model, revision);
-    let (config_filename, tokenizer_filename, weights_filename) = if args.offline {
-        let cache = Cache::default();
-        (
-            cache
-                .get(&repo, "config.json")
-                .ok_or(anyhow!("Missing config file in cache"))?,
-            cache
-                .get(&repo, "tokenizer.json")
-                .ok_or(anyhow!("Missing tokenizer file in cache"))?,
-            cache
-                .get(&repo, "model.safetensors")
-                .ok_or(anyhow!("Missing weights file in cache"))?,
-        )
-    } else {
-        let api = Api::new()?;
-        (
-            api.get(&repo, "config.json").await?,
-            api.get(&repo, "tokenizer.json").await?,
-            api.get(&repo, "model.safetensors").await?,
-        )
-    };
-    let config = std::fs::read_to_string(config_filename)?;
-    let config: Config = serde_json::from_str(&config)?;
-    let mut tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
-    let tokenizer = tokenizer.with_padding(None).with_truncation(None);
-
-    let weights = unsafe { candle::safetensors::MmapedFile::new(weights_filename)? };
-    let weights = weights.deserialize()?;
-    let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, device.clone());
-    let model = BertModel::load(&vb, &config)?;
-
-    let tokens = tokenizer
-        .encode(args.prompt, true)
-        .map_err(E::msg)?
-        .get_ids()
-        .to_vec();
-    let token_ids = Tensor::new(&tokens[..], &device)?.unsqueeze(0)?;
-    let token_type_ids = token_ids.zeros_like()?;
-    println!("Loaded and encoded {:?}", start.elapsed());
-    for _ in 0..args.n {
-        let start = std::time::Instant::now();
-        let _ys = model.forward(&token_ids, &token_type_ids)?;
-        println!("Took {:?}", start.elapsed());
-        // println!("Ys {:?}", ys.shape());
+    let (model, mut tokenizer) = args.build_model_and_tokenizer().await?;
+    let device = &model.device;
+
+    if let Some(prompt) = args.prompt {
+        let tokenizer = tokenizer.with_padding(None).with_truncation(None);
+        let tokens = tokenizer
+            .encode(prompt, true)
+            .map_err(E::msg)?
+            .get_ids()
+            .to_vec();
+        let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
+        let token_type_ids = token_ids.zeros_like()?;
+        println!("Loaded and encoded {:?}", start.elapsed());
+        for _ in 0..args.n {
+            let start = std::time::Instant::now();
+            let _ys = model.forward(&token_ids, &token_type_ids)?;
+            println!("Took {:?}", start.elapsed());
+        }
+    } else {
+        let sentences = [
+            "The cat sits outside",
+            "A man is playing guitar",
+            "I love pasta",
+            "The new movie is awesome",
+            "The cat plays in the garden",
+            "A woman watches TV",
+            "The new movie is so great",
+            "Do you like pizza?",
+        ];
+        let n_sentences = sentences.len();
+        if let Some(pp) = tokenizer.get_padding_mut() {
+            pp.strategy = tokenizers::PaddingStrategy::BatchLongest
+        }
+        let tokens = tokenizer
+            .encode_batch(sentences.to_vec(), true)
+            .map_err(E::msg)?;
+        let token_ids = tokens
+            .iter()
+            .map(|tokens| {
+                let tokens = tokens.get_ids().to_vec();
+                Ok(Tensor::new(tokens.as_slice(), device)?)
+            })
+            .collect::<Result<Vec<_>>>()?;
+        let token_ids = Tensor::stack(&token_ids, 0)?;
+        let token_type_ids = token_ids.zeros_like()?;
+        println!("running inference on batch {:?}", token_ids.shape());
+        let embeddings = model.forward(&token_ids, &token_type_ids)?;
+        println!("generated embeddings {:?}", embeddings.shape());
+        // Take the embedding for the first token of each sentence.
+        // TODO: mean or max pooling?
+        let embeddings = embeddings.narrow(1, 0, 1)?.squeeze(1)?;
+        let mut similarities = vec![];
+        for i in 0..n_sentences {
+            let e_i = embeddings.get(i)?;
+            for j in (i + 1)..n_sentences {
+                let e_j = embeddings.get(j)?;
+                let sum_ij = (&e_i * &e_j)?.sum_all()?.reshape(())?.to_scalar::<f32>()?;
+                let sum_i2 = (&e_i * &e_i)?.sum_all()?.reshape(())?.to_scalar::<f32>()?;
+                let sum_j2 = (&e_j * &e_j)?.sum_all()?.reshape(())?.to_scalar::<f32>()?;
+                let cosine_similarity = sum_ij / (sum_i2 * sum_j2).sqrt();
+                similarities.push((cosine_similarity, i, j))
+            }
+        }
+        similarities.sort_by(|u, v| v.0.total_cmp(&u.0));
+        for &(score, i, j) in similarities[..5].iter() {
+            println!("score: {score:.2} '{}' '{}'", sentences[i], sentences[j])
+        }
     }
     Ok(())
 }
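Note on the batch path: `get_padding_mut()` only yields a value if the tokenizer file already ships padding parameters. If it returns `None`, the eight sentences keep their different lengths and `Tensor::stack` will fail on mismatched shapes. A minimal sketch of a fallback, using the `tokenizers` crate's `PaddingParams`/`with_padding` API (this fallback is not part of the diff above):

    // Hypothetical fallback, not in this PR: install padding params when
    // the tokenizer config does not provide any, so that encode_batch pads
    // every sentence to the longest one in the batch.
    if tokenizer.get_padding_mut().is_none() {
        let pp = tokenizers::PaddingParams {
            strategy: tokenizers::PaddingStrategy::BatchLongest,
            ..Default::default()
        };
        tokenizer.with_padding(Some(pp));
    }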
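On the `// TODO: mean or max pooling?` comment: the diff takes the first token's embedding as the sentence vector. Mean pooling over the token dimension is the usual choice for sentence-transformers checkpoints. A sketch, assuming candle exposes `dims3` and scalar division on tensors (API details may differ in the crate version this PR targets):

    // Hedged sketch: average token embeddings instead of taking token 0.
    // `embeddings` is assumed to have shape (n_sentences, n_tokens, hidden_size).
    let (_n_sentences, n_tokens, _hidden_size) = embeddings.dims3()?;
    let embeddings = (embeddings.sum(1)? / (n_tokens as f64))?;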
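The ranking itself is plain pairwise cosine similarity, sim(i, j) = e_i . e_j / (|e_i| |e_j|), sorted descending so the five closest pairs are printed. A self-contained sketch of the same math on plain f32 vectors, with toy embeddings standing in for model output (no candle dependency):

    fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
        // Mirrors the diff: sum_ij / sqrt(sum_i2 * sum_j2).
        let sum_ij: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
        let sum_i2: f32 = a.iter().map(|x| x * x).sum();
        let sum_j2: f32 = b.iter().map(|y| y * y).sum();
        sum_ij / (sum_i2 * sum_j2).sqrt()
    }

    fn main() {
        // Toy embeddings standing in for the BERT output.
        let embeddings: Vec<Vec<f32>> = vec![
            vec![0.9, 0.1, 0.0],
            vec![0.8, 0.2, 0.1],
            vec![0.0, 1.0, 0.0],
        ];
        let n = embeddings.len();
        let mut similarities = vec![];
        for i in 0..n {
            for j in (i + 1)..n {
                similarities.push((cosine_similarity(&embeddings[i], &embeddings[j]), i, j));
            }
        }
        // Highest similarity first, as in the example's output loop.
        similarities.sort_by(|u, v| v.0.total_cmp(&u.0));
        for &(score, i, j) in &similarities {
            println!("score: {score:.2} ({i}, {j})");
        }
    }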