bert attention mask (#1934)

* bert attention mask

* Allow for using None as a mask.

* Revert part of the changes so that the proper default mask applies.

* Cosmetic change.

* Another cosmetic tweak.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
This commit is contained in:
Zheng Li
2024-08-01 14:26:19 +08:00
committed by GitHub
parent 24d54d0ff9
commit 4a52aeb437
3 changed files with 53 additions and 20 deletions

View File

@ -126,7 +126,7 @@ fn main() -> Result<()> {
println!("Loaded and encoded {:?}", start.elapsed());
for idx in 0..args.n {
let start = std::time::Instant::now();
let ys = model.forward(&token_ids, &token_type_ids)?;
let ys = model.forward(&token_ids, &token_type_ids, None)?;
if idx == 0 {
println!("{ys}");
}
@ -163,11 +163,19 @@ fn main() -> Result<()> {
Ok(Tensor::new(tokens.as_slice(), device)?)
})
.collect::<Result<Vec<_>>>()?;
let attention_mask = tokens
.iter()
.map(|tokens| {
let tokens = tokens.get_attention_mask().to_vec();
Ok(Tensor::new(tokens.as_slice(), device)?)
})
.collect::<Result<Vec<_>>>()?;
let token_ids = Tensor::stack(&token_ids, 0)?;
let attention_mask = Tensor::stack(&attention_mask, 0)?;
let token_type_ids = token_ids.zeros_like()?;
println!("running inference on batch {:?}", token_ids.shape());
let embeddings = model.forward(&token_ids, &token_type_ids)?;
let embeddings = model.forward(&token_ids, &token_type_ids, Some(&attention_mask))?;
println!("generated embeddings {:?}", embeddings.shape());
// Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?;