From 5bc66c68fa56c84f6c0457e71ef115b83fdc7428 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Mon, 3 Jul 2023 13:39:24 +0200 Subject: [PATCH 1/2] Adding saving capabilities. --- candle-core/src/safetensors.rs | 114 ++++++++++++++++++++++++++++++++- 1 file changed, 113 insertions(+), 1 deletion(-) diff --git a/candle-core/src/safetensors.rs b/candle-core/src/safetensors.rs index 99e11c60..ba7c16a8 100644 --- a/candle-core/src/safetensors.rs +++ b/candle-core/src/safetensors.rs @@ -1,5 +1,61 @@ -use crate::{Device, Error, Result, Tensor, WithDType}; +use crate::{DType, Device, Error, Result, Tensor, WithDType}; use safetensors::tensor as st; +use std::borrow::Cow; + +impl From for st::Dtype { + fn from(value: DType) -> Self { + match value { + DType::U8 => st::Dtype::U8, + DType::U32 => st::Dtype::U32, + DType::BF16 => st::Dtype::BF16, + DType::F16 => st::Dtype::F16, + DType::F32 => st::Dtype::F32, + DType::F64 => st::Dtype::F64, + } + } +} + +impl TryFrom for DType { + type Error = Error; + fn try_from(value: st::Dtype) -> Result { + match value { + st::Dtype::U8 => Ok(DType::U8), + st::Dtype::U32 => Ok(DType::U32), + st::Dtype::BF16 => Ok(DType::BF16), + st::Dtype::F16 => Ok(DType::F16), + st::Dtype::F32 => Ok(DType::F32), + st::Dtype::F64 => Ok(DType::F64), + dtype => Err(Error::UnsupportedSafeTensorDtype(dtype)), + } + } +} + +impl st::View for Tensor { + fn dtype(&self) -> st::Dtype { + self.dtype().into() + } + fn shape(&self) -> &[usize] { + self.shape().dims() + } + + fn data(&self) -> Cow<[u8]> { + // This copies data from GPU to CPU. + convert_back(self).unwrap() + } + + fn data_len(&self) -> usize { + let n: usize = self.shape().elem_count(); + let bytes_per_element = self.dtype().size_in_bytes(); + n * bytes_per_element + } +} + +impl Tensor { + pub fn save>(&self, name: &str, filename: P) -> Result<()> { + let data = [(name, self.clone())]; + Ok(st::serialize_to_file(data, &None, filename.as_ref())?) + } +} fn convert_(view: st::TensorView<'_>, device: &Device) -> Result { let v = view.data(); @@ -24,6 +80,21 @@ fn convert_(view: st::TensorView<'_>, device: &Device) -> Result(value: Cow<'_, [T]>) -> Result> { + let size_in_bytes = T::DTYPE.size_in_bytes(); + // SAFETY: + // + // Every T is larger than u8, so there is no issue regarding alignment. + // This is safe only because we explicitly take the lifetime from the Cow's lifetime + // and consume the original Cow. + // This means that borrowed Cow, will keep their lifetime information, preventing + // this slice from being accessed after freeing the original memory. + let slice = unsafe { + std::slice::from_raw_parts(value.as_ptr() as *const u8, value.len() * size_in_bytes) + }; + Ok(Cow::Borrowed(slice)) +} + pub fn convert(view: st::TensorView<'_>, device: &Device) -> Result { match view.dtype() { st::Dtype::U8 => convert_::(view, device), @@ -36,6 +107,21 @@ pub fn convert(view: st::TensorView<'_>, device: &Device) -> Result { } } +pub fn convert_back(tensor: &Tensor) -> Result> { + match tensor.dtype() { + DType::U8 => convert_back_::(tensor.storage_data()?), + DType::U32 => convert_back_::(tensor.storage_data()?), + DType::F16 => convert_back_::(tensor.storage_data()?), + DType::BF16 => convert_back_::(tensor.storage_data()?), + DType::F32 => convert_back_::(tensor.storage_data()?), + DType::F64 => convert_back_::(tensor.storage_data()?), + // DType::BF16 => convert_::(view, device), + // DType::F16 => convert_::(view, device), + // DType::F32 => convert_::(view, device), + // DType::F64 => convert_::(view, device), + } +} + // If Rust allowed for self-referential struct, we could store both the Mmap buffer and the // SafeTensor bits in the same struct and avoid having the final users calling two methods. // We could try using the ouroboros crate or equivalent for this at some point. @@ -84,3 +170,29 @@ impl<'a> SafeTensors<'a> { self.0.names() } } + +#[cfg(test)] +mod tests { + use super::*; + use std::collections::HashMap; + + #[test] + fn save_single_tensor() { + let t = Tensor::zeros((2, 2), DType::F32, &Device::Cpu).unwrap(); + t.save("t", "t.safetensors").unwrap(); + let bytes = std::fs::read("t.safetensors").unwrap(); + assert_eq!(bytes, b"@\0\0\0\0\0\0\0{\"t\":{\"dtype\":\"F32\",\"shape\":[2,2],\"data_offsets\":[0,16]}} \0\0\0\0"); + std::fs::remove_file("t.safetensors").unwrap(); + } + + #[test] + fn save_multiple_tensors() { + let t = Tensor::zeros((2, 2), DType::F32, &Device::Cpu).unwrap(); + let u = Tensor::zeros((1, 2), DType::F32, &Device::Cpu).unwrap(); + let map: HashMap<_, _> = [("t", t), ("u", u)].into_iter().collect(); + st::serialize_to_file(map, &None, std::path::Path::new("multi.safetensors")).unwrap(); + let bytes = std::fs::read("multi.safetensors").unwrap(); + assert_eq!(bytes, b"x\0\0\0\0\0\0\0{\"t\":{\"dtype\":\"F32\",\"shape\":[2,2],\"data_offsets\":[0,16]},\"u\":{\"dtype\":\"F32\",\"shape\":[1,2],\"data_offsets\":[16,24]}} \0\0\0\0\0\0\0\0"); + std::fs::remove_file("multi.safetensors").unwrap(); + } +} From 0b3cc215f14345b2685a1ff1dc62da087bdc7abb Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Mon, 3 Jul 2023 13:52:27 +0200 Subject: [PATCH 2/2] Address comments. --- candle-core/src/safetensors.rs | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/candle-core/src/safetensors.rs b/candle-core/src/safetensors.rs index ba7c16a8..ae93f4df 100644 --- a/candle-core/src/safetensors.rs +++ b/candle-core/src/safetensors.rs @@ -51,7 +51,11 @@ impl st::View for Tensor { } impl Tensor { - pub fn save>(&self, name: &str, filename: P) -> Result<()> { + pub fn save_safetensors>( + &self, + name: &str, + filename: P, + ) -> Result<()> { let data = [(name, self.clone())]; Ok(st::serialize_to_file(data, &None, filename.as_ref())?) } @@ -80,7 +84,7 @@ fn convert_(view: st::TensorView<'_>, device: &Device) -> Result(value: Cow<'_, [T]>) -> Result> { +fn convert_back_(value: Cow<'_, [T]>) -> Cow<'_, [u8]> { let size_in_bytes = T::DTYPE.size_in_bytes(); // SAFETY: // @@ -92,7 +96,7 @@ fn convert_back_(value: Cow<'_, [T]>) -> Result> { let slice = unsafe { std::slice::from_raw_parts(value.as_ptr() as *const u8, value.len() * size_in_bytes) }; - Ok(Cow::Borrowed(slice)) + Cow::Borrowed(slice) } pub fn convert(view: st::TensorView<'_>, device: &Device) -> Result { @@ -109,16 +113,12 @@ pub fn convert(view: st::TensorView<'_>, device: &Device) -> Result { pub fn convert_back(tensor: &Tensor) -> Result> { match tensor.dtype() { - DType::U8 => convert_back_::(tensor.storage_data()?), - DType::U32 => convert_back_::(tensor.storage_data()?), - DType::F16 => convert_back_::(tensor.storage_data()?), - DType::BF16 => convert_back_::(tensor.storage_data()?), - DType::F32 => convert_back_::(tensor.storage_data()?), - DType::F64 => convert_back_::(tensor.storage_data()?), - // DType::BF16 => convert_::(view, device), - // DType::F16 => convert_::(view, device), - // DType::F32 => convert_::(view, device), - // DType::F64 => convert_::(view, device), + DType::U8 => Ok(convert_back_::(tensor.storage_data()?)), + DType::U32 => Ok(convert_back_::(tensor.storage_data()?)), + DType::F16 => Ok(convert_back_::(tensor.storage_data()?)), + DType::BF16 => Ok(convert_back_::(tensor.storage_data()?)), + DType::F32 => Ok(convert_back_::(tensor.storage_data()?)), + DType::F64 => Ok(convert_back_::(tensor.storage_data()?)), } } @@ -179,7 +179,7 @@ mod tests { #[test] fn save_single_tensor() { let t = Tensor::zeros((2, 2), DType::F32, &Device::Cpu).unwrap(); - t.save("t", "t.safetensors").unwrap(); + t.save_safetensors("t", "t.safetensors").unwrap(); let bytes = std::fs::read("t.safetensors").unwrap(); assert_eq!(bytes, b"@\0\0\0\0\0\0\0{\"t\":{\"dtype\":\"F32\",\"shape\":[2,2],\"data_offsets\":[0,16]}} \0\0\0\0"); std::fs::remove_file("t.safetensors").unwrap();