From a0460cd2b13a396ff8545dc1bbffa741f2ec3d79 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Wed, 10 Apr 2024 21:19:21 +0200
Subject: [PATCH] Add the code-gemma models. (#2038)

* Add the code-gemma models.

* Tweak to the gemma config.
---
 candle-examples/examples/gemma/main.rs  | 12 ++++++++++++
 candle-transformers/src/models/gemma.rs | 19 +++++++++++++++----
 2 files changed, 27 insertions(+), 4 deletions(-)

diff --git a/candle-examples/examples/gemma/main.rs b/candle-examples/examples/gemma/main.rs
index 0e37f5cd..a5f7d591 100644
--- a/candle-examples/examples/gemma/main.rs
+++ b/candle-examples/examples/gemma/main.rs
@@ -30,6 +30,14 @@ enum Which {
     InstructV1_1_2B,
     #[value(name = "1.1-7b-it")]
     InstructV1_1_7B,
+    #[value(name = "code-2b")]
+    CodeBase2B,
+    #[value(name = "code-7b")]
+    CodeBase7B,
+    #[value(name = "code-2b-it")]
+    CodeInstruct2B,
+    #[value(name = "code-7b-it")]
+    CodeInstruct7B,
 }
 
 struct TextGeneration {
@@ -224,6 +232,10 @@ fn main() -> Result<()> {
             Which::Base7B => "google/gemma-7b".to_string(),
             Which::Instruct2B => "google/gemma-2b-it".to_string(),
             Which::Instruct7B => "google/gemma-7b-it".to_string(),
+            Which::CodeBase2B => "google/codegemma-2b".to_string(),
+            Which::CodeBase7B => "google/codegemma-7b".to_string(),
+            Which::CodeInstruct2B => "google/codegemma-2b-it".to_string(),
+            Which::CodeInstruct7B => "google/codegemma-7b-it".to_string(),
         },
     };
     let repo = api.repo(Repo::with_revision(
diff --git a/candle-transformers/src/models/gemma.rs b/candle-transformers/src/models/gemma.rs
index ab2a9582..15e4dccb 100644
--- a/candle-transformers/src/models/gemma.rs
+++ b/candle-transformers/src/models/gemma.rs
@@ -1,7 +1,7 @@
 use std::sync::Arc;
 use candle::{DType, Device, Module, Result, Tensor, D};
-use candle_nn::{linear_b as linear, Linear, VarBuilder};
+use candle_nn::{linear_b as linear, Activation, Linear, VarBuilder};
 
 fn default_max_position_embeddings() -> usize {
     4096
 }
@@ -11,8 +11,9 @@ fn default_max_position_embeddings() -> usize {
 pub struct Config {
     pub attention_bias: bool,
     pub head_dim: usize,
-    #[serde(alias = "hidden_activation")]
-    pub hidden_act: candle_nn::Activation,
+    // The code gemma configs include both hidden_act and hidden_activation.
+    pub hidden_act: Option<Activation>,
+    pub hidden_activation: Option<Activation>,
     pub hidden_size: usize,
     pub intermediate_size: usize,
     pub num_attention_heads: usize,
@@ -26,6 +27,16 @@ pub struct Config {
     pub max_position_embeddings: usize,
 }
 
+impl Config {
+    fn hidden_act(&self) -> Result<Activation> {
+        match (self.hidden_act, self.hidden_activation) {
+            (None, Some(act)) | (Some(act), None) => Ok(act),
+            (Some(_), Some(_)) => candle::bail!("both hidden_act and hidden_activation are set"),
+            (None, None) => candle::bail!("none of hidden_act and hidden_activation are set"),
+        }
+    }
+}
+
 #[derive(Debug, Clone)]
 struct RmsNorm {
     weight: Tensor,
@@ -127,7 +138,7 @@ impl MLP {
             gate_proj,
             up_proj,
             down_proj,
-            act_fn: cfg.hidden_act,
+            act_fn: cfg.hidden_act()?,
         })
     }
 }
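
Note for readers of this patch: the config change replaces a single serde-aliased field with two optional fields plus a `hidden_act()` accessor, because the gemma and code-gemma config.json files do not name the activation the same way. The standalone sketch below (not part of the patch) mirrors that resolution logic so it can be tried outside of candle; the local `Activation` enum and the JSON values are illustrative stand-ins for `candle_nn::Activation` and the real Hugging Face config files.

// Minimal sketch of the Config::hidden_act fallback, assuming only the
// serde and serde_json crates. The stand-in enum and JSON snippets are
// hypothetical; candle_nn::Activation has more variants and its own
// serde names.
#[derive(Debug, Clone, Copy, PartialEq, serde::Deserialize)]
#[serde(rename_all = "snake_case")]
enum Activation {
    Gelu,
    GeluPytorchTanh,
}

#[derive(Debug, serde::Deserialize)]
struct Config {
    hidden_act: Option<Activation>,
    hidden_activation: Option<Activation>,
}

impl Config {
    // Mirrors the patched accessor: exactly one of the two fields must be
    // present, and either spelling is accepted.
    fn hidden_act(&self) -> Result<Activation, String> {
        match (self.hidden_act, self.hidden_activation) {
            (None, Some(act)) | (Some(act), None) => Ok(act),
            (Some(_), Some(_)) => Err("both hidden_act and hidden_activation are set".into()),
            (None, None) => Err("none of hidden_act and hidden_activation are set".into()),
        }
    }
}

fn main() {
    // A gemma-style config names the activation hidden_act...
    let gemma: Config = serde_json::from_str(r#"{ "hidden_act": "gelu" }"#).unwrap();
    assert_eq!(gemma.hidden_act().unwrap(), Activation::Gelu);

    // ...while a code-gemma-style config uses hidden_activation instead.
    let code_gemma: Config =
        serde_json::from_str(r#"{ "hidden_activation": "gelu_pytorch_tanh" }"#).unwrap();
    assert_eq!(code_gemma.hidden_act().unwrap(), Activation::GeluPytorchTanh);
}

With serde (derive feature) and serde_json as dependencies this compiles as-is; in the patch itself the accessor returns candle's `Result` and reports the same two error cases via `candle::bail!`.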