mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 10:26:33 +00:00
Add ModernBert sentence classifier (#2796)
This commit is contained in:
@ -6,14 +6,15 @@
|
||||
//! - See modernbert in [candle-examples](https://github.com/huggingface/candle/tree/main/candle-examples/) for runnable code
|
||||
//!
|
||||
|
||||
use candle::{DType, Device, Result, Tensor, D};
|
||||
use candle::{DType, Device, IndexOp, Result, Tensor, D};
|
||||
use candle_nn::{
|
||||
embedding, layer_norm_no_bias, linear_no_bias, ops::softmax, Embedding, LayerNorm, Linear,
|
||||
Module, VarBuilder,
|
||||
embedding, layer_norm_no_bias, linear, linear_no_bias, ops::softmax, Embedding, LayerNorm,
|
||||
Linear, Module, VarBuilder,
|
||||
};
|
||||
use serde::Deserialize;
|
||||
|
||||
use core::f32;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize)]
|
||||
@ -30,6 +31,24 @@ pub struct Config {
|
||||
pub global_rope_theta: f64,
|
||||
pub local_attention: usize,
|
||||
pub local_rope_theta: f64,
|
||||
#[serde(default)]
|
||||
#[serde(flatten)]
|
||||
pub classifier_config: Option<ClassifierConfig>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, PartialEq, Copy, Default)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum ClassifierPooling {
|
||||
#[default]
|
||||
CLS,
|
||||
MEAN,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize)]
|
||||
pub struct ClassifierConfig {
|
||||
pub id2label: HashMap<String, String>,
|
||||
pub label2id: HashMap<String, String>,
|
||||
pub classifier_pooling: ClassifierPooling,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
@ -310,7 +329,6 @@ pub struct ModernBert {
|
||||
norm: LayerNorm,
|
||||
layers: Vec<ModernBertLayer>,
|
||||
final_norm: LayerNorm,
|
||||
head: ModernBertHead,
|
||||
local_attention_size: usize,
|
||||
}
|
||||
|
||||
@ -359,14 +377,12 @@ impl ModernBert {
|
||||
config.layer_norm_eps,
|
||||
vb.pp("model.final_norm"),
|
||||
)?;
|
||||
let head = ModernBertHead::load(vb.pp("head"), config)?;
|
||||
|
||||
Ok(Self {
|
||||
word_embeddings,
|
||||
norm,
|
||||
layers,
|
||||
final_norm,
|
||||
head,
|
||||
local_attention_size: config.local_attention,
|
||||
})
|
||||
}
|
||||
@ -381,7 +397,7 @@ impl ModernBert {
|
||||
for layer in self.layers.iter() {
|
||||
xs = layer.forward(&xs, &global_attention_mask, &local_attention_mask)?;
|
||||
}
|
||||
let xs = xs.apply(&self.final_norm)?.apply(&self.head)?;
|
||||
let xs = xs.apply(&self.final_norm)?;
|
||||
Ok(xs)
|
||||
}
|
||||
}
|
||||
@ -391,17 +407,98 @@ impl ModernBert {
|
||||
pub struct ModernBertForMaskedLM {
|
||||
model: ModernBert,
|
||||
decoder: ModernBertDecoder,
|
||||
head: ModernBertHead,
|
||||
}
|
||||
|
||||
impl ModernBertForMaskedLM {
|
||||
pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
|
||||
let model = ModernBert::load(vb.clone(), config)?;
|
||||
let decoder = ModernBertDecoder::load(vb.clone(), config)?;
|
||||
Ok(Self { model, decoder })
|
||||
let head = ModernBertHead::load(vb.pp("head"), config)?;
|
||||
Ok(Self {
|
||||
model,
|
||||
decoder,
|
||||
head,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn forward(&self, xs: &Tensor, mask: &Tensor) -> Result<Tensor> {
|
||||
let xs = self.model.forward(xs, mask)?.apply(&self.decoder)?;
|
||||
let xs = self
|
||||
.model
|
||||
.forward(xs, mask)?
|
||||
.apply(&self.head)?
|
||||
.apply(&self.decoder)?;
|
||||
Ok(xs)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ModernBertClassifier {
|
||||
classifier: Linear,
|
||||
}
|
||||
|
||||
impl ModernBertClassifier {
|
||||
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
|
||||
// The decoder weights are tied with the embeddings layer weights
|
||||
let classifier = linear(
|
||||
config.hidden_size,
|
||||
config
|
||||
.classifier_config
|
||||
.as_ref()
|
||||
.map(|cc| cc.id2label.len())
|
||||
.unwrap_or_default(),
|
||||
vb.pp("classifier"),
|
||||
)?;
|
||||
Ok(Self { classifier })
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for ModernBertClassifier {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let xs = xs.apply(&self.classifier)?;
|
||||
softmax(&xs, D::Minus1)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ModernBertForSequenceClassification {
|
||||
model: ModernBert,
|
||||
head: ModernBertHead,
|
||||
classifier: ModernBertClassifier,
|
||||
classifier_pooling: ClassifierPooling,
|
||||
}
|
||||
|
||||
impl ModernBertForSequenceClassification {
|
||||
pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
|
||||
let model = ModernBert::load(vb.clone(), config)?;
|
||||
let classifier = ModernBertClassifier::load(vb.clone(), config)?;
|
||||
let head = ModernBertHead::load(vb.pp("head"), config)?;
|
||||
Ok(Self {
|
||||
model,
|
||||
head,
|
||||
classifier,
|
||||
classifier_pooling: config
|
||||
.classifier_config
|
||||
.as_ref()
|
||||
.map(|cc| cc.classifier_pooling)
|
||||
.unwrap_or_default(),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn forward(&self, xs: &Tensor, mask: &Tensor) -> Result<Tensor> {
|
||||
let output = self.model.forward(xs, mask)?;
|
||||
let last_hidden_state = match self.classifier_pooling {
|
||||
ClassifierPooling::CLS => output.i((.., .., 0))?,
|
||||
ClassifierPooling::MEAN => {
|
||||
let unsqueezed_mask = &mask.unsqueeze(D::Minus1)?.to_dtype(DType::F32)?;
|
||||
let sum_output = output.broadcast_mul(unsqueezed_mask)?.sum(1)?;
|
||||
sum_output.broadcast_div(&mask.sum_keepdim(1)?.to_dtype(DType::F32)?)?
|
||||
}
|
||||
};
|
||||
let xs = self
|
||||
.head
|
||||
.forward(&last_hidden_state)?
|
||||
.apply(&self.classifier)?;
|
||||
Ok(xs)
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user