mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Normalize embeddings in the bert example. (#390)
This commit is contained in:
@ -39,6 +39,10 @@ struct Args {
|
||||
/// The number of times to run the prompt.
|
||||
#[arg(long, default_value = "1")]
|
||||
n: usize,
|
||||
|
||||
/// L2 normalization for embeddings.
|
||||
#[arg(long, default_value = "true")]
|
||||
normalize_embeddings: bool,
|
||||
}
|
||||
|
||||
impl Args {
|
||||
@ -164,7 +168,13 @@ fn main() -> Result<()> {
|
||||
// Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
|
||||
let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?;
|
||||
let embeddings = (embeddings.sum(1)? / (n_tokens as f64))?;
|
||||
let embeddings = if args.normalize_embeddings {
|
||||
normalize_l2(&embeddings)?
|
||||
} else {
|
||||
embeddings
|
||||
};
|
||||
println!("pooled embeddings {:?}", embeddings.shape());
|
||||
|
||||
let mut similarities = vec![];
|
||||
for i in 0..n_sentences {
|
||||
let e_i = embeddings.get(i)?;
|
||||
@ -184,3 +194,7 @@ fn main() -> Result<()> {
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn normalize_l2(v: &Tensor) -> Result<Tensor> {
|
||||
Ok(v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)?)
|
||||
}
|
||||
|
Reference in New Issue
Block a user