ModernBERT model (#2713)

* layer_norm_no_bias

* Modernbert model.

* Format + cleanup error.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
Author: Jani Monoses
Date: 2025-01-13 09:39:27 +02:00 (committed by GitHub)
Parent: 2344c4e4b8
Commit: 461e8c1685
6 changed files with 612 additions and 1 deletion


@@ -0,0 +1,12 @@
# candle-modernbert

ModernBERT is a bidirectional encoder-only language model. In this example, it is used for the fill-mask task:

## Usage

```bash
cargo run --example modernbert --release -- --model modern-bert-large --prompt 'The capital of France is [MASK].'
```
```markdown
Sentence: 1 : The capital of France is Paris.
```
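The example also exposes `--cpu` to force CPU execution and `--prompt` to replace the built-in test sentences; when `--prompt` is omitted, a small set of fill-mask prompts is run. The prompt below is only an illustration:
```bash
cargo run --example modernbert --release -- --model modern-bert-base --cpu --prompt 'Rust is a systems programming [MASK].'
```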


@@ -0,0 +1,180 @@
use std::path::PathBuf;
use anyhow::{Error as E, Result};
use candle::{Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::modernbert;
use clap::{Parser, ValueEnum};
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::{PaddingParams, Tokenizer};
#[derive(Debug, Clone, ValueEnum)]
enum Model {
ModernBertBase,
ModernBertLarge,
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
#[arg(long)]
model_id: Option<String>,
#[arg(long, default_value = "main")]
revision: String,
#[arg(long, default_value = "modern-bert-base")]
model: Model,
// Path to the tokenizer file.
#[arg(long)]
tokenizer_file: Option<String>,
// Path to the weight files.
#[arg(long)]
weight_files: Option<String>,
// Path to the config file.
#[arg(long)]
config_file: Option<String>,
/// When set, compute embeddings for this prompt.
#[arg(long)]
prompt: Option<String>,
}
fn main() -> Result<()> {
let args = Args::parse();
let api = Api::new()?;
let model_id = match &args.model_id {
Some(model_id) => model_id.to_string(),
None => match args.model {
Model::ModernBertBase => "answerdotai/ModernBERT-base".to_string(),
Model::ModernBertLarge => "answerdotai/ModernBERT-large".to_string(),
},
};
let repo = api.repo(Repo::with_revision(
model_id,
RepoType::Model,
args.revision,
));
let tokenizer_filename = match args.tokenizer_file {
Some(file) => std::path::PathBuf::from(file),
None => repo.get("tokenizer.json")?,
};
let config_filename = match args.config_file {
Some(file) => std::path::PathBuf::from(file),
None => repo.get("config.json")?,
};
let weights_filename = match args.weight_files {
Some(files) => PathBuf::from(files),
None => match repo.get("model.safetensors") {
Ok(safetensors) => safetensors,
Err(_) => match repo.get("pytorch_model.bin") {
Ok(pytorch_model) => pytorch_model,
Err(e) => {
anyhow::bail!("Model weights not found. The weights should either be a `model.safetensors` or `pytorch_model.bin` file. Error: {e}")
}
},
},
};
let config = std::fs::read_to_string(config_filename)?;
let config: modernbert::Config = serde_json::from_str(&config)?;
let mut tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let device = candle_examples::device(args.cpu)?;
let vb = if weights_filename.ends_with("model.safetensors") {
unsafe {
VarBuilder::from_mmaped_safetensors(&[weights_filename], candle::DType::F32, &device)
.unwrap()
}
} else {
println!("Loading weights from pytorch_model.bin");
VarBuilder::from_pth(&weights_filename, candle::DType::F32, &device).unwrap()
};
tokenizer
.with_padding(Some(PaddingParams {
strategy: tokenizers::PaddingStrategy::BatchLongest,
pad_id: config.pad_token_id,
..Default::default()
}))
.with_truncation(None)
.map_err(E::msg)?;
let prompt = match &args.prompt {
Some(p) => vec![p.as_str()],
None => vec![
"Hello I'm a [MASK] model.",
"I'm a [MASK] boy.",
"I'm [MASK] in berlin.",
"The capital of France is [MASK].",
],
};
let model = modernbert::ModernBertForMaskedLM::load(vb, &config)?;
let input_ids = tokenize_batch(&tokenizer, prompt.clone(), &device)?;
let attention_mask = get_attention_mask(&tokenizer, prompt.clone(), &device)?;
let output = model
.forward(&input_ids, &attention_mask)?
.to_dtype(candle::DType::F32)?;
let max_outs = output.argmax(2)?;
let max_out = max_outs.to_vec2::<u32>()?;
let max_out_refs: Vec<&[u32]> = max_out.iter().map(|v| v.as_slice()).collect();
let decoded = tokenizer.decode_batch(&max_out_refs, true).unwrap();
for (i, sentence) in decoded.iter().enumerate() {
println!("Sentence: {} : {}", i + 1, sentence);
}
Ok(())
}
pub fn tokenize_batch(
tokenizer: &Tokenizer,
input: Vec<&str>,
device: &Device,
) -> anyhow::Result<Tensor> {
let tokens = tokenizer.encode_batch(input, true).map_err(E::msg)?;
let token_ids = tokens
.iter()
.map(|tokens| {
let tokens = tokens.get_ids().to_vec();
Tensor::new(tokens.as_slice(), device)
})
.collect::<candle::Result<Vec<_>>>()?;
Ok(Tensor::stack(&token_ids, 0)?)
}
pub fn get_attention_mask(
tokenizer: &Tokenizer,
input: Vec<&str>,
device: &Device,
) -> anyhow::Result<Tensor> {
let tokens = tokenizer.encode_batch(input, true).map_err(E::msg)?;
let attention_mask = tokens
.iter()
.map(|tokens| {
let tokens = tokens.get_attention_mask().to_vec();
Tensor::new(tokens.as_slice(), device)
})
.collect::<candle::Result<Vec<_>>>()?;
Ok(Tensor::stack(&attention_mask, 0)?)
}
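The example above argmax-decodes every position, which reprints the whole sentence. A possible variation (not part of this commit) is to read out only the `[MASK]` slot; the helper below is a hypothetical sketch built on the `tokenizers` crate's `token_to_id`/`id_to_token` and candle's `IndexOp` indexing:
```rust
use anyhow::{anyhow, Result};
use candle::{IndexOp, Tensor};
use tokenizers::Tokenizer;

/// Most likely filler token for the first [MASK] in row `row` of the batch.
/// `logits` has shape (batch, seq_len, vocab_size), `input_ids` (batch, seq_len).
fn top_mask_prediction(
    logits: &Tensor,
    input_ids: &Tensor,
    tokenizer: &Tokenizer,
    row: usize,
) -> Result<String> {
    let mask_id = tokenizer
        .token_to_id("[MASK]")
        .ok_or_else(|| anyhow!("tokenizer has no [MASK] token"))?;
    // Locate the masked position in the (already padded) input ids.
    let ids = input_ids.i(row)?.to_vec1::<u32>()?;
    let pos = ids
        .iter()
        .position(|&id| id == mask_id)
        .ok_or_else(|| anyhow!("prompt contains no [MASK] token"))?;
    // Logits for that single position, argmax over the vocabulary.
    let token_id = logits.i((row, pos))?.argmax(0)?.to_scalar::<u32>()?;
    Ok(tokenizer
        .id_to_token(token_id)
        .unwrap_or_else(|| token_id.to_string()))
}
```
It would be called as `top_mask_prediction(&output, &input_ids, &tokenizer, 0)?` after the forward pass.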


@@ -155,6 +155,15 @@ pub fn layer_norm<C: Into<LayerNormConfig>>(
    })
}
pub fn layer_norm_no_bias(size: usize, eps: f64, vb: crate::VarBuilder) -> Result<LayerNorm> {
let config = LayerNormConfig {
eps,
remove_mean: true,
affine: false,
};
layer_norm(size, config, vb)
}
/// RmsNorm is a specialized version of the LayerNorm module.
#[derive(Clone, Debug)]
pub struct RmsNorm(LayerNorm);
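For reference, a minimal usage sketch of the new helper (not part of this commit), assuming a fresh `VarMap`-backed `VarBuilder` so the scale weight is simply initialized to ones; in ModernBERT the same call instead maps onto the pretrained weights:
```rust
use candle::{DType, Device, Result, Tensor};
use candle_nn::{layer_norm_no_bias, Module, VarBuilder, VarMap};

fn main() -> Result<()> {
    let device = Device::Cpu;
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device);

    // Layer norm over 16 features with a learned scale but no bias term.
    let ln = layer_norm_no_bias(16, 1e-5, vb.pp("norm"))?;
    let xs = Tensor::randn(0f32, 1f32, (2, 16), &device)?;
    let ys = ln.forward(&xs)?;
    println!("{:?}", ys.dims()); // [2, 16]
    Ok(())
}
```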


@@ -46,7 +46,9 @@ pub use embedding::{embedding, Embedding};
pub use func::{func, func_t, Func, FuncT};
pub use group_norm::{group_norm, GroupNorm};
pub use init::Init;
pub use layer_norm::{
    layer_norm, layer_norm_no_bias, rms_norm, LayerNorm, LayerNormConfig, RmsNorm,
};
pub use linear::{linear, linear_b, linear_no_bias, Linear};
pub use ops::Dropout;
pub use optim::{AdamW, Optimizer, ParamsAdamW, SGD};


@@ -60,6 +60,7 @@ pub mod mmdit;
pub mod mobileclip;
pub mod mobilenetv4;
pub mod mobileone;
pub mod modernbert;
pub mod moondream;
pub mod mpt;
pub mod nvembed_v2;


@@ -0,0 +1,407 @@
//! ModernBERT
//!
//! ModernBERT is a modernized bidirectional encoder-only Transformer model.
//! - [Arxiv](https://arxiv.org/abs/2412.13663) "Smarter, Better, Faster, Longer: A Modern Bidirectional Encoder for Fast, Memory Efficient, and Long Context Finetuning and Inference"
//! - Upstream [GitHub repo](https://github.com/AnswerDotAI/ModernBERT).
//! - See modernbert in [candle-examples](https://github.com/huggingface/candle/tree/main/candle-examples/) for runnable code.
//!
use candle::{DType, Device, Result, Tensor, D};
use candle_nn::{
embedding, layer_norm_no_bias, linear_no_bias, ops::softmax, Embedding, LayerNorm, Linear,
Module, VarBuilder,
};
use serde::Deserialize;
use core::f32;
use std::sync::Arc;
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct Config {
pub vocab_size: usize,
pub hidden_size: usize,
pub num_hidden_layers: usize,
pub num_attention_heads: usize,
pub intermediate_size: usize,
pub max_position_embeddings: usize,
pub layer_norm_eps: f64,
pub pad_token_id: u32,
pub global_attn_every_n_layers: usize,
pub global_rope_theta: f64,
pub local_attention: usize,
pub local_rope_theta: f64,
}
#[derive(Debug, Clone)]
struct RotaryEmbedding {
sin: Tensor,
cos: Tensor,
}
impl RotaryEmbedding {
fn new(dtype: DType, config: &Config, rope_theta: f64, dev: &Device) -> Result<Self> {
let dim = config.hidden_size / config.num_attention_heads;
let inv_freq: Vec<_> = (0..dim)
.step_by(2)
.map(|i| 1f32 / rope_theta.powf(i as f64 / dim as f64) as f32)
.collect();
let inv_freq_len = inv_freq.len();
let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
let max_seq_len = config.max_position_embeddings;
let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
.to_dtype(dtype)?
.reshape((max_seq_len, 1))?;
let freqs = t.matmul(&inv_freq)?;
Ok(Self {
sin: freqs.sin()?,
cos: freqs.cos()?,
})
}
fn apply_rotary_emb_qkv(&self, q: &Tensor, k: &Tensor) -> Result<(Tensor, Tensor)> {
let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &self.cos, &self.sin)?;
let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &self.cos, &self.sin)?;
Ok((q_embed, k_embed))
}
}
#[derive(Clone)]
struct ModernBertAttention {
qkv: Linear,
proj: Linear,
num_attention_heads: usize,
attention_head_size: usize,
rotary_emb: Arc<RotaryEmbedding>,
}
impl ModernBertAttention {
fn load(vb: VarBuilder, config: &Config, rotary_emb: Arc<RotaryEmbedding>) -> Result<Self> {
let num_attention_heads = config.num_attention_heads;
let attention_head_size = config.hidden_size / config.num_attention_heads;
let qkv = linear_no_bias(config.hidden_size, config.hidden_size * 3, vb.pp("Wqkv"))?;
let proj = linear_no_bias(config.hidden_size, config.hidden_size, vb.pp("Wo"))?;
Ok(Self {
qkv,
proj,
num_attention_heads,
attention_head_size,
rotary_emb,
})
}
fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
let xs = hidden_states.clone();
let (b, seq_len, d) = xs.dims3()?;
let qkv = xs
.apply(&self.qkv)?
.reshape((
b,
seq_len,
3,
self.num_attention_heads,
self.attention_head_size,
))?
.permute((2, 0, 3, 1, 4))?;
let q = qkv.get(0)?;
let k = qkv.get(1)?;
let v = qkv.get(2)?;
let (q, k) = self.rotary_emb.apply_rotary_emb_qkv(&q, &k)?;
let scale = (self.attention_head_size as f64).powf(-0.5);
let q = (q * scale)?;
let att = q.matmul(&k.transpose(D::Minus2, D::Minus1)?)?;
let att = att.broadcast_add(attention_mask)?;
let att = softmax(&att, D::Minus1)?;
let xs = att.matmul(&v)?;
let xs = xs.transpose(1, 2)?.reshape((b, seq_len, d))?;
let xs = xs.apply(&self.proj)?;
let xs = xs.reshape((b, seq_len, d))?;
Ok(xs)
}
}
#[derive(Clone)]
pub struct ModernBertMLP {
wi: Linear,
wo: Linear,
}
impl ModernBertMLP {
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
let wi = linear_no_bias(
config.hidden_size,
config.intermediate_size * 2,
vb.pp("Wi"),
)?;
let wo = linear_no_bias(config.intermediate_size, config.hidden_size, vb.pp("Wo"))?;
Ok(Self { wi, wo })
}
}
impl Module for ModernBertMLP {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let xs = xs.apply(&self.wi)?;
let xs = xs.chunk(2, D::Minus1)?;
let xs = (&xs[0].gelu_erf()? * &xs[1])?.apply(&self.wo)?; // GeGLU
Ok(xs)
}
}
#[derive(Clone)]
pub struct ModernBertLayer {
attn: ModernBertAttention,
mlp: ModernBertMLP,
attn_norm: Option<LayerNorm>,
mlp_norm: LayerNorm,
uses_local_attention: bool,
}
impl ModernBertLayer {
fn load(
vb: VarBuilder,
config: &Config,
rotary_emb: Arc<RotaryEmbedding>,
uses_local_attention: bool,
) -> Result<Self> {
let attn = ModernBertAttention::load(vb.pp("attn"), config, rotary_emb)?;
let mlp = ModernBertMLP::load(vb.pp("mlp"), config)?;
let attn_norm = layer_norm_no_bias(
config.hidden_size,
config.layer_norm_eps,
vb.pp("attn_norm"),
)
.ok();
let mlp_norm =
layer_norm_no_bias(config.hidden_size, config.layer_norm_eps, vb.pp("mlp_norm"))?;
Ok(Self {
attn,
mlp,
attn_norm,
mlp_norm,
uses_local_attention,
})
}
fn forward(
&self,
xs: &Tensor,
global_attention_mask: &Tensor,
local_attention_mask: &Tensor,
) -> Result<Tensor> {
let residual = xs.clone();
let mut xs = xs.clone();
if let Some(norm) = &self.attn_norm {
xs = xs.apply(norm)?;
}
let attention_mask = if self.uses_local_attention {
&global_attention_mask.broadcast_add(local_attention_mask)?
} else {
global_attention_mask
};
let xs = self.attn.forward(&xs, attention_mask)?;
let xs = (xs + residual)?;
let mlp_out = xs.apply(&self.mlp_norm)?.apply(&self.mlp)?;
let xs = (xs + mlp_out)?;
Ok(xs)
}
}
#[derive(Clone)]
pub struct ModernBertHead {
dense: Linear,
norm: LayerNorm,
}
impl ModernBertHead {
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
let dense = linear_no_bias(config.hidden_size, config.hidden_size, vb.pp("dense"))?;
let norm = layer_norm_no_bias(config.hidden_size, config.layer_norm_eps, vb.pp("norm"))?;
Ok(Self { dense, norm })
}
}
impl Module for ModernBertHead {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let xs = xs.apply(&self.dense)?.gelu_erf()?.apply(&self.norm)?;
Ok(xs)
}
}
#[derive(Clone)]
pub struct ModernBertDecoder {
decoder: Linear,
}
impl ModernBertDecoder {
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
// The decoder weights are tied with the embeddings layer weights
let decoder_weights = vb.get(
(config.vocab_size, config.hidden_size),
"model.embeddings.tok_embeddings.weight",
)?;
let decoder_bias = vb.get(config.vocab_size, "decoder.bias")?;
let decoder = Linear::new(decoder_weights, Some(decoder_bias));
Ok(Self { decoder })
}
}
impl Module for ModernBertDecoder {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let xs = xs.apply(&self.decoder)?;
Ok(xs)
}
}
// Global attention mask calculated from padded token inputs
fn prepare_4d_attention_mask(
mask: &Tensor,
dtype: DType,
tgt_len: Option<usize>,
) -> Result<Tensor> {
let bsz = mask.dim(0)?;
let src_len = mask.dim(1)?;
let tgt_len = tgt_len.unwrap_or(src_len);
let expanded_mask = mask
.unsqueeze(1)?
.unsqueeze(2)?
.expand((bsz, 1, tgt_len, src_len))?
.to_dtype(dtype)?;
let inverted_mask = (1.0 - expanded_mask)?;
(inverted_mask * f32::MIN as f64)?.to_dtype(dtype)
}
// Attention mask caused by the sliding window
fn get_local_attention_mask(
seq_len: usize,
max_distance: usize,
device: &Device,
) -> Result<Tensor> {
let mask: Vec<_> = (0..seq_len)
.flat_map(|i| {
(0..seq_len).map(move |j| {
if (j as i32 - i as i32).abs() > max_distance as i32 {
f32::NEG_INFINITY
} else {
0.
}
})
})
.collect();
Tensor::from_slice(&mask, (seq_len, seq_len), device)
}
// ModernBERT backbone
#[derive(Clone)]
pub struct ModernBert {
word_embeddings: Embedding,
norm: LayerNorm,
layers: Vec<ModernBertLayer>,
final_norm: LayerNorm,
head: ModernBertHead,
local_attention_size: usize,
}
impl ModernBert {
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
let word_embeddings = embedding(
config.vocab_size,
config.hidden_size,
vb.pp("model.embeddings.tok_embeddings"),
)?;
let norm = layer_norm_no_bias(
config.hidden_size,
config.layer_norm_eps,
vb.pp("model.embeddings.norm"),
)?;
let global_rotary_emb = Arc::new(RotaryEmbedding::new(
vb.dtype(),
config,
config.global_rope_theta,
vb.device(),
)?);
let local_rotary_emb = Arc::new(RotaryEmbedding::new(
vb.dtype(),
config,
config.local_rope_theta,
vb.device(),
)?);
let mut layers = Vec::with_capacity(config.num_hidden_layers);
for layer_id in 0..config.num_hidden_layers {
let layer_uses_local_attention = layer_id % config.global_attn_every_n_layers != 0;
layers.push(ModernBertLayer::load(
vb.pp(format!("model.layers.{layer_id}")),
config,
if layer_uses_local_attention {
local_rotary_emb.clone()
} else {
global_rotary_emb.clone()
},
layer_uses_local_attention,
)?);
}
let final_norm = layer_norm_no_bias(
config.hidden_size,
config.layer_norm_eps,
vb.pp("model.final_norm"),
)?;
let head = ModernBertHead::load(vb.pp("head"), config)?;
Ok(Self {
word_embeddings,
norm,
layers,
final_norm,
head,
local_attention_size: config.local_attention,
})
}
fn forward(&self, xs: &Tensor, mask: &Tensor) -> Result<Tensor> {
let seq_len = xs.shape().dims()[1];
let global_attention_mask =
prepare_4d_attention_mask(mask, DType::F32, None)?.to_device(xs.device())?;
let local_attention_mask =
get_local_attention_mask(seq_len, self.local_attention_size / 2, xs.device())?;
let mut xs = xs.apply(&self.word_embeddings)?.apply(&self.norm)?;
for layer in self.layers.iter() {
xs = layer.forward(&xs, &global_attention_mask, &local_attention_mask)?;
}
let xs = xs.apply(&self.final_norm)?.apply(&self.head)?;
Ok(xs)
}
}
// ModernBERT for the fill-mask task
#[derive(Clone)]
pub struct ModernBertForMaskedLM {
model: ModernBert,
decoder: ModernBertDecoder,
}
impl ModernBertForMaskedLM {
pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
let model = ModernBert::load(vb.clone(), config)?;
let decoder = ModernBertDecoder::load(vb.clone(), config)?;
Ok(Self { model, decoder })
}
pub fn forward(&self, xs: &Tensor, mask: &Tensor) -> Result<Tensor> {
let xs = self.model.forward(xs, mask)?.apply(&self.decoder)?;
Ok(xs)
}
}
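To make the two masks concrete, here is a standalone sketch (not part of this commit) that rebuilds them for a toy batch of one five-token sequence whose last token is padding, assuming `local_attention = 4` so the sliding window spans two positions on each side; a local-attention layer adds both biases before the softmax, as `ModernBertLayer::forward` does:
```rust
use candle::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // One sequence of length 5, last token is padding.
    let padding_mask = Tensor::new(&[[1u32, 1, 1, 1, 0]], &dev)?;

    // Padding positions become a large negative bias broadcastable to
    // (batch, heads, seq, seq), mirroring `prepare_4d_attention_mask`.
    let global = padding_mask
        .unsqueeze(1)?
        .unsqueeze(2)?
        .expand((1, 1, 5, 5))?
        .to_dtype(DType::F32)?;
    let global = ((1.0 - global)? * f32::MIN as f64)?;

    // Sliding-window bias: tokens more than local_attention / 2 = 2
    // positions apart get -inf, mirroring `get_local_attention_mask`.
    let win: Vec<f32> = (0..5i64)
        .flat_map(|i| {
            (0..5i64).map(move |j| {
                if (j - i).abs() > 2 {
                    f32::NEG_INFINITY
                } else {
                    0.
                }
            })
        })
        .collect();
    let local = Tensor::from_slice(&win, (5, 5), &dev)?;

    // Local-attention layers add both biases to the attention scores.
    let combined = global.broadcast_add(&local)?;
    println!("{combined}");
    Ok(())
}
```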