mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 02:16:37 +00:00
ModernBERT model (#2713)
* layer_norm_no_bias * Modernbert model. * Format + cleanup error. --------- Co-authored-by: laurent <laurent.mazare@gmail.com>
This commit is contained in:
@ -60,6 +60,7 @@ pub mod mmdit;
|
||||
pub mod mobileclip;
|
||||
pub mod mobilenetv4;
|
||||
pub mod mobileone;
|
||||
pub mod modernbert;
|
||||
pub mod moondream;
|
||||
pub mod mpt;
|
||||
pub mod nvembed_v2;
|
||||
|
407
candle-transformers/src/models/modernbert.rs
Normal file
407
candle-transformers/src/models/modernbert.rs
Normal file
@ -0,0 +1,407 @@
|
||||
//! ModernBERT
|
||||
//!
|
||||
//! ModernBERT is a modernized bidirectional encoder-only Transformer model.
|
||||
//! - [Arxiv](https://arxiv.org/abs/2412.13663) "Smarter, Better, Faster, Longer: A Modern Bidirectional Encoder for Fast, Memory Efficient, and Long Context Finetuning and Inference"
|
||||
//! - Upstream [Github repo](https://github.com/AnswerDotAI/ModernBERT).
|
||||
//! - See modernbert in [candle-examples](https://github.com/huggingface/candle/tree/main/candle-examples/) for runnable code
|
||||
//!
|
||||
|
||||
use candle::{DType, Device, Result, Tensor, D};
|
||||
use candle_nn::{
|
||||
embedding, layer_norm_no_bias, linear_no_bias, ops::softmax, Embedding, LayerNorm, Linear,
|
||||
Module, VarBuilder,
|
||||
};
|
||||
use serde::Deserialize;
|
||||
|
||||
use core::f32;
|
||||
use std::sync::Arc;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize)]
|
||||
pub struct Config {
|
||||
pub vocab_size: usize,
|
||||
pub hidden_size: usize,
|
||||
pub num_hidden_layers: usize,
|
||||
pub num_attention_heads: usize,
|
||||
pub intermediate_size: usize,
|
||||
pub max_position_embeddings: usize,
|
||||
pub layer_norm_eps: f64,
|
||||
pub pad_token_id: u32,
|
||||
pub global_attn_every_n_layers: usize,
|
||||
pub global_rope_theta: f64,
|
||||
pub local_attention: usize,
|
||||
pub local_rope_theta: f64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct RotaryEmbedding {
|
||||
sin: Tensor,
|
||||
cos: Tensor,
|
||||
}
|
||||
|
||||
impl RotaryEmbedding {
|
||||
fn new(dtype: DType, config: &Config, rope_theta: f64, dev: &Device) -> Result<Self> {
|
||||
let dim = config.hidden_size / config.num_attention_heads;
|
||||
let inv_freq: Vec<_> = (0..dim)
|
||||
.step_by(2)
|
||||
.map(|i| 1f32 / rope_theta.powf(i as f64 / dim as f64) as f32)
|
||||
.collect();
|
||||
let inv_freq_len = inv_freq.len();
|
||||
let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
|
||||
let max_seq_len = config.max_position_embeddings;
|
||||
let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
|
||||
.to_dtype(dtype)?
|
||||
.reshape((max_seq_len, 1))?;
|
||||
let freqs = t.matmul(&inv_freq)?;
|
||||
Ok(Self {
|
||||
sin: freqs.sin()?,
|
||||
cos: freqs.cos()?,
|
||||
})
|
||||
}
|
||||
|
||||
fn apply_rotary_emb_qkv(&self, q: &Tensor, k: &Tensor) -> Result<(Tensor, Tensor)> {
|
||||
let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &self.cos, &self.sin)?;
|
||||
let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &self.cos, &self.sin)?;
|
||||
Ok((q_embed, k_embed))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct ModernBertAttention {
|
||||
qkv: Linear,
|
||||
proj: Linear,
|
||||
num_attention_heads: usize,
|
||||
attention_head_size: usize,
|
||||
rotary_emb: Arc<RotaryEmbedding>,
|
||||
}
|
||||
|
||||
impl ModernBertAttention {
|
||||
fn load(vb: VarBuilder, config: &Config, rotary_emb: Arc<RotaryEmbedding>) -> Result<Self> {
|
||||
let num_attention_heads = config.num_attention_heads;
|
||||
let attention_head_size = config.hidden_size / config.num_attention_heads;
|
||||
|
||||
let qkv = linear_no_bias(config.hidden_size, config.hidden_size * 3, vb.pp("Wqkv"))?;
|
||||
let proj = linear_no_bias(config.hidden_size, config.hidden_size, vb.pp("Wo"))?;
|
||||
|
||||
Ok(Self {
|
||||
qkv,
|
||||
proj,
|
||||
num_attention_heads,
|
||||
attention_head_size,
|
||||
rotary_emb,
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
|
||||
let xs = hidden_states.clone();
|
||||
let (b, seq_len, d) = xs.dims3()?;
|
||||
let qkv = xs
|
||||
.apply(&self.qkv)?
|
||||
.reshape((
|
||||
b,
|
||||
seq_len,
|
||||
3,
|
||||
self.num_attention_heads,
|
||||
self.attention_head_size,
|
||||
))?
|
||||
.permute((2, 0, 3, 1, 4))?;
|
||||
|
||||
let q = qkv.get(0)?;
|
||||
let k = qkv.get(1)?;
|
||||
let v = qkv.get(2)?;
|
||||
|
||||
let (q, k) = self.rotary_emb.apply_rotary_emb_qkv(&q, &k)?;
|
||||
|
||||
let scale = (self.attention_head_size as f64).powf(-0.5);
|
||||
let q = (q * scale)?;
|
||||
|
||||
let att = q.matmul(&k.transpose(D::Minus2, D::Minus1)?)?;
|
||||
|
||||
let att = att.broadcast_add(attention_mask)?;
|
||||
let att = softmax(&att, D::Minus1)?;
|
||||
|
||||
let xs = att.matmul(&v)?;
|
||||
|
||||
let xs = xs.transpose(1, 2)?.reshape((b, seq_len, d))?;
|
||||
let xs = xs.apply(&self.proj)?;
|
||||
let xs = xs.reshape((b, seq_len, d))?;
|
||||
|
||||
Ok(xs)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ModernBertMLP {
|
||||
wi: Linear,
|
||||
wo: Linear,
|
||||
}
|
||||
|
||||
impl ModernBertMLP {
|
||||
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
|
||||
let wi = linear_no_bias(
|
||||
config.hidden_size,
|
||||
config.intermediate_size * 2,
|
||||
vb.pp("Wi"),
|
||||
)?;
|
||||
let wo = linear_no_bias(config.intermediate_size, config.hidden_size, vb.pp("Wo"))?;
|
||||
Ok(Self { wi, wo })
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for ModernBertMLP {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let xs = xs.apply(&self.wi)?;
|
||||
let xs = xs.chunk(2, D::Minus1)?;
|
||||
let xs = (&xs[0].gelu_erf()? * &xs[1])?.apply(&self.wo)?; // GeGLU
|
||||
Ok(xs)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ModernBertLayer {
|
||||
attn: ModernBertAttention,
|
||||
mlp: ModernBertMLP,
|
||||
attn_norm: Option<LayerNorm>,
|
||||
mlp_norm: LayerNorm,
|
||||
uses_local_attention: bool,
|
||||
}
|
||||
|
||||
impl ModernBertLayer {
|
||||
fn load(
|
||||
vb: VarBuilder,
|
||||
config: &Config,
|
||||
rotary_emb: Arc<RotaryEmbedding>,
|
||||
uses_local_attention: bool,
|
||||
) -> Result<Self> {
|
||||
let attn = ModernBertAttention::load(vb.pp("attn"), config, rotary_emb)?;
|
||||
let mlp = ModernBertMLP::load(vb.pp("mlp"), config)?;
|
||||
let attn_norm = layer_norm_no_bias(
|
||||
config.hidden_size,
|
||||
config.layer_norm_eps,
|
||||
vb.pp("attn_norm"),
|
||||
)
|
||||
.ok();
|
||||
let mlp_norm =
|
||||
layer_norm_no_bias(config.hidden_size, config.layer_norm_eps, vb.pp("mlp_norm"))?;
|
||||
Ok(Self {
|
||||
attn,
|
||||
mlp,
|
||||
attn_norm,
|
||||
mlp_norm,
|
||||
uses_local_attention,
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(
|
||||
&self,
|
||||
xs: &Tensor,
|
||||
global_attention_mask: &Tensor,
|
||||
local_attention_mask: &Tensor,
|
||||
) -> Result<Tensor> {
|
||||
let residual = xs.clone();
|
||||
let mut xs = xs.clone();
|
||||
if let Some(norm) = &self.attn_norm {
|
||||
xs = xs.apply(norm)?;
|
||||
}
|
||||
|
||||
let attention_mask = if self.uses_local_attention {
|
||||
&global_attention_mask.broadcast_add(local_attention_mask)?
|
||||
} else {
|
||||
global_attention_mask
|
||||
};
|
||||
let xs = self.attn.forward(&xs, attention_mask)?;
|
||||
let xs = (xs + residual)?;
|
||||
let mlp_out = xs.apply(&self.mlp_norm)?.apply(&self.mlp)?;
|
||||
let xs = (xs + mlp_out)?;
|
||||
Ok(xs)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ModernBertHead {
|
||||
dense: Linear,
|
||||
norm: LayerNorm,
|
||||
}
|
||||
|
||||
impl ModernBertHead {
|
||||
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
|
||||
let dense = linear_no_bias(config.hidden_size, config.hidden_size, vb.pp("dense"))?;
|
||||
let norm = layer_norm_no_bias(config.hidden_size, config.layer_norm_eps, vb.pp("norm"))?;
|
||||
Ok(Self { dense, norm })
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for ModernBertHead {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let xs = xs.apply(&self.dense)?.gelu_erf()?.apply(&self.norm)?;
|
||||
Ok(xs)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ModernBertDecoder {
|
||||
decoder: Linear,
|
||||
}
|
||||
|
||||
impl ModernBertDecoder {
|
||||
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
|
||||
// The decoder weights are tied with the embeddings layer weights
|
||||
let decoder_weights = vb.get(
|
||||
(config.vocab_size, config.hidden_size),
|
||||
"model.embeddings.tok_embeddings.weight",
|
||||
)?;
|
||||
let decoder_bias = vb.get(config.vocab_size, "decoder.bias")?;
|
||||
let decoder = Linear::new(decoder_weights, Some(decoder_bias));
|
||||
Ok(Self { decoder })
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for ModernBertDecoder {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let xs = xs.apply(&self.decoder)?;
|
||||
Ok(xs)
|
||||
}
|
||||
}
|
||||
|
||||
// Global attention mask calculated from padded token inputs
|
||||
fn prepare_4d_attention_mask(
|
||||
mask: &Tensor,
|
||||
dtype: DType,
|
||||
tgt_len: Option<usize>,
|
||||
) -> Result<Tensor> {
|
||||
let bsz = mask.dim(0)?;
|
||||
let src_len = mask.dim(1)?;
|
||||
let tgt_len = tgt_len.unwrap_or(src_len);
|
||||
|
||||
let expanded_mask = mask
|
||||
.unsqueeze(1)?
|
||||
.unsqueeze(2)?
|
||||
.expand((bsz, 1, tgt_len, src_len))?
|
||||
.to_dtype(dtype)?;
|
||||
|
||||
let inverted_mask = (1.0 - expanded_mask)?;
|
||||
|
||||
(inverted_mask * f32::MIN as f64)?.to_dtype(dtype)
|
||||
}
|
||||
|
||||
// Attention mask caused by the sliding window
|
||||
fn get_local_attention_mask(
|
||||
seq_len: usize,
|
||||
max_distance: usize,
|
||||
device: &Device,
|
||||
) -> Result<Tensor> {
|
||||
let mask: Vec<_> = (0..seq_len)
|
||||
.flat_map(|i| {
|
||||
(0..seq_len).map(move |j| {
|
||||
if (j as i32 - i as i32).abs() > max_distance as i32 {
|
||||
f32::NEG_INFINITY
|
||||
} else {
|
||||
0.
|
||||
}
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
Tensor::from_slice(&mask, (seq_len, seq_len), device)
|
||||
}
|
||||
|
||||
// ModernBERT backbone
|
||||
#[derive(Clone)]
|
||||
pub struct ModernBert {
|
||||
word_embeddings: Embedding,
|
||||
norm: LayerNorm,
|
||||
layers: Vec<ModernBertLayer>,
|
||||
final_norm: LayerNorm,
|
||||
head: ModernBertHead,
|
||||
local_attention_size: usize,
|
||||
}
|
||||
|
||||
impl ModernBert {
|
||||
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
|
||||
let word_embeddings = embedding(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
vb.pp("model.embeddings.tok_embeddings"),
|
||||
)?;
|
||||
let norm = layer_norm_no_bias(
|
||||
config.hidden_size,
|
||||
config.layer_norm_eps,
|
||||
vb.pp("model.embeddings.norm"),
|
||||
)?;
|
||||
let global_rotary_emb = Arc::new(RotaryEmbedding::new(
|
||||
vb.dtype(),
|
||||
config,
|
||||
config.global_rope_theta,
|
||||
vb.device(),
|
||||
)?);
|
||||
let local_rotary_emb = Arc::new(RotaryEmbedding::new(
|
||||
vb.dtype(),
|
||||
config,
|
||||
config.local_rope_theta,
|
||||
vb.device(),
|
||||
)?);
|
||||
|
||||
let mut layers = Vec::with_capacity(config.num_hidden_layers);
|
||||
for layer_id in 0..config.num_hidden_layers {
|
||||
let layer_uses_local_attention = layer_id % config.global_attn_every_n_layers != 0;
|
||||
layers.push(ModernBertLayer::load(
|
||||
vb.pp(format!("model.layers.{layer_id}")),
|
||||
config,
|
||||
if layer_uses_local_attention {
|
||||
local_rotary_emb.clone()
|
||||
} else {
|
||||
global_rotary_emb.clone()
|
||||
},
|
||||
layer_uses_local_attention,
|
||||
)?);
|
||||
}
|
||||
|
||||
let final_norm = layer_norm_no_bias(
|
||||
config.hidden_size,
|
||||
config.layer_norm_eps,
|
||||
vb.pp("model.final_norm"),
|
||||
)?;
|
||||
let head = ModernBertHead::load(vb.pp("head"), config)?;
|
||||
|
||||
Ok(Self {
|
||||
word_embeddings,
|
||||
norm,
|
||||
layers,
|
||||
final_norm,
|
||||
head,
|
||||
local_attention_size: config.local_attention,
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(&self, xs: &Tensor, mask: &Tensor) -> Result<Tensor> {
|
||||
let seq_len = xs.shape().dims()[1];
|
||||
let global_attention_mask =
|
||||
prepare_4d_attention_mask(mask, DType::F32, None)?.to_device(xs.device())?;
|
||||
let local_attention_mask =
|
||||
get_local_attention_mask(seq_len, self.local_attention_size / 2, xs.device())?;
|
||||
let mut xs = xs.apply(&self.word_embeddings)?.apply(&self.norm)?;
|
||||
for layer in self.layers.iter() {
|
||||
xs = layer.forward(&xs, &global_attention_mask, &local_attention_mask)?;
|
||||
}
|
||||
let xs = xs.apply(&self.final_norm)?.apply(&self.head)?;
|
||||
Ok(xs)
|
||||
}
|
||||
}
|
||||
|
||||
// ModernBERT for the fill-mask task
|
||||
#[derive(Clone)]
|
||||
pub struct ModernBertForMaskedLM {
|
||||
model: ModernBert,
|
||||
decoder: ModernBertDecoder,
|
||||
}
|
||||
|
||||
impl ModernBertForMaskedLM {
|
||||
pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
|
||||
let model = ModernBert::load(vb.clone(), config)?;
|
||||
let decoder = ModernBertDecoder::load(vb.clone(), config)?;
|
||||
Ok(Self { model, decoder })
|
||||
}
|
||||
|
||||
pub fn forward(&self, xs: &Tensor, mask: &Tensor) -> Result<Tensor> {
|
||||
let xs = self.model.forward(xs, mask)?.apply(&self.decoder)?;
|
||||
Ok(xs)
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user