mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 19:58:35 +00:00
Move more safetensors bits to the shared module.
This commit is contained in:
@ -1,8 +1,5 @@
|
||||
use super::*;
|
||||
use candle::{Device, Result, Tensor};
|
||||
use memmap2::MmapOptions;
|
||||
use safetensors::SafeTensors;
|
||||
use std::fs::File;
|
||||
use candle::{safetensors::SafeTensors, Device, Result, Tensor};
|
||||
use std::path::PathBuf;
|
||||
|
||||
pub struct VarBuilder<'a> {
|
||||
@ -30,8 +27,9 @@ impl<'a> VarBuilder<'a> {
|
||||
pub fn get(&self, tensor_name: &str) -> Result<Tensor> {
|
||||
// Unwrap or 0 just to let the proper error flow.
|
||||
let index = self.routing.get(tensor_name).unwrap_or(&0);
|
||||
let view = self.safetensors[*index].tensor(tensor_name).unwrap();
|
||||
candle::safetensors::convert(view, &self.device)?.to_dtype(DTYPE)
|
||||
self.safetensors[*index]
|
||||
.tensor(tensor_name, &self.device)?
|
||||
.to_dtype(DTYPE)
|
||||
}
|
||||
}
|
||||
|
||||
@ -107,18 +105,12 @@ impl Llama {
|
||||
) -> Result<Self> {
|
||||
let handles: Vec<_> = filenames
|
||||
.iter()
|
||||
.map(|f| {
|
||||
let file = File::open(f).unwrap();
|
||||
unsafe { MmapOptions::new().map(&file).unwrap() }
|
||||
})
|
||||
.collect();
|
||||
.map(candle::safetensors::MmapedFile::new)
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
let tensors: Vec<_> = handles
|
||||
.iter()
|
||||
.map(|h| {
|
||||
let tensors = SafeTensors::deserialize(h).unwrap();
|
||||
tensors
|
||||
})
|
||||
.collect();
|
||||
.map(|h| h.deserialize())
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
|
||||
let vb = VarBuilder::new(tensors, device.clone());
|
||||
|
||||
|
Reference in New Issue
Block a user