Simplify the parameters used by sum and sum_keepdim. (#165)

This commit is contained in:
Laurent Mazare
2023-07-14 08:22:08 +01:00
committed by GitHub
parent 2bfa791336
commit a2f72edc0d
13 changed files with 179 additions and 98 deletions

View File

@ -604,7 +604,7 @@ fn main() -> Result<()> {
println!("generated embeddings {:?}", embeddings.shape());
// Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
let (_n_sentence, n_tokens, _hidden_size) = embeddings.shape().r3()?;
let embeddings = (embeddings.sum_keepdim(&[1])? / (n_tokens as f64))?.squeeze(1)?;
let embeddings = (embeddings.sum(1)? / (n_tokens as f64))?;
println!("pooled embeddings {:?}", embeddings.shape());
let mut similarities = vec![];
for i in 0..n_sentences {