Add BertForMaskedLM to support SPLADE Models (#2550)

* add bert for masked lm

* working example

* add example readme

* Clippy fix.

* And apply rustfmt.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
This commit is contained in:
Akshay Ballal
2024-10-07 23:28:21 +02:00
committed by GitHub
parent edf7668291
commit 937e8eda74
3 changed files with 335 additions and 0 deletions

View File

@ -504,3 +504,100 @@ fn get_extended_attention_mask(attention_mask: &Tensor, dtype: DType) -> Result<
(attention_mask.ones_like()? - &attention_mask)?
.broadcast_mul(&Tensor::try_from(f32::MIN)?.to_device(attention_mask.device())?)
}
//https://github.com/huggingface/transformers/blob/1bd604d11c405dfb8b78bda4062d88fc75c17de0/src/transformers/models/bert/modeling_bert.py#L752-L766
struct BertPredictionHeadTransform {
dense: Linear,
activation: HiddenActLayer,
layer_norm: LayerNorm,
}
impl BertPredictionHeadTransform {
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
let dense = linear(config.hidden_size, config.hidden_size, vb.pp("dense"))?;
let activation = HiddenActLayer::new(config.hidden_act);
let layer_norm = layer_norm(
config.hidden_size,
config.layer_norm_eps,
vb.pp("LayerNorm"),
)?;
Ok(Self {
dense,
activation,
layer_norm,
})
}
}
impl Module for BertPredictionHeadTransform {
fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
let hidden_states = self
.activation
.forward(&self.dense.forward(hidden_states)?)?;
self.layer_norm.forward(&hidden_states)
}
}
// https://github.com/huggingface/transformers/blob/1bd604d11c405dfb8b78bda4062d88fc75c17de0/src/transformers/models/bert/modeling_bert.py#L769C1-L790C1
pub struct BertLMPredictionHead {
transform: BertPredictionHeadTransform,
decoder: Linear,
}
impl BertLMPredictionHead {
pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
let transform = BertPredictionHeadTransform::load(vb.pp("transform"), config)?;
let decoder = linear(config.hidden_size, config.vocab_size, vb.pp("decoder"))?;
Ok(Self { transform, decoder })
}
}
impl Module for BertLMPredictionHead {
fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
self.decoder
.forward(&self.transform.forward(hidden_states)?)
}
}
// https://github.com/huggingface/transformers/blob/1bd604d11c405dfb8b78bda4062d88fc75c17de0/src/transformers/models/bert/modeling_bert.py#L792
pub struct BertOnlyMLMHead {
predictions: BertLMPredictionHead,
}
impl BertOnlyMLMHead {
pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
let predictions = BertLMPredictionHead::load(vb.pp("predictions"), config)?;
Ok(Self { predictions })
}
}
impl Module for BertOnlyMLMHead {
fn forward(&self, sequence_output: &Tensor) -> Result<Tensor> {
self.predictions.forward(sequence_output)
}
}
pub struct BertForMaskedLM {
bert: BertModel,
cls: BertOnlyMLMHead,
}
impl BertForMaskedLM {
pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
let bert = BertModel::load(vb.pp("bert"), config)?;
let cls = BertOnlyMLMHead::load(vb.pp("cls"), config)?;
Ok(Self { bert, cls })
}
pub fn forward(
&self,
input_ids: &Tensor,
token_type_ids: &Tensor,
attention_mask: Option<&Tensor>,
) -> Result<Tensor> {
let sequence_output = self
.bert
.forward(input_ids, token_type_ids, attention_mask)?;
self.cls.forward(&sequence_output)
}
}