mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 02:16:37 +00:00
Add Nvembed v2 model (#2649)
* Update mod.rs * Create mod.rs * Create decoder.rs * Create model.rs * Create main.rs * Create README.md * Update README.md * Update main.rs * Update and rename decoder.rs to embedding.rs * Update mod.rs * Update model.rs
This commit is contained in:
@ -62,6 +62,7 @@ pub mod mobilenetv4;
|
||||
pub mod mobileone;
|
||||
pub mod moondream;
|
||||
pub mod mpt;
|
||||
pub mod nvembed_v2;
|
||||
pub mod olmo;
|
||||
pub mod openclip;
|
||||
pub mod paligemma;
|
||||
|
294
candle-transformers/src/models/nvembed_v2/embedding.rs
Normal file
294
candle-transformers/src/models/nvembed_v2/embedding.rs
Normal file
@ -0,0 +1,294 @@
|
||||
/// Mistral LLM, https://github.com/mistralai/mistral-src
|
||||
use crate::models::{
|
||||
mistral::Config,
|
||||
with_tracing::{linear_no_bias, Linear, RmsNorm},
|
||||
};
|
||||
use crate::utils::repeat_kv;
|
||||
use candle::{DType, Device, Module, Result, Tensor};
|
||||
use candle_nn::{Activation, VarBuilder};
|
||||
use std::sync::Arc;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct RotaryEmbedding {
|
||||
sin: Tensor,
|
||||
cos: Tensor,
|
||||
}
|
||||
|
||||
impl RotaryEmbedding {
|
||||
fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
|
||||
let rope_theta = cfg.rope_theta as f32;
|
||||
let dim = cfg.hidden_size / cfg.num_attention_heads;
|
||||
let max_seq_len = cfg.max_position_embeddings;
|
||||
let inv_freq: Vec<_> = (0..dim)
|
||||
.step_by(2)
|
||||
.map(|i| 1f32 / rope_theta.powf(i as f32 / dim as f32))
|
||||
.collect();
|
||||
let inv_freq_len = inv_freq.len();
|
||||
let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
|
||||
let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
|
||||
.to_dtype(dtype)?
|
||||
.reshape((max_seq_len, 1))?;
|
||||
let freqs = t.matmul(&inv_freq)?;
|
||||
Ok(Self {
|
||||
sin: freqs.sin()?,
|
||||
cos: freqs.cos()?,
|
||||
})
|
||||
}
|
||||
|
||||
fn apply_rotary_emb_qkv(
|
||||
&self,
|
||||
q: &Tensor,
|
||||
k: &Tensor,
|
||||
seqlen_offset: usize,
|
||||
) -> Result<(Tensor, Tensor)> {
|
||||
let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
|
||||
let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
|
||||
let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
|
||||
let q_embed = candle_nn::rotary_emb::rope(q, &cos, &sin)?;
|
||||
let k_embed = candle_nn::rotary_emb::rope(k, &cos, &sin)?;
|
||||
Ok((q_embed, k_embed))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
#[allow(clippy::upper_case_acronyms)]
|
||||
struct MLP {
|
||||
gate_proj: Linear,
|
||||
up_proj: Linear,
|
||||
down_proj: Linear,
|
||||
act_fn: Activation,
|
||||
}
|
||||
|
||||
impl MLP {
|
||||
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
let hidden_sz = cfg.hidden_size;
|
||||
let intermediate_sz = cfg.intermediate_size;
|
||||
let gate_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("gate_proj"))?;
|
||||
let up_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("up_proj"))?;
|
||||
let down_proj = linear_no_bias(intermediate_sz, hidden_sz, vb.pp("down_proj"))?;
|
||||
Ok(Self {
|
||||
gate_proj,
|
||||
up_proj,
|
||||
down_proj,
|
||||
act_fn: cfg.hidden_act,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for MLP {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let lhs = xs.apply(&self.gate_proj)?.apply(&self.act_fn)?;
|
||||
let rhs = xs.apply(&self.up_proj)?;
|
||||
(lhs * rhs)?.apply(&self.down_proj)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct Attention {
|
||||
q_proj: Linear,
|
||||
k_proj: Linear,
|
||||
v_proj: Linear,
|
||||
o_proj: Linear,
|
||||
num_heads: usize,
|
||||
num_kv_heads: usize,
|
||||
num_kv_groups: usize,
|
||||
head_dim: usize,
|
||||
hidden_size: usize,
|
||||
rotary_emb: Arc<RotaryEmbedding>,
|
||||
}
|
||||
|
||||
impl Attention {
|
||||
fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
let hidden_sz = cfg.hidden_size;
|
||||
let num_heads = cfg.num_attention_heads;
|
||||
let num_kv_heads = cfg.num_key_value_heads;
|
||||
let num_kv_groups = num_heads / num_kv_heads;
|
||||
let head_dim = hidden_sz / num_heads;
|
||||
let q_proj = linear_no_bias(hidden_sz, num_heads * head_dim, vb.pp("q_proj"))?;
|
||||
let k_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("k_proj"))?;
|
||||
let v_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("v_proj"))?;
|
||||
let o_proj = linear_no_bias(num_heads * head_dim, hidden_sz, vb.pp("o_proj"))?;
|
||||
Ok(Self {
|
||||
q_proj,
|
||||
k_proj,
|
||||
v_proj,
|
||||
o_proj,
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
num_kv_groups,
|
||||
head_dim,
|
||||
hidden_size: hidden_sz,
|
||||
rotary_emb,
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(
|
||||
&mut self,
|
||||
xs: &Tensor,
|
||||
attention_mask: Option<&Tensor>,
|
||||
seqlen_offset: usize,
|
||||
) -> Result<Tensor> {
|
||||
let (b_sz, q_len, _) = xs.dims3()?;
|
||||
|
||||
let query_states = self.q_proj.forward(xs)?;
|
||||
let key_states = self.k_proj.forward(xs)?;
|
||||
let value_states = self.v_proj.forward(xs)?;
|
||||
|
||||
let query_states = query_states
|
||||
.reshape((b_sz, q_len, self.num_heads, self.head_dim))?
|
||||
.transpose(1, 2)?
|
||||
.contiguous()?;
|
||||
|
||||
let key_states = key_states
|
||||
.reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
|
||||
.transpose(1, 2)?
|
||||
.contiguous()?;
|
||||
let value_states = value_states
|
||||
.reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
|
||||
.transpose(1, 2)?;
|
||||
|
||||
let (query_states, key_states) =
|
||||
self.rotary_emb
|
||||
.apply_rotary_emb_qkv(&query_states, &key_states, seqlen_offset)?;
|
||||
|
||||
let key_states = repeat_kv(key_states, self.num_kv_groups)?;
|
||||
let value_states = repeat_kv(value_states, self.num_kv_groups)?;
|
||||
|
||||
let scale = 1f64 / f64::sqrt(self.head_dim as f64);
|
||||
let attn_weights = (query_states.matmul(&key_states.transpose(2, 3)?)? * scale)?;
|
||||
|
||||
let attn_weights = match attention_mask {
|
||||
None => attn_weights,
|
||||
Some(mask) => attn_weights.broadcast_add(mask)?,
|
||||
};
|
||||
let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
|
||||
let attn_output = attn_weights.matmul(&value_states)?;
|
||||
|
||||
attn_output
|
||||
.transpose(1, 2)?
|
||||
.reshape((b_sz, q_len, self.hidden_size))?
|
||||
.apply(&self.o_proj)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct DecoderLayer {
|
||||
self_attn: Attention,
|
||||
mlp: MLP,
|
||||
input_layernorm: RmsNorm,
|
||||
post_attention_layernorm: RmsNorm,
|
||||
}
|
||||
|
||||
impl DecoderLayer {
|
||||
fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
let self_attn = Attention::new(rotary_emb, cfg, vb.pp("self_attn"))?;
|
||||
let mlp = MLP::new(cfg, vb.pp("mlp"))?;
|
||||
let input_layernorm =
|
||||
RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
|
||||
let post_attention_layernorm = RmsNorm::new(
|
||||
cfg.hidden_size,
|
||||
cfg.rms_norm_eps,
|
||||
vb.pp("post_attention_layernorm"),
|
||||
)?;
|
||||
Ok(Self {
|
||||
self_attn,
|
||||
mlp,
|
||||
input_layernorm,
|
||||
post_attention_layernorm,
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(
|
||||
&mut self,
|
||||
xs: &Tensor,
|
||||
attention_mask: Option<&Tensor>,
|
||||
seqlen_offset: usize,
|
||||
) -> Result<Tensor> {
|
||||
let residual = xs;
|
||||
let xs = self.input_layernorm.forward(xs)?;
|
||||
|
||||
let xs = self.self_attn.forward(&xs, attention_mask, seqlen_offset)?;
|
||||
|
||||
let xs = (xs + residual)?;
|
||||
let residual = &xs;
|
||||
let xs = xs.apply(&self.post_attention_layernorm)?.apply(&self.mlp)?;
|
||||
residual + xs
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Model {
|
||||
embed_tokens: candle_nn::Embedding,
|
||||
layers: Vec<DecoderLayer>,
|
||||
norm: RmsNorm,
|
||||
pub cfg: Config,
|
||||
}
|
||||
|
||||
impl Model {
|
||||
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
let embed_tokens =
|
||||
candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("embed_tokens"))?;
|
||||
let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb.device())?);
|
||||
let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
|
||||
let vb_l = vb.pp("layers");
|
||||
for layer_idx in 0..cfg.num_hidden_layers {
|
||||
let layer = DecoderLayer::new(rotary_emb.clone(), cfg, vb_l.pp(layer_idx))?;
|
||||
layers.push(layer)
|
||||
}
|
||||
let norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("norm"))?;
|
||||
Ok(Self {
|
||||
embed_tokens,
|
||||
layers,
|
||||
norm,
|
||||
cfg: cfg.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
// Attn mask used to mask out padding tokens
|
||||
pub fn forward(
|
||||
&mut self,
|
||||
attn_mask: &Tensor,
|
||||
input_ids: &Tensor,
|
||||
dtype: DType,
|
||||
) -> Result<Tensor> {
|
||||
let mut xs = self.embed_tokens.forward(input_ids)?;
|
||||
|
||||
// Expand to 4d mask for sdpa
|
||||
let attn_mask = prepare_4d_attention_mask(attn_mask, dtype, None)?;
|
||||
|
||||
for layer in self.layers.iter_mut() {
|
||||
xs = layer.forward(&xs, Some(&attn_mask), 0)?;
|
||||
}
|
||||
|
||||
// Return hiddens instead of logits
|
||||
xs.apply(&self.norm)
|
||||
}
|
||||
}
|
||||
|
||||
fn prepare_4d_attention_mask(
|
||||
mask: &Tensor,
|
||||
dtype: DType,
|
||||
tgt_len: Option<usize>,
|
||||
) -> Result<Tensor> {
|
||||
let bsz = mask.dims()[0];
|
||||
let src_len = mask.dims()[1];
|
||||
let tgt_len = tgt_len.unwrap_or(src_len);
|
||||
|
||||
let expanded_mask = mask
|
||||
.unsqueeze(1)?
|
||||
.unsqueeze(2)?
|
||||
.expand((bsz, 1, tgt_len, src_len))?
|
||||
.to_dtype(dtype)?;
|
||||
|
||||
let inverted_mask = (1.0 - expanded_mask)?;
|
||||
|
||||
(inverted_mask * get_dtype_min_val(dtype))?.to_dtype(dtype)
|
||||
}
|
||||
|
||||
fn get_dtype_min_val(dtype: DType) -> f64 {
|
||||
match dtype {
|
||||
DType::F32 => f32::MIN as f64,
|
||||
DType::F64 => f64::MIN,
|
||||
_ => panic!("Unsupported data type"),
|
||||
}
|
||||
}
|
18
candle-transformers/src/models/nvembed_v2/mod.rs
Normal file
18
candle-transformers/src/models/nvembed_v2/mod.rs
Normal file
@ -0,0 +1,18 @@
|
||||
//! NV-Embed-v2
|
||||
//!
|
||||
//! NV-Embed-v2 is a text embedding model that combines a Mistral decoder with a latent attention mechanism to produce high-quality text embeddings.
|
||||
//!
|
||||
//! This implementation is based on the [paper](https://arxiv.org/pdf/2405.17428) and [weights](https://huggingface.co/nvidia/NV-Embed-v2)
|
||||
//!
|
||||
//! # Query-Passage Retrieval Example
|
||||
//! ```bash
|
||||
//! cargo run --example nvembed_v2 --release
|
||||
//! ```
|
||||
//!
|
||||
//! # Sentence Embedding Example
|
||||
//! ```bash
|
||||
//! cargo run --example nvembed_v2 --release -- --prompt "Here is a test sentence"
|
||||
//! ```
|
||||
|
||||
pub mod embedding;
|
||||
pub mod model;
|
233
candle-transformers/src/models/nvembed_v2/model.rs
Normal file
233
candle-transformers/src/models/nvembed_v2/model.rs
Normal file
@ -0,0 +1,233 @@
|
||||
use super::embedding::Model as EmbeddingModel;
|
||||
use crate::models::{
|
||||
mistral::Config,
|
||||
with_tracing::{layer_norm, linear, linear_no_bias, LayerNorm, Linear},
|
||||
};
|
||||
use candle::{DType, Device, Result, Tensor, D};
|
||||
use candle_nn::{ops::softmax_last_dim, LayerNormConfig, Module, VarBuilder};
|
||||
|
||||
// Geglu and feedforward from candle-transformers/src/models/stable_diffusion/attention.rs
|
||||
#[derive(Debug)]
|
||||
struct GeGlu {
|
||||
proj: Linear,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl GeGlu {
|
||||
fn new(vs: VarBuilder, dim_in: usize, dim_out: usize) -> Result<Self> {
|
||||
let proj = linear(dim_in, dim_out * 2, vs)?;
|
||||
let span = tracing::span!(tracing::Level::TRACE, "geglu");
|
||||
Ok(Self { proj, span })
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for GeGlu {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let hidden_states_and_gate = self.proj.forward(xs)?.chunk(2, D::Minus1)?;
|
||||
&hidden_states_and_gate[0] * hidden_states_and_gate[1].gelu()?
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct FeedForward {
|
||||
project_in: GeGlu,
|
||||
linear: Linear,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl FeedForward {
|
||||
fn new(vs: VarBuilder, dim: usize, dim_out: Option<usize>, mult: usize) -> Result<Self> {
|
||||
let inner_dim = dim * mult;
|
||||
let dim_out = dim_out.unwrap_or(dim);
|
||||
let vs = vs.pp("net");
|
||||
let project_in = GeGlu::new(vs.pp("0"), dim, inner_dim)?;
|
||||
let linear = linear(inner_dim, dim_out, vs.pp("2"))?;
|
||||
let span = tracing::span!(tracing::Level::TRACE, "ff");
|
||||
Ok(Self {
|
||||
project_in,
|
||||
linear,
|
||||
span,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for FeedForward {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let xs = self.project_in.forward(xs)?;
|
||||
self.linear.forward(&xs)
|
||||
}
|
||||
}
|
||||
|
||||
// CrossAttention from candle-transformers/src/models/stable_diffusion/attention.rs
|
||||
#[derive(Debug)]
|
||||
struct CrossAttention {
|
||||
to_q: Linear,
|
||||
to_kv: Linear,
|
||||
to_out: Linear,
|
||||
heads: usize,
|
||||
scale: f64,
|
||||
span: tracing::Span,
|
||||
span_attn: tracing::Span,
|
||||
span_softmax: tracing::Span,
|
||||
}
|
||||
|
||||
impl CrossAttention {
|
||||
fn new(
|
||||
vs: VarBuilder,
|
||||
query_dim: usize,
|
||||
context_dim: Option<usize>,
|
||||
heads: usize,
|
||||
dim_head: usize,
|
||||
) -> Result<Self> {
|
||||
let inner_dim = dim_head * heads;
|
||||
let context_dim = context_dim.unwrap_or(query_dim);
|
||||
let scale = 1.0 / f64::sqrt(dim_head as f64);
|
||||
let to_q = linear_no_bias(query_dim, inner_dim, vs.pp("to_q"))?;
|
||||
let to_kv = linear_no_bias(context_dim, inner_dim * 2, vs.pp("to_kv"))?;
|
||||
let to_out = linear_no_bias(inner_dim, query_dim, vs.pp("to_out"))?;
|
||||
let span = tracing::span!(tracing::Level::TRACE, "xa");
|
||||
let span_attn = tracing::span!(tracing::Level::TRACE, "xa-attn");
|
||||
let span_softmax = tracing::span!(tracing::Level::TRACE, "xa-softmax");
|
||||
Ok(Self {
|
||||
to_q,
|
||||
to_kv,
|
||||
to_out,
|
||||
heads,
|
||||
scale,
|
||||
span,
|
||||
span_attn,
|
||||
span_softmax,
|
||||
})
|
||||
}
|
||||
|
||||
fn reshape_heads_to_batch_dim(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let (batch_size, seq_len, dim) = xs.dims3()?;
|
||||
xs.reshape((batch_size, seq_len, self.heads, dim / self.heads))?
|
||||
.transpose(1, 2)?
|
||||
.reshape((batch_size * self.heads, seq_len, dim / self.heads))
|
||||
}
|
||||
|
||||
fn reshape_batch_dim_to_heads(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let (batch_size, seq_len, dim) = xs.dims3()?;
|
||||
xs.reshape((batch_size / self.heads, self.heads, seq_len, dim))?
|
||||
.transpose(1, 2)?
|
||||
.reshape((batch_size / self.heads, seq_len, dim * self.heads))
|
||||
}
|
||||
|
||||
fn attention(&self, query: &Tensor, key: &Tensor, value: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span_attn.enter();
|
||||
|
||||
let in_dtype = query.dtype();
|
||||
let query = query.to_dtype(DType::F32)?;
|
||||
let key = key.to_dtype(DType::F32)?;
|
||||
let value = value.to_dtype(DType::F32)?;
|
||||
let xs = query.matmul(&(key.t()? * self.scale)?)?;
|
||||
let xs = {
|
||||
let _enter = self.span_softmax.enter();
|
||||
softmax_last_dim(&xs)?
|
||||
};
|
||||
let xs = xs.matmul(&value)?.to_dtype(in_dtype)?;
|
||||
|
||||
self.reshape_batch_dim_to_heads(&xs)
|
||||
}
|
||||
|
||||
fn forward(&self, xs: &Tensor, context: Option<&Tensor>) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let query = self.to_q.forward(xs)?;
|
||||
let context = context.unwrap_or(xs).contiguous()?;
|
||||
let kv_chunks = self
|
||||
.to_kv
|
||||
.forward(&context)?
|
||||
.chunk(2, context.shape().dims().len() - 1)?;
|
||||
let (key, value) = (kv_chunks[0].clone(), kv_chunks[1].clone());
|
||||
let query = self.reshape_heads_to_batch_dim(&query)?;
|
||||
let key = self.reshape_heads_to_batch_dim(&key)?;
|
||||
let value = self.reshape_heads_to_batch_dim(&value)?;
|
||||
|
||||
let xs = self.attention(&query, &key, &value)?;
|
||||
self.to_out.forward(&xs)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Model {
|
||||
embedding_model: EmbeddingModel,
|
||||
cross_attn: CrossAttention,
|
||||
cross_attn_norm: LayerNorm,
|
||||
cross_attn_context_norm: LayerNorm,
|
||||
ff: FeedForward,
|
||||
ff_norm: LayerNorm,
|
||||
latents: Tensor,
|
||||
pub device: Device,
|
||||
pub dtype: DType,
|
||||
}
|
||||
|
||||
impl Model {
|
||||
pub fn new(vb: VarBuilder) -> Result<Self> {
|
||||
// Embedding model
|
||||
let cfg = Config::config_7b_v0_1(false);
|
||||
let embedding_model = EmbeddingModel::new(&cfg, vb.pp("embedding_model"))?;
|
||||
|
||||
// Latent attention
|
||||
let dim = 4096;
|
||||
let vb = vb.pp("latent_attention_model");
|
||||
let latents = vb.get((512, dim), "latents")?;
|
||||
|
||||
// Cross attend blocks
|
||||
let vb = vb.pp("cross_attend_blocks");
|
||||
let cross_attn_norm = layer_norm(dim, LayerNormConfig::default(), vb.pp("0.norm"))?;
|
||||
let cross_attn_context_norm = layer_norm(
|
||||
dim,
|
||||
candle_nn::LayerNormConfig::default(),
|
||||
vb.pp("0.norm_context"),
|
||||
)?;
|
||||
let cross_attn = CrossAttention::new(vb.pp("0.fn"), dim, None, 8, 4096)?;
|
||||
|
||||
let ff_norm = layer_norm(dim, LayerNormConfig::default(), vb.pp("1.norm"))?;
|
||||
let ff = FeedForward::new(vb.pp("1.fn"), dim, None, 4)?;
|
||||
|
||||
Ok(Self {
|
||||
embedding_model,
|
||||
cross_attn,
|
||||
cross_attn_norm,
|
||||
cross_attn_context_norm,
|
||||
ff,
|
||||
ff_norm,
|
||||
latents,
|
||||
device: vb.device().clone(),
|
||||
dtype: vb.dtype(),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn forward(
|
||||
&mut self,
|
||||
input_ids: &Tensor,
|
||||
attn_mask: &Tensor,
|
||||
pool_mask: &Tensor,
|
||||
) -> Result<Tensor> {
|
||||
// Embedding model
|
||||
let hiddens = self
|
||||
.embedding_model
|
||||
.forward(attn_mask, input_ids, self.dtype)?;
|
||||
|
||||
// Latent attention
|
||||
let b = hiddens.dims()[0];
|
||||
let x = self.latents.unsqueeze(0)?.repeat((b, 1, 1))?;
|
||||
let original_hiddens = &hiddens;
|
||||
|
||||
let hiddens = self.cross_attn_norm.forward(original_hiddens)?;
|
||||
let x = self.cross_attn_context_norm.forward(&x)?;
|
||||
let cross_hiddens = (self.cross_attn.forward(&hiddens, Some(&x))? + original_hiddens)?;
|
||||
|
||||
let hiddens = self.ff_norm.forward(&cross_hiddens)?;
|
||||
let hiddens = (self.ff.forward(&hiddens)? + cross_hiddens)?;
|
||||
|
||||
// Mean pooling
|
||||
let hiddens_masked = hiddens.broadcast_mul(&pool_mask.unsqueeze(D::Minus1)?)?;
|
||||
let s = hiddens_masked.sum(1)?;
|
||||
let d = pool_mask.sum_keepdim(1)?;
|
||||
s.broadcast_div(&d)
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user