Adding saving capabilities.

This commit is contained in:
Nicolas Patry
2023-07-03 13:39:24 +02:00
parent d0d530dfdc
commit 5bc66c68fa

View File

@ -1,5 +1,61 @@
use crate::{Device, Error, Result, Tensor, WithDType};
use crate::{DType, Device, Error, Result, Tensor, WithDType};
use safetensors::tensor as st;
use std::borrow::Cow;
impl From<DType> for st::Dtype {
fn from(value: DType) -> Self {
match value {
DType::U8 => st::Dtype::U8,
DType::U32 => st::Dtype::U32,
DType::BF16 => st::Dtype::BF16,
DType::F16 => st::Dtype::F16,
DType::F32 => st::Dtype::F32,
DType::F64 => st::Dtype::F64,
}
}
}
impl TryFrom<st::Dtype> for DType {
type Error = Error;
fn try_from(value: st::Dtype) -> Result<Self> {
match value {
st::Dtype::U8 => Ok(DType::U8),
st::Dtype::U32 => Ok(DType::U32),
st::Dtype::BF16 => Ok(DType::BF16),
st::Dtype::F16 => Ok(DType::F16),
st::Dtype::F32 => Ok(DType::F32),
st::Dtype::F64 => Ok(DType::F64),
dtype => Err(Error::UnsupportedSafeTensorDtype(dtype)),
}
}
}
impl st::View for Tensor {
fn dtype(&self) -> st::Dtype {
self.dtype().into()
}
fn shape(&self) -> &[usize] {
self.shape().dims()
}
fn data(&self) -> Cow<[u8]> {
// This copies data from GPU to CPU.
convert_back(self).unwrap()
}
fn data_len(&self) -> usize {
let n: usize = self.shape().elem_count();
let bytes_per_element = self.dtype().size_in_bytes();
n * bytes_per_element
}
}
impl Tensor {
pub fn save<P: AsRef<std::path::Path>>(&self, name: &str, filename: P) -> Result<()> {
let data = [(name, self.clone())];
Ok(st::serialize_to_file(data, &None, filename.as_ref())?)
}
}
fn convert_<T: WithDType>(view: st::TensorView<'_>, device: &Device) -> Result<Tensor> {
let v = view.data();
@ -24,6 +80,21 @@ fn convert_<T: WithDType>(view: st::TensorView<'_>, device: &Device) -> Result<T
}
}
fn convert_back_<T: WithDType>(value: Cow<'_, [T]>) -> Result<Cow<'_, [u8]>> {
let size_in_bytes = T::DTYPE.size_in_bytes();
// SAFETY:
//
// Every T is larger than u8, so there is no issue regarding alignment.
// This is safe only because we explicitly take the lifetime from the Cow's lifetime
// and consume the original Cow.
// This means that borrowed Cow, will keep their lifetime information, preventing
// this slice from being accessed after freeing the original memory.
let slice = unsafe {
std::slice::from_raw_parts(value.as_ptr() as *const u8, value.len() * size_in_bytes)
};
Ok(Cow::Borrowed(slice))
}
pub fn convert(view: st::TensorView<'_>, device: &Device) -> Result<Tensor> {
match view.dtype() {
st::Dtype::U8 => convert_::<u8>(view, device),
@ -36,6 +107,21 @@ pub fn convert(view: st::TensorView<'_>, device: &Device) -> Result<Tensor> {
}
}
pub fn convert_back(tensor: &Tensor) -> Result<Cow<[u8]>> {
match tensor.dtype() {
DType::U8 => convert_back_::<u8>(tensor.storage_data()?),
DType::U32 => convert_back_::<u32>(tensor.storage_data()?),
DType::F16 => convert_back_::<half::f16>(tensor.storage_data()?),
DType::BF16 => convert_back_::<half::bf16>(tensor.storage_data()?),
DType::F32 => convert_back_::<f32>(tensor.storage_data()?),
DType::F64 => convert_back_::<f64>(tensor.storage_data()?),
// DType::BF16 => convert_::<half::bf16>(view, device),
// DType::F16 => convert_::<half::f16>(view, device),
// DType::F32 => convert_::<f32>(view, device),
// DType::F64 => convert_::<f64>(view, device),
}
}
// If Rust allowed for self-referential struct, we could store both the Mmap buffer and the
// SafeTensor bits in the same struct and avoid having the final users calling two methods.
// We could try using the ouroboros crate or equivalent for this at some point.
@ -84,3 +170,29 @@ impl<'a> SafeTensors<'a> {
self.0.names()
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::collections::HashMap;
#[test]
fn save_single_tensor() {
let t = Tensor::zeros((2, 2), DType::F32, &Device::Cpu).unwrap();
t.save("t", "t.safetensors").unwrap();
let bytes = std::fs::read("t.safetensors").unwrap();
assert_eq!(bytes, b"@\0\0\0\0\0\0\0{\"t\":{\"dtype\":\"F32\",\"shape\":[2,2],\"data_offsets\":[0,16]}} \0\0\0\0");
std::fs::remove_file("t.safetensors").unwrap();
}
#[test]
fn save_multiple_tensors() {
let t = Tensor::zeros((2, 2), DType::F32, &Device::Cpu).unwrap();
let u = Tensor::zeros((1, 2), DType::F32, &Device::Cpu).unwrap();
let map: HashMap<_, _> = [("t", t), ("u", u)].into_iter().collect();
st::serialize_to_file(map, &None, std::path::Path::new("multi.safetensors")).unwrap();
let bytes = std::fs::read("multi.safetensors").unwrap();
assert_eq!(bytes, b"x\0\0\0\0\0\0\0{\"t\":{\"dtype\":\"F32\",\"shape\":[2,2],\"data_offsets\":[0,16]},\"u\":{\"dtype\":\"F32\",\"shape\":[1,2],\"data_offsets\":[16,24]}} \0\0\0\0\0\0\0\0");
std::fs::remove_file("multi.safetensors").unwrap();
}
}