mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Normalize embeddings in the bert example. (#390)
This commit is contained in:
@ -39,6 +39,10 @@ struct Args {
|
|||||||
/// The number of times to run the prompt.
|
/// The number of times to run the prompt.
|
||||||
#[arg(long, default_value = "1")]
|
#[arg(long, default_value = "1")]
|
||||||
n: usize,
|
n: usize,
|
||||||
|
|
||||||
|
/// L2 normalization for embeddings.
|
||||||
|
#[arg(long, default_value = "true")]
|
||||||
|
normalize_embeddings: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Args {
|
impl Args {
|
||||||
@ -164,7 +168,13 @@ fn main() -> Result<()> {
|
|||||||
// Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
|
// Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
|
||||||
let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?;
|
let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?;
|
||||||
let embeddings = (embeddings.sum(1)? / (n_tokens as f64))?;
|
let embeddings = (embeddings.sum(1)? / (n_tokens as f64))?;
|
||||||
|
let embeddings = if args.normalize_embeddings {
|
||||||
|
normalize_l2(&embeddings)?
|
||||||
|
} else {
|
||||||
|
embeddings
|
||||||
|
};
|
||||||
println!("pooled embeddings {:?}", embeddings.shape());
|
println!("pooled embeddings {:?}", embeddings.shape());
|
||||||
|
|
||||||
let mut similarities = vec![];
|
let mut similarities = vec![];
|
||||||
for i in 0..n_sentences {
|
for i in 0..n_sentences {
|
||||||
let e_i = embeddings.get(i)?;
|
let e_i = embeddings.get(i)?;
|
||||||
@ -184,3 +194,7 @@ fn main() -> Result<()> {
|
|||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Divides `v` by its L2 norm taken along dimension 1.
///
/// The squared entries are summed over dimension 1 with that axis kept, so the
/// resulting norm can be broadcast back against `v` for the division.
/// No epsilon is added: a slice whose entries are all zero has a zero norm and
/// is divided by it as-is.
///
/// # Errors
///
/// Propagates any error returned by the underlying tensor operations.
pub fn normalize_l2(v: &Tensor) -> Result<Tensor> {
    let squared = v.sqr()?;
    let norms = squared.sum_keepdim(1)?.sqrt()?;
    let normalized = v.broadcast_div(&norms)?;
    Ok(normalized)
}
|
||||||
|
Reference in New Issue
Block a user