Adapt more examples to the updated safetensor api. (#947)

* Simplify the safetensor usage.

* Convert more examples.

* Move more examples.

* Adapt stable-diffusion.
This commit is contained in:
Laurent Mazare
2023-09-23 21:26:03 +01:00
committed by GitHub
parent 890d069092
commit bb3471ea31
14 changed files with 31 additions and 84 deletions

View File

@ -138,18 +138,9 @@ fn main() -> Result<()> {
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let weights = filenames
.iter()
.map(|f| Ok(unsafe { candle::safetensors::MmapedFile::new(f)? }))
.collect::<Result<Vec<_>>>()?;
let weights = weights
.iter()
.map(|f| Ok(f.deserialize()?))
.collect::<Result<Vec<_>>>()?;
let start = std::time::Instant::now();
let device = candle_examples::device(args.cpu)?;
let vb = VarBuilder::from_safetensors(weights, DType::F32, &device);
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
let config = Config::starcoder_1b();
let model = GPTBigCode::load(vb, config)?;
println!("loaded the model in {:?}", start.elapsed());