Add the new gemma models. (#2023)

* Add the new gemma models.

* Revert the lightning changes.

* Support for the 1.1 models.
Laurent Mazare authored on 2024-04-06 21:25:38 +02:00, committed by GitHub
parent 9fd52b3b71, commit 33c9b66554
2 changed files with 29 additions and 7 deletions


@@ -16,6 +16,22 @@ use candle_transformers::generation::LogitsProcessor;
 use hf_hub::{api::sync::Api, Repo, RepoType};
 use tokenizers::Tokenizer;
 
+#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
+enum Which {
+    #[value(name = "2b")]
+    Base2B,
+    #[value(name = "7b")]
+    Base7B,
+    #[value(name = "2b-it")]
+    Instruct2B,
+    #[value(name = "7b-it")]
+    Instruct7B,
+    #[value(name = "1.1-2b-it")]
+    InstructV1_1_2B,
+    #[value(name = "1.1-7b-it")]
+    InstructV1_1_7B,
+}
+
 struct TextGeneration {
     model: Model,
     device: Device,
@@ -165,6 +181,10 @@ struct Args {
     /// The context size to consider for the repeat penalty.
     #[arg(long, default_value_t = 64)]
     repeat_last_n: usize,
+
+    /// The model to use.
+    #[arg(long, default_value = "2b")]
+    which: Which,
 }
 
 fn main() -> Result<()> {
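The new Which enum and the --which flag rely on clap's derive-based ValueEnum, so each #[value(name = "...")] attribute defines the exact string accepted on the command line. The standalone sketch below (trimmed to two variants, assuming clap 4 with the derive feature; it is not the example's actual code) shows how the parsing behaves:

use clap::{Parser, ValueEnum};

#[derive(Clone, Copy, Debug, PartialEq, Eq, ValueEnum)]
enum Which {
    #[value(name = "2b")]
    Base2B,
    #[value(name = "1.1-2b-it")]
    InstructV1_1_2B,
}

#[derive(Parser, Debug)]
struct Args {
    /// The model to use.
    #[arg(long, default_value = "2b")]
    which: Which,
}

fn main() {
    // `--which 1.1-2b-it` parses to `Which::InstructV1_1_2B`; omitting the flag
    // falls back to the "2b" default, i.e. `Which::Base2B`.
    let args = Args::parse();
    println!("{:?}", args.which);
}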
@@ -196,14 +216,15 @@ fn main() -> Result<()> {
     let start = std::time::Instant::now();
     let api = Api::new()?;
     let model_id = match &args.model_id {
-        Some(model_id) => match model_id.as_str() {
-            "7b-it" => "google/gemma-7b-it".to_string(),
-            "7b" => "google/gemma-7b".to_string(),
-            "2b-it" => "google/gemma-2b-it".to_string(),
-            "2b" => "google/gemma-2b".to_string(),
-            _ => model_id.to_string(),
+        Some(model_id) => model_id.to_string(),
+        None => match args.which {
+            Which::InstructV1_1_2B => "google/gemma-1.1-2b-it".to_string(),
+            Which::InstructV1_1_7B => "google/gemma-1.1-7b-it".to_string(),
+            Which::Base2B => "google/gemma-2b".to_string(),
+            Which::Base7B => "google/gemma-7b".to_string(),
+            Which::Instruct2B => "google/gemma-2b-it".to_string(),
+            Which::Instruct7B => "google/gemma-7b-it".to_string(),
         },
-        None => "google/gemma-2b".to_string(),
     };
     let repo = api.repo(Repo::with_revision(
         model_id,
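Once the repo id is resolved, the existing hf_hub plumbing shown above takes over. As a rough, self-contained sketch of that path (the "main" revision, the tokenizer.json file name, and the use of anyhow are assumptions here, and the gated gemma repos additionally require an authenticated Hugging Face token):

use hf_hub::{api::sync::Api, Repo, RepoType};

fn main() -> anyhow::Result<()> {
    // Resolve the repo the same way the example does for the default "2b"
    // variant, then download a single file from it.
    let api = Api::new()?;
    let repo = api.repo(Repo::with_revision(
        "google/gemma-2b".to_string(),
        RepoType::Model,
        "main".to_string(),
    ));
    let tokenizer_path = repo.get("tokenizer.json")?;
    println!("tokenizer at {}", tokenizer_path.display());
    Ok(())
}

With the diff applied, selecting a 1.1 checkpoint from the CLI comes down to passing e.g. --which 1.1-2b-it, while --model-id still overrides the mapping entirely.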


@@ -11,6 +11,7 @@ fn default_max_position_embeddings() -> usize {
 pub struct Config {
     pub attention_bias: bool,
     pub head_dim: usize,
+    #[serde(alias = "hidden_activation")]
     pub hidden_act: candle_nn::Activation,
     pub hidden_size: usize,
     pub intermediate_size: usize,
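The #[serde(alias = "hidden_activation")] attribute lets the same Config struct deserialize both the original gemma config.json files, which use the key "hidden_act", and configs that spell it "hidden_activation", as the 1.1 checkpoints appear to. A minimal sketch of the behaviour, with the field simplified to a String instead of candle_nn::Activation and the JSON values purely illustrative:

use serde::Deserialize;

#[derive(Debug, Deserialize)]
struct Config {
    #[serde(alias = "hidden_activation")]
    hidden_act: String,
}

fn main() -> Result<(), serde_json::Error> {
    // Both spellings of the key end up in the same `hidden_act` field.
    let old_style: Config = serde_json::from_str(r#"{ "hidden_act": "gelu" }"#)?;
    let new_style: Config = serde_json::from_str(r#"{ "hidden_activation": "gelu_pytorch_tanh" }"#)?;
    println!("{old_style:?} {new_style:?}");
    Ok(())
}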