Add a quantized version of recurrent-gemma. (#2054)

* Add a quantized version of recurrent-gemma.

* Share the rglru part.

* Get the quantized gemma model to work.
Author: Laurent Mazare (committed by GitHub)
Date:   2024-04-13 20:07:01 +02:00
Parent: 4c88c3ce06
Commit: 50e49ecc5f

6 changed files with 521 additions and 67 deletions

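For context, a minimal sketch of how the new quantized model is meant to be loaded. `VarBuilder::from_gguf` is candle-transformers' existing entry point for gguf weights; the `quantized_recurrent_gemma` module path and the `Model::new` signature are assumptions based on how the float model is constructed, not verbatim from this commit.

```rust
use candle::Device;
use candle_transformers::models::recurrent_gemma::Config;
// Module added by this PR; the exact API is assumed to mirror the float model.
use candle_transformers::models::quantized_recurrent_gemma::Model;
use candle_transformers::quantized_var_builder::VarBuilder;

fn load_quantized(gguf_path: &str, cfg: &Config) -> candle::Result<Model> {
    let device = Device::Cpu;
    // Weights stay quantized in memory; matmuls run against the quantized tensors.
    let vb = VarBuilder::from_gguf(gguf_path, &device)?;
    Model::new(cfg, vb)
}
```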

@@ -40,16 +40,20 @@ fn default_max_seq_len() -> usize {
 }
 
 #[derive(Debug, Clone)]
-struct RmsNorm {
+pub(crate) struct RmsNorm {
     weight: Tensor,
     eps: f64,
 }
 
 impl RmsNorm {
-    fn new(dim: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
+    pub(crate) fn new(dim: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
         let weight = vb.get(dim, "weight")?;
         Ok(Self { weight, eps })
     }
+
+    pub(crate) fn from_weight(weight: Tensor, eps: f64) -> Self {
+        Self { weight, eps }
+    }
 }
 
 impl Module for RmsNorm {
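The new `from_weight` constructor is what the quantized path needs: norm weights are tiny, so the quantized loader can dequantize them to a plain `Tensor` and wrap them directly instead of going through a float `VarBuilder`. A sketch of the likely call site, as it would appear inside candle-transformers (the helper itself is hypothetical; `get` and `dequantize` are the existing quantized `VarBuilder`/`QTensor` APIs):

```rust
use candle::Result;
use crate::models::recurrent_gemma::RmsNorm;
use crate::quantized_var_builder::VarBuilder;

// Hypothetical helper in the new quantized module: fetch the quantized norm
// weight, dequantize it to a plain Tensor, and wrap it via from_weight.
fn rms_norm(dim: usize, eps: f64, vb: VarBuilder) -> Result<RmsNorm> {
    let weight = vb.get(dim, "weight")?.dequantize(vb.device())?;
    Ok(RmsNorm::from_weight(weight, eps))
}
```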
@@ -70,7 +74,7 @@ impl Module for RmsNorm {
 }
 
 #[derive(Debug, Clone)]
-struct RotaryEmbedding {
+pub(crate) struct RotaryEmbedding {
     sin: Tensor,
     cos: Tensor,
 }
@@ -83,7 +87,7 @@ fn rotate_half(xs: &Tensor) -> Result<Tensor> {
 }
 
 impl RotaryEmbedding {
-    fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
+    pub(crate) fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
         if cfg.partial_rotary_factor != 0.5 {
             candle::bail!("partial-rotary-factor {} <> 0.5", cfg.partial_rotary_factor)
         }
@@ -106,7 +110,7 @@ impl RotaryEmbedding {
         })
     }
 
-    fn apply_rotary_emb_qkv(
+    pub(crate) fn apply_rotary_emb_qkv(
         &self,
         q: &Tensor,
         k: &Tensor,
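These `pub(crate)` changes let the quantized attention block reuse the float model's rotary embedding unchanged. For reference, the update the shared helper applies is standard RoPE, built on the `rotate_half` function visible in the hunk header above; the following is a self-contained restatement of that math, not the candle code verbatim:

```rust
use candle::{D, Result, Tensor};

// q/k update: x * cos + rotate_half(x) * sin, where rotate_half swaps the two
// halves of the last dimension and negates the second one.
fn apply_rope(x: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> {
    let half = x.dim(D::Minus1)? / 2;
    let x1 = x.narrow(D::Minus1, 0, half)?;
    let x2 = x.narrow(D::Minus1, half, half)?;
    let rotated = Tensor::cat(&[&x2.neg()?, &x1], D::Minus1)?; // rotate_half
    x.broadcast_mul(cos)? + rotated.broadcast_mul(sin)?
}
```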
@@ -156,15 +160,15 @@ impl Module for Mlp {
 
 // Real-Gated Linear Recurrent Unit
 #[derive(Debug, Clone)]
-struct Rglru {
-    recurrent_param: Tensor,
-    input_gate_weight: Tensor,
-    input_gate_bias: Tensor,
-    recurrent_gate_weight: Tensor,
-    recurrent_gate_bias: Tensor,
-    block_width: usize,
-    n_heads: usize,
-    recurrent_states: Option<Tensor>,
+pub(crate) struct Rglru {
+    pub(crate) recurrent_param: Tensor,
+    pub(crate) input_gate_weight: Tensor,
+    pub(crate) input_gate_bias: Tensor,
+    pub(crate) recurrent_gate_weight: Tensor,
+    pub(crate) recurrent_gate_bias: Tensor,
+    pub(crate) block_width: usize,
+    pub(crate) n_heads: usize,
+    pub(crate) recurrent_states: Option<Tensor>,
 }
 
 fn baddbmm(a: &Tensor, b: &Tensor, c: &Tensor) -> Result<Tensor> {
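Making `Rglru` and all of its fields `pub(crate)` is the "share the rglru part" piece of this PR: the quantized module can dequantize the gate parameters itself and assemble the struct directly rather than duplicating the recurrence. An illustrative construction (the `dequant` closure is hypothetical shorthand for fetching and dequantizing a named tensor from the gguf file):

```rust
use candle::{Result, Tensor};
use crate::models::recurrent_gemma::Rglru;

// Sketch of the quantized loader's side; tensor names follow the struct fields.
fn build_rglru(
    mut dequant: impl FnMut(&str) -> Result<Tensor>,
    lru_width: usize,
    n_heads: usize,
) -> Result<Rglru> {
    Ok(Rglru {
        recurrent_param: dequant("recurrent_param")?,
        input_gate_weight: dequant("input_gate_weight")?,
        input_gate_bias: dequant("input_gate_bias")?,
        recurrent_gate_weight: dequant("recurrent_gate_weight")?,
        recurrent_gate_bias: dequant("recurrent_gate_bias")?,
        block_width: lru_width / n_heads, // per-head slice of the recurrent width
        n_heads,
        recurrent_states: None, // no cached state until the first forward pass
    })
}
```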
@@ -200,7 +204,7 @@ impl Rglru {
     }
 
     // https://github.com/huggingface/transformers/blob/0bd58f1ce0573c0e3269de4215a17d318add49b9/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py#L303
-    pub fn forward(&mut self, xs: &Tensor, pos: usize) -> Result<Tensor> {
+    pub(crate) fn forward(&mut self, xs: &Tensor, pos: usize) -> Result<Tensor> {
         let (b_sz, seq_len, lru_width) = xs.dims3()?;
         let pos = Tensor::arange(pos as u32, (pos + seq_len) as u32, xs.device())?;
         let reset = pos.eq(0u32)?.unsqueeze(1)?.unsqueeze(0)?;
@@ -237,7 +241,7 @@ impl Rglru {
             reset.broadcast_add(&((1.0 - &reset)?.broadcast_mul(&(1.0 - a_square)?.sqrt()?))?)?;
 
         let normalized_x = (gated_inputs * multiplier.to_dtype(xs.dtype()))?;
-        let (hidden_states, recurrent_states) = self.rnn_scan(
+        let (hidden_states, recurrent_states) = rnn_scan(
             &normalized_x,
             &recurrent_gate,
             &reset,
@@ -246,54 +250,53 @@
         self.recurrent_states = Some(recurrent_states);
         Ok(hidden_states)
     }
+}
 
-    fn rnn_scan(
-        &self,
-        hidden_states: &Tensor,
-        recurrent_gate: &Tensor,
-        reset: &Tensor,
-        recurrent_states: Option<&Tensor>,
-    ) -> Result<(Tensor, Tensor)> {
-        let acc_dtype = DType::F32;
-        let dev = hidden_states.device();
-        let in_dtype = hidden_states.dtype();
-        let inv_reset = (1.0 - reset)?.to_dtype(recurrent_gate.dtype())?;
-        let recurrent_gate = recurrent_gate.broadcast_mul(&inv_reset)?;
-        let (c, r) = if hidden_states.dim(1)? == 1 {
-            match recurrent_states {
-                None => {
-                    let next_state = hidden_states.i((.., 0))?.to_dtype(acc_dtype)?;
-                    (hidden_states.clone(), next_state)
-                }
-                Some(recurrent_states) => {
-                    let contextualized_states =
-                        recurrent_gate.to_dtype(acc_dtype)? * recurrent_states.unsqueeze(1)?;
-                    let contextualized_states =
-                        (contextualized_states + hidden_states.to_dtype(acc_dtype)?)?;
-                    let c = contextualized_states.to_dtype(in_dtype)?;
-                    let l = contextualized_states.dim(1)?;
-                    let r = contextualized_states.i((.., l - 1))?;
-                    (c, r)
-                }
-            }
-        } else {
-            let mut recurrent_states = match recurrent_states {
-                None => Tensor::zeros(hidden_states.i((.., 0))?.shape(), acc_dtype, dev)?,
-                Some(r) => r.clone(),
-            };
-            let mut contextualized_states = vec![];
-            for t in 0..hidden_states.dim(1)? {
-                recurrent_states =
-                    (recurrent_gate.i((.., t))?.to_dtype(acc_dtype)? * recurrent_states)?;
-                recurrent_states =
-                    (recurrent_states + hidden_states.i((.., t))?.to_dtype(acc_dtype)?)?;
-                contextualized_states.push(recurrent_states.to_dtype(in_dtype)?)
-            }
-            let contextualized_states = Tensor::stack(&contextualized_states, 1)?;
-            (contextualized_states, recurrent_states)
-        };
-        Ok((c, r))
-    }
-}
+fn rnn_scan(
+    hidden_states: &Tensor,
+    recurrent_gate: &Tensor,
+    reset: &Tensor,
+    recurrent_states: Option<&Tensor>,
+) -> Result<(Tensor, Tensor)> {
+    let acc_dtype = DType::F32;
+    let dev = hidden_states.device();
+    let in_dtype = hidden_states.dtype();
+    let inv_reset = (1.0 - reset)?.to_dtype(recurrent_gate.dtype())?;
+    let recurrent_gate = recurrent_gate.broadcast_mul(&inv_reset)?;
+    let (c, r) = if hidden_states.dim(1)? == 1 {
+        match recurrent_states {
+            None => {
+                let next_state = hidden_states.i((.., 0))?.to_dtype(acc_dtype)?;
+                (hidden_states.clone(), next_state)
+            }
+            Some(recurrent_states) => {
+                let contextualized_states =
+                    recurrent_gate.to_dtype(acc_dtype)? * recurrent_states.unsqueeze(1)?;
+                let contextualized_states =
+                    (contextualized_states + hidden_states.to_dtype(acc_dtype)?)?;
+                let c = contextualized_states.to_dtype(in_dtype)?;
+                let l = contextualized_states.dim(1)?;
+                let r = contextualized_states.i((.., l - 1))?;
+                (c, r)
+            }
+        }
+    } else {
+        let mut recurrent_states = match recurrent_states {
+            None => Tensor::zeros(hidden_states.i((.., 0))?.shape(), acc_dtype, dev)?,
+            Some(r) => r.clone(),
+        };
+        let mut contextualized_states = vec![];
+        for t in 0..hidden_states.dim(1)? {
+            recurrent_states =
+                (recurrent_gate.i((.., t))?.to_dtype(acc_dtype)? * recurrent_states)?;
+            recurrent_states =
+                (recurrent_states + hidden_states.i((.., t))?.to_dtype(acc_dtype)?)?;
+            contextualized_states.push(recurrent_states.to_dtype(in_dtype)?)
+        }
+        let contextualized_states = Tensor::stack(&contextualized_states, 1)?;
+        (contextualized_states, recurrent_states)
+    };
+    Ok((c, r))
+}
 
 #[derive(Debug, Clone)]
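Moving `rnn_scan` out of `impl Rglru` into a free function means the quantized `Rglru` path calls exactly the same scan. What it computes is a gated linear recurrence: per step, `h_t = a_t * h_(t-1) + x_t`, where `a_t` is the recurrent gate (masked to zero at reset positions) and the final `h` is cached as `recurrent_states` for the next call. A minimal scalar sketch of that loop (illustrative only; the real code batches this over heads and accumulates in f32):

```rust
// h_t = a_t * h_{t-1} + x_t; returns every step's output plus the final
// state, mirroring the (hidden_states, recurrent_states) pair above.
fn scan_scalar(a: &[f32], x: &[f32], mut h: f32) -> (Vec<f32>, f32) {
    let mut out = Vec::with_capacity(x.len());
    for (&a_t, &x_t) in a.iter().zip(x) {
        h = a_t * h + x_t;
        out.push(h);
    }
    (out, h)
}
```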