From 10b2e693ff5ba8f3f9247d8a0780092500e2ca9a Mon Sep 17 00:00:00 2001 From: Laurent Date: Sun, 3 Nov 2024 16:42:02 +0100 Subject: [PATCH] Add the SmolLM2 models. --- candle-examples/examples/llama/main.rs | 57 +++++++++++++++++++------- 1 file changed, 43 insertions(+), 14 deletions(-) diff --git a/candle-examples/examples/llama/main.rs b/candle-examples/examples/llama/main.rs index cc99b6c1..6b154308 100644 --- a/candle-examples/examples/llama/main.rs +++ b/candle-examples/examples/llama/main.rs @@ -43,6 +43,18 @@ enum Which { Solar10_7B, #[value(name = "tiny-llama-1.1b-chat")] TinyLlama1_1BChat, + #[value(name = "SmolLM2-1B")] + SmolLM2_1B, + #[value(name = "SmolLM2-1B-Instruct")] + SmolLM2_1BInstruct, + #[value(name = "SmolLM2-360M")] + SmolLM2_360M, + #[value(name = "SmolLM2-360M-Instruct")] + SmolLM2_360MInstruct, + #[value(name = "SmolLM2-135M")] + SmolLM2_135M, + #[value(name = "SmolLM2-135M-Instruct")] + SmolLM2_135MInstruct, } #[derive(Parser, Debug)] @@ -134,19 +146,28 @@ fn main() -> Result<()> { }; let (llama, tokenizer_filename, mut cache, config) = { let api = Api::new()?; - let model_id = args.model_id.unwrap_or_else(|| match args.which { - Which::V1 => "Narsil/amall-7b".to_string(), - Which::V2 => "meta-llama/Llama-2-7b-hf".to_string(), - Which::V3 => "meta-llama/Meta-Llama-3-8B".to_string(), - Which::V3Instruct => "meta-llama/Meta-Llama-3-8B-Instruct".to_string(), - Which::V31 => "meta-llama/Llama-3.1-8B".to_string(), - Which::V31Instruct => "meta-llama/Llama-3.1-8B-Instruct".to_string(), - Which::V32_1b => "meta-llama/Llama-3.2-1B".to_string(), - Which::V32_1bInstruct => "meta-llama/Llama-3.2-1B-Instruct".to_string(), - Which::V32_3b => "meta-llama/Llama-3.2-3B".to_string(), - Which::V32_3bInstruct => "meta-llama/Llama-3.2-3B-Instruct".to_string(), - Which::Solar10_7B => "upstage/SOLAR-10.7B-v1.0".to_string(), - Which::TinyLlama1_1BChat => "TinyLlama/TinyLlama-1.1B-Chat-v1.0".to_string(), + let model_id = args.model_id.unwrap_or_else(|| { + let str = 
match args.which { + Which::V1 => "Narsil/amall-7b", + Which::V2 => "meta-llama/Llama-2-7b-hf", + Which::V3 => "meta-llama/Meta-Llama-3-8B", + Which::V3Instruct => "meta-llama/Meta-Llama-3-8B-Instruct", + Which::V31 => "meta-llama/Llama-3.1-8B", + Which::V31Instruct => "meta-llama/Llama-3.1-8B-Instruct", + Which::V32_1b => "meta-llama/Llama-3.2-1B", + Which::V32_1bInstruct => "meta-llama/Llama-3.2-1B-Instruct", + Which::V32_3b => "meta-llama/Llama-3.2-3B", + Which::V32_3bInstruct => "meta-llama/Llama-3.2-3B-Instruct", + Which::Solar10_7B => "upstage/SOLAR-10.7B-v1.0", + Which::TinyLlama1_1BChat => "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + Which::SmolLM2_135M => "HuggingFaceTB/SmolLM2-135M", + Which::SmolLM2_135MInstruct => "HuggingFaceTB/SmolLM2-135M-Instruct", + Which::SmolLM2_360M => "HuggingFaceTB/SmolLM2-360M", + Which::SmolLM2_360MInstruct => "HuggingFaceTB/SmolLM2-360M-Instruct", + Which::SmolLM2_1B => "HuggingFaceTB/SmolLM2-1.7B", + Which::SmolLM2_1BInstruct => "HuggingFaceTB/SmolLM2-1.7B-Instruct", + }; + str.to_string() }); println!("loading the model weights from {model_id}"); let revision = args.revision.unwrap_or("main".to_string()); @@ -169,7 +190,15 @@ fn main() -> Result<()> { | Which::Solar10_7B => { candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")? } - Which::V32_1b | Which::V32_1bInstruct | Which::TinyLlama1_1BChat => { + Which::SmolLM2_360M + | Which::SmolLM2_360MInstruct + | Which::SmolLM2_135M + | Which::SmolLM2_135MInstruct + | Which::SmolLM2_1B + | Which::SmolLM2_1BInstruct + | Which::V32_1b + | Which::V32_1bInstruct + | Which::TinyLlama1_1BChat => { vec![api.get("model.safetensors")?] } };