Support safetensors weights in llama2.c inference. (#317)

This commit is contained in:
Laurent Mazare
2023-08-03 11:10:58 +01:00
committed by GitHub
parent 74845a4dcd
commit a79286885c
2 changed files with 18 additions and 7 deletions

View File

@ -27,7 +27,7 @@ struct InferenceCmd {
#[arg(long, default_value = "")]
prompt: String,
/// Config file in binary format.
/// Config file in binary or safetensors format.
#[arg(long)]
config: Option<String>,
@ -225,11 +225,22 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
let device = candle_examples::device(common_args.cpu)?;
let mut file = std::fs::File::open(config_path)?;
let config = Config::from_reader(&mut file)?;
println!("{config:?}");
let weights = TransformerWeights::from_reader(&mut file, &config, &device)?;
let vb = weights.var_builder(&config, &device)?;
let is_safetensors = config_path
.extension()
.map_or(false, |v| v == "safetensors");
let (vb, config) = if is_safetensors {
let config = Config::tiny();
let tensors = candle::safetensors::load(config_path, &device)?;
let vb = candle_nn::VarBuilder::from_tensors(tensors, candle::DType::F32, &device);
(vb, config)
} else {
let mut file = std::fs::File::open(config_path)?;
let config = Config::from_reader(&mut file)?;
println!("{config:?}");
let weights = TransformerWeights::from_reader(&mut file, &config, &device)?;
let vb = weights.var_builder(&config, &device)?;
(vb, config)
};
let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
let model = Llama::load(vb, &cache, config)?;