Mirror of https://github.com/huggingface/candle.git, synced 2025-06-16 10:38:54 +00:00
bert attention mask (#1934)
* bert attention mask
* Allow for using None as a mask.
* Revert part of the changes so that the proper default mask applies.
* Cosmetic change.
* Another cosmetic tweak.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
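The change threads an optional attention mask through the BERT forward pass. The model-side code is not part of this excerpt, but from the call sites below the updated signature plausibly looks like the sketch here (names assumed from the candle-transformers bert module); per the commit message, passing None falls back to the proper default mask.

    // Plausible post-change signature, inferred from the call sites in this
    // diff; not shown in the excerpt itself.
    impl BertModel {
        pub fn forward(
            &self,
            input_ids: &Tensor,
            token_type_ids: &Tensor,
            attention_mask: Option<&Tensor>, // None -> default all-ones mask applies
        ) -> Result<Tensor> {
            // ...
        }
    }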
@@ -126,7 +126,7 @@ fn main() -> Result<()> {
         println!("Loaded and encoded {:?}", start.elapsed());
         for idx in 0..args.n {
             let start = std::time::Instant::now();
-            let ys = model.forward(&token_ids, &token_type_ids)?;
+            let ys = model.forward(&token_ids, &token_type_ids, None)?;
             if idx == 0 {
                 println!("{ys}");
             }
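In this single-prompt path there is no padding, so no mask is needed and None suffices. A minimal sketch of that call, assuming the example's surrounding `tokenizer`, `model`, and `device` variables:

    // Minimal sketch; `tokenizer`, `model`, `device`, and the anyhow-style
    // `E::msg` error conversion are assumed from the surrounding example.
    let tokens = tokenizer
        .encode(prompt, true)
        .map_err(E::msg)?
        .get_ids()
        .to_vec();
    let token_ids = Tensor::new(&tokens[..], &device)?.unsqueeze(0)?;
    let token_type_ids = token_ids.zeros_like()?;
    // A single unpadded sequence: every position is real, so pass None.
    let ys = model.forward(&token_ids, &token_type_ids, None)?;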
@@ -163,11 +163,19 @@ fn main() -> Result<()> {
                 Ok(Tensor::new(tokens.as_slice(), device)?)
             })
             .collect::<Result<Vec<_>>>()?;
+        let attention_mask = tokens
+            .iter()
+            .map(|tokens| {
+                let tokens = tokens.get_attention_mask().to_vec();
+                Ok(Tensor::new(tokens.as_slice(), device)?)
+            })
+            .collect::<Result<Vec<_>>>()?;

         let token_ids = Tensor::stack(&token_ids, 0)?;
+        let attention_mask = Tensor::stack(&attention_mask, 0)?;
         let token_type_ids = token_ids.zeros_like()?;
         println!("running inference on batch {:?}", token_ids.shape());
-        let embeddings = model.forward(&token_ids, &token_type_ids)?;
+        let embeddings = model.forward(&token_ids, &token_type_ids, Some(&attention_mask))?;
         println!("generated embeddings {:?}", embeddings.shape());
         // Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
         let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?;
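The comment in the hunk notes that the mean pooling still averages over all tokens, padding included. A natural follow-up, not part of this commit, is to reuse the attention mask for masked mean pooling; a hedged sketch using candle tensor ops:

    // Hypothetical masked mean-pooling, not part of this commit: weight each
    // token by the attention mask so padding does not dilute the average.
    let mask = attention_mask
        .to_dtype(embeddings.dtype())? // (n_sentence, n_tokens)
        .unsqueeze(2)?; // (n_sentence, n_tokens, 1)
    let summed = embeddings.broadcast_mul(&mask)?.sum(1)?; // (n_sentence, hidden_size)
    let n_real = mask.sum(1)?; // (n_sentence, 1): count of non-padding tokens
    let embeddings = summed.broadcast_div(&n_real)?;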