Mirror of https://github.com/huggingface/candle.git (synced 2025-06-19 19:58:35 +00:00)
Add a quantized version of recurrent-gemma. (#2054)
* Add a quantized version of recurrent-gemma.
* Share the rglru part.
* Get the quantized gemma model to work.
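To make the "share the rglru part" bullet concrete: the diff below turns the RG-LRU pieces into pub(crate) items, so a quantized sibling module inside the same crate can build them from already-materialized (e.g. dequantized) weights instead of duplicating the recurrence. Below is a minimal, hypothetical sketch of that reuse; the module path crate::models::recurrent_gemma and the helper names are assumptions for illustration, not code from this commit.

// Hypothetical sketch of the reuse enabled by this change (not the actual
// quantized recurrent-gemma implementation). Everything referenced from
// recurrent_gemma (RmsNorm::from_weight, the pub(crate) Rglru fields and its
// pub(crate) forward) is introduced by the hunks below.
use candle::Tensor;

use crate::models::recurrent_gemma::{RmsNorm, Rglru};

// Build the shared RmsNorm from an already-materialized (e.g. dequantized) weight.
fn norm_from_dequantized(weight: Tensor, eps: f64) -> RmsNorm {
    RmsNorm::from_weight(weight, eps)
}

// Build the shared RG-LRU directly; its fields are pub(crate) after this change,
// so the quantized side only has to supply plain tensors.
#[allow(clippy::too_many_arguments)]
fn rglru_from_dequantized(
    recurrent_param: Tensor,
    input_gate_weight: Tensor,
    input_gate_bias: Tensor,
    recurrent_gate_weight: Tensor,
    recurrent_gate_bias: Tensor,
    block_width: usize,
    n_heads: usize,
) -> Rglru {
    Rglru {
        recurrent_param,
        input_gate_weight,
        input_gate_bias,
        recurrent_gate_weight,
        recurrent_gate_bias,
        block_width,
        n_heads,
        // Starts empty; Rglru::forward fills it in as decoding progresses.
        recurrent_states: None,
    }
}

The quantized model then only needs to own its quantized projection layers; the recurrence itself runs through the single pub(crate) Rglru::forward.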
@@ -40,16 +40,20 @@ fn default_max_seq_len() -> usize {
 }
 
 #[derive(Debug, Clone)]
-struct RmsNorm {
+pub(crate) struct RmsNorm {
     weight: Tensor,
     eps: f64,
 }
 
 impl RmsNorm {
-    fn new(dim: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
+    pub(crate) fn new(dim: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
         let weight = vb.get(dim, "weight")?;
         Ok(Self { weight, eps })
     }
+
+    pub(crate) fn from_weight(weight: Tensor, eps: f64) -> Self {
+        Self { weight, eps }
+    }
 }
 
 impl Module for RmsNorm {
@@ -70,7 +74,7 @@ impl Module for RmsNorm {
 }
 
 #[derive(Debug, Clone)]
-struct RotaryEmbedding {
+pub(crate) struct RotaryEmbedding {
     sin: Tensor,
     cos: Tensor,
 }
@@ -83,7 +87,7 @@ fn rotate_half(xs: &Tensor) -> Result<Tensor> {
 }
 
 impl RotaryEmbedding {
-    fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
+    pub(crate) fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
         if cfg.partial_rotary_factor != 0.5 {
             candle::bail!("partial-rotary-factor {} <> 0.5", cfg.partial_rotary_factor)
         }
@@ -106,7 +110,7 @@ impl RotaryEmbedding {
         })
     }
 
-    fn apply_rotary_emb_qkv(
+    pub(crate) fn apply_rotary_emb_qkv(
         &self,
         q: &Tensor,
         k: &Tensor,
@@ -156,15 +160,15 @@ impl Module for Mlp {
 
 // Real-Gated Linear Recurrent Unit
 #[derive(Debug, Clone)]
-struct Rglru {
-    recurrent_param: Tensor,
-    input_gate_weight: Tensor,
-    input_gate_bias: Tensor,
-    recurrent_gate_weight: Tensor,
-    recurrent_gate_bias: Tensor,
-    block_width: usize,
-    n_heads: usize,
-    recurrent_states: Option<Tensor>,
+pub(crate) struct Rglru {
+    pub(crate) recurrent_param: Tensor,
+    pub(crate) input_gate_weight: Tensor,
+    pub(crate) input_gate_bias: Tensor,
+    pub(crate) recurrent_gate_weight: Tensor,
+    pub(crate) recurrent_gate_bias: Tensor,
+    pub(crate) block_width: usize,
+    pub(crate) n_heads: usize,
+    pub(crate) recurrent_states: Option<Tensor>,
 }
 
 fn baddbmm(a: &Tensor, b: &Tensor, c: &Tensor) -> Result<Tensor> {
@@ -200,7 +204,7 @@ impl Rglru {
     }
 
     // https://github.com/huggingface/transformers/blob/0bd58f1ce0573c0e3269de4215a17d318add49b9/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py#L303
-    pub fn forward(&mut self, xs: &Tensor, pos: usize) -> Result<Tensor> {
+    pub(crate) fn forward(&mut self, xs: &Tensor, pos: usize) -> Result<Tensor> {
         let (b_sz, seq_len, lru_width) = xs.dims3()?;
         let pos = Tensor::arange(pos as u32, (pos + seq_len) as u32, xs.device())?;
         let reset = pos.eq(0u32)?.unsqueeze(1)?.unsqueeze(0)?;
@@ -237,7 +241,7 @@ impl Rglru {
             reset.broadcast_add(&((1.0 - &reset)?.broadcast_mul(&(1.0 - a_square)?.sqrt()?))?)?;
         let normalized_x = (gated_inputs * multiplier.to_dtype(xs.dtype()))?;
 
-        let (hidden_states, recurrent_states) = self.rnn_scan(
+        let (hidden_states, recurrent_states) = rnn_scan(
             &normalized_x,
             &recurrent_gate,
             &reset,
@@ -246,54 +250,53 @@ impl Rglru {
         self.recurrent_states = Some(recurrent_states);
         Ok(hidden_states)
     }
+}
 
-    fn rnn_scan(
-        &self,
-        hidden_states: &Tensor,
-        recurrent_gate: &Tensor,
-        reset: &Tensor,
-        recurrent_states: Option<&Tensor>,
-    ) -> Result<(Tensor, Tensor)> {
-        let acc_dtype = DType::F32;
-        let dev = hidden_states.device();
-        let in_dtype = hidden_states.dtype();
-        let inv_reset = (1.0 - reset)?.to_dtype(recurrent_gate.dtype())?;
-        let recurrent_gate = recurrent_gate.broadcast_mul(&inv_reset)?;
-        let (c, r) = if hidden_states.dim(1)? == 1 {
-            match recurrent_states {
-                None => {
-                    let next_state = hidden_states.i((.., 0))?.to_dtype(acc_dtype)?;
-                    (hidden_states.clone(), next_state)
-                }
-                Some(recurrent_states) => {
-                    let contextualized_states =
-                        recurrent_gate.to_dtype(acc_dtype)? * recurrent_states.unsqueeze(1)?;
-                    let contextualized_states =
-                        (contextualized_states + hidden_states.to_dtype(acc_dtype)?)?;
-                    let c = contextualized_states.to_dtype(in_dtype)?;
-                    let l = contextualized_states.dim(1)?;
-                    let r = contextualized_states.i((.., l - 1))?;
-                    (c, r)
-                }
-            }
-        } else {
-            let mut recurrent_states = match recurrent_states {
-                None => Tensor::zeros(hidden_states.i((.., 0))?.shape(), acc_dtype, dev)?,
-                Some(r) => r.clone(),
-            };
-            let mut contextualized_states = vec![];
-            for t in 0..hidden_states.dim(1)? {
-                recurrent_states =
-                    (recurrent_gate.i((.., t))?.to_dtype(acc_dtype)? * recurrent_states)?;
-                recurrent_states =
-                    (recurrent_states + hidden_states.i((.., t))?.to_dtype(acc_dtype)?)?;
-                contextualized_states.push(recurrent_states.to_dtype(in_dtype)?)
-            }
-            let contextualized_states = Tensor::stack(&contextualized_states, 1)?;
-            (contextualized_states, recurrent_states)
-        };
-        Ok((c, r))
-    }
-}
+fn rnn_scan(
+    hidden_states: &Tensor,
+    recurrent_gate: &Tensor,
+    reset: &Tensor,
+    recurrent_states: Option<&Tensor>,
+) -> Result<(Tensor, Tensor)> {
+    let acc_dtype = DType::F32;
+    let dev = hidden_states.device();
+    let in_dtype = hidden_states.dtype();
+    let inv_reset = (1.0 - reset)?.to_dtype(recurrent_gate.dtype())?;
+    let recurrent_gate = recurrent_gate.broadcast_mul(&inv_reset)?;
+    let (c, r) = if hidden_states.dim(1)? == 1 {
+        match recurrent_states {
+            None => {
+                let next_state = hidden_states.i((.., 0))?.to_dtype(acc_dtype)?;
+                (hidden_states.clone(), next_state)
+            }
+            Some(recurrent_states) => {
+                let contextualized_states =
+                    recurrent_gate.to_dtype(acc_dtype)? * recurrent_states.unsqueeze(1)?;
+                let contextualized_states =
+                    (contextualized_states + hidden_states.to_dtype(acc_dtype)?)?;
+                let c = contextualized_states.to_dtype(in_dtype)?;
+                let l = contextualized_states.dim(1)?;
+                let r = contextualized_states.i((.., l - 1))?;
+                (c, r)
+            }
+        }
+    } else {
+        let mut recurrent_states = match recurrent_states {
+            None => Tensor::zeros(hidden_states.i((.., 0))?.shape(), acc_dtype, dev)?,
+            Some(r) => r.clone(),
+        };
+        let mut contextualized_states = vec![];
+        for t in 0..hidden_states.dim(1)? {
+            recurrent_states =
+                (recurrent_gate.i((.., t))?.to_dtype(acc_dtype)? * recurrent_states)?;
+            recurrent_states =
+                (recurrent_states + hidden_states.i((.., t))?.to_dtype(acc_dtype)?)?;
+            contextualized_states.push(recurrent_states.to_dtype(in_dtype)?)
+        }
+        let contextualized_states = Tensor::stack(&contextualized_states, 1)?;
+        (contextualized_states, recurrent_states)
+    };
+    Ok((c, r))
+}
 
 #[derive(Debug, Clone)]
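For reference, in the notation of the hunks above (a_t = recurrent_gate, x̃_t = normalized_x, h_t = the recurrent state), the now-shared rnn_scan computes the per-timestep linear recurrence, written here in LaTeX:

h_t = a_t \odot (1 - \mathrm{reset}_t) \odot h_{t-1} + \tilde{x}_t,
\qquad
\tilde{x}_t = \mathrm{gated\_inputs}_t \odot \bigl(\mathrm{reset}_t + (1 - \mathrm{reset}_t)\sqrt{1 - a_t^2}\bigr)

The single-token branch applies one such update against the cached state; the prefill branch loops over t, stacking the per-step h_t and returning the last one as the new cache.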