diff --git a/candle-core/src/safetensors.rs b/candle-core/src/safetensors.rs index 99e11c60..ba7c16a8 100644 --- a/candle-core/src/safetensors.rs +++ b/candle-core/src/safetensors.rs @@ -1,5 +1,61 @@ -use crate::{Device, Error, Result, Tensor, WithDType}; +use crate::{DType, Device, Error, Result, Tensor, WithDType}; use safetensors::tensor as st; +use std::borrow::Cow; + +impl From for st::Dtype { + fn from(value: DType) -> Self { + match value { + DType::U8 => st::Dtype::U8, + DType::U32 => st::Dtype::U32, + DType::BF16 => st::Dtype::BF16, + DType::F16 => st::Dtype::F16, + DType::F32 => st::Dtype::F32, + DType::F64 => st::Dtype::F64, + } + } +} + +impl TryFrom for DType { + type Error = Error; + fn try_from(value: st::Dtype) -> Result { + match value { + st::Dtype::U8 => Ok(DType::U8), + st::Dtype::U32 => Ok(DType::U32), + st::Dtype::BF16 => Ok(DType::BF16), + st::Dtype::F16 => Ok(DType::F16), + st::Dtype::F32 => Ok(DType::F32), + st::Dtype::F64 => Ok(DType::F64), + dtype => Err(Error::UnsupportedSafeTensorDtype(dtype)), + } + } +} + +impl st::View for Tensor { + fn dtype(&self) -> st::Dtype { + self.dtype().into() + } + fn shape(&self) -> &[usize] { + self.shape().dims() + } + + fn data(&self) -> Cow<[u8]> { + // This copies data from GPU to CPU. + convert_back(self).unwrap() + } + + fn data_len(&self) -> usize { + let n: usize = self.shape().elem_count(); + let bytes_per_element = self.dtype().size_in_bytes(); + n * bytes_per_element + } +} + +impl Tensor { + pub fn save>(&self, name: &str, filename: P) -> Result<()> { + let data = [(name, self.clone())]; + Ok(st::serialize_to_file(data, &None, filename.as_ref())?) + } +} fn convert_(view: st::TensorView<'_>, device: &Device) -> Result { let v = view.data(); @@ -24,6 +80,21 @@ fn convert_(view: st::TensorView<'_>, device: &Device) -> Result(value: Cow<'_, [T]>) -> Result> { + let size_in_bytes = T::DTYPE.size_in_bytes(); + // SAFETY: + // + // Every T is larger than u8, so there is no issue regarding alignment. + // This is safe only because we explicitly take the lifetime from the Cow's lifetime + // and consume the original Cow. + // This means that borrowed Cow, will keep their lifetime information, preventing + // this slice from being accessed after freeing the original memory. + let slice = unsafe { + std::slice::from_raw_parts(value.as_ptr() as *const u8, value.len() * size_in_bytes) + }; + Ok(Cow::Borrowed(slice)) +} + pub fn convert(view: st::TensorView<'_>, device: &Device) -> Result { match view.dtype() { st::Dtype::U8 => convert_::(view, device), @@ -36,6 +107,21 @@ pub fn convert(view: st::TensorView<'_>, device: &Device) -> Result { } } +pub fn convert_back(tensor: &Tensor) -> Result> { + match tensor.dtype() { + DType::U8 => convert_back_::(tensor.storage_data()?), + DType::U32 => convert_back_::(tensor.storage_data()?), + DType::F16 => convert_back_::(tensor.storage_data()?), + DType::BF16 => convert_back_::(tensor.storage_data()?), + DType::F32 => convert_back_::(tensor.storage_data()?), + DType::F64 => convert_back_::(tensor.storage_data()?), + // DType::BF16 => convert_::(view, device), + // DType::F16 => convert_::(view, device), + // DType::F32 => convert_::(view, device), + // DType::F64 => convert_::(view, device), + } +} + // If Rust allowed for self-referential struct, we could store both the Mmap buffer and the // SafeTensor bits in the same struct and avoid having the final users calling two methods. // We could try using the ouroboros crate or equivalent for this at some point. @@ -84,3 +170,29 @@ impl<'a> SafeTensors<'a> { self.0.names() } } + +#[cfg(test)] +mod tests { + use super::*; + use std::collections::HashMap; + + #[test] + fn save_single_tensor() { + let t = Tensor::zeros((2, 2), DType::F32, &Device::Cpu).unwrap(); + t.save("t", "t.safetensors").unwrap(); + let bytes = std::fs::read("t.safetensors").unwrap(); + assert_eq!(bytes, b"@\0\0\0\0\0\0\0{\"t\":{\"dtype\":\"F32\",\"shape\":[2,2],\"data_offsets\":[0,16]}} \0\0\0\0"); + std::fs::remove_file("t.safetensors").unwrap(); + } + + #[test] + fn save_multiple_tensors() { + let t = Tensor::zeros((2, 2), DType::F32, &Device::Cpu).unwrap(); + let u = Tensor::zeros((1, 2), DType::F32, &Device::Cpu).unwrap(); + let map: HashMap<_, _> = [("t", t), ("u", u)].into_iter().collect(); + st::serialize_to_file(map, &None, std::path::Path::new("multi.safetensors")).unwrap(); + let bytes = std::fs::read("multi.safetensors").unwrap(); + assert_eq!(bytes, b"x\0\0\0\0\0\0\0{\"t\":{\"dtype\":\"F32\",\"shape\":[2,2],\"data_offsets\":[0,16]},\"u\":{\"dtype\":\"F32\",\"shape\":[1,2],\"data_offsets\":[16,24]}} \0\0\0\0\0\0\0\0"); + std::fs::remove_file("multi.safetensors").unwrap(); + } +}