bert attention mask (#1934)

* bert attention mask

* Allow for using None as a mask.

* Revert part of the changes so that the proper default mask applies.

* Cosmetic change.

* Another cosmetic tweak.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
Author: Zheng Li
Date: 2024-08-01 14:26:19 +08:00
Committed by: GitHub
Parent: 24d54d0ff9
Commit: 4a52aeb437
3 changed files with 53 additions and 20 deletions

@@ -55,11 +55,21 @@ impl Model {
                 Tensor::new(tokens.as_slice(), device)
             })
             .collect::<Result<Vec<_>, _>>()?;
+        let attention_mask: Vec<Tensor> = tokens
+            .iter()
+            .map(|tokens| {
+                let tokens = tokens.get_attention_mask().to_vec();
+                Tensor::new(tokens.as_slice(), device)
+            })
+            .collect::<Result<Vec<_>, _>>()?;
         let token_ids = Tensor::stack(&token_ids, 0)?;
+        let attention_mask = Tensor::stack(&attention_mask, 0)?;
         let token_type_ids = token_ids.zeros_like()?;
         console_log!("running inference on batch {:?}", token_ids.shape());
-        let embeddings = self.bert.forward(&token_ids, &token_type_ids)?;
+        let embeddings = self
+            .bert
+            .forward(&token_ids, &token_type_ids, Some(&attention_mask))?;
         console_log!("generated embeddings {:?}", embeddings.shape());
         // Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
         let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?;
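
The hunk above always passes an explicit mask; per the commit notes, None is also accepted, in which case the model falls back to its default mask. A minimal sketch of both call paths, assuming the forward signature introduced here (attention_mask: Option<&Tensor>), the repo's candle crate alias for candle-core, and a hypothetical helper name embed_batch:

    use candle::{Result, Tensor};
    use candle_transformers::models::bert::BertModel;

    // Hypothetical helper sketching the two call paths described in the commit
    // message: an explicit padding mask, or `None`, which lets the model apply
    // its default mask (assumed here to treat every position as a real token).
    fn embed_batch(bert: &BertModel, token_ids: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
        let token_type_ids = token_ids.zeros_like()?;
        bert.forward(token_ids, &token_type_ids, attention_mask)
    }

At the call site in the hunk this corresponds to Some(&attention_mask); a caller whose inputs are all the same length and need no padding can simply pass None.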