Add NV-Embed-v2 model (#2649)

* Update mod.rs

* Create mod.rs

* Create decoder.rs

* Create model.rs

* Create main.rs

* Create README.md

* Update README.md

* Update main.rs

* Update and rename decoder.rs to embedding.rs

* Update mod.rs

* Update model.rs

Author: cdoko
Date: 2024-12-03 05:56:01 -04:00
Committed by: GitHub
Parent: 6f715f9256
Commit: 145aa7193c

6 changed files with 803 additions and 0 deletions

@@ -62,6 +62,7 @@ pub mod mobilenetv4;
pub mod mobileone;
pub mod moondream;
pub mod mpt;
pub mod nvembed_v2;
pub mod olmo;
pub mod openclip;
pub mod paligemma;

@@ -0,0 +1,294 @@
/// Mistral LLM, https://github.com/mistralai/mistral-src
use crate::models::{
mistral::Config,
with_tracing::{linear_no_bias, Linear, RmsNorm},
};
use crate::utils::repeat_kv;
use candle::{DType, Device, Module, Result, Tensor};
use candle_nn::{Activation, VarBuilder};
use std::sync::Arc;
#[derive(Debug, Clone)]
struct RotaryEmbedding {
sin: Tensor,
cos: Tensor,
}
impl RotaryEmbedding {
fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
let rope_theta = cfg.rope_theta as f32;
let dim = cfg.hidden_size / cfg.num_attention_heads;
let max_seq_len = cfg.max_position_embeddings;
let inv_freq: Vec<_> = (0..dim)
.step_by(2)
.map(|i| 1f32 / rope_theta.powf(i as f32 / dim as f32))
.collect();
let inv_freq_len = inv_freq.len();
let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
.to_dtype(dtype)?
.reshape((max_seq_len, 1))?;
let freqs = t.matmul(&inv_freq)?;
Ok(Self {
sin: freqs.sin()?,
cos: freqs.cos()?,
})
}
fn apply_rotary_emb_qkv(
&self,
q: &Tensor,
k: &Tensor,
seqlen_offset: usize,
) -> Result<(Tensor, Tensor)> {
let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
let q_embed = candle_nn::rotary_emb::rope(q, &cos, &sin)?;
let k_embed = candle_nn::rotary_emb::rope(k, &cos, &sin)?;
Ok((q_embed, k_embed))
}
}
#[derive(Debug, Clone)]
#[allow(clippy::upper_case_acronyms)]
struct MLP {
gate_proj: Linear,
up_proj: Linear,
down_proj: Linear,
act_fn: Activation,
}
impl MLP {
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let hidden_sz = cfg.hidden_size;
let intermediate_sz = cfg.intermediate_size;
let gate_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("gate_proj"))?;
let up_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("up_proj"))?;
let down_proj = linear_no_bias(intermediate_sz, hidden_sz, vb.pp("down_proj"))?;
Ok(Self {
gate_proj,
up_proj,
down_proj,
act_fn: cfg.hidden_act,
})
}
}
impl Module for MLP {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let lhs = xs.apply(&self.gate_proj)?.apply(&self.act_fn)?;
let rhs = xs.apply(&self.up_proj)?;
(lhs * rhs)?.apply(&self.down_proj)
}
}
#[derive(Debug, Clone)]
struct Attention {
q_proj: Linear,
k_proj: Linear,
v_proj: Linear,
o_proj: Linear,
num_heads: usize,
num_kv_heads: usize,
num_kv_groups: usize,
head_dim: usize,
hidden_size: usize,
rotary_emb: Arc<RotaryEmbedding>,
}
impl Attention {
fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
let hidden_sz = cfg.hidden_size;
let num_heads = cfg.num_attention_heads;
let num_kv_heads = cfg.num_key_value_heads;
let num_kv_groups = num_heads / num_kv_heads;
let head_dim = hidden_sz / num_heads;
let q_proj = linear_no_bias(hidden_sz, num_heads * head_dim, vb.pp("q_proj"))?;
let k_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("k_proj"))?;
let v_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("v_proj"))?;
let o_proj = linear_no_bias(num_heads * head_dim, hidden_sz, vb.pp("o_proj"))?;
Ok(Self {
q_proj,
k_proj,
v_proj,
o_proj,
num_heads,
num_kv_heads,
num_kv_groups,
head_dim,
hidden_size: hidden_sz,
rotary_emb,
})
}
fn forward(
&mut self,
xs: &Tensor,
attention_mask: Option<&Tensor>,
seqlen_offset: usize,
) -> Result<Tensor> {
let (b_sz, q_len, _) = xs.dims3()?;
let query_states = self.q_proj.forward(xs)?;
let key_states = self.k_proj.forward(xs)?;
let value_states = self.v_proj.forward(xs)?;
let query_states = query_states
.reshape((b_sz, q_len, self.num_heads, self.head_dim))?
.transpose(1, 2)?
.contiguous()?;
let key_states = key_states
.reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
.transpose(1, 2)?
.contiguous()?;
let value_states = value_states
.reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
.transpose(1, 2)?;
let (query_states, key_states) =
self.rotary_emb
.apply_rotary_emb_qkv(&query_states, &key_states, seqlen_offset)?;
let key_states = repeat_kv(key_states, self.num_kv_groups)?;
let value_states = repeat_kv(value_states, self.num_kv_groups)?;
let scale = 1f64 / f64::sqrt(self.head_dim as f64);
let attn_weights = (query_states.matmul(&key_states.transpose(2, 3)?)? * scale)?;
let attn_weights = match attention_mask {
None => attn_weights,
Some(mask) => attn_weights.broadcast_add(mask)?,
};
let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
let attn_output = attn_weights.matmul(&value_states)?;
attn_output
.transpose(1, 2)?
.reshape((b_sz, q_len, self.hidden_size))?
.apply(&self.o_proj)
}
}
#[derive(Debug, Clone)]
struct DecoderLayer {
self_attn: Attention,
mlp: MLP,
input_layernorm: RmsNorm,
post_attention_layernorm: RmsNorm,
}
impl DecoderLayer {
fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
let self_attn = Attention::new(rotary_emb, cfg, vb.pp("self_attn"))?;
let mlp = MLP::new(cfg, vb.pp("mlp"))?;
let input_layernorm =
RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
let post_attention_layernorm = RmsNorm::new(
cfg.hidden_size,
cfg.rms_norm_eps,
vb.pp("post_attention_layernorm"),
)?;
Ok(Self {
self_attn,
mlp,
input_layernorm,
post_attention_layernorm,
})
}
fn forward(
&mut self,
xs: &Tensor,
attention_mask: Option<&Tensor>,
seqlen_offset: usize,
) -> Result<Tensor> {
let residual = xs;
let xs = self.input_layernorm.forward(xs)?;
let xs = self.self_attn.forward(&xs, attention_mask, seqlen_offset)?;
let xs = (xs + residual)?;
let residual = &xs;
let xs = xs.apply(&self.post_attention_layernorm)?.apply(&self.mlp)?;
residual + xs
}
}
#[derive(Debug, Clone)]
pub struct Model {
embed_tokens: candle_nn::Embedding,
layers: Vec<DecoderLayer>,
norm: RmsNorm,
pub cfg: Config,
}
impl Model {
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let embed_tokens =
candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("embed_tokens"))?;
let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb.device())?);
let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
let vb_l = vb.pp("layers");
for layer_idx in 0..cfg.num_hidden_layers {
let layer = DecoderLayer::new(rotary_emb.clone(), cfg, vb_l.pp(layer_idx))?;
layers.push(layer)
}
let norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("norm"))?;
Ok(Self {
embed_tokens,
layers,
norm,
cfg: cfg.clone(),
})
}
// Attn mask used to mask out padding tokens
pub fn forward(
&mut self,
attn_mask: &Tensor,
input_ids: &Tensor,
dtype: DType,
) -> Result<Tensor> {
let mut xs = self.embed_tokens.forward(input_ids)?;
// Expand to 4d mask for sdpa
let attn_mask = prepare_4d_attention_mask(attn_mask, dtype, None)?;
for layer in self.layers.iter_mut() {
xs = layer.forward(&xs, Some(&attn_mask), 0)?;
}
// Return hiddens instead of logits
xs.apply(&self.norm)
}
}
fn prepare_4d_attention_mask(
mask: &Tensor,
dtype: DType,
tgt_len: Option<usize>,
) -> Result<Tensor> {
let bsz = mask.dims()[0];
let src_len = mask.dims()[1];
let tgt_len = tgt_len.unwrap_or(src_len);
let expanded_mask = mask
.unsqueeze(1)?
.unsqueeze(2)?
.expand((bsz, 1, tgt_len, src_len))?
.to_dtype(dtype)?;
let inverted_mask = (1.0 - expanded_mask)?;
(inverted_mask * get_dtype_min_val(dtype))?.to_dtype(dtype)
}
fn get_dtype_min_val(dtype: DType) -> f64 {
match dtype {
DType::F32 => f32::MIN as f64,
DType::F64 => f64::MIN,
_ => panic!("Unsupported data type"),
}
}
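
The padding-mask handling above (expand to 4-D, invert, scale by the dtype minimum) is easiest to see on a toy input. The snippet below is a standalone sketch, not part of the diff; it simply repeats the same tensor ops on a single length-4 sequence whose last token is padding, using the `candle` crate alias this repository uses.

use candle::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // One sequence of length 4 with the last position padded out.
    let mask = Tensor::new(&[[1f32, 1., 1., 0.]], &dev)?; // (batch=1, seq=4)
    let (bsz, src_len) = (1usize, 4usize);
    // Same expansion as prepare_4d_attention_mask: (b, s) -> (b, 1, s, s).
    let expanded = mask
        .unsqueeze(1)?
        .unsqueeze(2)?
        .expand((bsz, 1, src_len, src_len))?
        .to_dtype(DType::F32)?;
    // Invert and scale so padded columns become a large negative bias.
    let bias = ((1.0 - expanded)? * f32::MIN as f64)?;
    // After broadcast_add into the attention scores, softmax assigns the
    // padded positions effectively zero weight.
    println!("{bias}");
    Ok(())
}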

@@ -0,0 +1,18 @@
//! NV-Embed-v2
//!
//! NV-Embed-v2 is a text embedding model that combines a Mistral decoder with a latent attention mechanism to produce high-quality text embeddings.
//!
//! This implementation is based on the [paper](https://arxiv.org/pdf/2405.17428) and [weights](https://huggingface.co/nvidia/NV-Embed-v2)
//!
//! # Query-Passage Retrieval Example
//! ```bash
//! cargo run --example nvembed_v2 --release
//! ```
//!
//! # Sentence Embedding Example
//! ```bash
//! cargo run --example nvembed_v2 --release -- --prompt "Here is a test sentence"
//! ```
pub mod embedding;
pub mod model;
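
As a complement to the cargo commands above, the sketch below shows how the module added in this PR might be driven directly once the safetensors weights have been loaded into a `VarBuilder`. It is a hedged sketch, not code from this PR: the hub download, tokenization, and dtype handling are omitted, and both masks are assumed to already be tensors matching the model's dtype.

use candle::{Result, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::nvembed_v2::model::Model;

fn embed(vb: VarBuilder, input_ids: &Tensor, attn_mask: &Tensor) -> Result<Tensor> {
    // Model::new reads both the Mistral decoder ("embedding_model.*") and the
    // latent-attention weights ("latent_attention_model.*") from the VarBuilder.
    let mut model = Model::new(vb)?;
    // With no instruction prefix, the pooling mask is simply the attention mask.
    let pool_mask = attn_mask.clone();
    // Returns mean-pooled latent-attention outputs of shape (batch, 4096).
    model.forward(input_ids, attn_mask, &pool_mask)
}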

@@ -0,0 +1,233 @@
use super::embedding::Model as EmbeddingModel;
use crate::models::{
mistral::Config,
with_tracing::{layer_norm, linear, linear_no_bias, LayerNorm, Linear},
};
use candle::{DType, Device, Result, Tensor, D};
use candle_nn::{ops::softmax_last_dim, LayerNormConfig, Module, VarBuilder};
// Geglu and feedforward from candle-transformers/src/models/stable_diffusion/attention.rs
#[derive(Debug)]
struct GeGlu {
proj: Linear,
span: tracing::Span,
}
impl GeGlu {
fn new(vs: VarBuilder, dim_in: usize, dim_out: usize) -> Result<Self> {
let proj = linear(dim_in, dim_out * 2, vs)?;
let span = tracing::span!(tracing::Level::TRACE, "geglu");
Ok(Self { proj, span })
}
}
impl Module for GeGlu {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let hidden_states_and_gate = self.proj.forward(xs)?.chunk(2, D::Minus1)?;
&hidden_states_and_gate[0] * hidden_states_and_gate[1].gelu()?
}
}
#[derive(Debug)]
struct FeedForward {
project_in: GeGlu,
linear: Linear,
span: tracing::Span,
}
impl FeedForward {
fn new(vs: VarBuilder, dim: usize, dim_out: Option<usize>, mult: usize) -> Result<Self> {
let inner_dim = dim * mult;
let dim_out = dim_out.unwrap_or(dim);
let vs = vs.pp("net");
let project_in = GeGlu::new(vs.pp("0"), dim, inner_dim)?;
let linear = linear(inner_dim, dim_out, vs.pp("2"))?;
let span = tracing::span!(tracing::Level::TRACE, "ff");
Ok(Self {
project_in,
linear,
span,
})
}
}
impl Module for FeedForward {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let xs = self.project_in.forward(xs)?;
self.linear.forward(&xs)
}
}
// CrossAttention from candle-transformers/src/models/stable_diffusion/attention.rs
#[derive(Debug)]
struct CrossAttention {
to_q: Linear,
to_kv: Linear,
to_out: Linear,
heads: usize,
scale: f64,
span: tracing::Span,
span_attn: tracing::Span,
span_softmax: tracing::Span,
}
impl CrossAttention {
fn new(
vs: VarBuilder,
query_dim: usize,
context_dim: Option<usize>,
heads: usize,
dim_head: usize,
) -> Result<Self> {
let inner_dim = dim_head * heads;
let context_dim = context_dim.unwrap_or(query_dim);
let scale = 1.0 / f64::sqrt(dim_head as f64);
let to_q = linear_no_bias(query_dim, inner_dim, vs.pp("to_q"))?;
let to_kv = linear_no_bias(context_dim, inner_dim * 2, vs.pp("to_kv"))?;
let to_out = linear_no_bias(inner_dim, query_dim, vs.pp("to_out"))?;
let span = tracing::span!(tracing::Level::TRACE, "xa");
let span_attn = tracing::span!(tracing::Level::TRACE, "xa-attn");
let span_softmax = tracing::span!(tracing::Level::TRACE, "xa-softmax");
Ok(Self {
to_q,
to_kv,
to_out,
heads,
scale,
span,
span_attn,
span_softmax,
})
}
fn reshape_heads_to_batch_dim(&self, xs: &Tensor) -> Result<Tensor> {
let (batch_size, seq_len, dim) = xs.dims3()?;
xs.reshape((batch_size, seq_len, self.heads, dim / self.heads))?
.transpose(1, 2)?
.reshape((batch_size * self.heads, seq_len, dim / self.heads))
}
fn reshape_batch_dim_to_heads(&self, xs: &Tensor) -> Result<Tensor> {
let (batch_size, seq_len, dim) = xs.dims3()?;
xs.reshape((batch_size / self.heads, self.heads, seq_len, dim))?
.transpose(1, 2)?
.reshape((batch_size / self.heads, seq_len, dim * self.heads))
}
fn attention(&self, query: &Tensor, key: &Tensor, value: &Tensor) -> Result<Tensor> {
let _enter = self.span_attn.enter();
let in_dtype = query.dtype();
let query = query.to_dtype(DType::F32)?;
let key = key.to_dtype(DType::F32)?;
let value = value.to_dtype(DType::F32)?;
let xs = query.matmul(&(key.t()? * self.scale)?)?;
let xs = {
let _enter = self.span_softmax.enter();
softmax_last_dim(&xs)?
};
let xs = xs.matmul(&value)?.to_dtype(in_dtype)?;
self.reshape_batch_dim_to_heads(&xs)
}
fn forward(&self, xs: &Tensor, context: Option<&Tensor>) -> Result<Tensor> {
let _enter = self.span.enter();
let query = self.to_q.forward(xs)?;
let context = context.unwrap_or(xs).contiguous()?;
let kv_chunks = self
.to_kv
.forward(&context)?
.chunk(2, context.shape().dims().len() - 1)?;
let (key, value) = (kv_chunks[0].clone(), kv_chunks[1].clone());
let query = self.reshape_heads_to_batch_dim(&query)?;
let key = self.reshape_heads_to_batch_dim(&key)?;
let value = self.reshape_heads_to_batch_dim(&value)?;
let xs = self.attention(&query, &key, &value)?;
self.to_out.forward(&xs)
}
}
#[derive(Debug)]
pub struct Model {
embedding_model: EmbeddingModel,
cross_attn: CrossAttention,
cross_attn_norm: LayerNorm,
cross_attn_context_norm: LayerNorm,
ff: FeedForward,
ff_norm: LayerNorm,
latents: Tensor,
pub device: Device,
pub dtype: DType,
}
impl Model {
pub fn new(vb: VarBuilder) -> Result<Self> {
// Embedding model
let cfg = Config::config_7b_v0_1(false);
let embedding_model = EmbeddingModel::new(&cfg, vb.pp("embedding_model"))?;
// Latent attention
let dim = 4096;
let vb = vb.pp("latent_attention_model");
let latents = vb.get((512, dim), "latents")?;
// Cross attend blocks
let vb = vb.pp("cross_attend_blocks");
let cross_attn_norm = layer_norm(dim, LayerNormConfig::default(), vb.pp("0.norm"))?;
let cross_attn_context_norm = layer_norm(
dim,
candle_nn::LayerNormConfig::default(),
vb.pp("0.norm_context"),
)?;
let cross_attn = CrossAttention::new(vb.pp("0.fn"), dim, None, 8, 4096)?;
let ff_norm = layer_norm(dim, LayerNormConfig::default(), vb.pp("1.norm"))?;
let ff = FeedForward::new(vb.pp("1.fn"), dim, None, 4)?;
Ok(Self {
embedding_model,
cross_attn,
cross_attn_norm,
cross_attn_context_norm,
ff,
ff_norm,
latents,
device: vb.device().clone(),
dtype: vb.dtype(),
})
}
pub fn forward(
&mut self,
input_ids: &Tensor,
attn_mask: &Tensor,
pool_mask: &Tensor,
) -> Result<Tensor> {
// Embedding model
let hiddens = self
.embedding_model
.forward(attn_mask, input_ids, self.dtype)?;
// Latent attention
let b = hiddens.dims()[0];
let x = self.latents.unsqueeze(0)?.repeat((b, 1, 1))?;
let original_hiddens = &hiddens;
let hiddens = self.cross_attn_norm.forward(original_hiddens)?;
let x = self.cross_attn_context_norm.forward(&x)?;
let cross_hiddens = (self.cross_attn.forward(&hiddens, Some(&x))? + original_hiddens)?;
let hiddens = self.ff_norm.forward(&cross_hiddens)?;
let hiddens = (self.ff.forward(&hiddens)? + cross_hiddens)?;
// Mean pooling
let hiddens_masked = hiddens.broadcast_mul(&pool_mask.unsqueeze(D::Minus1)?)?;
let s = hiddens_masked.sum(1)?;
let d = pool_mask.sum_keepdim(1)?;
s.broadcast_div(&d)
}
}
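
For the query-passage retrieval use case referenced in the example's README, the embeddings returned by `Model::forward` are typically L2-normalized and compared by dot product, which is equivalent to cosine similarity after normalization. The helper below is a small hedged sketch along those lines, not part of this PR.

use candle::{Result, Tensor};

// query_emb: (1, dim), passage_embs: (n, dim); returns (1, n) similarity scores.
fn cosine_scores(query_emb: &Tensor, passage_embs: &Tensor) -> Result<Tensor> {
    let q = query_emb.broadcast_div(&query_emb.sqr()?.sum_keepdim(1)?.sqrt()?)?;
    let p = passage_embs.broadcast_div(&passage_embs.sqr()?.sum_keepdim(1)?.sqrt()?)?;
    q.matmul(&p.t()?)
}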