Add ModernBert sentence classifier (#2796)

Mikhail Panfilov
2025-03-08 16:48:22 +03:00
committed by GitHub
parent 37db86ff79
commit e4ffb85228


@@ -6,14 +6,15 @@
//! - See modernbert in [candle-examples](https://github.com/huggingface/candle/tree/main/candle-examples/) for runnable code
//!
-use candle::{DType, Device, Result, Tensor, D};
+use candle::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::{
-    embedding, layer_norm_no_bias, linear_no_bias, ops::softmax, Embedding, LayerNorm, Linear,
-    Module, VarBuilder,
+    embedding, layer_norm_no_bias, linear, linear_no_bias, ops::softmax, Embedding, LayerNorm,
+    Linear, Module, VarBuilder,
};
use serde::Deserialize;
use core::f32;
use std::collections::HashMap;
use std::sync::Arc;
#[derive(Debug, Clone, PartialEq, Deserialize)]
@@ -30,6 +31,24 @@ pub struct Config {
pub global_rope_theta: f64,
pub local_attention: usize,
pub local_rope_theta: f64,
+    #[serde(default)]
+    #[serde(flatten)]
+    pub classifier_config: Option<ClassifierConfig>,
}
+#[derive(Debug, Clone, Deserialize, PartialEq, Copy, Default)]
+#[serde(rename_all = "lowercase")]
+pub enum ClassifierPooling {
+    #[default]
+    CLS,
+    MEAN,
+}
+#[derive(Debug, Clone, PartialEq, Deserialize)]
+pub struct ClassifierConfig {
+    pub id2label: HashMap<String, String>,
+    pub label2id: HashMap<String, String>,
+    pub classifier_pooling: ClassifierPooling,
+}
#[derive(Debug, Clone)]
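
The flattened, defaulted classifier_config above means the classifier-specific keys are read straight from the top level of a Hugging Face config.json, and the whole block is simply None when they are absent. As a quick illustration (not part of the patch), the new types deserialize like this, assuming serde_json as an extra dependency and made-up label names:

use candle_transformers::models::modernbert::{ClassifierConfig, ClassifierPooling};

fn main() -> Result<(), serde_json::Error> {
    // These keys sit at the top level of config.json; serde's `flatten`
    // lifts them into `Config::classifier_config` when they are present.
    let json = r#"{
        "id2label": { "0": "negative", "1": "positive" },
        "label2id": { "negative": "0", "positive": "1" },
        "classifier_pooling": "mean"
    }"#;
    let cc: ClassifierConfig = serde_json::from_str(json)?;
    assert_eq!(cc.classifier_pooling, ClassifierPooling::MEAN);
    assert_eq!(cc.id2label["1"], "positive");
    Ok(())
}
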
@@ -310,7 +329,6 @@ pub struct ModernBert {
norm: LayerNorm,
layers: Vec<ModernBertLayer>,
final_norm: LayerNorm,
-    head: ModernBertHead,
local_attention_size: usize,
}
@@ -359,14 +377,12 @@ impl ModernBert {
config.layer_norm_eps,
vb.pp("model.final_norm"),
)?;
-        let head = ModernBertHead::load(vb.pp("head"), config)?;
Ok(Self {
word_embeddings,
norm,
layers,
final_norm,
-            head,
local_attention_size: config.local_attention,
})
}
@@ -381,7 +397,7 @@ impl ModernBert {
for layer in self.layers.iter() {
xs = layer.forward(&xs, &global_attention_mask, &local_attention_mask)?;
}
-        let xs = xs.apply(&self.final_norm)?.apply(&self.head)?;
+        let xs = xs.apply(&self.final_norm)?;
Ok(xs)
}
}
@@ -391,17 +407,98 @@ impl ModernBert {
pub struct ModernBertForMaskedLM {
model: ModernBert,
decoder: ModernBertDecoder,
+    head: ModernBertHead,
}
impl ModernBertForMaskedLM {
pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
let model = ModernBert::load(vb.clone(), config)?;
let decoder = ModernBertDecoder::load(vb.clone(), config)?;
-        Ok(Self { model, decoder })
+        let head = ModernBertHead::load(vb.pp("head"), config)?;
+        Ok(Self {
+            model,
+            decoder,
+            head,
+        })
}
pub fn forward(&self, xs: &Tensor, mask: &Tensor) -> Result<Tensor> {
-        let xs = self.model.forward(xs, mask)?.apply(&self.decoder)?;
+        let xs = self
+            .model
+            .forward(xs, mask)?
+            .apply(&self.head)?
+            .apply(&self.decoder)?;
Ok(xs)
}
}
+#[derive(Clone)]
+pub struct ModernBertClassifier {
+    classifier: Linear,
+}
+impl ModernBertClassifier {
+    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
+        // Classification head: a linear projection from the hidden size to the
+        // number of labels declared in the config's id2label map.
+        let classifier = linear(
+            config.hidden_size,
+            config
+                .classifier_config
+                .as_ref()
+                .map(|cc| cc.id2label.len())
+                .unwrap_or_default(),
+            vb.pp("classifier"),
+        )?;
+        Ok(Self { classifier })
+    }
+}
+impl Module for ModernBertClassifier {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let xs = xs.apply(&self.classifier)?;
+        softmax(&xs, D::Minus1)
+    }
+}
+#[derive(Clone)]
+pub struct ModernBertForSequenceClassification {
+    model: ModernBert,
+    head: ModernBertHead,
+    classifier: ModernBertClassifier,
+    classifier_pooling: ClassifierPooling,
+}
+impl ModernBertForSequenceClassification {
+    pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
+        let model = ModernBert::load(vb.clone(), config)?;
+        let classifier = ModernBertClassifier::load(vb.clone(), config)?;
+        let head = ModernBertHead::load(vb.pp("head"), config)?;
+        Ok(Self {
+            model,
+            head,
+            classifier,
+            classifier_pooling: config
+                .classifier_config
+                .as_ref()
+                .map(|cc| cc.classifier_pooling)
+                .unwrap_or_default(),
+        })
+    }
+    pub fn forward(&self, xs: &Tensor, mask: &Tensor) -> Result<Tensor> {
+        let output = self.model.forward(xs, mask)?;
+        let last_hidden_state = match self.classifier_pooling {
+            // CLS pooling: keep the hidden state of the first ([CLS]) token.
+            ClassifierPooling::CLS => output.i((.., 0))?,
+            ClassifierPooling::MEAN => {
+                let unsqueezed_mask = &mask.unsqueeze(D::Minus1)?.to_dtype(DType::F32)?;
+                let sum_output = output.broadcast_mul(unsqueezed_mask)?.sum(1)?;
+                sum_output.broadcast_div(&mask.sum_keepdim(1)?.to_dtype(DType::F32)?)?
+            }
+        };
+        let xs = self
+            .head
+            .forward(&last_hidden_state)?
+            .apply(&self.classifier)?;
+        Ok(xs)
+    }
+}
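
The MEAN branch in forward above is a mask-weighted average over the sequence dimension. A standalone sketch of the same arithmetic on toy tensors (shapes and values invented for illustration, not taken from the patch):

use candle::{DType, Device, Result, Tensor, D};

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Toy hidden states: (batch = 1, seq_len = 3, hidden = 2); the last position is padding.
    let output = Tensor::new(&[[[1f32, 2.], [3., 4.], [100., 100.]]], &device)?;
    let mask = Tensor::new(&[[1u32, 1, 0]], &device)?;

    // Zero out padded positions, sum over the sequence, divide by the real token count.
    let unsqueezed_mask = mask.unsqueeze(D::Minus1)?.to_dtype(DType::F32)?;
    let sum_output = output.broadcast_mul(&unsqueezed_mask)?.sum(1)?;
    let pooled = sum_output.broadcast_div(&mask.sum_keepdim(1)?.to_dtype(DType::F32)?)?;

    // The padded position is ignored: prints [[2.0, 3.0]].
    println!("{:?}", pooled.to_vec2::<f32>()?);
    Ok(())
}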
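
And an end-to-end sketch of the new classifier as a whole; the file names, token ids, and the anyhow/serde_json dependencies are assumptions for illustration, not part of this commit:

use candle::{DType, Device, Tensor, D};
use candle_nn::VarBuilder;
use candle_transformers::models::modernbert::{Config, ModernBertForSequenceClassification};

fn main() -> anyhow::Result<()> {
    let device = Device::Cpu;
    // config.json must carry id2label / label2id / classifier_pooling so the flattened
    // ClassifierConfig is populated and the classifier knows its output size.
    let config: Config = serde_json::from_str(&std::fs::read_to_string("config.json")?)?;
    let vb = unsafe {
        VarBuilder::from_mmaped_safetensors(&["model.safetensors"], DType::F32, &device)?
    };
    let model = ModernBertForSequenceClassification::load(vb, &config)?;

    // Placeholder token ids and attention mask; in practice both come from the tokenizer.
    let input_ids = Tensor::new(&[[50281u32, 1276, 310, 50282]], &device)?;
    let attention_mask = Tensor::new(&[[1u32, 1, 1, 1]], &device)?;

    // (batch, num_labels) class probabilities; softmax is applied inside ModernBertClassifier.
    let probs = model.forward(&input_ids, &attention_mask)?;
    let predicted = probs.argmax(D::Minus1)?.to_vec1::<u32>()?;
    println!("predicted label ids: {predicted:?}");
    Ok(())
}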