mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Support safetensors weights in llama2.c inference. (#317)
This commit is contained in:
@ -27,7 +27,7 @@ struct InferenceCmd {
|
|||||||
#[arg(long, default_value = "")]
|
#[arg(long, default_value = "")]
|
||||||
prompt: String,
|
prompt: String,
|
||||||
|
|
||||||
/// Config file in binary format.
|
/// Config file in binary or safetensors format.
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
config: Option<String>,
|
config: Option<String>,
|
||||||
|
|
||||||
@ -225,11 +225,22 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
|
|||||||
|
|
||||||
let device = candle_examples::device(common_args.cpu)?;
|
let device = candle_examples::device(common_args.cpu)?;
|
||||||
|
|
||||||
let mut file = std::fs::File::open(config_path)?;
|
let is_safetensors = config_path
|
||||||
let config = Config::from_reader(&mut file)?;
|
.extension()
|
||||||
println!("{config:?}");
|
.map_or(false, |v| v == "safetensors");
|
||||||
let weights = TransformerWeights::from_reader(&mut file, &config, &device)?;
|
let (vb, config) = if is_safetensors {
|
||||||
let vb = weights.var_builder(&config, &device)?;
|
let config = Config::tiny();
|
||||||
|
let tensors = candle::safetensors::load(config_path, &device)?;
|
||||||
|
let vb = candle_nn::VarBuilder::from_tensors(tensors, candle::DType::F32, &device);
|
||||||
|
(vb, config)
|
||||||
|
} else {
|
||||||
|
let mut file = std::fs::File::open(config_path)?;
|
||||||
|
let config = Config::from_reader(&mut file)?;
|
||||||
|
println!("{config:?}");
|
||||||
|
let weights = TransformerWeights::from_reader(&mut file, &config, &device)?;
|
||||||
|
let vb = weights.var_builder(&config, &device)?;
|
||||||
|
(vb, config)
|
||||||
|
};
|
||||||
let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
|
let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
|
||||||
let model = Llama::load(vb, &cache, config)?;
|
let model = Llama::load(vb, &cache, config)?;
|
||||||
|
|
||||||
|
@ -104,7 +104,7 @@ impl TransformerWeights {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn var_builder(&self, cfg: &Config, device: &Device) -> Result<VarBuilder> {
|
pub fn var_builder(&self, cfg: &Config, device: &Device) -> Result<VarBuilder<'static>> {
|
||||||
let mut ws = std::collections::HashMap::new();
|
let mut ws = std::collections::HashMap::new();
|
||||||
let mut insert = |name: &str, t: Tensor| {
|
let mut insert = |name: &str, t: Tensor| {
|
||||||
ws.insert(name.to_string(), t);
|
ws.insert(name.to_string(), t);
|
||||||
|
Reference in New Issue
Block a user