Adapt more examples to the updated safetensor api. (#947)

* Simplify the safetensor usage.

* Convert more examples.

* Move more examples.

* Adapt stable-diffusion.
This commit is contained in:
Laurent Mazare
2023-09-23 21:26:03 +01:00
committed by GitHub
parent 890d069092
commit bb3471ea31
14 changed files with 31 additions and 84 deletions

View File

@ -177,21 +177,12 @@ fn main() -> Result<()> {
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let start = std::time::Instant::now();
let weights = filenames
.iter()
.map(|f| Ok(unsafe { candle::safetensors::MmapedFile::new(f)? }))
.collect::<Result<Vec<_>>>()?;
let weights = weights
.iter()
.map(|f| Ok(f.deserialize()?))
.collect::<Result<Vec<_>>>()?;
let dtype = if args.use_f32 {
DType::F32
} else {
DType::BF16
};
let vb = VarBuilder::from_safetensors(weights, dtype, &device);
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
let config = Config::falcon7b();
config.validate()?;
let model = Falcon::load(vb, config)?;