From 50e49ecc5f4e72c3807483e55835719490d59d18 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Sat, 13 Apr 2024 20:07:01 +0200 Subject: [PATCH] Add a quantized version of recurrent-gemma. (#2054) * Add a quantized version of recurrent-gemma. * Share the rglru part. * Get the quantized gemma model to work. --- README.md | 5 +- .../examples/recurrent-gemma/main.rs | 45 +- candle-transformers/src/models/mod.rs | 1 + .../src/models/quantized_recurrent_gemma.rs | 412 ++++++++++++++++++ .../src/models/recurrent_gemma.rs | 123 +++--- .../src/quantized_var_builder.rs | 2 +- 6 files changed, 521 insertions(+), 67 deletions(-) create mode 100644 candle-transformers/src/models/quantized_recurrent_gemma.rs diff --git a/README.md b/README.md index b9e603b2..ad72f30f 100644 --- a/README.md +++ b/README.md @@ -63,8 +63,9 @@ We also provide a some command line based examples using state of the art models - [LLaMA and LLaMA-v2](./candle-examples/examples/llama/): general LLM, includes the SOLAR-10.7B variant. - [Falcon](./candle-examples/examples/falcon/): general LLM. -- [Gemma](./candle-examples/examples/gemma/): 2b and 7b general LLMs from Google - Deepmind. +- [Gemma](./candle-examples/examples/gemma/): 2b and 7b general LLMs from Google Deepmind. +- [RecurrentGemma](./candle-examples/examples/recurrent-gemma/): 2b and 7b + Griffin based models from Google that mix attention with a RNN like state. - [Phi-1, Phi-1.5, and Phi-2](./candle-examples/examples/phi/): 1.3b and 2.7b general LLMs with performance on par with LLaMA-v2 7b. - [StableLM-3B-4E1T](./candle-examples/examples/stable-lm/): a 3b general LLM pre-trained on 1T tokens of English and code datasets. Also supports diff --git a/candle-examples/examples/recurrent-gemma/main.rs b/candle-examples/examples/recurrent-gemma/main.rs index 121bea7e..b28acb0d 100644 --- a/candle-examples/examples/recurrent-gemma/main.rs +++ b/candle-examples/examples/recurrent-gemma/main.rs @@ -7,7 +7,8 @@ extern crate accelerate_src; use anyhow::{Error as E, Result}; use clap::Parser; -use candle_transformers::models::recurrent_gemma::{Config, Model}; +use candle_transformers::models::quantized_recurrent_gemma::Model as QModel; +use candle_transformers::models::recurrent_gemma::{Config, Model as BModel}; use candle::{DType, Device, Tensor}; use candle_examples::token_output_stream::TokenOutputStream; @@ -16,6 +17,20 @@ use candle_transformers::generation::LogitsProcessor; use hf_hub::{api::sync::Api, Repo, RepoType}; use tokenizers::Tokenizer; +enum Model { + B(BModel), + Q(QModel), +} + +impl Model { + fn forward(&mut self, xs: &Tensor, pos: usize) -> candle::Result { + match self { + Self::B(m) => m.forward(xs, pos), + Self::Q(m) => m.forward(xs, pos), + } + } +} + #[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)] enum Which { #[value(name = "2b")] @@ -195,6 +210,9 @@ struct Args { /// The model to use. #[arg(long, default_value = "2b")] which: Which, + + #[arg(long)] + quantized: bool, } fn main() -> Result<()> { @@ -250,7 +268,18 @@ fn main() -> Result<()> { .split(',') .map(std::path::PathBuf::from) .collect::>(), - None => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?, + None => { + if args.quantized { + let filename = match args.which { + Which::Base2B => "recurrent-gemma-2b-q4k.gguf", + Which::Instruct2B => "recurrent-gemma-7b-q4k.gguf", + }; + let filename = api.model("lmz/candle-gemma".to_string()).get(filename)?; + vec![filename] + } else { + candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")? + } + } }; println!("retrieved the files in {:?}", start.elapsed()); let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; @@ -263,8 +292,16 @@ fn main() -> Result<()> { } else { DType::F32 }; - let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? }; - let model = Model::new(&config, vb.pp("model"))?; + let model = if args.quantized { + let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf( + &filenames[0], + &device, + )?; + Model::Q(QModel::new(&config, vb.pp("model"))?) + } else { + let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? }; + Model::B(BModel::new(&config, vb.pp("model"))?) + }; println!("loaded the model in {:?}", start.elapsed()); diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs index 21938baa..f4a71931 100644 --- a/candle-transformers/src/models/mod.rs +++ b/candle-transformers/src/models/mod.rs @@ -37,6 +37,7 @@ pub mod quantized_mistral; pub mod quantized_mixformer; pub mod quantized_moondream; pub mod quantized_mpt; +pub mod quantized_recurrent_gemma; pub mod quantized_rwkv_v5; pub mod quantized_rwkv_v6; pub mod quantized_stable_lm; diff --git a/candle-transformers/src/models/quantized_recurrent_gemma.rs b/candle-transformers/src/models/quantized_recurrent_gemma.rs new file mode 100644 index 00000000..c28064da --- /dev/null +++ b/candle-transformers/src/models/quantized_recurrent_gemma.rs @@ -0,0 +1,412 @@ +use crate::quantized_nn::{linear_b as linear, Embedding, Linear}; +pub use crate::quantized_var_builder::VarBuilder; +use candle::{DType, Device, IndexOp, Module, Result, Tensor, D}; +use std::sync::Arc; + +use crate::models::recurrent_gemma::{Config, Rglru, RmsNorm, RotaryEmbedding, TemporalBlockType}; + +fn rms_norm(size: usize, eps: f64, vb: VarBuilder) -> Result { + let weight = vb.get(size, "weight")?.dequantize(vb.device())?; + Ok(RmsNorm::from_weight(weight, eps)) +} + +#[derive(Debug, Clone)] +struct Mlp { + gate_proj: Linear, + up_proj: Linear, + down_proj: Linear, + act_fn: candle_nn::Activation, +} + +impl Mlp { + fn new(cfg: &Config, vb: VarBuilder) -> Result { + let h = cfg.hidden_size; + let intermediate_size = cfg.intermediate_size / 2; + let gate_proj = linear(h, intermediate_size, true, vb.pp("gate_proj"))?; + let up_proj = linear(h, intermediate_size, true, vb.pp("up_proj"))?; + let down_proj = linear(intermediate_size, h, true, vb.pp("down_proj"))?; + Ok(Self { + gate_proj, + up_proj, + down_proj, + act_fn: cfg.hidden_activation, + }) + } +} + +impl Module for Mlp { + fn forward(&self, xs: &Tensor) -> Result { + let gate = xs.apply(&self.gate_proj)?.apply(&self.act_fn)?; + (gate * xs.apply(&self.up_proj))?.apply(&self.down_proj) + } +} + +fn rglru(cfg: &Config, vb: VarBuilder) -> Result { + let h = cfg.hidden_size; + let lru_width = cfg.lru_width.unwrap_or(h); + let n_heads = cfg.num_attention_heads; + let block_width = lru_width / n_heads; + let recurrent_param = vb.get((lru_width,), "recurrent_param")?; + let input_gate_weight = vb.get((n_heads, block_width, block_width), "input_gate_weight")?; + let input_gate_bias = vb.get((n_heads, block_width), "input_gate_bias")?; + let recurrent_gate_weight = + vb.get((n_heads, block_width, block_width), "recurrent_gate_weight")?; + let recurrent_gate_bias = vb.get((n_heads, block_width), "recurrent_gate_bias")?; + Ok(Rglru { + recurrent_param: recurrent_param.dequantize(vb.device())?, + input_gate_bias: input_gate_bias.dequantize(vb.device())?, + input_gate_weight: input_gate_weight.dequantize(vb.device())?, + recurrent_gate_bias: recurrent_gate_bias.dequantize(vb.device())?, + recurrent_gate_weight: recurrent_gate_weight.dequantize(vb.device())?, + block_width, + n_heads, + recurrent_states: None, + }) +} + +#[derive(Debug, Clone)] +struct RecurrentBlock { + linear_y: Linear, + linear_x: Linear, + linear_out: Linear, + conv_1d: candle_nn::Conv1d, + conv1d_state: Option, + conv1d_width: usize, + rg_lru: Rglru, + act_fn: candle_nn::Activation, +} + +impl RecurrentBlock { + fn new(cfg: &Config, vb: VarBuilder) -> Result { + let h = cfg.hidden_size; + let lru_width = cfg.lru_width.unwrap_or(h); + let linear_y = linear(h, lru_width, true, vb.pp("linear_y"))?; + let linear_x = linear(h, lru_width, true, vb.pp("linear_x"))?; + let linear_out = linear(lru_width, h, true, vb.pp("linear_out"))?; + + let conv_1d = { + let ws = vb + .get((lru_width, 1, cfg.conv1d_width), "conv_1d.weight")? + .dequantize(vb.device())?; + let bs = vb.get(lru_width, "conv_1d.bias")?.dequantize(vb.device())?; + let config = candle_nn::Conv1dConfig { + groups: lru_width, + padding: cfg.conv1d_width - 1, + ..Default::default() + }; + candle_nn::Conv1d::new(ws, Some(bs), config) + }; + let rg_lru = rglru(cfg, vb.pp("rg_lru"))?; + Ok(Self { + linear_y, + linear_x, + linear_out, + conv_1d, + conv1d_state: None, + conv1d_width: cfg.conv1d_width, + rg_lru, + act_fn: cfg.hidden_activation, + }) + } + + pub fn forward(&mut self, xs: &Tensor, pos: usize) -> Result { + let (_b_sz, seq_len, _) = xs.dims3()?; + + let y_branch = xs.apply(&self.linear_y)?.apply(&self.act_fn)?; + let x_branch = xs.apply(&self.linear_x)?.transpose(1, 2)?; + let x_branch = if pos == 0 { + let x_len = x_branch.dim(D::Minus1)?; + let pad = self.conv1d_width as i64 - x_len as i64 - 1; + let padded = match pad.cmp(&0) { + std::cmp::Ordering::Equal => x_branch.clone(), + std::cmp::Ordering::Less => { + let rev_pad = (-pad) as usize; + x_branch.narrow(D::Minus1, rev_pad, x_len - rev_pad)? + } + std::cmp::Ordering::Greater => { + x_branch.pad_with_zeros(D::Minus1, pad as usize, 0)? + } + }; + self.conv1d_state = Some(padded); + x_branch + .apply(&self.conv_1d)? + .narrow(D::Minus1, 0, seq_len)? + } else { + let conv_state = match self.conv1d_state.as_ref() { + None => candle::bail!("empty cache despite pos > 0"), + Some(s) => Tensor::cat(&[s, &x_branch], D::Minus1)?, + }; + let w = self.conv_1d.weight().i((.., 0, ..))?; + let x_branch = conv_state.broadcast_mul(&w)?.sum(D::Minus1)?; + let x_branch = match self.conv_1d.bias() { + None => x_branch, + Some(b) => x_branch.broadcast_add(b)?, + }; + let x_branch = x_branch.unsqueeze(D::Minus1)?; + self.conv1d_state = Some(conv_state.i((.., .., 1..))?); + x_branch + }; + let x_branch = x_branch.transpose(1, 2)?; + let x_branch = self.rg_lru.forward(&x_branch, pos)?; + (x_branch * y_branch)?.apply(&self.linear_out) + } +} + +#[derive(Debug, Clone)] +struct SdpaAttention { + q_proj: Linear, + k_proj: Linear, + v_proj: Linear, + o_proj: Linear, + n_heads: usize, + n_kv_heads: usize, + head_dim: usize, + hidden_size: usize, + kv_cache: Option<(Tensor, Tensor)>, + rotary_emb: Arc, +} + +impl SdpaAttention { + fn new(rotary_emb: Arc, cfg: &Config, vb: VarBuilder) -> Result { + let h = cfg.hidden_size; + let n_heads = cfg.num_attention_heads; + let n_kv_heads = cfg.num_key_value_heads; + let hd = cfg.head_dim; + let q_proj = linear(h, n_heads * hd, cfg.attention_bias, vb.pp("q_proj"))?; + let k_proj = linear(h, n_kv_heads * hd, cfg.attention_bias, vb.pp("k_proj"))?; + let v_proj = linear(h, n_kv_heads * hd, cfg.attention_bias, vb.pp("v_proj"))?; + let o_proj = linear(n_heads * hd, h, true, vb.pp("o_proj"))?; + Ok(Self { + q_proj, + k_proj, + v_proj, + o_proj, + n_heads, + n_kv_heads, + head_dim: hd, + hidden_size: h, + kv_cache: None, + rotary_emb, + }) + } + + fn repeat_kv(&self, x: Tensor) -> Result { + let n_rep = self.n_heads / self.n_kv_heads; + crate::utils::repeat_kv(x, n_rep) + } + + fn forward( + &mut self, + xs: &Tensor, + attention_mask: Option<&Tensor>, + pos: usize, + ) -> Result { + let (bsz, q_len, _) = xs.dims3()?; + + let query_states = xs.apply(&self.q_proj)?; + let key_states = xs.apply(&self.k_proj)?; + let value_states = xs.apply(&self.v_proj)?; + + let query_states = query_states + .reshape((bsz, q_len, self.n_heads, self.head_dim))? + .transpose(1, 2)?; + let key_states = key_states + .reshape((bsz, q_len, self.n_kv_heads, self.head_dim))? + .transpose(1, 2)?; + let value_states = value_states + .reshape((bsz, q_len, self.n_kv_heads, self.head_dim))? + .transpose(1, 2)?; + let query_states = query_states.chunk(2, D::Minus1)?; + let key_states = key_states.chunk(2, D::Minus1)?; + let (query_rot, key_rot) = + self.rotary_emb + .apply_rotary_emb_qkv(&query_states[0], &key_states[0], pos)?; + let query_states = Tensor::cat(&[&query_rot, &query_states[1]], D::Minus1)?.contiguous()?; + let key_states = Tensor::cat(&[&key_rot, &key_states[1]], D::Minus1)?.contiguous()?; + + let (key_states, value_states) = match &self.kv_cache { + None => (key_states, value_states), + Some((prev_k, prev_v)) => { + let key_states = Tensor::cat(&[prev_k, &key_states], 2)?; + let value_states = Tensor::cat(&[prev_v, &value_states], 2)?; + (key_states, value_states) + } + }; + self.kv_cache = Some((key_states.clone(), value_states.clone())); + + let key_states = self.repeat_kv(key_states)?; + let value_states = self.repeat_kv(value_states)?; + let xs = { + let att = (query_states.matmul(&key_states.t()?)? / (self.head_dim as f64).sqrt())?; + let att = if q_len == 1 { + att + } else { + match attention_mask { + None => att, + Some(mask) => att.broadcast_add(mask)?, + } + }; + let att = candle_nn::ops::softmax_last_dim(&att)?; + att.matmul(&value_states.contiguous()?)? + }; + + let xs = xs + .transpose(1, 2)? + .reshape((bsz, q_len, self.hidden_size))?; + self.o_proj.forward(&xs) + } +} + +#[derive(Debug, Clone)] +enum TemporalBlock { + Recurrent(RecurrentBlock), + Attention(SdpaAttention), +} + +impl TemporalBlock { + fn forward( + &mut self, + xs: &Tensor, + attention_mask: Option<&Tensor>, + pos: usize, + ) -> Result { + match self { + Self::Recurrent(b) => b.forward(xs, pos), + Self::Attention(b) => b.forward(xs, attention_mask, pos), + } + } +} + +#[derive(Debug, Clone)] +struct DecoderLayer { + temporal_pre_norm: RmsNorm, + channel_pre_norm: RmsNorm, + temporal_block: TemporalBlock, + mlp_block: Mlp, +} + +impl DecoderLayer { + fn new( + block_idx: usize, + rotary_emb: Arc, + cfg: &Config, + vb: VarBuilder, + ) -> Result { + let h = cfg.hidden_size; + let temporal_pre_norm = rms_norm(h, cfg.rms_norm_eps, vb.pp("temporal_pre_norm"))?; + let channel_pre_norm = rms_norm(h, cfg.rms_norm_eps, vb.pp("channel_pre_norm"))?; + let temporal_block = match cfg.block_types[block_idx % cfg.block_types.len()] { + TemporalBlockType::Recurrent => { + let block = RecurrentBlock::new(cfg, vb.pp("temporal_block"))?; + TemporalBlock::Recurrent(block) + } + TemporalBlockType::Attention => { + let block = SdpaAttention::new(rotary_emb, cfg, vb.pp("temporal_block"))?; + TemporalBlock::Attention(block) + } + }; + let mlp_block = Mlp::new(cfg, vb.pp("mlp_block"))?; + Ok(Self { + temporal_pre_norm, + channel_pre_norm, + temporal_block, + mlp_block, + }) + } + + fn forward( + &mut self, + xs: &Tensor, + attention_mask: Option<&Tensor>, + pos: usize, + ) -> Result { + let residual = xs; + let xs = xs.apply(&self.temporal_pre_norm)?; + let xs = self.temporal_block.forward(&xs, attention_mask, pos)?; + let xs = (xs + residual)?; + let residual = &xs; + let xs = xs.apply(&self.channel_pre_norm)?.apply(&self.mlp_block)?; + xs + residual + } +} + +#[derive(Debug, Clone)] +pub struct Model { + embed_tokens: Embedding, + layers: Vec, + final_norm: RmsNorm, + lm_head: Linear, + hidden_size: usize, + logits_soft_cap: f64, + device: Device, +} + +impl Model { + pub fn new(cfg: &Config, vb: VarBuilder) -> Result { + let embed_tokens = Embedding::new(cfg.vocab_size, cfg.hidden_size, vb.pp("embed_tokens"))?; + let rotary_emb = Arc::new(RotaryEmbedding::new(DType::F32, cfg, vb.device())?); + let vb_b = vb.pp("layers"); + let mut layers = Vec::with_capacity(cfg.num_hidden_layers); + for idx in 0..cfg.num_hidden_layers { + let layer = DecoderLayer::new(idx, rotary_emb.clone(), cfg, vb_b.pp(idx))?; + layers.push(layer) + } + let final_norm = rms_norm(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("final_norm"))?; + let lm_head = linear( + cfg.hidden_size, + cfg.vocab_size, + false, + vb.pp("embed_tokens"), + )?; + Ok(Self { + embed_tokens, + layers, + final_norm, + lm_head, + hidden_size: cfg.hidden_size, + logits_soft_cap: cfg.logits_soft_cap, + device: vb.device().clone(), + }) + } + + fn prepare_decoder_attention_mask( + &self, + b_size: usize, + tgt_len: usize, + seqlen_offset: usize, + ) -> Result { + let mask: Vec<_> = (0..tgt_len) + .flat_map(|i| (0..tgt_len).map(move |j| if i < j { f32::NEG_INFINITY } else { 0. })) + .collect(); + let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?; + let mask = if seqlen_offset > 0 { + let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, &self.device)?; + Tensor::cat(&[&mask0, &mask], D::Minus1)? + } else { + mask + }; + mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))? + .to_dtype(DType::F32) + } + + pub fn forward(&mut self, xs: &Tensor, pos: usize) -> Result { + let (b_size, seq_len) = xs.dims2()?; + let attention_mask = if seq_len <= 1 { + None + } else { + let mask = self.prepare_decoder_attention_mask(b_size, seq_len, pos)?; + Some(mask) + }; + let xs = xs.apply(&self.embed_tokens)?; + let mut xs = (xs * (self.hidden_size as f64).sqrt())?; + for layer in self.layers.iter_mut() { + xs = layer.forward(&xs, attention_mask.as_ref(), pos)?; + } + let logits = xs + .narrow(1, seq_len - 1, 1)? + .apply(&self.final_norm)? + .apply(&self.lm_head)?; + let logits = ((logits / self.logits_soft_cap)?.tanh()? * self.logits_soft_cap)?; + Ok(logits) + } +} diff --git a/candle-transformers/src/models/recurrent_gemma.rs b/candle-transformers/src/models/recurrent_gemma.rs index 712cc347..24d2b7e3 100644 --- a/candle-transformers/src/models/recurrent_gemma.rs +++ b/candle-transformers/src/models/recurrent_gemma.rs @@ -40,16 +40,20 @@ fn default_max_seq_len() -> usize { } #[derive(Debug, Clone)] -struct RmsNorm { +pub(crate) struct RmsNorm { weight: Tensor, eps: f64, } impl RmsNorm { - fn new(dim: usize, eps: f64, vb: VarBuilder) -> Result { + pub(crate) fn new(dim: usize, eps: f64, vb: VarBuilder) -> Result { let weight = vb.get(dim, "weight")?; Ok(Self { weight, eps }) } + + pub(crate) fn from_weight(weight: Tensor, eps: f64) -> Self { + Self { weight, eps } + } } impl Module for RmsNorm { @@ -70,7 +74,7 @@ impl Module for RmsNorm { } #[derive(Debug, Clone)] -struct RotaryEmbedding { +pub(crate) struct RotaryEmbedding { sin: Tensor, cos: Tensor, } @@ -83,7 +87,7 @@ fn rotate_half(xs: &Tensor) -> Result { } impl RotaryEmbedding { - fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result { + pub(crate) fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result { if cfg.partial_rotary_factor != 0.5 { candle::bail!("partial-rotary-factor {} <> 0.5", cfg.partial_rotary_factor) } @@ -106,7 +110,7 @@ impl RotaryEmbedding { }) } - fn apply_rotary_emb_qkv( + pub(crate) fn apply_rotary_emb_qkv( &self, q: &Tensor, k: &Tensor, @@ -156,15 +160,15 @@ impl Module for Mlp { // Real-Gated Linear Recurrent Unit #[derive(Debug, Clone)] -struct Rglru { - recurrent_param: Tensor, - input_gate_weight: Tensor, - input_gate_bias: Tensor, - recurrent_gate_weight: Tensor, - recurrent_gate_bias: Tensor, - block_width: usize, - n_heads: usize, - recurrent_states: Option, +pub(crate) struct Rglru { + pub(crate) recurrent_param: Tensor, + pub(crate) input_gate_weight: Tensor, + pub(crate) input_gate_bias: Tensor, + pub(crate) recurrent_gate_weight: Tensor, + pub(crate) recurrent_gate_bias: Tensor, + pub(crate) block_width: usize, + pub(crate) n_heads: usize, + pub(crate) recurrent_states: Option, } fn baddbmm(a: &Tensor, b: &Tensor, c: &Tensor) -> Result { @@ -200,7 +204,7 @@ impl Rglru { } // https://github.com/huggingface/transformers/blob/0bd58f1ce0573c0e3269de4215a17d318add49b9/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py#L303 - pub fn forward(&mut self, xs: &Tensor, pos: usize) -> Result { + pub(crate) fn forward(&mut self, xs: &Tensor, pos: usize) -> Result { let (b_sz, seq_len, lru_width) = xs.dims3()?; let pos = Tensor::arange(pos as u32, (pos + seq_len) as u32, xs.device())?; let reset = pos.eq(0u32)?.unsqueeze(1)?.unsqueeze(0)?; @@ -237,7 +241,7 @@ impl Rglru { reset.broadcast_add(&((1.0 - &reset)?.broadcast_mul(&(1.0 - a_square)?.sqrt()?))?)?; let normalized_x = (gated_inputs * multiplier.to_dtype(xs.dtype()))?; - let (hidden_states, recurrent_states) = self.rnn_scan( + let (hidden_states, recurrent_states) = rnn_scan( &normalized_x, &recurrent_gate, &reset, @@ -246,54 +250,53 @@ impl Rglru { self.recurrent_states = Some(recurrent_states); Ok(hidden_states) } +} - fn rnn_scan( - &self, - hidden_states: &Tensor, - recurrent_gate: &Tensor, - reset: &Tensor, - recurrent_states: Option<&Tensor>, - ) -> Result<(Tensor, Tensor)> { - let acc_dtype = DType::F32; - let dev = hidden_states.device(); - let in_dtype = hidden_states.dtype(); - let inv_reset = (1.0 - reset)?.to_dtype(recurrent_gate.dtype())?; - let recurrent_gate = recurrent_gate.broadcast_mul(&inv_reset)?; - let (c, r) = if hidden_states.dim(1)? == 1 { - match recurrent_states { - None => { - let next_state = hidden_states.i((.., 0))?.to_dtype(acc_dtype)?; - (hidden_states.clone(), next_state) - } - Some(recurrent_states) => { - let contextualized_states = - recurrent_gate.to_dtype(acc_dtype)? * recurrent_states.unsqueeze(1)?; - let contextualized_states = - (contextualized_states + hidden_states.to_dtype(acc_dtype)?)?; - let c = contextualized_states.to_dtype(in_dtype)?; - let l = contextualized_states.dim(1)?; - let r = contextualized_states.i((.., l - 1))?; - (c, r) - } +fn rnn_scan( + hidden_states: &Tensor, + recurrent_gate: &Tensor, + reset: &Tensor, + recurrent_states: Option<&Tensor>, +) -> Result<(Tensor, Tensor)> { + let acc_dtype = DType::F32; + let dev = hidden_states.device(); + let in_dtype = hidden_states.dtype(); + let inv_reset = (1.0 - reset)?.to_dtype(recurrent_gate.dtype())?; + let recurrent_gate = recurrent_gate.broadcast_mul(&inv_reset)?; + let (c, r) = if hidden_states.dim(1)? == 1 { + match recurrent_states { + None => { + let next_state = hidden_states.i((.., 0))?.to_dtype(acc_dtype)?; + (hidden_states.clone(), next_state) } - } else { - let mut recurrent_states = match recurrent_states { - None => Tensor::zeros(hidden_states.i((.., 0))?.shape(), acc_dtype, dev)?, - Some(r) => r.clone(), - }; - let mut contextualized_states = vec![]; - for t in 0..hidden_states.dim(1)? { - recurrent_states = - (recurrent_gate.i((.., t))?.to_dtype(acc_dtype)? * recurrent_states)?; - recurrent_states = - (recurrent_states + hidden_states.i((.., t))?.to_dtype(acc_dtype)?)?; - contextualized_states.push(recurrent_states.to_dtype(in_dtype)?) + Some(recurrent_states) => { + let contextualized_states = + recurrent_gate.to_dtype(acc_dtype)? * recurrent_states.unsqueeze(1)?; + let contextualized_states = + (contextualized_states + hidden_states.to_dtype(acc_dtype)?)?; + let c = contextualized_states.to_dtype(in_dtype)?; + let l = contextualized_states.dim(1)?; + let r = contextualized_states.i((.., l - 1))?; + (c, r) } - let contextualized_states = Tensor::stack(&contextualized_states, 1)?; - (contextualized_states, recurrent_states) + } + } else { + let mut recurrent_states = match recurrent_states { + None => Tensor::zeros(hidden_states.i((.., 0))?.shape(), acc_dtype, dev)?, + Some(r) => r.clone(), }; - Ok((c, r)) - } + let mut contextualized_states = vec![]; + for t in 0..hidden_states.dim(1)? { + recurrent_states = + (recurrent_gate.i((.., t))?.to_dtype(acc_dtype)? * recurrent_states)?; + recurrent_states = + (recurrent_states + hidden_states.i((.., t))?.to_dtype(acc_dtype)?)?; + contextualized_states.push(recurrent_states.to_dtype(in_dtype)?) + } + let contextualized_states = Tensor::stack(&contextualized_states, 1)?; + (contextualized_states, recurrent_states) + }; + Ok((c, r)) } #[derive(Debug, Clone)] diff --git a/candle-transformers/src/quantized_var_builder.rs b/candle-transformers/src/quantized_var_builder.rs index a963e311..875a2b45 100644 --- a/candle-transformers/src/quantized_var_builder.rs +++ b/candle-transformers/src/quantized_var_builder.rs @@ -63,7 +63,7 @@ impl VarBuilder { let path = self.path(name); match self.data.get(&path) { None => { - candle::bail!("cannot find tensor {name}") + candle::bail!("cannot find tensor {path}") } Some(qtensor) => { let shape = s.into();