mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 19:47:12 +00:00
Move more safetensors bits to the shared module.
This commit is contained in:
@ -25,3 +25,46 @@ pub fn convert(view: st::TensorView<'_>, device: &Device) -> Result<Tensor> {
|
||||
dt => todo!("Unhandled dtype {dt:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
// If Rust allowed for self-referential struct, we could store both the Mmap buffer and the
|
||||
// SafeTensor bits in the same struct and avoid having the final users calling two methods.
|
||||
// We could try using the ouroboros crate or equivalent for this at some point.
|
||||
// Wrap the SafeTensors main module so as to provide accessors with the candle types for errors,
|
||||
// dtypes, etc
|
||||
pub struct SafeTensors<'a>(st::SafeTensors<'a>);
|
||||
|
||||
pub struct MmapedFile(memmap2::Mmap);
|
||||
|
||||
impl MmapedFile {
|
||||
pub fn new<P: AsRef<std::path::Path>>(p: P) -> Result<Self> {
|
||||
let file = std::fs::File::open(p)?;
|
||||
let mmap = unsafe { memmap2::MmapOptions::new().map(&file)? };
|
||||
Ok(Self(mmap))
|
||||
}
|
||||
|
||||
pub fn deserialize(&self) -> Result<SafeTensors<'_>> {
|
||||
let st = safetensors::SafeTensors::deserialize(&self.0)?;
|
||||
Ok(SafeTensors(st))
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> SafeTensors<'a> {
|
||||
pub fn tensor(&self, name: &str, device: &Device) -> Result<Tensor> {
|
||||
convert(self.0.tensor(name)?, device)
|
||||
}
|
||||
|
||||
pub fn tensors(&self, device: &Device) -> Result<Vec<(String, Tensor)>> {
|
||||
self.0
|
||||
.tensors()
|
||||
.into_iter()
|
||||
.map(|(name, tensor_view)| {
|
||||
let tensor = convert(tensor_view, device)?;
|
||||
Ok((name, tensor))
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn names(&self) -> Vec<&String> {
|
||||
self.0.names()
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user