mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 18:28:24 +00:00
Adapt more examples to the updated safetensor api. (#947)
* Simplify the safetensor usage. * Convert more examples. * Move more examples. * Adapt stable-diffusion.
This commit is contained in:
@ -138,18 +138,9 @@ fn main() -> Result<()> {
|
||||
println!("retrieved the files in {:?}", start.elapsed());
|
||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||
|
||||
let weights = filenames
|
||||
.iter()
|
||||
.map(|f| Ok(unsafe { candle::safetensors::MmapedFile::new(f)? }))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
let weights = weights
|
||||
.iter()
|
||||
.map(|f| Ok(f.deserialize()?))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
let vb = VarBuilder::from_safetensors(weights, DType::F32, &device);
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
|
||||
let config = Config::starcoder_1b();
|
||||
let model = GPTBigCode::load(vb, config)?;
|
||||
println!("loaded the model in {:?}", start.elapsed());
|
||||
|
Reference in New Issue
Block a user