Support embedding model gte-Qwen1.5-7B-instruct (#2190)

* Support embedding model gte-Qwen1.5-7B-instruct

This is a text embedding model based on Qwen2. The two models share the same
architecture except for the last MLP module. This commit makes a minimal
modification to the existing Qwen2 implementation so that both models are
supported.

An example is provided and has been verified against the official
PyTorch implementation.

* Avoid doing the 'last-token filtering' based on the absence of an attention mask (a usage sketch of the resulting API follows below).
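
A minimal sketch of how the split might be used after this change. The pooling and normalization shown here are illustrative assumptions (the helper functions are hypothetical), not code added by this commit:

    use candle::{Result, Tensor};
    use candle_transformers::models::qwen2::{Model, ModelForCausalLM};

    // Embedding path: Model::forward now returns the normed hidden states for every
    // position and lets the caller decide how to pool them into a sentence embedding.
    fn embed(model: &mut Model, input_ids: &Tensor, attn_mask: &Tensor) -> Result<Tensor> {
        let hidden = model.forward(input_ids, 0, Some(attn_mask))?; // (batch, seq, hidden)
        // Last-token pooling for an unpadded batch; with right padding you would pick
        // each row's last non-padding position instead. The pooling choice is an assumption.
        let (_b, seq_len, _h) = hidden.dims3()?;
        let last = hidden.narrow(1, seq_len - 1, 1)?.squeeze(1)?;
        // L2-normalize so that cosine similarity reduces to a dot product.
        let norm = last.sqr()?.sum_keepdim(1)?.sqrt()?;
        last.broadcast_div(&norm)
    }

    // Generation path: ModelForCausalLM keeps the previous behaviour, applying the
    // last-token narrow and lm_head on top of the shared base model.
    fn next_token_logits(model: &mut ModelForCausalLM, input_ids: &Tensor) -> Result<Tensor> {
        model.forward(input_ids, 0) // (batch, 1, vocab_size)
    }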

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
Author: Yin Guobing
Date: 2024-05-17 03:34:10 +08:00
Committed by: GitHub
Parent: bdaa34216a
Commit: 349c3e806a
4 changed files with 260 additions and 16 deletions


@@ -1,5 +1,5 @@
 use crate::models::with_tracing::{linear, linear_no_bias, Linear, RmsNorm};
-use candle::{DType, Device, Module, Result, Tensor, D};
+use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
 use candle_nn::{Activation, VarBuilder};
 use std::sync::Arc;
@@ -250,7 +250,6 @@ pub struct Model {
     embed_tokens: candle_nn::Embedding,
     layers: Vec<DecoderLayer>,
     norm: RmsNorm,
-    lm_head: Linear,
     sliding_window: usize,
     device: Device,
     dtype: DType,
@@ -269,19 +268,17 @@ impl Model {
             layers.push(layer)
         }
         let norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb_m.pp("norm"))?;
-        let lm_head = linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
         Ok(Self {
             embed_tokens,
             layers,
             norm,
-            lm_head,
             sliding_window: cfg.sliding_window,
             device: vb.device().clone(),
             dtype: vb.dtype(),
         })
     }

-    fn prepare_decoder_attention_mask(
+    fn prepare_causal_attention_mask(
         &self,
         b_size: usize,
         tgt_len: usize,
@@ -301,7 +298,7 @@ impl Model {
             .collect();
         let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?;
         let mask = if seqlen_offset > 0 {
-            let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, &self.device)?;
+            let mask0 = Tensor::zeros((tgt_len, seqlen_offset), self.dtype, &self.device)?;
             Tensor::cat(&[&mask0, &mask], D::Minus1)?
         } else {
             mask
@@ -310,21 +307,42 @@ impl Model {
             .to_dtype(self.dtype)
     }

-    pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
+    fn prepare_attention_mask(&self, attn_mask: &Tensor) -> Result<Tensor> {
+        let (b_sz, sql_len) = attn_mask.dims2()?;
+        let mut mask: Vec<Tensor> = vec![];
+        for b in 0..b_sz {
+            mask.push(attn_mask.i((b, ..))?.expand((1, 1, sql_len, sql_len))?);
+        }
+        let mask = Tensor::cat(&mask, 0)?;
+        let on_true = mask.zeros_like()?.to_dtype(self.dtype)?;
+        let on_false = Tensor::new(f32::NEG_INFINITY, &self.device)?
+            .broadcast_as(mask.shape())?
+            .to_dtype(self.dtype)?;
+        mask.where_cond(&on_true, &on_false)
+    }
+
+    pub fn forward(
+        &mut self,
+        input_ids: &Tensor,
+        seqlen_offset: usize,
+        attn_mask: Option<&Tensor>,
+    ) -> Result<Tensor> {
         let (b_size, seq_len) = input_ids.dims2()?;
-        let attention_mask = if seq_len <= 1 {
-            None
-        } else {
-            let mask = self.prepare_decoder_attention_mask(b_size, seq_len, seqlen_offset)?;
-            Some(mask)
+        let attention_mask: Option<Tensor> = match attn_mask {
+            Some(mask) => Some(self.prepare_attention_mask(mask)?),
+            None => {
+                if seq_len <= 1 {
+                    None
+                } else {
+                    Some(self.prepare_causal_attention_mask(b_size, seq_len, seqlen_offset)?)
+                }
+            }
         };
         let mut xs = self.embed_tokens.forward(input_ids)?;
         for layer in self.layers.iter_mut() {
             xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)?
         }
-        xs.narrow(1, seq_len - 1, 1)?
-            .apply(&self.norm)?
-            .apply(&self.lm_head)
+        xs.apply(&self.norm)
     }

     pub fn clear_kv_cache(&mut self) {
@@ -333,3 +351,32 @@ impl Model {
         }
     }
 }
+
+#[derive(Debug, Clone)]
+pub struct ModelForCausalLM {
+    base_model: Model,
+    lm_head: Linear,
+}
+
+impl ModelForCausalLM {
+    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let lm_head = linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
+        let base_model = Model::new(cfg, vb)?;
+        Ok(Self {
+            base_model,
+            lm_head,
+        })
+    }
+
+    pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
+        let (_b_size, seq_len) = input_ids.dims2()?;
+        self.base_model
+            .forward(input_ids, seqlen_offset, None)?
+            .narrow(1, seq_len - 1, 1)?
+            .apply(&self.lm_head)
+    }
+
+    pub fn clear_kv_cache(&mut self) {
+        self.base_model.clear_kv_cache()
+    }
+}
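
For reference, one way a caller might build the (batch, seq_len) 0/1 padding mask that the new attn_mask parameter expects; this helper is a hypothetical illustration, not part of the commit. Inside the model, prepare_attention_mask expands the mask to (batch, 1, seq_len, seq_len) and turns it into an additive bias of 0 for real tokens and -inf for padding before it reaches the attention layers.

    use candle::{Device, Result, Tensor};

    // Build a (batch, max_len) mask with 1 for real tokens and 0 for right padding,
    // given each sequence's true length (hypothetical helper for illustration).
    fn padding_mask(lengths: &[usize], max_len: usize, device: &Device) -> Result<Tensor> {
        let rows = lengths
            .iter()
            .map(|&len| {
                let mut row = vec![0u32; max_len];
                row[..len].iter_mut().for_each(|v| *v = 1);
                Tensor::from_vec(row, max_len, device)
            })
            .collect::<Result<Vec<_>>>()?;
        Tensor::stack(&rows, 0)
    }

A mask built this way is passed as Some(&mask) to Model::forward, while passing None keeps the purely causal behaviour used for generation.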