Add Gemma 3 1b IT to Gemma examples (#2809)

- Updates the Gemma example to include the Gemma 3 1b instruction-tuned model.
This commit is contained in:
André Cipriani Bandarra
2025-03-16 16:00:48 +00:00
committed by GitHub
parent 468d1d525f
commit cbf5fc80c2

View File

@ -50,6 +50,8 @@ enum Which {
InstructV2_9B, InstructV2_9B,
#[value(name = "3-1b")] #[value(name = "3-1b")]
BaseV3_1B, BaseV3_1B,
#[value(name = "3-1b-it")]
InstructV3_1B,
} }
enum Model { enum Model {
@ -272,6 +274,7 @@ fn main() -> Result<()> {
Which::BaseV2_9B => "google/gemma-2-9b".to_string(), Which::BaseV2_9B => "google/gemma-2-9b".to_string(),
Which::InstructV2_9B => "google/gemma-2-9b-it".to_string(), Which::InstructV2_9B => "google/gemma-2-9b-it".to_string(),
Which::BaseV3_1B => "google/gemma-3-1b-pt".to_string(), Which::BaseV3_1B => "google/gemma-3-1b-pt".to_string(),
Which::InstructV3_1B => "google/gemma-3-1b-it".to_string(),
}, },
}; };
let repo = api.repo(Repo::with_revision( let repo = api.repo(Repo::with_revision(
@ -292,13 +295,10 @@ fn main() -> Result<()> {
.split(',') .split(',')
.map(std::path::PathBuf::from) .map(std::path::PathBuf::from)
.collect::<Vec<_>>(), .collect::<Vec<_>>(),
None => { None => match args.which {
if args.which == Which::BaseV3_1B { Which::BaseV3_1B | Which::InstructV3_1B => vec![repo.get("model.safetensors")?],
vec![repo.get("model.safetensors")?] _ => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
} else { },
candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?
}
}
}; };
println!("retrieved the files in {:?}", start.elapsed()); println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
@ -331,7 +331,7 @@ fn main() -> Result<()> {
let model = Model2::new(args.use_flash_attn, &config, vb)?; let model = Model2::new(args.use_flash_attn, &config, vb)?;
Model::V2(model) Model::V2(model)
} }
Which::BaseV3_1B => { Which::BaseV3_1B | Which::InstructV3_1B => {
let config: Config3 = serde_json::from_reader(std::fs::File::open(config_filename)?)?; let config: Config3 = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
let model = Model3::new(args.use_flash_attn, &config, vb)?; let model = Model3::new(args.use_flash_attn, &config, vb)?;
Model::V3(model) Model::V3(model)