mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Add Gemma 3 1b IT to Gemma examples (#2809)
- Updates the Gemma example to include the instruction-tuned Gemma 3 1b model.
This commit is contained in:

committed by
GitHub

parent
468d1d525f
commit
cbf5fc80c2
@ -50,6 +50,8 @@ enum Which {
|
|||||||
InstructV2_9B,
|
InstructV2_9B,
|
||||||
#[value(name = "3-1b")]
|
#[value(name = "3-1b")]
|
||||||
BaseV3_1B,
|
BaseV3_1B,
|
||||||
|
#[value(name = "3-1b-it")]
|
||||||
|
InstructV3_1B,
|
||||||
}
|
}
|
||||||
|
|
||||||
enum Model {
|
enum Model {
|
||||||
@ -272,6 +274,7 @@ fn main() -> Result<()> {
|
|||||||
Which::BaseV2_9B => "google/gemma-2-9b".to_string(),
|
Which::BaseV2_9B => "google/gemma-2-9b".to_string(),
|
||||||
Which::InstructV2_9B => "google/gemma-2-9b-it".to_string(),
|
Which::InstructV2_9B => "google/gemma-2-9b-it".to_string(),
|
||||||
Which::BaseV3_1B => "google/gemma-3-1b-pt".to_string(),
|
Which::BaseV3_1B => "google/gemma-3-1b-pt".to_string(),
|
||||||
|
Which::InstructV3_1B => "google/gemma-3-1b-it".to_string(),
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
let repo = api.repo(Repo::with_revision(
|
let repo = api.repo(Repo::with_revision(
|
||||||
@ -292,13 +295,10 @@ fn main() -> Result<()> {
|
|||||||
.split(',')
|
.split(',')
|
||||||
.map(std::path::PathBuf::from)
|
.map(std::path::PathBuf::from)
|
||||||
.collect::<Vec<_>>(),
|
.collect::<Vec<_>>(),
|
||||||
None => {
|
None => match args.which {
|
||||||
if args.which == Which::BaseV3_1B {
|
Which::BaseV3_1B | Which::InstructV3_1B => vec![repo.get("model.safetensors")?],
|
||||||
vec![repo.get("model.safetensors")?]
|
_ => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
|
||||||
} else {
|
},
|
||||||
candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
println!("retrieved the files in {:?}", start.elapsed());
|
println!("retrieved the files in {:?}", start.elapsed());
|
||||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||||
@ -331,7 +331,7 @@ fn main() -> Result<()> {
|
|||||||
let model = Model2::new(args.use_flash_attn, &config, vb)?;
|
let model = Model2::new(args.use_flash_attn, &config, vb)?;
|
||||||
Model::V2(model)
|
Model::V2(model)
|
||||||
}
|
}
|
||||||
Which::BaseV3_1B => {
|
Which::BaseV3_1B | Which::InstructV3_1B => {
|
||||||
let config: Config3 = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
|
let config: Config3 = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
|
||||||
let model = Model3::new(args.use_flash_attn, &config, vb)?;
|
let model = Model3::new(args.use_flash_attn, &config, vb)?;
|
||||||
Model::V3(model)
|
Model::V3(model)
|
||||||
|
Reference in New Issue
Block a user