Mirror of https://github.com/huggingface/candle.git, synced 2025-06-15 10:26:33 +00:00
Implementing DistilBertForMaskedLM. (#2866)
* Initial commit: model weights working, prediction incorrect
* moved distilbertformaskedlm into distilbert modeling file
* made maskedLM like bert example, still incorrect predictions
* finally not getting NaNs, fixed attention mask
* getting correct output sentences
* get top k predictions
* fixed output formatting slightly
* added default arg for model_id
* lint
* moved masked token example code from distilbertformaskedlm example to distilbert example
* lint
* removed distilbertformaskedlm example
* cleanup
* clippy
* removed embedding normalization from example
* made output and model dependent on args instead of prompt
* lint
* replaced or_ok anyhow error with anyhow context
* changed error message for mask token not found
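
The last few items describe the fill-mask example flow: locate the mask token in the encoded prompt, then report the top-k vocabulary entries at that position. Below is a minimal sketch of that post-processing step, assuming the logits come from the DistilBertForMaskedLM::forward added in the diff below (shape (batch, seq_len, vocab_size)) and that the caller already knows the mask position; the helper name and the candle_core crate alias are illustrative, not the exact example code.

use candle_core::{DType, IndexOp, Result, Tensor};

/// Return the `k` most likely vocabulary ids (with their logits) for the token
/// at `mask_index`, given masked-LM logits of shape (batch, seq_len, vocab_size).
fn top_k_predictions(logits: &Tensor, mask_index: usize, k: usize) -> Result<Vec<(usize, f32)>> {
    // Logits for the first batch element at the masked position: (vocab_size,).
    let mask_logits = logits
        .i((0usize, mask_index))?
        .to_dtype(DType::F32)?
        .to_vec1::<f32>()?;
    // Pair each vocabulary id with its logit, sort descending, keep the k best.
    let mut scored: Vec<(usize, f32)> = mask_logits.into_iter().enumerate().collect();
    scored.sort_by(|a, b| b.1.total_cmp(&a.1));
    scored.truncate(k);
    Ok(scored)
}

Mapping the returned ids back to strings is then a plain tokenizer lookup (Tokenizer::id_to_token).
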
@@ -19,7 +19,7 @@ fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor>
 
 #[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)]
 #[serde(rename_all = "lowercase")]
-enum HiddenAct {
+pub enum HiddenAct {
     Gelu,
     Relu,
 }
@@ -49,22 +49,22 @@ impl Module for HiddenActLayer {
 
 #[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Default)]
 #[serde(rename_all = "lowercase")]
-enum PositionEmbeddingType {
+pub enum PositionEmbeddingType {
     #[default]
     Absolute,
 }
 
 #[derive(Debug, Clone, PartialEq, Deserialize)]
 pub struct Config {
-    vocab_size: usize,
-    dim: usize,
+    pub vocab_size: usize,
+    pub dim: usize,
     n_layers: usize,
     n_heads: usize,
     hidden_dim: usize,
     activation: HiddenAct,
     max_position_embeddings: usize,
     initializer_range: f64,
-    pad_token_id: usize,
+    pub pad_token_id: usize,
     #[serde(default)]
     position_embedding_type: PositionEmbeddingType,
     #[serde(default)]
@@ -345,3 +345,107 @@ impl DistilBertModel {
         Ok(sequence_output)
     }
 }
+
+struct DistilBertPredictionHeadTransform {
+    dense: Linear,
+    activation: HiddenActLayer,
+    layer_norm: LayerNorm,
+}
+
+impl DistilBertPredictionHeadTransform {
+    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
+        let dense = linear(config.dim, config.dim, vb.pp("vocab_transform"))?;
+        let activation = HiddenActLayer::new(config.activation);
+        let layer_norm = layer_norm(config.dim, 1e-12, vb.pp("vocab_layer_norm"))?;
+        Ok(Self {
+            dense,
+            activation,
+            layer_norm,
+        })
+    }
+}
+
+impl Module for DistilBertPredictionHeadTransform {
+    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
+        let hidden_states = self
+            .activation
+            .forward(&self.dense.forward(hidden_states)?)?;
+        self.layer_norm.forward(&hidden_states)
+    }
+}
+
+// https://github.com/huggingface/transformers/blob/1bd604d11c405dfb8b78bda4062d88fc75c17de0/src/transformers/models/bert/modeling_bert.py#L769C1-L790C1
+pub struct DistilBertLMPredictionHead {
+    transform: DistilBertPredictionHeadTransform,
+    decoder: Linear,
+}
+
+impl DistilBertLMPredictionHead {
+    pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
+        let transform = DistilBertPredictionHeadTransform::load(vb.clone(), config)?;
+
+        // distil_bert_uncased uses the word embeddings for the vocab projector weight, but has a separate vocab_projector bias
+        let vocab_projector_weight_vb = vb.pp("distilbert.embeddings.word_embeddings");
+        let init_ws = candle_nn::init::DEFAULT_KAIMING_NORMAL;
+        let ws = vocab_projector_weight_vb.get_with_hints(
+            (config.vocab_size, config.dim),
+            "weight",
+            init_ws,
+        )?;
+        let bound = 1. / (config.dim as f64).sqrt();
+        let init_bs = candle_nn::Init::Uniform {
+            lo: -bound,
+            up: bound,
+        };
+
+        let vocab_projector_bias_vb = vb.pp("vocab_projector");
+        let bs = vocab_projector_bias_vb.get_with_hints(config.vocab_size, "bias", init_bs)?;
+
+        let decoder = Linear::from_weights(ws, Some(bs));
+
+        Ok(Self { transform, decoder })
+    }
+}
+
+impl Module for DistilBertLMPredictionHead {
+    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
+        self.decoder
+            .forward(&self.transform.forward(hidden_states)?)
+    }
+}
+
+// https://github.com/huggingface/transformers/blob/1bd604d11c405dfb8b78bda4062d88fc75c17de0/src/transformers/models/bert/modeling_bert.py#L792
+pub struct DistilBertOnlyMLMHead {
+    predictions: DistilBertLMPredictionHead,
+}
+
+impl DistilBertOnlyMLMHead {
+    pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
+        let predictions = DistilBertLMPredictionHead::load(vb.clone(), config)?;
+        Ok(Self { predictions })
+    }
+}
+
+impl Module for DistilBertOnlyMLMHead {
+    fn forward(&self, sequence_output: &Tensor) -> Result<Tensor> {
+        self.predictions.forward(sequence_output)
+    }
+}
+
+pub struct DistilBertForMaskedLM {
+    pub bert: DistilBertModel,
+    cls: DistilBertOnlyMLMHead,
+}
+
+impl DistilBertForMaskedLM {
+    pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
+        let bert = DistilBertModel::load(vb.pp("distilbert"), config)?;
+        let cls = DistilBertOnlyMLMHead::load(vb.clone(), config)?;
+        Ok(Self { bert, cls })
+    }
+
+    pub fn forward(&self, input_ids: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
+        let sequence_output = self.bert.forward(input_ids, attention_mask)?;
+        self.cls.forward(&sequence_output)
+    }
+}
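
Wired together, the new head loads directly from a standard DistilBERT checkpoint. Below is a rough end-to-end sketch of fill-mask inference; the file names, the "[MASK]" literal, and the all-zero (seq_len, seq_len) attention mask (assuming, as in the repository's existing distilbert example, that non-zero mask entries mark positions to drop) are assumptions modelled on the example described in the commit message, not code from this diff, and candle_core is the out-of-workspace name for the candle crate.

use anyhow::Context;
use candle_core::{DType, Device, IndexOp, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::distilbert::{Config, DistilBertForMaskedLM};
use tokenizers::Tokenizer;

fn main() -> anyhow::Result<()> {
    let device = Device::Cpu;

    // Placeholder paths: point these at a downloaded distilbert-base-uncased checkpoint.
    let config: Config = serde_json::from_str(&std::fs::read_to_string("config.json")?)?;
    let tokenizer = Tokenizer::from_file("tokenizer.json").map_err(anyhow::Error::msg)?;
    let vb = unsafe {
        VarBuilder::from_mmaped_safetensors(&["model.safetensors"], DType::F32, &device)?
    };
    let model = DistilBertForMaskedLM::load(vb, &config)?;

    // Encode a prompt containing the mask token and locate its position.
    let encoding = tokenizer
        .encode("The capital of France is [MASK].", true)
        .map_err(anyhow::Error::msg)?;
    let ids = encoding.get_ids().to_vec();
    let mask_id = tokenizer
        .token_to_id("[MASK]")
        .context("mask token not found in the tokenizer vocabulary")?;
    let mask_index = ids
        .iter()
        .position(|&id| id == mask_id)
        .context("no mask token in the prompt")?;

    let input_ids = Tensor::new(ids.as_slice(), &device)?.unsqueeze(0)?;
    // All-zero mask: nothing is masked out, which is what a single unpadded
    // masked-LM prompt needs (assumed semantics; see masked_fill above).
    let attention_mask = Tensor::zeros((ids.len(), ids.len()), DType::U8, &device)?;

    // Logits over the vocabulary at every position: (1, seq_len, vocab_size).
    let logits = model.forward(&input_ids, &attention_mask)?;
    let mask_logits = logits.i((0usize, mask_index))?;
    println!("mask-position logits shape: {:?}", mask_logits.shape());
    Ok(())
}

From here, the top_k_predictions helper sketched after the commit message would rank the candidate tokens at mask_index.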