Mirror of https://github.com/huggingface/candle.git
Use the same default as pytorch for sum. (#164)
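For context: before this change, candle's reductions kept the reduced dimensions as size 1, and the example below relied on that by squeezing and reshaping afterwards. After it, `sum` drops the reduced dimensions, matching PyTorch's default, and the new `sum_keepdim` is the explicit opt-in to the previous behaviour. That is why the pooling line switches to `sum_keepdim` while the `reshape(())` calls after `sum_all` can simply go away. A minimal sketch of the post-change semantics (the `candle_core` import path, shapes, and dtype are illustrative assumptions, not taken from this commit):

use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    // A small (2, 3) tensor to show the new reduction semantics.
    let t = Tensor::ones((2, 3), DType::F32, &Device::Cpu)?;
    // `sum` now drops the reduced dimension, as in PyTorch: (2, 3) -> (2,).
    println!("sum         -> {:?}", t.sum(&[1])?.shape());
    // `sum_keepdim` keeps it as size 1: (2, 3) -> (2, 1).
    println!("sum_keepdim -> {:?}", t.sum_keepdim(&[1])?.shape());
    // `sum_all` now yields a rank-0 tensor, so `to_scalar` no longer needs
    // the intermediate `reshape(())` removed in the diff below.
    println!("sum_all     -> {}", t.sum_all()?.to_scalar::<f32>()?);
    Ok(())
}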
@@ -604,16 +604,16 @@ fn main() -> Result<()> {
     println!("generated embeddings {:?}", embeddings.shape());
     // Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
     let (_n_sentence, n_tokens, _hidden_size) = embeddings.shape().r3()?;
-    let embeddings = (embeddings.sum(&[1])? / (n_tokens as f64))?.squeeze(1)?;
+    let embeddings = (embeddings.sum_keepdim(&[1])? / (n_tokens as f64))?.squeeze(1)?;
     println!("pooled embeddings {:?}", embeddings.shape());
     let mut similarities = vec![];
     for i in 0..n_sentences {
         let e_i = embeddings.get(i)?;
         for j in (i + 1)..n_sentences {
             let e_j = embeddings.get(j)?;
-            let sum_ij = (&e_i * &e_j)?.sum_all()?.reshape(())?.to_scalar::<f32>()?;
-            let sum_i2 = (&e_i * &e_i)?.sum_all()?.reshape(())?.to_scalar::<f32>()?;
-            let sum_j2 = (&e_j * &e_j)?.sum_all()?.reshape(())?.to_scalar::<f32>()?;
+            let sum_ij = (&e_i * &e_j)?.sum_all()?.to_scalar::<f32>()?;
+            let sum_i2 = (&e_i * &e_i)?.sum_all()?.to_scalar::<f32>()?;
+            let sum_j2 = (&e_j * &e_j)?.sum_all()?.to_scalar::<f32>()?;
             let cosine_similarity = sum_ij / (sum_i2 * sum_j2).sqrt();
             similarities.push((cosine_similarity, i, j))
         }
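For reference, the inner loop computes a plain cosine similarity, dot(e_i, e_j) / (||e_i|| * ||e_j||), over the mean-pooled sentence embeddings. Below is a standalone sketch of those two steps against the updated API (the helper names and the `candle_core` import path are assumptions for illustration; with the new default, `sum(&[1])` already drops the token dimension, so the `sum_keepdim`-plus-`squeeze` pair kept in the example is one of two equivalent spellings):

use candle_core::{Result, Tensor};

// Mean-pool token embeddings of shape (n_sentences, n_tokens, hidden_size)
// down to (n_sentences, hidden_size), mirroring the pooling step above
// (padding tokens are included in the mean, as in the example).
fn mean_pool(embeddings: &Tensor) -> Result<Tensor> {
    let (_n_sentences, n_tokens, _hidden_size) = embeddings.shape().r3()?;
    // With the PyTorch-style default, summing over dim 1 drops that dimension,
    // so no squeeze(1) is needed afterwards.
    let pooled = (embeddings.sum(&[1])? / (n_tokens as f64))?;
    Ok(pooled)
}

// Cosine similarity between two 1-D embedding tensors: dot(a, b) / (|a| * |b|).
fn cosine_similarity(a: &Tensor, b: &Tensor) -> Result<f32> {
    let dot = (a * b)?.sum_all()?.to_scalar::<f32>()?;
    let a2 = (a * a)?.sum_all()?.to_scalar::<f32>()?;
    let b2 = (b * b)?.sum_all()?.to_scalar::<f32>()?;
    Ok(dot / (a2 * b2).sqrt())
}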