mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 03:28:50 +00:00
Address comments.
This commit is contained in:
@ -51,7 +51,11 @@ impl st::View for Tensor {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl Tensor {
|
impl Tensor {
|
||||||
pub fn save<P: AsRef<std::path::Path>>(&self, name: &str, filename: P) -> Result<()> {
|
pub fn save_safetensors<P: AsRef<std::path::Path>>(
|
||||||
|
&self,
|
||||||
|
name: &str,
|
||||||
|
filename: P,
|
||||||
|
) -> Result<()> {
|
||||||
let data = [(name, self.clone())];
|
let data = [(name, self.clone())];
|
||||||
Ok(st::serialize_to_file(data, &None, filename.as_ref())?)
|
Ok(st::serialize_to_file(data, &None, filename.as_ref())?)
|
||||||
}
|
}
|
||||||
@ -80,7 +84,7 @@ fn convert_<T: WithDType>(view: st::TensorView<'_>, device: &Device) -> Result<T
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn convert_back_<T: WithDType>(value: Cow<'_, [T]>) -> Result<Cow<'_, [u8]>> {
|
fn convert_back_<T: WithDType>(value: Cow<'_, [T]>) -> Cow<'_, [u8]> {
|
||||||
let size_in_bytes = T::DTYPE.size_in_bytes();
|
let size_in_bytes = T::DTYPE.size_in_bytes();
|
||||||
// SAFETY:
|
// SAFETY:
|
||||||
//
|
//
|
||||||
@ -92,7 +96,7 @@ fn convert_back_<T: WithDType>(value: Cow<'_, [T]>) -> Result<Cow<'_, [u8]>> {
|
|||||||
let slice = unsafe {
|
let slice = unsafe {
|
||||||
std::slice::from_raw_parts(value.as_ptr() as *const u8, value.len() * size_in_bytes)
|
std::slice::from_raw_parts(value.as_ptr() as *const u8, value.len() * size_in_bytes)
|
||||||
};
|
};
|
||||||
Ok(Cow::Borrowed(slice))
|
Cow::Borrowed(slice)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn convert(view: st::TensorView<'_>, device: &Device) -> Result<Tensor> {
|
pub fn convert(view: st::TensorView<'_>, device: &Device) -> Result<Tensor> {
|
||||||
@ -109,16 +113,12 @@ pub fn convert(view: st::TensorView<'_>, device: &Device) -> Result<Tensor> {
|
|||||||
|
|
||||||
pub fn convert_back(tensor: &Tensor) -> Result<Cow<[u8]>> {
|
pub fn convert_back(tensor: &Tensor) -> Result<Cow<[u8]>> {
|
||||||
match tensor.dtype() {
|
match tensor.dtype() {
|
||||||
DType::U8 => convert_back_::<u8>(tensor.storage_data()?),
|
DType::U8 => Ok(convert_back_::<u8>(tensor.storage_data()?)),
|
||||||
DType::U32 => convert_back_::<u32>(tensor.storage_data()?),
|
DType::U32 => Ok(convert_back_::<u32>(tensor.storage_data()?)),
|
||||||
DType::F16 => convert_back_::<half::f16>(tensor.storage_data()?),
|
DType::F16 => Ok(convert_back_::<half::f16>(tensor.storage_data()?)),
|
||||||
DType::BF16 => convert_back_::<half::bf16>(tensor.storage_data()?),
|
DType::BF16 => Ok(convert_back_::<half::bf16>(tensor.storage_data()?)),
|
||||||
DType::F32 => convert_back_::<f32>(tensor.storage_data()?),
|
DType::F32 => Ok(convert_back_::<f32>(tensor.storage_data()?)),
|
||||||
DType::F64 => convert_back_::<f64>(tensor.storage_data()?),
|
DType::F64 => Ok(convert_back_::<f64>(tensor.storage_data()?)),
|
||||||
// DType::BF16 => convert_::<half::bf16>(view, device),
|
|
||||||
// DType::F16 => convert_::<half::f16>(view, device),
|
|
||||||
// DType::F32 => convert_::<f32>(view, device),
|
|
||||||
// DType::F64 => convert_::<f64>(view, device),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -179,7 +179,7 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn save_single_tensor() {
|
fn save_single_tensor() {
|
||||||
let t = Tensor::zeros((2, 2), DType::F32, &Device::Cpu).unwrap();
|
let t = Tensor::zeros((2, 2), DType::F32, &Device::Cpu).unwrap();
|
||||||
t.save("t", "t.safetensors").unwrap();
|
t.save_safetensors("t", "t.safetensors").unwrap();
|
||||||
let bytes = std::fs::read("t.safetensors").unwrap();
|
let bytes = std::fs::read("t.safetensors").unwrap();
|
||||||
assert_eq!(bytes, b"@\0\0\0\0\0\0\0{\"t\":{\"dtype\":\"F32\",\"shape\":[2,2],\"data_offsets\":[0,16]}} \0\0\0\0");
|
assert_eq!(bytes, b"@\0\0\0\0\0\0\0{\"t\":{\"dtype\":\"F32\",\"shape\":[2,2],\"data_offsets\":[0,16]}} \0\0\0\0");
|
||||||
std::fs::remove_file("t.safetensors").unwrap();
|
std::fs::remove_file("t.safetensors").unwrap();
|
||||||
|
Reference in New Issue
Block a user