mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 19:18:50 +00:00
Add the new gemma models. (#2023)
* Add the new gemma models. * Revert the lightning changes. * Support for the 1.1 models.
This commit is contained in:
@ -16,6 +16,22 @@ use candle_transformers::generation::LogitsProcessor;
|
|||||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||||
use tokenizers::Tokenizer;
|
use tokenizers::Tokenizer;
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
|
||||||
|
enum Which {
|
||||||
|
#[value(name = "2b")]
|
||||||
|
Base2B,
|
||||||
|
#[value(name = "7b")]
|
||||||
|
Base7B,
|
||||||
|
#[value(name = "2b-it")]
|
||||||
|
Instruct2B,
|
||||||
|
#[value(name = "7b-it")]
|
||||||
|
Instruct7B,
|
||||||
|
#[value(name = "1.1-2b-it")]
|
||||||
|
InstructV1_1_2B,
|
||||||
|
#[value(name = "1.1-7b-it")]
|
||||||
|
InstructV1_1_7B,
|
||||||
|
}
|
||||||
|
|
||||||
struct TextGeneration {
|
struct TextGeneration {
|
||||||
model: Model,
|
model: Model,
|
||||||
device: Device,
|
device: Device,
|
||||||
@ -165,6 +181,10 @@ struct Args {
|
|||||||
/// The context size to consider for the repeat penalty.
|
/// The context size to consider for the repeat penalty.
|
||||||
#[arg(long, default_value_t = 64)]
|
#[arg(long, default_value_t = 64)]
|
||||||
repeat_last_n: usize,
|
repeat_last_n: usize,
|
||||||
|
|
||||||
|
/// The model to use.
|
||||||
|
#[arg(long, default_value = "2b")]
|
||||||
|
which: Which,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn main() -> Result<()> {
|
fn main() -> Result<()> {
|
||||||
@ -196,14 +216,15 @@ fn main() -> Result<()> {
|
|||||||
let start = std::time::Instant::now();
|
let start = std::time::Instant::now();
|
||||||
let api = Api::new()?;
|
let api = Api::new()?;
|
||||||
let model_id = match &args.model_id {
|
let model_id = match &args.model_id {
|
||||||
Some(model_id) => match model_id.as_str() {
|
Some(model_id) => model_id.to_string(),
|
||||||
"7b-it" => "google/gemma-7b-it".to_string(),
|
None => match args.which {
|
||||||
"7b" => "google/gemma-7b".to_string(),
|
Which::InstructV1_1_2B => "google/gemma-1.1-2b-it".to_string(),
|
||||||
"2b-it" => "google/gemma-2b-it".to_string(),
|
Which::InstructV1_1_7B => "google/gemma-1.1-7b-it".to_string(),
|
||||||
"2b" => "google/gemma-2b".to_string(),
|
Which::Base2B => "google/gemma-2b".to_string(),
|
||||||
_ => model_id.to_string(),
|
Which::Base7B => "google/gemma-7b".to_string(),
|
||||||
|
Which::Instruct2B => "google/gemma-2b-it".to_string(),
|
||||||
|
Which::Instruct7B => "google/gemma-7b-it".to_string(),
|
||||||
},
|
},
|
||||||
None => "google/gemma-2b".to_string(),
|
|
||||||
};
|
};
|
||||||
let repo = api.repo(Repo::with_revision(
|
let repo = api.repo(Repo::with_revision(
|
||||||
model_id,
|
model_id,
|
||||||
|
@ -11,6 +11,7 @@ fn default_max_position_embeddings() -> usize {
|
|||||||
pub struct Config {
|
pub struct Config {
|
||||||
pub attention_bias: bool,
|
pub attention_bias: bool,
|
||||||
pub head_dim: usize,
|
pub head_dim: usize,
|
||||||
|
#[serde(alias = "hidden_activation")]
|
||||||
pub hidden_act: candle_nn::Activation,
|
pub hidden_act: candle_nn::Activation,
|
||||||
pub hidden_size: usize,
|
pub hidden_size: usize,
|
||||||
pub intermediate_size: usize,
|
pub intermediate_size: usize,
|
||||||
|
Reference in New Issue
Block a user