Implementing DistilBertForMaskedLM. (#2866)

* Initial commit: model weights working, prediction incorrect

* moved distilbertformaskedlm into distilbert modeling file

* made masked LM example like the bert example, still incorrect predictions

* finally not getting NaNs, fixed attention mask

* getting correct output sentences

* get top k predictions

* fixed output formatting slightly

* added default arg for model_id

* lint

* moved masked token example code from distilbertformaskedlm example to distilbert example

* lint

* removed distilbertformaskedlm example

* cleanup

* clippy

* removed embedding normalization from example

* made output and model dependent on args instead of prompt

* lint

* replaced ok_or anyhow error with anyhow context (see the sketch after this list)

* changed error message for mask token not found
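
The last two items describe the same error-handling cleanup in the example. A minimal sketch of the idea, using a hypothetical find_mask_index helper rather than the commit's exact code:

use anyhow::{Context, Result};

// anyhow's Context trait attaches a message directly to an Option,
// replacing the older ok_or(anyhow!(...)) construction.
fn find_mask_index(token_ids: &[u32], mask_id: u32) -> Result<usize> {
    token_ids
        .iter()
        .position(|&id| id == mask_id)
        .context("mask token not found in the input prompt")
}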
Kyle Birnbaum authored 2025-04-11 04:25:39 -07:00, committed by GitHub
parent d339b01726
commit eb478ece92
3 changed files with 375 additions and 60 deletions


@@ -19,7 +19,7 @@ fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor>
 #[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)]
 #[serde(rename_all = "lowercase")]
-enum HiddenAct {
+pub enum HiddenAct {
     Gelu,
     Relu,
 }
@@ -49,22 +49,22 @@ impl Module for HiddenActLayer {
 #[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Default)]
 #[serde(rename_all = "lowercase")]
-enum PositionEmbeddingType {
+pub enum PositionEmbeddingType {
     #[default]
     Absolute,
 }
 
 #[derive(Debug, Clone, PartialEq, Deserialize)]
 pub struct Config {
-    vocab_size: usize,
-    dim: usize,
+    pub vocab_size: usize,
+    pub dim: usize,
     n_layers: usize,
     n_heads: usize,
     hidden_dim: usize,
     activation: HiddenAct,
     max_position_embeddings: usize,
     initializer_range: f64,
-    pad_token_id: usize,
+    pub pad_token_id: usize,
     #[serde(default)]
     position_embedding_type: PositionEmbeddingType,
     #[serde(default)]
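
Making vocab_size, dim, and pad_token_id public lets example code read them off a deserialized Config. A minimal sketch of loading one, assuming a locally downloaded config.json and the crate's distilbert module path:

use anyhow::Result;
use candle_transformers::models::distilbert::Config;

fn load_config(path: &str) -> Result<Config> {
    // Config derives Deserialize, so serde_json parses the model's
    // config.json into it directly.
    let raw = std::fs::read_to_string(path)?;
    Ok(serde_json::from_str(&raw)?)
}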
@@ -345,3 +345,107 @@ impl DistilBertModel {
         Ok(sequence_output)
     }
 }
+
+struct DistilBertPredictionHeadTransform {
+    dense: Linear,
+    activation: HiddenActLayer,
+    layer_norm: LayerNorm,
+}
+
+impl DistilBertPredictionHeadTransform {
+    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
+        let dense = linear(config.dim, config.dim, vb.pp("vocab_transform"))?;
+        let activation = HiddenActLayer::new(config.activation);
+        let layer_norm = layer_norm(config.dim, 1e-12, vb.pp("vocab_layer_norm"))?;
+        Ok(Self {
+            dense,
+            activation,
+            layer_norm,
+        })
+    }
+}
+
+impl Module for DistilBertPredictionHeadTransform {
+    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
+        let hidden_states = self
+            .activation
+            .forward(&self.dense.forward(hidden_states)?)?;
+        self.layer_norm.forward(&hidden_states)
+    }
+}
+
+// https://github.com/huggingface/transformers/blob/1bd604d11c405dfb8b78bda4062d88fc75c17de0/src/transformers/models/bert/modeling_bert.py#L769C1-L790C1
+pub struct DistilBertLMPredictionHead {
+    transform: DistilBertPredictionHeadTransform,
+    decoder: Linear,
+}
+
+impl DistilBertLMPredictionHead {
+    pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
+        let transform = DistilBertPredictionHeadTransform::load(vb.clone(), config)?;
+
+        // distil_bert_uncased uses the word embeddings for the vocab projector weight, but has a separate vocab_projector bias
+        let vocab_projector_weight_vb = vb.pp("distilbert.embeddings.word_embeddings");
+        let init_ws = candle_nn::init::DEFAULT_KAIMING_NORMAL;
+        let ws = vocab_projector_weight_vb.get_with_hints(
+            (config.vocab_size, config.dim),
+            "weight",
+            init_ws,
+        )?;
+        let bound = 1. / (config.dim as f64).sqrt();
+        let init_bs = candle_nn::Init::Uniform {
+            lo: -bound,
+            up: bound,
+        };
+
+        let vocab_projector_bias_vb = vb.pp("vocab_projector");
+        let bs = vocab_projector_bias_vb.get_with_hints(config.vocab_size, "bias", init_bs)?;
+
+        let decoder = Linear::from_weights(ws, Some(bs));
+
+        Ok(Self { transform, decoder })
+    }
+}
+
+impl Module for DistilBertLMPredictionHead {
+    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
+        self.decoder
+            .forward(&self.transform.forward(hidden_states)?)
+    }
+}
+
+// https://github.com/huggingface/transformers/blob/1bd604d11c405dfb8b78bda4062d88fc75c17de0/src/transformers/models/bert/modeling_bert.py#L792
+pub struct DistilBertOnlyMLMHead {
+    predictions: DistilBertLMPredictionHead,
+}
+
+impl DistilBertOnlyMLMHead {
+    pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
+        let predictions = DistilBertLMPredictionHead::load(vb.clone(), config)?;
+        Ok(Self { predictions })
+    }
+}
+
+impl Module for DistilBertOnlyMLMHead {
+    fn forward(&self, sequence_output: &Tensor) -> Result<Tensor> {
+        self.predictions.forward(sequence_output)
+    }
+}
+
+pub struct DistilBertForMaskedLM {
+    pub bert: DistilBertModel,
+    cls: DistilBertOnlyMLMHead,
+}
+
+impl DistilBertForMaskedLM {
+    pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
+        let bert = DistilBertModel::load(vb.pp("distilbert"), config)?;
+        let cls = DistilBertOnlyMLMHead::load(vb.clone(), config)?;
+        Ok(Self { bert, cls })
+    }
+
+    // Encode with the base model, then project hidden states to vocabulary logits.
+    pub fn forward(&self, input_ids: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
+        let sequence_output = self.bert.forward(input_ids, attention_mask)?;
+        self.cls.forward(&sequence_output)
+    }
+}
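
A hedged usage sketch for the new head, loosely following what the updated distilbert example does (tokenize a prompt containing [MASK], run the model, take the top-k tokens at the masked position). The file paths, the all-ones attention mask, and the top-k selection here are assumptions for illustration, not the example's verbatim code:

use anyhow::{Context, Error as E, Result};
use candle_core::{DType, Device, IndexOp, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::distilbert::{Config, DistilBertForMaskedLM};
use tokenizers::Tokenizer;

fn top_k_for_mask(prompt: &str, k: usize) -> Result<Vec<String>> {
    let device = Device::Cpu;
    let tokenizer = Tokenizer::from_file("tokenizer.json").map_err(E::msg)?;
    let config: Config = serde_json::from_str(&std::fs::read_to_string("config.json")?)?;
    let vb = unsafe {
        VarBuilder::from_mmaped_safetensors(&["model.safetensors"], DType::F32, &device)?
    };
    // Root VarBuilder: the head resolves "distilbert.*" and "vocab_projector" itself.
    let model = DistilBertForMaskedLM::load(vb, &config)?;

    let encoding = tokenizer.encode(prompt, true).map_err(E::msg)?;
    let ids = encoding.get_ids().to_vec();
    let mask_id = tokenizer.token_to_id("[MASK]").context("no [MASK] token in vocab")?;
    let mask_pos = ids
        .iter()
        .position(|&id| id == mask_id)
        .context("mask token not found")?;

    let input_ids = Tensor::new(ids.as_slice(), &device)?.unsqueeze(0)?;
    // Single unpadded prompt, so every position may attend everywhere.
    let attention_mask = input_ids.ones_like()?;
    let logits = model.forward(&input_ids, &attention_mask)?;

    // Logits at the masked position; sort descending and keep the k best tokens.
    let scores = logits.i((0, mask_pos))?.to_vec1::<f32>()?;
    let mut indexed: Vec<(usize, f32)> = scores.into_iter().enumerate().collect();
    indexed.sort_by(|a, b| b.1.total_cmp(&a.1));
    Ok(indexed
        .iter()
        .take(k)
        .filter_map(|&(id, _)| tokenizer.id_to_token(id as u32))
        .collect())
}

Because the decoder weight is the word-embedding matrix itself (see DistilBertLMPredictionHead::load above), the projection back to vocabulary space reuses the input embeddings; only the vocab_projector bias is a separate parameter.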