Normalize embeddings in the bert example. (#390)

Laurent Mazare
2023-08-10 14:05:55 +02:00
committed by GitHub
parent b765f2c37f
commit 385f0d261c


@@ -39,6 +39,10 @@ struct Args {
/// The number of times to run the prompt.
#[arg(long, default_value = "1")]
n: usize,
/// L2 normalization for embeddings.
#[arg(long, default_value = "true")]
normalize_embeddings: bool,
}
impl Args {
@@ -164,7 +168,13 @@ fn main() -> Result<()> {
// Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?;
let embeddings = (embeddings.sum(1)? / (n_tokens as f64))?;
let embeddings = if args.normalize_embeddings {
normalize_l2(&embeddings)?
} else {
embeddings
};
println!("pooled embeddings {:?}", embeddings.shape()); println!("pooled embeddings {:?}", embeddings.shape());
let mut similarities = vec![]; let mut similarities = vec![];
for i in 0..n_sentences { for i in 0..n_sentences {
let e_i = embeddings.get(i)?; let e_i = embeddings.get(i)?;
@ -184,3 +194,7 @@ fn main() -> Result<()> {
} }
Ok(()) Ok(())
} }
pub fn normalize_l2(v: &Tensor) -> Result<Tensor> {
Ok(v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)?)
}
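For context, a minimal standalone sketch of what the new normalize_l2 helper does (not part of this commit; the crate name candle_core and the toy values are assumptions for illustration). Each row is divided by its own L2 norm, so every pooled embedding ends up with unit length and the later cosine similarities reduce to plain dot products.

use candle_core::{Device, Result, Tensor};

// Same helper as in the diff: divide each row by its L2 norm along dim 1.
pub fn normalize_l2(v: &Tensor) -> Result<Tensor> {
    Ok(v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)?)
}

fn main() -> Result<()> {
    // Two 2-dimensional "embeddings" with easy-to-check L2 norms (5 and 13).
    let embeddings = Tensor::new(&[[3f32, 4.], [5., 12.]], &Device::Cpu)?;
    let normalized = normalize_l2(&embeddings)?;
    // Rows become roughly [[0.6, 0.8], [0.3846, 0.9231]], each with unit norm.
    println!("{normalized}");
    Ok(())
}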