Support embedding model gte-Qwen1.5-7B-instruct (#2190)
* Support embedding model gte-Qwen1.5-7B-instruct

gte-Qwen1.5-7B-instruct is a text embedding model based on Qwen2; the two share the same architecture except for the final projection (the lm_head). This commit makes a minimal modification to the existing Qwen2 implementation so that it supports both models. An example is provided, and it has been verified against the official PyTorch implementation.

* Avoid doing the 'last-token filtering' based on the absence of the attention mask.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
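For orientation, here is a minimal sketch of batched embedding extraction through the new forward signature. The token ids, the `embed` helper, and the last-token pooling step are illustrative assumptions, not the example shipped with this commit:

// Sketch only: batched embedding extraction via the new
// Model::forward(input_ids, seqlen_offset, attn_mask) from this diff.
// The token ids are made up; `Model` is the type defined in qwen2.rs.
use candle::{Device, IndexOp, Result, Tensor};

fn embed(model: &mut Model, dev: &Device) -> Result<Tensor> {
    // Two right-padded sequences; 1 marks a real token, 0 marks padding.
    let input_ids = Tensor::new(&[[11u32, 42, 7, 0], [11, 42, 0, 0]], dev)?;
    let attn_mask = Tensor::new(&[[1u32, 1, 1, 0], [1, 1, 0, 0]], dev)?;
    // The reworked forward returns normed hidden states for every position,
    // shape (batch, seq_len, hidden_size); Some(mask) supplies a padding
    // mask instead of the causal one.
    let hidden = model.forward(&input_ids, 0, Some(&attn_mask))?;
    // Last-token pooling: take each sequence's final non-padded position.
    let lens = attn_mask.sum(1)?.to_vec1::<u32>()?;
    let pooled: Vec<Tensor> = lens
        .iter()
        .enumerate()
        .map(|(b, &l)| hidden.i((b, l as usize - 1, ..)))
        .collect::<Result<_>>()?;
    Tensor::stack(&pooled, 0) // (batch, hidden_size)
}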
@@ -1,5 +1,5 @@
 use crate::models::with_tracing::{linear, linear_no_bias, Linear, RmsNorm};
-use candle::{DType, Device, Module, Result, Tensor, D};
+use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
 use candle_nn::{Activation, VarBuilder};
 use std::sync::Arc;

@@ -250,7 +250,6 @@ pub struct Model {
     embed_tokens: candle_nn::Embedding,
     layers: Vec<DecoderLayer>,
     norm: RmsNorm,
-    lm_head: Linear,
     sliding_window: usize,
     device: Device,
     dtype: DType,
@@ -269,19 +268,17 @@ impl Model {
             layers.push(layer)
         }
         let norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb_m.pp("norm"))?;
-        let lm_head = linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
         Ok(Self {
             embed_tokens,
             layers,
             norm,
-            lm_head,
             sliding_window: cfg.sliding_window,
             device: vb.device().clone(),
             dtype: vb.dtype(),
         })
     }

-    fn prepare_decoder_attention_mask(
+    fn prepare_causal_attention_mask(
         &self,
         b_size: usize,
         tgt_len: usize,
@@ -301,7 +298,7 @@ impl Model {
             .collect();
         let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?;
         let mask = if seqlen_offset > 0 {
-            let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, &self.device)?;
+            let mask0 = Tensor::zeros((tgt_len, seqlen_offset), self.dtype, &self.device)?;
             Tensor::cat(&[&mask0, &mask], D::Minus1)?
         } else {
             mask
@@ -310,21 +307,42 @@ impl Model {
             .to_dtype(self.dtype)
     }

-    pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
+    fn prepare_attention_mask(&self, attn_mask: &Tensor) -> Result<Tensor> {
+        let (b_sz, sql_len) = attn_mask.dims2()?;
+        let mut mask: Vec<Tensor> = vec![];
+        for b in 0..b_sz {
+            mask.push(attn_mask.i((b, ..))?.expand((1, 1, sql_len, sql_len))?);
+        }
+        let mask = Tensor::cat(&mask, 0)?;
+        let on_true = mask.zeros_like()?.to_dtype(self.dtype)?;
+        let on_false = Tensor::new(f32::NEG_INFINITY, &self.device)?
+            .broadcast_as(mask.shape())?
+            .to_dtype(self.dtype)?;
+        mask.where_cond(&on_true, &on_false)
+    }
+
+    pub fn forward(
+        &mut self,
+        input_ids: &Tensor,
+        seqlen_offset: usize,
+        attn_mask: Option<&Tensor>,
+    ) -> Result<Tensor> {
         let (b_size, seq_len) = input_ids.dims2()?;
-        let attention_mask = if seq_len <= 1 {
-            None
-        } else {
-            let mask = self.prepare_decoder_attention_mask(b_size, seq_len, seqlen_offset)?;
-            Some(mask)
+        let attention_mask: Option<Tensor> = match attn_mask {
+            Some(mask) => Some(self.prepare_attention_mask(mask)?),
+            None => {
+                if seq_len <= 1 {
+                    None
+                } else {
+                    Some(self.prepare_causal_attention_mask(b_size, seq_len, seqlen_offset)?)
+                }
+            }
         };
         let mut xs = self.embed_tokens.forward(input_ids)?;
         for layer in self.layers.iter_mut() {
             xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)?
         }
-        xs.narrow(1, seq_len - 1, 1)?
-            .apply(&self.norm)?
-            .apply(&self.lm_head)
+        xs.apply(&self.norm)
     }

     pub fn clear_kv_cache(&mut self) {
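The `prepare_attention_mask` helper introduced above expands a (batch, seq_len) 0/1 padding mask into an additive bias of shape (batch, 1, seq_len, seq_len): 0 where the key position holds a real token, negative infinity where it is padding, repeated across every query row, so attention is bidirectional, as the embedding model expects. A standalone toy version of the same steps (a sketch assuming only candle itself, not part of this diff):

use candle::{DType, Device, IndexOp, Result, Tensor};

fn demo_mask() -> Result<()> {
    let dev = Device::Cpu;
    let mask = Tensor::new(&[[1u32, 1, 0]], &dev)?; // one sequence, last token padded
    let (_b, l) = mask.dims2()?;
    let expanded = mask.i((0, ..))?.expand((1, 1, l, l))?; // same expand as the diff
    let on_true = expanded.zeros_like()?.to_dtype(DType::F32)?;
    let on_false = Tensor::new(f32::NEG_INFINITY, &dev)?.broadcast_as(expanded.shape())?;
    let bias = expanded.where_cond(&on_true, &on_false)?;
    // Every query row of `bias` is [0.0, 0.0, -inf]: softmax over the keys
    // gives the padded position zero weight.
    println!("{bias}");
    Ok(())
}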
@@ -333,3 +351,32 @@ impl Model {
         }
     }
 }
+
+#[derive(Debug, Clone)]
+pub struct ModelForCausalLM {
+    base_model: Model,
+    lm_head: Linear,
+}
+
+impl ModelForCausalLM {
+    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let lm_head = linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
+        let base_model = Model::new(cfg, vb)?;
+        Ok(Self {
+            base_model,
+            lm_head,
+        })
+    }
+
+    pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
+        let (_b_size, seq_len) = input_ids.dims2()?;
+        self.base_model
+            .forward(input_ids, seqlen_offset, None)?
+            .narrow(1, seq_len - 1, 1)?
+            .apply(&self.lm_head)
+    }
+
+    pub fn clear_kv_cache(&mut self) {
+        self.base_model.clear_kv_cache()
+    }
+}
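After this change the two entry points divide the work: Model::forward returns normed hidden states for every position (the embedding path), while ModelForCausalLM re-adds the last-token narrow and the lm_head projection for generation. A hypothetical shape check, assuming the two types from this diff:

use candle::{Result, Tensor};

// Sketch only: the two forward paths after this diff.
fn shapes(model: &mut Model, causal_lm: &mut ModelForCausalLM, input_ids: &Tensor) -> Result<()> {
    let hidden = model.forward(input_ids, 0, None)?; // (batch, seq_len, hidden_size)
    let logits = causal_lm.forward(input_ids, 0)?;   // (batch, 1, vocab_size)
    assert_eq!(hidden.dims()[1], input_ids.dims()[1]); // every position kept
    assert_eq!(logits.dims()[1], 1);                   // only the last position
    Ok(())
}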