Use avg pooling before the cosine similarity.

This commit is contained in:
laurent
2023-07-05 17:05:50 +01:00
parent a4a60a13fa
commit 174e57d216

View File

@ -743,9 +743,10 @@ async fn main() -> Result<()> {
println!("running inference on batch {:?}", token_ids.shape());
let embeddings = model.forward(&token_ids, &token_type_ids)?;
println!("generated embeddings {:?}", embeddings.shape());
// Take the embedding for the first token of each sentence.
// TODO: mean or max pooling?
let embeddings = embeddings.narrow(1, 0, 1)?.squeeze(1)?;
// Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
let (_n_sentence, n_tokens, _hidden_size) = embeddings.shape().r3()?;
let embeddings = (embeddings.sum(&[1])? / (n_tokens as f64))?.squeeze(1)?;
println!("pooled embeddings {:?}", embeddings.shape());
let mut similarities = vec![];
for i in 0..n_sentences {
let e_i = embeddings.get(i)?;