Add BertForMaskedLM to support SPLADE Models (#2550)

* add bert for masked lm
* working example
* add example readme
* Clippy fix.
* And apply rustfmt.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
@@ -504,3 +504,100 @@ fn get_extended_attention_mask(attention_mask: &Tensor, dtype: DType) -> Result<
    (attention_mask.ones_like()? - &attention_mask)?
        .broadcast_mul(&Tensor::try_from(f32::MIN)?.to_device(attention_mask.device())?)
}
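For context, the closing lines above turn a 1/0 padding mask into an additive attention mask: kept positions map to 0.0 and padded positions to f32::MIN, which is then added to the raw attention scores before the softmax. For example, a mask of [1., 1., 0.] becomes [0.0, 0.0, f32::MIN].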
// https://github.com/huggingface/transformers/blob/1bd604d11c405dfb8b78bda4062d88fc75c17de0/src/transformers/models/bert/modeling_bert.py#L752-L766
struct BertPredictionHeadTransform {
    dense: Linear,
    activation: HiddenActLayer,
    layer_norm: LayerNorm,
}

impl BertPredictionHeadTransform {
    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
        let dense = linear(config.hidden_size, config.hidden_size, vb.pp("dense"))?;
        let activation = HiddenActLayer::new(config.hidden_act);
        let layer_norm = layer_norm(
            config.hidden_size,
            config.layer_norm_eps,
            vb.pp("LayerNorm"),
        )?;
        Ok(Self {
            dense,
            activation,
            layer_norm,
        })
    }
}

impl Module for BertPredictionHeadTransform {
    // dense -> activation -> LayerNorm, applied to each token's hidden state.
    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
        let hidden_states = self
            .activation
            .forward(&self.dense.forward(hidden_states)?)?;
        self.layer_norm.forward(&hidden_states)
    }
}

// https://github.com/huggingface/transformers/blob/1bd604d11c405dfb8b78bda4062d88fc75c17de0/src/transformers/models/bert/modeling_bert.py#L769C1-L790C1
pub struct BertLMPredictionHead {
    transform: BertPredictionHeadTransform,
    decoder: Linear,
}

impl BertLMPredictionHead {
    pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
        let transform = BertPredictionHeadTransform::load(vb.pp("transform"), config)?;
        let decoder = linear(config.hidden_size, config.vocab_size, vb.pp("decoder"))?;
        Ok(Self { transform, decoder })
    }
}

impl Module for BertLMPredictionHead {
    // Projects the transformed hidden states to vocabulary-sized logits.
    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
        self.decoder
            .forward(&self.transform.forward(hidden_states)?)
    }
}

// https://github.com/huggingface/transformers/blob/1bd604d11c405dfb8b78bda4062d88fc75c17de0/src/transformers/models/bert/modeling_bert.py#L792
pub struct BertOnlyMLMHead {
    predictions: BertLMPredictionHead,
}

impl BertOnlyMLMHead {
    pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
        let predictions = BertLMPredictionHead::load(vb.pp("predictions"), config)?;
        Ok(Self { predictions })
    }
}

impl Module for BertOnlyMLMHead {
    fn forward(&self, sequence_output: &Tensor) -> Result<Tensor> {
        self.predictions.forward(sequence_output)
    }
}
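Note that the nesting here (cls.predictions.transform, cls.predictions.decoder) mirrors the Hugging Face BertForMaskedLM parameter naming, so existing MLM checkpoints load through the VarBuilder prefixes without renaming any tensors.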
pub struct BertForMaskedLM {
    bert: BertModel,
    cls: BertOnlyMLMHead,
}

impl BertForMaskedLM {
    pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
        let bert = BertModel::load(vb.pp("bert"), config)?;
        let cls = BertOnlyMLMHead::load(vb.pp("cls"), config)?;
        Ok(Self { bert, cls })
    }

    // Returns per-token vocabulary logits of shape (batch, seq_len, vocab_size).
    pub fn forward(
        &self,
        input_ids: &Tensor,
        token_type_ids: &Tensor,
        attention_mask: Option<&Tensor>,
    ) -> Result<Tensor> {
        let sequence_output = self
            .bert
            .forward(input_ids, token_type_ids, attention_mask)?;
        self.cls.forward(&sequence_output)
    }
}
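Since the commit targets SPLADE models, here is a minimal sketch of how the MLM logits can be turned into a SPLADE-style sparse vector: apply log(1 + ReLU(logits)) and max-pool over the sequence dimension. It assumes an already-loaded BertForMaskedLM and pre-tokenized inputs; the splade_embedding name is a hypothetical helper, not part of this commit.

// Sketch only: SPLADE-style pooling on top of BertForMaskedLM.
// `splade_embedding` is an illustrative helper name; padding positions
// are not masked out before pooling here, for brevity.
use candle::{Result, Tensor};

fn splade_embedding(
    model: &BertForMaskedLM,
    input_ids: &Tensor,
    token_type_ids: &Tensor,
    attention_mask: &Tensor,
) -> Result<Tensor> {
    // (batch, seq_len, vocab_size) vocabulary logits from the MLM head.
    let logits = model.forward(input_ids, token_type_ids, Some(attention_mask))?;
    // SPLADE activation: log(1 + ReLU(logits)).
    let activated = (logits.relu()? + 1.0)?.log()?;
    // Max-pool over the sequence dimension -> (batch, vocab_size) term weights.
    activated.max(1)
}

In practice the activated logits are usually multiplied by the attention mask before pooling, so that padded tokens cannot contribute terms to the sparse vector.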