From a79286885caaf453821dcc8a1328eba0cf573092 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Thu, 3 Aug 2023 11:10:58 +0100 Subject: [PATCH] Support safetensors weights in llama2.c inference. (#317) --- candle-examples/examples/llama2-c/main.rs | 23 +++++++++++++++----- candle-examples/examples/llama2-c/weights.rs | 2 +- 2 files changed, 18 insertions(+), 7 deletions(-) diff --git a/candle-examples/examples/llama2-c/main.rs b/candle-examples/examples/llama2-c/main.rs index 8b64fdd2..612dc358 100644 --- a/candle-examples/examples/llama2-c/main.rs +++ b/candle-examples/examples/llama2-c/main.rs @@ -27,7 +27,7 @@ struct InferenceCmd { #[arg(long, default_value = "")] prompt: String, - /// Config file in binary format. + /// Config file in binary or safetensors format. #[arg(long)] config: Option, @@ -225,11 +225,22 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> { let device = candle_examples::device(common_args.cpu)?; - let mut file = std::fs::File::open(config_path)?; - let config = Config::from_reader(&mut file)?; - println!("{config:?}"); - let weights = TransformerWeights::from_reader(&mut file, &config, &device)?; - let vb = weights.var_builder(&config, &device)?; + let is_safetensors = config_path + .extension() + .map_or(false, |v| v == "safetensors"); + let (vb, config) = if is_safetensors { + let config = Config::tiny(); + let tensors = candle::safetensors::load(config_path, &device)?; + let vb = candle_nn::VarBuilder::from_tensors(tensors, candle::DType::F32, &device); + (vb, config) + } else { + let mut file = std::fs::File::open(config_path)?; + let config = Config::from_reader(&mut file)?; + println!("{config:?}"); + let weights = TransformerWeights::from_reader(&mut file, &config, &device)?; + let vb = weights.var_builder(&config, &device)?; + (vb, config) + }; let cache = model::Cache::new(true, &config, vb.pp("rot"))?; let model = Llama::load(vb, &cache, config)?; diff --git a/candle-examples/examples/llama2-c/weights.rs b/candle-examples/examples/llama2-c/weights.rs index ae1fd6d9..2daed057 100644 --- a/candle-examples/examples/llama2-c/weights.rs +++ b/candle-examples/examples/llama2-c/weights.rs @@ -104,7 +104,7 @@ impl TransformerWeights { }) } - pub fn var_builder(&self, cfg: &Config, device: &Device) -> Result { + pub fn var_builder(&self, cfg: &Config, device: &Device) -> Result> { let mut ws = std::collections::HashMap::new(); let mut insert = |name: &str, t: Tensor| { ws.insert(name.to_string(), t);