mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 03:54:56 +00:00
bert attention mask (#1934)
* bert attention mask
* Allow for using None as a mask.
* Revert part of the changes so that the proper default mask applies.
* Cosmetic change.
* Another cosmetic tweak.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
This commit is contained in:
@@ -55,11 +55,21 @@ impl Model {
|
||||
Tensor::new(tokens.as_slice(), device)
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
let attention_mask: Vec<Tensor> = tokens
|
||||
.iter()
|
||||
.map(|tokens| {
|
||||
let tokens = tokens.get_attention_mask().to_vec();
|
||||
Tensor::new(tokens.as_slice(), device)
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
|
||||
let token_ids = Tensor::stack(&token_ids, 0)?;
|
||||
let attention_mask = Tensor::stack(&attention_mask, 0)?;
|
||||
let token_type_ids = token_ids.zeros_like()?;
|
||||
console_log!("running inference on batch {:?}", token_ids.shape());
|
||||
let embeddings = self.bert.forward(&token_ids, &token_type_ids)?;
|
||||
let embeddings = self
|
||||
.bert
|
||||
.forward(&token_ids, &token_type_ids, Some(&attention_mask))?;
|
||||
console_log!("generated embeddings {:?}", embeddings.shape());
|
||||
// Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
|
||||
let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?;
|
||||
|
Reference in New Issue
Block a user