diff --git a/README.md b/README.md index 543b9ca8..06d3104e 100644 --- a/README.md +++ b/README.md @@ -65,7 +65,7 @@ We also provide a some command line based examples using state of the art models - [Falcon](./candle-examples/examples/falcon/): general LLM. - [Codegeex4](./candle-examples/examples/codegeex4-9b/): Code completion,code interpreter,web search,fuction calling,repository-level - [GLM4](./candle-examples/examples/glm4/): Open Multilingual Multimodal Chat LMs by THUDM -- [Gemma](./candle-examples/examples/gemma/): 2b and 7b general LLMs from Google Deepmind. +- [Gemma v1 and v2](./candle-examples/examples/gemma/): 2b and 7b+/9b general LLMs from Google Deepmind. - [RecurrentGemma](./candle-examples/examples/recurrent-gemma/): 2b and 7b Griffin based models from Google that mix attention with a RNN like state. - [Phi-1, Phi-1.5, Phi-2, and Phi-3](./candle-examples/examples/phi/): 1.3b, @@ -208,7 +208,7 @@ If you have an addition to this list, please submit a pull request. - StarCoder, StarCoder2. - Phi 1, 1.5, 2, and 3. - Mamba, Minimal Mamba - - Gemma 2b and 7b. + - Gemma v1 2b and 7b+, v2 2b and 9b. - Mistral 7b v0.1. - Mixtral 8x7b v0.1. - StableLM-3B-4E1T, StableLM-2-1.6B, Stable-Code-3B. diff --git a/candle-examples/examples/gemma/README.md b/candle-examples/examples/gemma/README.md index 5d77c7a4..674333aa 100644 --- a/candle-examples/examples/gemma/README.md +++ b/candle-examples/examples/gemma/README.md @@ -1,27 +1,27 @@ # candle-gemma: 2b and 7b LLMs from Google DeepMind [Gemma](https://ai.google.dev/gemma/docs) is a collection of lightweight open -models published by Google Deepmind with a 2b and a 7b variant. - -In order to use the example below, you have to accept the license on the -[HuggingFace Hub Gemma repo](https://huggingface.co/google/gemma-7b) and set up -your access token via the [HuggingFace cli login -command](https://huggingface.co/docs/huggingface_hub/guides/cli#huggingface-cli-login). +models published by Google Deepmind with a 2b and a 7b variant for the first +version, and a 2b and a 9b variant for v2. ## Running the example ```bash -$ cargo run --example gemma --release -- --prompt "fn count_primes(max_n: usize)" -fn count_primes(max_n: usize) -> usize { - let mut primes = vec![true; max_n]; - for i in 2..=max_n { - if primes[i] { - for j in i * i..max_n { - primes[j] = false; - } - } - } - primes.len() -} +$ cargo run --example gemma --features cuda -r -- \ + --prompt "Here is a proof that square root of 2 is not rational: " + +Here is a proof that square root of 2 is not rational: + +Let us assume it to be rational. Then, we can write √2 = p/q where q ≠ 0 and p and q are integers with no common factors other than 1. Squaring both sides gives us (p/q)^2 = 2 or p^2/q^2 = 2. This implies that p^2 is divisible by 2, which means that p must be even. Let us write p = 2m where m is an integer. Substituting this in the above equation we get: + +(p^2)/q^2 = 2 or (4m^2)/q^2 = 2 or q^2/2m^2 = 1 which implies that q^2 must be divisible by 2, and hence q is even. This contradicts our assumption that p and q have no common factors other than 1. Hence we conclude that √2 cannot be rational. ``` +## Access restrictions + +In order to use the v1 examples, you have to accept the license on the +[HuggingFace Hub Gemma repo](https://huggingface.co/google/gemma-7b) and set up +your access token via the [HuggingFace cli login +command](https://huggingface.co/docs/huggingface_hub/guides/cli#huggingface-cli-login). + + diff --git a/candle-examples/examples/gemma/main.rs b/candle-examples/examples/gemma/main.rs index 31c55618..b11d7710 100644 --- a/candle-examples/examples/gemma/main.rs +++ b/candle-examples/examples/gemma/main.rs @@ -7,7 +7,8 @@ extern crate accelerate_src; use anyhow::{Error as E, Result}; use clap::Parser; -use candle_transformers::models::gemma::{Config, Model}; +use candle_transformers::models::gemma::{Config as Config1, Model as Model1}; +use candle_transformers::models::gemma2::{Config as Config2, Model as Model2}; use candle::{DType, Device, Tensor}; use candle_examples::token_output_stream::TokenOutputStream; @@ -38,6 +39,46 @@ enum Which { CodeInstruct2B, #[value(name = "code-7b-it")] CodeInstruct7B, + #[value(name = "2-2b")] + BaseV2_2B, + #[value(name = "2-2b-it")] + InstructV2_2B, + #[value(name = "2-9b")] + BaseV2_9B, + #[value(name = "2-9b-it")] + InstructV2_9B, +} + +impl Which { + fn is_v1(&self) -> bool { + match self { + Self::Base2B + | Self::Base7B + | Self::Instruct2B + | Self::Instruct7B + | Self::InstructV1_1_2B + | Self::InstructV1_1_7B + | Self::CodeBase2B + | Self::CodeBase7B + | Self::CodeInstruct2B + | Self::CodeInstruct7B => true, + Self::BaseV2_2B | Self::InstructV2_2B | Self::BaseV2_9B | Self::InstructV2_9B => false, + } + } +} + +enum Model { + V1(Model1), + V2(Model2), +} + +impl Model { + fn forward(&mut self, input_ids: &Tensor, pos: usize) -> candle::Result { + match self { + Self::V1(m) => m.forward(input_ids, pos), + Self::V2(m) => m.forward(input_ids, pos), + } + } } struct TextGeneration { @@ -191,7 +232,7 @@ struct Args { repeat_last_n: usize, /// The model to use. - #[arg(long, default_value = "2b")] + #[arg(long, default_value = "2-2b")] which: Which, #[arg(long)] @@ -239,6 +280,10 @@ fn main() -> Result<()> { Which::CodeBase7B => "google/codegemma-7b".to_string(), Which::CodeInstruct2B => "google/codegemma-2b-it".to_string(), Which::CodeInstruct7B => "google/codegemma-7b-it".to_string(), + Which::BaseV2_2B => "google/gemma-2-2b".to_string(), + Which::InstructV2_2B => "google/gemma-2-2b-it".to_string(), + Which::BaseV2_9B => "google/gemma-2-9b".to_string(), + Which::InstructV2_9B => "google/gemma-2-9b-it".to_string(), }, }; let repo = api.repo(Repo::with_revision( @@ -263,7 +308,6 @@ fn main() -> Result<()> { }; println!("retrieved the files in {:?}", start.elapsed()); let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; - let config: Config = serde_json::from_reader(std::fs::File::open(config_filename)?)?; let start = std::time::Instant::now(); let device = candle_examples::device(args.cpu)?; @@ -273,7 +317,15 @@ fn main() -> Result<()> { DType::F32 }; let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? }; - let model = Model::new(args.use_flash_attn, &config, vb)?; + let model = if args.which.is_v1() { + let config: Config1 = serde_json::from_reader(std::fs::File::open(config_filename)?)?; + let model = Model1::new(args.use_flash_attn, &config, vb)?; + Model::V1(model) + } else { + let config: Config2 = serde_json::from_reader(std::fs::File::open(config_filename)?)?; + let model = Model2::new(args.use_flash_attn, &config, vb)?; + Model::V2(model) + }; println!("loaded the model in {:?}", start.elapsed()); diff --git a/candle-transformers/src/models/gemma2.rs b/candle-transformers/src/models/gemma2.rs new file mode 100644 index 00000000..f0d65047 --- /dev/null +++ b/candle-transformers/src/models/gemma2.rs @@ -0,0 +1,449 @@ +use std::sync::Arc; + +use candle::{DType, Device, Module, Result, Tensor, D}; +use candle_nn::{linear_b as linear, Activation, Linear, VarBuilder}; + +fn default_max_position_embeddings() -> usize { + 4096 +} + +#[derive(serde::Deserialize, Debug, Clone)] +pub struct Config { + pub attention_bias: bool, + pub head_dim: usize, + pub hidden_activation: Activation, + pub hidden_size: usize, + pub intermediate_size: usize, + pub num_attention_heads: usize, + pub num_hidden_layers: usize, + pub num_key_value_heads: usize, + pub rms_norm_eps: f64, + pub rope_theta: f64, + pub vocab_size: usize, + pub final_logit_softcapping: Option, + pub attn_logit_softcapping: Option, + pub query_pre_attn_scalar: usize, + // TODO: Handle the sliding window in the attention mask. + pub sliding_window: Option, + + #[serde(default = "default_max_position_embeddings")] + pub max_position_embeddings: usize, +} + +#[derive(Debug, Clone)] +struct RmsNorm { + weight: Tensor, + eps: f64, +} + +impl RmsNorm { + fn new(dim: usize, eps: f64, vb: VarBuilder) -> Result { + let weight = vb.get(dim, "weight")?; + Ok(Self { weight, eps }) + } +} + +impl Module for RmsNorm { + fn forward(&self, x: &Tensor) -> Result { + let x_dtype = x.dtype(); + let internal_dtype = match x_dtype { + DType::F16 | DType::BF16 => DType::F32, + d => d, + }; + let hidden_size = x.dim(D::Minus1)?; + let x = x.to_dtype(internal_dtype)?; + let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?; + let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?; + x_normed + .to_dtype(x_dtype)? + .broadcast_mul(&(&self.weight + 1.0)?) + } +} + +#[derive(Debug, Clone)] +struct RotaryEmbedding { + sin: Tensor, + cos: Tensor, +} + +impl RotaryEmbedding { + fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result { + let dim = cfg.head_dim; + let max_seq_len = cfg.max_position_embeddings; + let inv_freq: Vec<_> = (0..dim) + .step_by(2) + .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32) + .collect(); + let inv_freq_len = inv_freq.len(); + let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?; + let t = Tensor::arange(0u32, max_seq_len as u32, dev)? + .to_dtype(dtype)? + .reshape((max_seq_len, 1))?; + let freqs = t.matmul(&inv_freq)?; + Ok(Self { + sin: freqs.sin()?, + cos: freqs.cos()?, + }) + } + + fn apply_rotary_emb_qkv( + &self, + q: &Tensor, + k: &Tensor, + seqlen_offset: usize, + ) -> Result<(Tensor, Tensor)> { + let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?; + let cos = self.cos.narrow(0, seqlen_offset, seq_len)?; + let sin = self.sin.narrow(0, seqlen_offset, seq_len)?; + let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?; + let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?; + Ok((q_embed, k_embed)) + } +} + +#[derive(Debug, Clone)] +#[allow(clippy::upper_case_acronyms)] +struct MLP { + gate_proj: Linear, + up_proj: Linear, + down_proj: Linear, + act_fn: candle_nn::Activation, +} + +impl MLP { + fn new(cfg: &Config, vb: VarBuilder) -> Result { + let hidden_sz = cfg.hidden_size; + let intermediate_sz = cfg.intermediate_size; + let gate_proj = linear(hidden_sz, intermediate_sz, false, vb.pp("gate_proj"))?; + let up_proj = linear(hidden_sz, intermediate_sz, false, vb.pp("up_proj"))?; + let down_proj = linear(intermediate_sz, hidden_sz, false, vb.pp("down_proj"))?; + Ok(Self { + gate_proj, + up_proj, + down_proj, + act_fn: cfg.hidden_activation, + }) + } +} + +impl Module for MLP { + fn forward(&self, xs: &Tensor) -> Result { + let lhs = xs.apply(&self.gate_proj)?.apply(&self.act_fn)?; + let rhs = xs.apply(&self.up_proj)?; + (lhs * rhs)?.apply(&self.down_proj) + } +} + +#[derive(Debug, Clone)] +struct Attention { + q_proj: Linear, + k_proj: Linear, + v_proj: Linear, + o_proj: Linear, + num_heads: usize, + num_kv_heads: usize, + num_kv_groups: usize, + head_dim: usize, + attn_logit_softcapping: Option, + rotary_emb: Arc, + kv_cache: Option<(Tensor, Tensor)>, + use_flash_attn: bool, +} + +impl Attention { + fn new( + rotary_emb: Arc, + use_flash_attn: bool, + cfg: &Config, + vb: VarBuilder, + ) -> Result { + let hidden_sz = cfg.hidden_size; + let num_heads = cfg.num_attention_heads; + let num_kv_heads = cfg.num_key_value_heads; + let num_kv_groups = num_heads / num_kv_heads; + let head_dim = cfg.head_dim; + let bias = cfg.attention_bias; + let q_proj = linear(hidden_sz, num_heads * head_dim, bias, vb.pp("q_proj"))?; + let k_proj = linear(hidden_sz, num_kv_heads * head_dim, bias, vb.pp("k_proj"))?; + let v_proj = linear(hidden_sz, num_kv_heads * head_dim, bias, vb.pp("v_proj"))?; + let o_proj = linear(num_heads * head_dim, hidden_sz, bias, vb.pp("o_proj"))?; + Ok(Self { + q_proj, + k_proj, + v_proj, + o_proj, + num_heads, + num_kv_heads, + num_kv_groups, + head_dim, + attn_logit_softcapping: cfg.attn_logit_softcapping, + rotary_emb, + kv_cache: None, + use_flash_attn, + }) + } + + fn forward( + &mut self, + xs: &Tensor, + attention_mask: Option<&Tensor>, + seqlen_offset: usize, + ) -> Result { + let (b_sz, q_len, _) = xs.dims3()?; + + let query_states = self.q_proj.forward(xs)?; + let key_states = self.k_proj.forward(xs)?; + let value_states = self.v_proj.forward(xs)?; + + let query_states = query_states + .reshape((b_sz, q_len, self.num_heads, self.head_dim))? + .transpose(1, 2)?; + let key_states = key_states + .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))? + .transpose(1, 2)?; + let value_states = value_states + .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))? + .transpose(1, 2)?; + + let (query_states, key_states) = + self.rotary_emb + .apply_rotary_emb_qkv(&query_states, &key_states, seqlen_offset)?; + + let (key_states, value_states) = match &self.kv_cache { + None => (key_states, value_states), + Some((prev_k, prev_v)) => { + let key_states = Tensor::cat(&[prev_k, &key_states], 2)?; + let value_states = Tensor::cat(&[prev_v, &value_states], 2)?; + (key_states, value_states) + } + }; + self.kv_cache = Some((key_states.clone(), value_states.clone())); + + let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?.contiguous()?; + let value_states = + crate::utils::repeat_kv(value_states, self.num_kv_groups)?.contiguous()?; + + let attn_output = if self.use_flash_attn { + // flash-attn expects (b_sz, seq_len, nheads, head_dim) + let q = query_states.transpose(1, 2)?; + let k = key_states.transpose(1, 2)?; + let v = value_states.transpose(1, 2)?; + let scale = 1f32 / (self.head_dim as f32).sqrt(); + flash_attn(&q, &k, &v, scale, attention_mask.is_some())?.transpose(1, 2)? + } else { + let scale = 1f64 / f64::sqrt(self.head_dim as f64); + let attn_weights = (query_states.matmul(&key_states.transpose(2, 3)?)? * scale)?; + + let attn_weights = match self.attn_logit_softcapping { + None => attn_weights, + Some(sc) => ((attn_weights / sc)?.tanh()? * sc)?, + }; + + let attn_weights = match attention_mask { + None => attn_weights, + Some(mask) => attn_weights.broadcast_add(mask)?, + }; + let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?; + attn_weights.matmul(&value_states)? + }; + attn_output + .transpose(1, 2)? + .reshape((b_sz, q_len, ()))? + .apply(&self.o_proj) + } + + fn clear_kv_cache(&mut self) { + self.kv_cache = None + } +} + +#[cfg(feature = "flash-attn")] +fn flash_attn( + q: &Tensor, + k: &Tensor, + v: &Tensor, + softmax_scale: f32, + causal: bool, +) -> Result { + candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal) +} + +#[cfg(not(feature = "flash-attn"))] +fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result { + unimplemented!("compile with '--features flash-attn'") +} + +#[derive(Debug, Clone)] +struct DecoderLayer { + self_attn: Attention, + mlp: MLP, + input_layernorm: RmsNorm, + pre_feedforward_layernorm: RmsNorm, + post_feedforward_layernorm: RmsNorm, + post_attention_layernorm: RmsNorm, +} + +impl DecoderLayer { + fn new( + rotary_emb: Arc, + use_flash_attn: bool, + cfg: &Config, + vb: VarBuilder, + ) -> Result { + let self_attn = Attention::new(rotary_emb, use_flash_attn, cfg, vb.pp("self_attn"))?; + let mlp = MLP::new(cfg, vb.pp("mlp"))?; + let input_layernorm = + RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?; + let pre_feedforward_layernorm = RmsNorm::new( + cfg.hidden_size, + cfg.rms_norm_eps, + vb.pp("pre_feedforward_layernorm"), + )?; + let post_feedforward_layernorm = RmsNorm::new( + cfg.hidden_size, + cfg.rms_norm_eps, + vb.pp("post_feedforward_layernorm"), + )?; + let post_attention_layernorm = RmsNorm::new( + cfg.hidden_size, + cfg.rms_norm_eps, + vb.pp("post_attention_layernorm"), + )?; + Ok(Self { + self_attn, + mlp, + input_layernorm, + pre_feedforward_layernorm, + post_feedforward_layernorm, + post_attention_layernorm, + }) + } + + fn forward( + &mut self, + xs: &Tensor, + attention_mask: Option<&Tensor>, + seqlen_offset: usize, + ) -> Result { + let residual = xs; + let xs = self.input_layernorm.forward(xs)?; + let xs = self.self_attn.forward(&xs, attention_mask, seqlen_offset)?; + let xs = xs.apply(&self.post_attention_layernorm)?; + let xs = (xs + residual)?; + let residual = &xs; + let xs = xs.apply(&self.pre_feedforward_layernorm)?; + let xs = xs.apply(&self.mlp)?; + let xs = xs.apply(&self.post_feedforward_layernorm)?; + residual + xs + } + + fn clear_kv_cache(&mut self) { + self.self_attn.clear_kv_cache() + } +} + +#[derive(Debug, Clone)] +pub struct Model { + embed_tokens: candle_nn::Embedding, + layers: Vec, + norm: RmsNorm, + lm_head: Linear, + final_logit_softcapping: Option, + device: Device, + dtype: DType, + hidden_size: usize, + sliding_window: Option, +} + +impl Model { + pub fn new(use_flash_attn: bool, cfg: &Config, vb: VarBuilder) -> Result { + let vb_m = vb.pp("model"); + let embed_tokens = + candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?; + let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb_m.device())?); + let mut layers = Vec::with_capacity(cfg.num_hidden_layers); + let vb_l = vb_m.pp("layers"); + for layer_idx in 0..cfg.num_hidden_layers { + let layer = + DecoderLayer::new(rotary_emb.clone(), use_flash_attn, cfg, vb_l.pp(layer_idx))?; + layers.push(layer) + } + let norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb_m.pp("norm"))?; + let lm_head = Linear::new(embed_tokens.embeddings().clone(), None); + Ok(Self { + embed_tokens, + layers, + norm, + lm_head, + final_logit_softcapping: cfg.final_logit_softcapping, + device: vb.device().clone(), + dtype: vb.dtype(), + hidden_size: cfg.hidden_size, + sliding_window: cfg.sliding_window, + }) + } + + fn prepare_decoder_attention_mask( + &self, + b_size: usize, + tgt_len: usize, + seqlen_offset: usize, + ) -> Result { + let mask: Vec<_> = match self.sliding_window { + None => (0..tgt_len) + .flat_map(|i| (0..tgt_len).map(move |j| if i < j { f32::NEG_INFINITY } else { 0. })) + .collect(), + Some(sliding_window) => (0..tgt_len) + .flat_map(|i| { + (0..tgt_len).map(move |j| { + if i < j || j + sliding_window < i { + f32::NEG_INFINITY + } else { + 0. + } + }) + }) + .collect(), + }; + let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?; + let mask = if seqlen_offset > 0 { + let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, &self.device)?; + Tensor::cat(&[&mask0, &mask], D::Minus1)? + } else { + mask + }; + mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))? + .to_dtype(self.dtype) + } + + pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result { + let (b_size, seq_len) = input_ids.dims2()?; + let attention_mask = if seq_len <= 1 { + None + } else { + let mask = self.prepare_decoder_attention_mask(b_size, seq_len, seqlen_offset)?; + Some(mask) + }; + let xs = self.embed_tokens.forward(input_ids)?; + let mut xs = (xs * (self.hidden_size as f64).sqrt())?; + for layer in self.layers.iter_mut() { + xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)? + } + let logits = xs + .narrow(1, seq_len - 1, 1)? + .apply(&self.norm)? + .apply(&self.lm_head)?; + let logits = match self.final_logit_softcapping { + None => logits, + Some(sc) => ((logits / sc)?.tanh()? * sc)?, + }; + + Ok(logits) + } + + pub fn clear_kv_cache(&mut self) { + for layer in self.layers.iter_mut() { + layer.clear_kv_cache() + } + } +} diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs index 7baaaf72..16952c6a 100644 --- a/candle-transformers/src/models/mod.rs +++ b/candle-transformers/src/models/mod.rs @@ -20,6 +20,7 @@ pub mod eva2; pub mod falcon; pub mod flux; pub mod gemma; +pub mod gemma2; pub mod glm4; pub mod hiera; pub mod jina_bert;