Mirror of https://github.com/huggingface/candle.git, synced 2025-06-17 11:08:52 +00:00
Remove the padding. (#838)
@@ -132,28 +132,14 @@ fn main() -> Result<()> {
             "Do you like pizza?",
         ];
         let n_sentences = sentences.len();
-        if let Some(pp) = tokenizer.get_padding_mut() {
-            pp.strategy = tokenizers::PaddingStrategy::BatchLongest
-        } else {
-            let pp = tokenizers::PaddingParams {
-                strategy: tokenizers::PaddingStrategy::BatchLongest,
-                ..Default::default()
-            };
-            tokenizer.with_padding(Some(pp));
-        }
-        let tokens = tokenizer
-            .encode_batch(sentences.to_vec(), true)
-            .map_err(E::msg)?;
-        let token_ids = tokens
-            .iter()
-            .map(|tokens| {
-                let tokens = tokens.get_ids().to_vec();
-                Ok(Tensor::new(tokens.as_slice(), model.device())?)
-            })
-            .collect::<Result<Vec<_>>>()?;
-
-        let token_ids = Tensor::stack(&token_ids, 0)?;
-        println!("running inference on batch {:?}", token_ids.shape());
-        let embeddings = model.forward(&token_ids)?;
-        println!("generated embeddings {:?}", embeddings.shape());
-        // Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
+        let mut all_embeddings = Vec::with_capacity(n_sentences);
+        for sentence in sentences {
+            let tokens = tokenizer
+                .encode(sentence, true)
+                .map_err(E::msg)?
+                .get_ids()
+                .to_vec();
+            let token_ids = Tensor::new(&tokens[..], model.device())?.unsqueeze(0)?;
+            let embeddings = model.forward(&token_ids)?;
+            println!("generated embeddings {:?}", embeddings.shape());
+            // Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
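
For reference, the per-sentence path on the new side of this hunk can be exercised on its own. The sketch below is a minimal standalone illustration rather than code from the example: the `anyhow`, `candle_core`, and `tokenizers` imports and the `encode_one` helper are assumptions made for the snippet.

    // Minimal sketch: one sentence -> a (1, n_tokens) u32 tensor, with no padding involved.
    use anyhow::{Error as E, Result};
    use candle_core::{Device, Tensor};
    use tokenizers::Tokenizer;

    // Hypothetical helper, not a function in the example file.
    fn encode_one(tokenizer: &Tokenizer, sentence: &str, device: &Device) -> Result<Tensor> {
        let tokens = tokenizer
            .encode(sentence, true) // `true` adds the special tokens, e.g. [CLS]/[SEP] for BERT
            .map_err(E::msg)?
            .get_ids()
            .to_vec();
        // Each sentence keeps its own length, so no padding configuration is required.
        Ok(Tensor::new(&tokens[..], device)?.unsqueeze(0)?)
    }

On the removed side, `tokenizer.with_padding(...)` with `PaddingStrategy::BatchLongest` padded every sentence to the longest one in the batch so that `encode_batch` could produce a single stacked tensor; with one forward pass per sentence that step is no longer needed.
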
@@ -165,15 +151,20 @@ fn main() -> Result<()> {
-            embeddings
-        };
-        println!("pooled embeddings {:?}", embeddings.shape());
+                embeddings
+            };
+            println!("pooled embeddings {:?}", embeddings.shape());
+            all_embeddings.push(embeddings)
+        }
 
         let mut similarities = vec![];
-        for i in 0..n_sentences {
-            let e_i = embeddings.get(i)?;
-            for j in (i + 1)..n_sentences {
-                let e_j = embeddings.get(j)?;
-                let sum_ij = (&e_i * &e_j)?.sum_all()?.to_scalar::<f32>()?;
-                let sum_i2 = (&e_i * &e_i)?.sum_all()?.to_scalar::<f32>()?;
-                let sum_j2 = (&e_j * &e_j)?.sum_all()?.to_scalar::<f32>()?;
+        for (i, e_i) in all_embeddings.iter().enumerate() {
+            for (j, e_j) in all_embeddings
+                .iter()
+                .enumerate()
+                .take(n_sentences)
+                .skip(i + 1)
+            {
+                let sum_ij = (e_i * e_j)?.sum_all()?.to_scalar::<f32>()?;
+                let sum_i2 = (e_i * e_i)?.sum_all()?.to_scalar::<f32>()?;
+                let sum_j2 = (e_j * e_j)?.sum_all()?.to_scalar::<f32>()?;
                 let cosine_similarity = sum_ij / (sum_i2 * sum_j2).sqrt();
                 similarities.push((cosine_similarity, i, j))
             }
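
The score pushed into `similarities` is plain cosine similarity between two mean-pooled sentence embeddings. As a standalone helper it would look roughly like this (same crate assumptions as in the sketch above; `cosine_similarity` is a hypothetical name):

    use candle_core::Tensor;

    // sum_ij / sqrt(sum_i2 * sum_j2), mirroring the computation in the loop above.
    fn cosine_similarity(e_i: &Tensor, e_j: &Tensor) -> anyhow::Result<f32> {
        let sum_ij = (e_i * e_j)?.sum_all()?.to_scalar::<f32>()?;
        let sum_i2 = (e_i * e_i)?.sum_all()?.to_scalar::<f32>()?;
        let sum_j2 = (e_j * e_j)?.sum_all()?.to_scalar::<f32>()?;
        Ok(sum_ij / (sum_i2 * sum_j2).sqrt())
    }

Because the avg-pooling reduces every sentence to a single vector of the model's hidden size, the pairwise comparison is unaffected by sentences having different token counts, which is what allows the padding to be dropped.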