VarBuilder cleanup (#627)

* VarBuilder cleanup.

* Implement the basic varbuilders.

* Add the sharded code.

* Proper support for tensor sharding.
This commit is contained in:
Laurent Mazare
2023-08-27 18:03:26 +01:00
committed by GitHub
parent be471d50ab
commit 4c338b0cd9
12 changed files with 409 additions and 291 deletions

View File

@ -13,7 +13,6 @@ use anyhow::{bail, Error as E, Result};
use clap::Parser;
use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use cudarc::driver::safe::CudaDevice;
use cudarc::nccl::safe::{Comm, Id};
@ -211,7 +210,7 @@ fn main() -> Result<()> {
.map(|h| Ok(h.deserialize()?))
.collect::<Result<Vec<_>>>()?;
let vb = VarBuilder::from_safetensors(tensors, dtype, &device);
let vb = candle_nn::var_builder::ShardedSafeTensors::var_builder(tensors, dtype, &device);
let llama = Llama::load(vb, &cache, &config, comm)?;
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;