Self-contained safetensors for the multiprocess llama example. (#950)

Laurent Mazare
2023-09-24 06:54:49 +01:00
committed by GitHub
parent 7edd755756
commit bcb0ed8f1c
3 changed files with 22 additions and 42 deletions

@@ -205,16 +205,9 @@ fn main() -> Result<()> {
     let cache = model::Cache::new(dtype, &config, &device)?;
     println!("building the model");
-    let handles = filenames
-        .iter()
-        .map(|f| Ok(unsafe { candle::safetensors::MmapedFile::new(f.as_path())? }))
-        .collect::<Result<Vec<_>>>()?;
-    let tensors: Vec<_> = handles
-        .iter()
-        .map(|h| Ok(h.deserialize()?))
-        .collect::<Result<Vec<_>>>()?;
-    let vb = candle_nn::var_builder::ShardedSafeTensors::var_builder(tensors, dtype, &device);
+    let vb = unsafe {
+        candle_nn::var_builder::ShardedSafeTensors::var_builder(&filenames, dtype, &device)?
+    };
     let llama = Llama::load(vb, &cache, &config, comm)?;
     let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
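The change collapses the earlier two-step flow, where each safetensors file was memory-mapped by hand through candle::safetensors::MmapedFile and each handle then deserialized, into a single constructor that takes the file paths and performs the mapping internally. The call site becomes unsafe because the returned var builder keeps the files memory-mapped, so the caller must guarantee they are not modified while the weights are in use. Below is a minimal sketch of the new call site; the helper function, the hard-coded shard names, the F16 dtype, and the CPU device are illustrative assumptions, not part of the commit:

use candle::{DType, Device, Result};
use candle_nn::var_builder::{ShardedSafeTensors, ShardedVarBuilder};
use std::path::PathBuf;

// Illustrative helper (not from the commit): wraps the self-contained
// constructor so the unsafe contract is documented in one place.
fn load_sharded(
    filenames: &[PathBuf],
    dtype: DType,
    device: &Device,
) -> Result<ShardedVarBuilder<'static>> {
    // SAFETY: the files are memory-mapped; they must not be modified
    // while the returned var builder is alive.
    unsafe { ShardedSafeTensors::var_builder(filenames, dtype, device) }
}

fn main() -> Result<()> {
    // Hypothetical shard list; the real example derives the paths
    // from the downloaded model files.
    let filenames: Vec<PathBuf> = vec![
        "model-00001-of-00002.safetensors".into(),
        "model-00002-of-00002.safetensors".into(),
    ];
    let vb = load_sharded(&filenames, DType::F16, &Device::Cpu)?;
    // The builder is then handed to the model loader, as in the example:
    // let llama = Llama::load(vb, &cache, &config, comm)?;
    drop(vb);
    Ok(())
}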