Add the code-gemma models. (#2038)
* Add the code-gemma models.

* Tweak to the gemma config.
@@ -30,6 +30,14 @@ enum Which {
     InstructV1_1_2B,
     #[value(name = "1.1-7b-it")]
     InstructV1_1_7B,
+    #[value(name = "code-2b")]
+    CodeBase2B,
+    #[value(name = "code-7b")]
+    CodeBase7B,
+    #[value(name = "code-2b-it")]
+    CodeInstruct2B,
+    #[value(name = "code-7b-it")]
+    CodeInstruct7B,
 }
 
 struct TextGeneration {
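The new variants plug into the example's clap-based CLI: the #[value(name = "...")] attribute is what maps a command-line string like code-2b-it onto an enum variant. A minimal sketch of that mechanism, assuming clap 4 with the derive feature (the Args struct and its field are illustrative stand-ins, not the example's actual definitions):

use clap::{Parser, ValueEnum};

// Hypothetical trimmed-down version of the example's Which enum.
#[derive(Clone, Copy, Debug, ValueEnum)]
enum Which {
    #[value(name = "code-2b-it")]
    CodeInstruct2B,
    #[value(name = "code-7b-it")]
    CodeInstruct7B,
}

#[derive(Parser, Debug)]
struct Args {
    /// Which model variant to run.
    #[arg(long, value_enum)]
    which: Which,
}

fn main() {
    // "--which code-2b-it" parses to Which::CodeInstruct2B.
    let args = Args::parse();
    println!("{:?}", args.which);
}

Assuming the example keeps its existing --which flag, selecting one of the new checkpoints would look like: cargo run --example gemma --release -- --which code-2b-it --prompt "..."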
@@ -224,6 +232,10 @@ fn main() -> Result<()> {
             Which::Base7B => "google/gemma-7b".to_string(),
             Which::Instruct2B => "google/gemma-2b-it".to_string(),
             Which::Instruct7B => "google/gemma-7b-it".to_string(),
+            Which::CodeBase2B => "google/codegemma-2b".to_string(),
+            Which::CodeBase7B => "google/codegemma-7b".to_string(),
+            Which::CodeInstruct2B => "google/codegemma-2b-it".to_string(),
+            Which::CodeInstruct7B => "google/codegemma-7b-it".to_string(),
         },
     };
     let repo = api.repo(Repo::with_revision(
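Each variant resolves to a Hugging Face Hub repository id, which then flows into hf-hub's sync API (the api.repo(Repo::with_revision( context line above). A hedged sketch of that download path, assuming the hf-hub and anyhow crates; fetch_config is a hypothetical helper name, not a function from the example:

use hf_hub::{api::sync::Api, Repo, RepoType};

fn fetch_config(model_id: &str) -> anyhow::Result<std::path::PathBuf> {
    let api = Api::new()?;
    let repo = api.repo(Repo::with_revision(
        model_id.to_string(),
        RepoType::Model,
        "main".to_string(),
    ));
    // Downloads config.json for the repo, or reuses the cached copy.
    Ok(repo.get("config.json")?)
}

The remaining hunks touch the gemma model definition itself rather than the example.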
@@ -1,7 +1,7 @@
 use std::sync::Arc;
 
 use candle::{DType, Device, Module, Result, Tensor, D};
-use candle_nn::{linear_b as linear, Linear, VarBuilder};
+use candle_nn::{linear_b as linear, Activation, Linear, VarBuilder};
 
 fn default_max_position_embeddings() -> usize {
     4096
@@ -11,8 +11,9 @@ fn default_max_position_embeddings() -> usize {
 pub struct Config {
     pub attention_bias: bool,
     pub head_dim: usize,
-    #[serde(alias = "hidden_activation")]
-    pub hidden_act: candle_nn::Activation,
+    // The code gemma configs include both hidden_act and hidden_activation.
+    pub hidden_act: Option<Activation>,
+    pub hidden_activation: Option<Activation>,
     pub hidden_size: usize,
     pub intermediate_size: usize,
     pub num_attention_heads: usize,
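The switch away from #[serde(alias = "hidden_activation")] matters because serde treats an alias as the same field: a config.json carrying both keys fails to deserialize with a duplicate-field error. Two independent Option fields accept such files, and a null value simply becomes None. A minimal sketch of the difference, with MiniConfig as a hypothetical stand-in for the real Config (the null value here is illustrative):

use serde::Deserialize;

#[derive(Debug, Deserialize)]
struct MiniConfig {
    hidden_act: Option<String>,
    hidden_activation: Option<String>,
}

fn main() -> serde_json::Result<()> {
    // Under the old alias scheme, this input would fail with
    // "duplicate field"; with two Option fields it parses cleanly.
    let raw = r#"{"hidden_act": "gelu_pytorch_tanh", "hidden_activation": null}"#;
    let cfg: MiniConfig = serde_json::from_str(raw)?;
    assert_eq!(cfg.hidden_act.as_deref(), Some("gelu_pytorch_tanh"));
    assert!(cfg.hidden_activation.is_none());
    Ok(())
}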
@@ -26,6 +27,16 @@ pub struct Config {
     pub max_position_embeddings: usize,
 }
 
+impl Config {
+    fn hidden_act(&self) -> Result<Activation> {
+        match (self.hidden_act, self.hidden_activation) {
+            (None, Some(act)) | (Some(act), None) => Ok(act),
+            (Some(_), Some(_)) => candle::bail!("both hidden_act and hidden_activation are set"),
+            (None, None) => candle::bail!("none of hidden_act and hidden_activation are set"),
+        }
+    }
+}
+
 #[derive(Debug, Clone)]
 struct RmsNorm {
     weight: Tensor,
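The helper collapses the two optional fields with a single tuple match: exactly one of them must be set. The same pattern in isolation, as a self-contained sketch with hypothetical names:

fn resolve<T: Copy>(a: Option<T>, b: Option<T>) -> Result<T, &'static str> {
    match (a, b) {
        // Exactly one source set: take it, whichever side it came from.
        (None, Some(v)) | (Some(v), None) => Ok(v),
        (Some(_), Some(_)) => Err("both fields are set"),
        (None, None) => Err("neither field is set"),
    }
}

fn main() {
    assert_eq!(resolve(Some(1), None), Ok(1));
    assert_eq!(resolve(None, Some(2)), Ok(2));
    assert!(resolve(Some(1), Some(2)).is_err());
    assert!(resolve::<i32>(None, None).is_err());
}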
@@ -127,7 +138,7 @@ impl MLP {
             gate_proj,
             up_proj,
             down_proj,
-            act_fn: cfg.hidden_act,
+            act_fn: cfg.hidden_act()?,
         })
     }
 }
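With the accessor in place, building the MLP reads the activation fallibly and propagates a config error via the ? operator instead of silently using a bad value. The stored candle_nn::Activation implements Module, so it applies directly to tensors in the forward pass; a small sketch, assuming the GeluPytorchTanh variant the gemma configs typically specify:

use candle::{Device, Module, Result, Tensor};
use candle_nn::Activation;

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let xs = Tensor::new(&[[-1f32, 0.0, 1.0]], &dev)?;
    // Activation implements Module, so it applies elementwise to a tensor.
    let ys = Activation::GeluPytorchTanh.forward(&xs)?;
    println!("{ys}");
    Ok(())
}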