use crate::{DType, Device, Error, Result, Tensor, WithDType}; use safetensors::tensor as st; use std::borrow::Cow; impl From for st::Dtype { fn from(value: DType) -> Self { match value { DType::U8 => st::Dtype::U8, DType::U32 => st::Dtype::U32, DType::BF16 => st::Dtype::BF16, DType::F16 => st::Dtype::F16, DType::F32 => st::Dtype::F32, DType::F64 => st::Dtype::F64, } } } impl TryFrom for DType { type Error = Error; fn try_from(value: st::Dtype) -> Result { match value { st::Dtype::U8 => Ok(DType::U8), st::Dtype::U32 => Ok(DType::U32), st::Dtype::BF16 => Ok(DType::BF16), st::Dtype::F16 => Ok(DType::F16), st::Dtype::F32 => Ok(DType::F32), st::Dtype::F64 => Ok(DType::F64), dtype => Err(Error::UnsupportedSafeTensorDtype(dtype)), } } } impl st::View for Tensor { fn dtype(&self) -> st::Dtype { self.dtype().into() } fn shape(&self) -> &[usize] { self.shape().dims() } fn data(&self) -> Cow<[u8]> { // This copies data from GPU to CPU. // TODO: Avoid the unwrap here. Cow::Owned(convert_back(self).unwrap()) } fn data_len(&self) -> usize { let n: usize = self.shape().elem_count(); let bytes_per_element = self.dtype().size_in_bytes(); n * bytes_per_element } } impl Tensor { pub fn save_safetensors>( &self, name: &str, filename: P, ) -> Result<()> { let data = [(name, self.clone())]; Ok(st::serialize_to_file(data, &None, filename.as_ref())?) } } fn convert_(view: st::TensorView<'_>, device: &Device) -> Result { let v = view.data(); let size_in_bytes = T::DTYPE.size_in_bytes(); let elem_count = v.len() / size_in_bytes; if (v.as_ptr() as usize) % size_in_bytes == 0 { // SAFETY This is safe because we just checked that this // was correctly aligned. let data: &[T] = unsafe { std::slice::from_raw_parts(v.as_ptr() as *const T, elem_count) }; Tensor::from_slice(data, view.shape(), device) } else { // XXX: We need to specify `T` here, otherwise the compiler will infer u8 because of the following cast // Making this vector too small to fit a full f16/f32/f64 weights, resulting in out-of-bounds access let mut c: Vec = Vec::with_capacity(elem_count); // SAFETY: We just created c, so the allocated memory is necessarily // contiguous and non overlapping with the view's data. // We're downgrading the `c` pointer from T to u8, which removes alignment // constraints. unsafe { std::ptr::copy_nonoverlapping(v.as_ptr(), c.as_mut_ptr() as *mut u8, v.len()); c.set_len(elem_count) } Tensor::from_slice(&c, view.shape(), device) } } fn convert_back_(mut vs: Vec) -> Vec { let size_in_bytes = T::DTYPE.size_in_bytes(); let length = vs.len() * size_in_bytes; let capacity = vs.capacity() * size_in_bytes; let ptr = vs.as_mut_ptr() as *mut u8; // Don't run the destructor for Vec std::mem::forget(vs); // SAFETY: // // Every T is larger than u8, so there is no issue regarding alignment. // This re-interpret the Vec as a Vec. unsafe { Vec::from_raw_parts(ptr, length, capacity) } } pub fn convert(view: st::TensorView<'_>, device: &Device) -> Result { match view.dtype() { st::Dtype::U8 => convert_::(view, device), st::Dtype::U32 => convert_::(view, device), st::Dtype::BF16 => convert_::(view, device), st::Dtype::F16 => convert_::(view, device), st::Dtype::F32 => convert_::(view, device), st::Dtype::F64 => convert_::(view, device), dtype => Err(Error::UnsupportedSafeTensorDtype(dtype)), } } pub fn convert_back(tensor: &Tensor) -> Result> { // TODO: This makes an unnecessary copy when the tensor is on the cpu. let tensor = tensor.flatten_all()?; match tensor.dtype() { DType::U8 => Ok(convert_back_::(tensor.to_vec1()?)), DType::U32 => Ok(convert_back_::(tensor.to_vec1()?)), DType::F16 => Ok(convert_back_::(tensor.to_vec1()?)), DType::BF16 => Ok(convert_back_::(tensor.to_vec1()?)), DType::F32 => Ok(convert_back_::(tensor.to_vec1()?)), DType::F64 => Ok(convert_back_::(tensor.to_vec1()?)), } } // If Rust allowed for self-referential struct, we could store both the Mmap buffer and the // SafeTensor bits in the same struct and avoid having the final users calling two methods. // We could try using the ouroboros crate or equivalent for this at some point. // Wrap the SafeTensors main module so as to provide accessors with the candle types for errors, // dtypes, etc pub struct SafeTensors<'a>(st::SafeTensors<'a>); pub struct MmapedFile(memmap2::Mmap); impl MmapedFile { /// Creates a wrapper around a memory mapped file from which you can retrieve /// tensors using [`MmapedFile::deserialize`] /// /// # Safety /// /// The unsafe is inherited from [`memmap2::MmapOptions`]. pub unsafe fn new>(p: P) -> Result { let file = std::fs::File::open(p)?; let mmap = memmap2::MmapOptions::new().map(&file)?; Ok(Self(mmap)) } pub fn deserialize(&self) -> Result> { let st = safetensors::SafeTensors::deserialize(&self.0)?; Ok(SafeTensors(st)) } } impl<'a> SafeTensors<'a> { pub fn from_buffer(buffer: &'a [u8]) -> Result { let st = safetensors::SafeTensors::deserialize(buffer)?; Ok(SafeTensors(st)) } pub fn tensor(&self, name: &str, device: &Device) -> Result { convert(self.0.tensor(name)?, device) } pub fn tensors(&self, device: &Device) -> Result> { self.0 .tensors() .into_iter() .map(|(name, tensor_view)| { let tensor = convert(tensor_view, device)?; Ok((name, tensor)) }) .collect() } pub fn names(&self) -> Vec<&String> { self.0.names() } } #[cfg(test)] mod tests { use super::*; use std::collections::HashMap; #[test] fn save_single_tensor() { let t = Tensor::zeros((2, 2), DType::F32, &Device::Cpu).unwrap(); t.save_safetensors("t", "t.safetensors").unwrap(); let bytes = std::fs::read("t.safetensors").unwrap(); assert_eq!(bytes, b"@\0\0\0\0\0\0\0{\"t\":{\"dtype\":\"F32\",\"shape\":[2,2],\"data_offsets\":[0,16]}} \0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0"); std::fs::remove_file("t.safetensors").unwrap(); } #[test] fn save_multiple_tensors() { let t = Tensor::zeros((2, 2), DType::F32, &Device::Cpu).unwrap(); let u = Tensor::zeros((1, 2), DType::F32, &Device::Cpu).unwrap(); let map: HashMap<_, _> = [("t", t), ("u", u)].into_iter().collect(); st::serialize_to_file(map, &None, std::path::Path::new("multi.safetensors")).unwrap(); let bytes = std::fs::read("multi.safetensors").unwrap(); assert_eq!(bytes, b"x\0\0\0\0\0\0\0{\"t\":{\"dtype\":\"F32\",\"shape\":[2,2],\"data_offsets\":[0,16]},\"u\":{\"dtype\":\"F32\",\"shape\":[1,2],\"data_offsets\":[16,24]}} \0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0"); std::fs::remove_file("multi.safetensors").unwrap(); } }