diff --git a/candle-core/src/safetensors.rs b/candle-core/src/safetensors.rs index ba7c16a8..ae93f4df 100644 --- a/candle-core/src/safetensors.rs +++ b/candle-core/src/safetensors.rs @@ -51,7 +51,11 @@ impl st::View for Tensor { } impl Tensor { - pub fn save>(&self, name: &str, filename: P) -> Result<()> { + pub fn save_safetensors>( + &self, + name: &str, + filename: P, + ) -> Result<()> { let data = [(name, self.clone())]; Ok(st::serialize_to_file(data, &None, filename.as_ref())?) } @@ -80,7 +84,7 @@ fn convert_(view: st::TensorView<'_>, device: &Device) -> Result(value: Cow<'_, [T]>) -> Result> { +fn convert_back_(value: Cow<'_, [T]>) -> Cow<'_, [u8]> { let size_in_bytes = T::DTYPE.size_in_bytes(); // SAFETY: // @@ -92,7 +96,7 @@ fn convert_back_(value: Cow<'_, [T]>) -> Result> { let slice = unsafe { std::slice::from_raw_parts(value.as_ptr() as *const u8, value.len() * size_in_bytes) }; - Ok(Cow::Borrowed(slice)) + Cow::Borrowed(slice) } pub fn convert(view: st::TensorView<'_>, device: &Device) -> Result { @@ -109,16 +113,12 @@ pub fn convert(view: st::TensorView<'_>, device: &Device) -> Result { pub fn convert_back(tensor: &Tensor) -> Result> { match tensor.dtype() { - DType::U8 => convert_back_::(tensor.storage_data()?), - DType::U32 => convert_back_::(tensor.storage_data()?), - DType::F16 => convert_back_::(tensor.storage_data()?), - DType::BF16 => convert_back_::(tensor.storage_data()?), - DType::F32 => convert_back_::(tensor.storage_data()?), - DType::F64 => convert_back_::(tensor.storage_data()?), - // DType::BF16 => convert_::(view, device), - // DType::F16 => convert_::(view, device), - // DType::F32 => convert_::(view, device), - // DType::F64 => convert_::(view, device), + DType::U8 => Ok(convert_back_::(tensor.storage_data()?)), + DType::U32 => Ok(convert_back_::(tensor.storage_data()?)), + DType::F16 => Ok(convert_back_::(tensor.storage_data()?)), + DType::BF16 => Ok(convert_back_::(tensor.storage_data()?)), + DType::F32 => Ok(convert_back_::(tensor.storage_data()?)), + DType::F64 => Ok(convert_back_::(tensor.storage_data()?)), } } @@ -179,7 +179,7 @@ mod tests { #[test] fn save_single_tensor() { let t = Tensor::zeros((2, 2), DType::F32, &Device::Cpu).unwrap(); - t.save("t", "t.safetensors").unwrap(); + t.save_safetensors("t", "t.safetensors").unwrap(); let bytes = std::fs::read("t.safetensors").unwrap(); assert_eq!(bytes, b"@\0\0\0\0\0\0\0{\"t\":{\"dtype\":\"F32\",\"shape\":[2,2],\"data_offsets\":[0,16]}} \0\0\0\0"); std::fs::remove_file("t.safetensors").unwrap();