diff --git a/candle-examples/examples/gemma/main.rs b/candle-examples/examples/gemma/main.rs index 9ee94a80..f6247c02 100644 --- a/candle-examples/examples/gemma/main.rs +++ b/candle-examples/examples/gemma/main.rs @@ -50,6 +50,8 @@ enum Which { InstructV2_9B, #[value(name = "3-1b")] BaseV3_1B, + #[value(name = "3-1b-it")] + InstructV3_1B, } enum Model { @@ -272,6 +274,7 @@ fn main() -> Result<()> { Which::BaseV2_9B => "google/gemma-2-9b".to_string(), Which::InstructV2_9B => "google/gemma-2-9b-it".to_string(), Which::BaseV3_1B => "google/gemma-3-1b-pt".to_string(), + Which::InstructV3_1B => "google/gemma-3-1b-it".to_string(), }, }; let repo = api.repo(Repo::with_revision( @@ -292,13 +295,10 @@ fn main() -> Result<()> { .split(',') .map(std::path::PathBuf::from) .collect::>(), - None => { - if args.which == Which::BaseV3_1B { - vec![repo.get("model.safetensors")?] - } else { - candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")? - } - } + None => match args.which { + Which::BaseV3_1B | Which::InstructV3_1B => vec![repo.get("model.safetensors")?], + _ => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?, + }, }; println!("retrieved the files in {:?}", start.elapsed()); let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; @@ -331,7 +331,7 @@ fn main() -> Result<()> { let model = Model2::new(args.use_flash_attn, &config, vb)?; Model::V2(model) } - Which::BaseV3_1B => { + Which::BaseV3_1B | Which::InstructV3_1B => { let config: Config3 = serde_json::from_reader(std::fs::File::open(config_filename)?)?; let model = Model3::new(args.use_flash_attn, &config, vb)?; Model::V3(model)