Add the code-gemma models. (#2038)

* Add the code-gemma models.

* Tweak to the gemma config.
Author: Laurent Mazare
Date: 2024-04-10 21:19:21 +02:00
Committed by: GitHub
Parent: b81ecf712d
Commit: a0460cd2b1
2 changed files with 27 additions and 4 deletions


@@ -1,7 +1,7 @@
 use std::sync::Arc;

 use candle::{DType, Device, Module, Result, Tensor, D};
-use candle_nn::{linear_b as linear, Linear, VarBuilder};
+use candle_nn::{linear_b as linear, Activation, Linear, VarBuilder};

 fn default_max_position_embeddings() -> usize {
     4096
@@ -11,8 +11,9 @@ fn default_max_position_embeddings() -> usize {
 pub struct Config {
     pub attention_bias: bool,
     pub head_dim: usize,
-    #[serde(alias = "hidden_activation")]
-    pub hidden_act: candle_nn::Activation,
+    // The code gemma configs include both hidden_act and hidden_activation.
+    pub hidden_act: Option<Activation>,
+    pub hidden_activation: Option<Activation>,
     pub hidden_size: usize,
     pub intermediate_size: usize,
     pub num_attention_heads: usize,
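
Since the Config struct is serde-deserialized (as the removed #[serde(...)] attribute indicates), an activation key that is absent from a checkpoint's config.json simply becomes None. A minimal, self-contained sketch of that behaviour, using String in place of candle_nn::Activation and serde_json as an assumed dependency (the key values below are illustrative):

```rust
// Minimal sketch of the optional-key deserialization; String stands in for
// candle_nn::Activation and the values are illustrative, not tied to any
// particular checkpoint.
use serde::Deserialize;

#[derive(Deserialize, Debug)]
struct ActKeys {
    hidden_act: Option<String>,
    hidden_activation: Option<String>,
}

fn main() -> Result<(), serde_json::Error> {
    // A config that only sets hidden_act: the other field deserializes to None.
    let a: ActKeys = serde_json::from_str(r#"{ "hidden_act": "gelu" }"#)?;
    assert_eq!(a.hidden_act.as_deref(), Some("gelu"));
    assert!(a.hidden_activation.is_none());

    // A config that only sets hidden_activation.
    let b: ActKeys = serde_json::from_str(r#"{ "hidden_activation": "gelu_pytorch_tanh" }"#)?;
    assert!(b.hidden_act.is_none());
    assert_eq!(b.hidden_activation.as_deref(), Some("gelu_pytorch_tanh"));
    Ok(())
}
```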
@@ -26,6 +27,16 @@ pub struct Config {
     pub max_position_embeddings: usize,
 }

+impl Config {
+    fn hidden_act(&self) -> Result<Activation> {
+        match (self.hidden_act, self.hidden_activation) {
+            (None, Some(act)) | (Some(act), None) => Ok(act),
+            (Some(_), Some(_)) => candle::bail!("both hidden_act and hidden_activation are set"),
+            (None, None) => candle::bail!("none of hidden_act and hidden_activation are set"),
+        }
+    }
+}
+
 #[derive(Debug, Clone)]
 struct RmsNorm {
     weight: Tensor,
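
The resolver accepts exactly one of the two keys and reports an error otherwise. A self-contained sketch of that three-way match, using a stand-in Activation enum and String errors so it runs without candle (variant names are illustrative; the real method returns candle::Result and uses candle::bail!):

```rust
// Stand-in activation type so the sketch compiles on its own.
#[derive(Debug, Clone, Copy, PartialEq)]
enum Activation {
    Gelu,
    Relu,
}

fn resolve(
    hidden_act: Option<Activation>,
    hidden_activation: Option<Activation>,
) -> Result<Activation, String> {
    match (hidden_act, hidden_activation) {
        // Exactly one key set: use it.
        (None, Some(act)) | (Some(act), None) => Ok(act),
        // Both or neither set: surface an error, mirroring the bail! calls above.
        (Some(_), Some(_)) => Err("both hidden_act and hidden_activation are set".into()),
        (None, None) => Err("none of hidden_act and hidden_activation are set".into()),
    }
}

fn main() {
    assert_eq!(resolve(Some(Activation::Gelu), None), Ok(Activation::Gelu));
    assert_eq!(resolve(None, Some(Activation::Relu)), Ok(Activation::Relu));
    assert!(resolve(Some(Activation::Gelu), Some(Activation::Relu)).is_err());
    assert!(resolve(None, None).is_err());
}
```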
@@ -127,7 +138,7 @@ impl MLP {
             gate_proj,
             up_proj,
             down_proj,
-            act_fn: cfg.hidden_act,
+            act_fn: cfg.hidden_act()?,
         })
     }
 }
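
Because hidden_act() returns a Result, the ? at this call site propagates the error out of MLP::new when a config provides neither (or, per the check above, both) of the activation keys, so a bad config surfaces as a candle error at model-construction time.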