mirror of https://github.com/huggingface/candle.git
Use avg pooling before the cosine similarity.
@@ -743,9 +743,10 @@ async fn main() -> Result<()> {
     println!("running inference on batch {:?}", token_ids.shape());
     let embeddings = model.forward(&token_ids, &token_type_ids)?;
     println!("generated embeddings {:?}", embeddings.shape());
-    // Take the embedding for the first token of each sentence.
-    // TODO: mean or max pooling?
-    let embeddings = embeddings.narrow(1, 0, 1)?.squeeze(1)?;
+    // Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
+    let (_n_sentence, n_tokens, _hidden_size) = embeddings.shape().r3()?;
+    let embeddings = (embeddings.sum(&[1])? / (n_tokens as f64))?.squeeze(1)?;
+    println!("pooled embeddings {:?}", embeddings.shape());
     let mut similarities = vec![];
     for i in 0..n_sentences {
         let e_i = embeddings.get(i)?;
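To make the change concrete, below is a minimal, self-contained sketch of the same pipeline (mean pooling over token embeddings, then cosine similarity between the pooled sentence vectors) written with plain Vec<f32> instead of candle tensors. The helpers mean_pool and cosine_similarity are illustrative names introduced here, not functions from this repository, and the toy data stands in for real model output.

    /// Mean-pool a (n_tokens, hidden_size) matrix of token embeddings into a
    /// single sentence embedding, mirroring the `sum / n_tokens` step in the diff.
    fn mean_pool(token_embeddings: &[Vec<f32>]) -> Vec<f32> {
        let n_tokens = token_embeddings.len() as f32;
        let hidden_size = token_embeddings[0].len();
        let mut pooled = vec![0.0f32; hidden_size];
        for tok in token_embeddings {
            for (p, v) in pooled.iter_mut().zip(tok) {
                *p += *v;
            }
        }
        pooled.iter_mut().for_each(|p| *p /= n_tokens);
        pooled
    }

    /// Cosine similarity between two pooled sentence embeddings:
    /// dot(a, b) / (|a| * |b|).
    fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
        let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
        let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
        let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
        dot / (norm_a * norm_b)
    }

    fn main() {
        // Two toy "sentences", each with 3 tokens and hidden size 4.
        let s1 = vec![
            vec![1.0, 0.0, 2.0, 0.0],
            vec![0.0, 1.0, 0.0, 1.0],
            vec![1.0, 1.0, 1.0, 1.0],
        ];
        let s2 = vec![
            vec![2.0, 0.0, 1.0, 0.0],
            vec![0.0, 2.0, 0.0, 2.0],
            vec![1.0, 0.0, 1.0, 0.0],
        ];
        let e1 = mean_pool(&s1);
        let e2 = mean_pool(&s2);
        println!("cosine similarity: {:.4}", cosine_similarity(&e1, &e2));
    }

As the new comment in the diff notes, the mean here is taken over all tokens, padding included; a masked mean that sums only real tokens and divides by their count would avoid diluting short sentences in a padded batch.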