Remove the padding. (#838)

Laurent Mazare (committed by GitHub)
2023-09-13 14:00:59 +02:00
parent b11a2a7b9d
commit 31ab2ddaeb

@@ -132,28 +132,14 @@ fn main() -> Result<()> {
             "Do you like pizza?",
         ];
         let n_sentences = sentences.len();
-        if let Some(pp) = tokenizer.get_padding_mut() {
-            pp.strategy = tokenizers::PaddingStrategy::BatchLongest
-        } else {
-            let pp = tokenizers::PaddingParams {
-                strategy: tokenizers::PaddingStrategy::BatchLongest,
-                ..Default::default()
-            };
-            tokenizer.with_padding(Some(pp));
-        }
-        let tokens = tokenizer
-            .encode_batch(sentences.to_vec(), true)
-            .map_err(E::msg)?;
-        let token_ids = tokens
-            .iter()
-            .map(|tokens| {
-                let tokens = tokens.get_ids().to_vec();
-                Ok(Tensor::new(tokens.as_slice(), model.device())?)
-            })
-            .collect::<Result<Vec<_>>>()?;
-        let token_ids = Tensor::stack(&token_ids, 0)?;
-        println!("running inference on batch {:?}", token_ids.shape());
-        let embeddings = model.forward(&token_ids)?;
-        println!("generated embeddings {:?}", embeddings.shape());
-        // Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
+        let mut all_embeddings = Vec::with_capacity(n_sentences);
+        for sentence in sentences {
+            let tokens = tokenizer
+                .encode(sentence, true)
+                .map_err(E::msg)?
+                .get_ids()
+                .to_vec();
+            let token_ids = Tensor::new(&tokens[..], model.device())?.unsqueeze(0)?;
+            let embeddings = model.forward(&token_ids)?;
+            println!("generated embeddings {:?}", embeddings.shape());
+            // Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
@@ -165,15 +151,20 @@ fn main() -> Result<()> {
-            embeddings
-        };
-        println!("pooled embeddings {:?}", embeddings.shape());
+                embeddings
+            };
+            println!("pooled embeddings {:?}", embeddings.shape());
+            all_embeddings.push(embeddings)
+        }
         let mut similarities = vec![];
-        for i in 0..n_sentences {
-            let e_i = embeddings.get(i)?;
-            for j in (i + 1)..n_sentences {
-                let e_j = embeddings.get(j)?;
-                let sum_ij = (&e_i * &e_j)?.sum_all()?.to_scalar::<f32>()?;
-                let sum_i2 = (&e_i * &e_i)?.sum_all()?.to_scalar::<f32>()?;
-                let sum_j2 = (&e_j * &e_j)?.sum_all()?.to_scalar::<f32>()?;
+        for (i, e_i) in all_embeddings.iter().enumerate() {
+            for (j, e_j) in all_embeddings
+                .iter()
+                .enumerate()
+                .take(n_sentences)
+                .skip(i + 1)
+            {
+                let sum_ij = (e_i * e_j)?.sum_all()?.to_scalar::<f32>()?;
+                let sum_i2 = (e_i * e_i)?.sum_all()?.to_scalar::<f32>()?;
+                let sum_j2 = (e_j * e_j)?.sum_all()?.to_scalar::<f32>()?;
                 let cosine_similarity = sum_ij / (sum_i2 * sum_j2).sqrt();
                 similarities.push((cosine_similarity, i, j))
             }