Move more safetensors bits to the shared module.

This commit is contained in:
laurent
2023-07-03 09:34:08 +01:00
parent b036faf6a0
commit 783b7054ee
3 changed files with 59 additions and 24 deletions

View File

@ -25,3 +25,46 @@ pub fn convert(view: st::TensorView<'_>, device: &Device) -> Result<Tensor> {
dt => todo!("Unhandled dtype {dt:?}"),
}
}
// If Rust allowed for self-referential struct, we could store both the Mmap buffer and the
// SafeTensor bits in the same struct and avoid having the final users calling two methods.
// We could try using the ouroboros crate or equivalent for this at some point.
// Wrap the SafeTensors main module so as to provide accessors with the candle types for errors,
// dtypes, etc
pub struct SafeTensors<'a>(st::SafeTensors<'a>);
pub struct MmapedFile(memmap2::Mmap);
impl MmapedFile {
pub fn new<P: AsRef<std::path::Path>>(p: P) -> Result<Self> {
let file = std::fs::File::open(p)?;
let mmap = unsafe { memmap2::MmapOptions::new().map(&file)? };
Ok(Self(mmap))
}
pub fn deserialize(&self) -> Result<SafeTensors<'_>> {
let st = safetensors::SafeTensors::deserialize(&self.0)?;
Ok(SafeTensors(st))
}
}
impl<'a> SafeTensors<'a> {
pub fn tensor(&self, name: &str, device: &Device) -> Result<Tensor> {
convert(self.0.tensor(name)?, device)
}
pub fn tensors(&self, device: &Device) -> Result<Vec<(String, Tensor)>> {
self.0
.tensors()
.into_iter()
.map(|(name, tensor_view)| {
let tensor = convert(tensor_view, device)?;
Ok((name, tensor))
})
.collect()
}
pub fn names(&self) -> Vec<&String> {
self.0.names()
}
}