mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 03:54:56 +00:00
Tensor mutability (#154)
* Working towards tensor mutability. * Use a ref-cell to provide tensor mutability.
This commit is contained in:
@ -40,7 +40,8 @@ impl st::View for Tensor {
|
||||
|
||||
fn data(&self) -> Cow<[u8]> {
|
||||
// This copies data from GPU to CPU.
|
||||
convert_back(self).unwrap()
|
||||
// TODO: Avoid the unwrap here.
|
||||
Cow::Owned(convert_back(self).unwrap())
|
||||
}
|
||||
|
||||
fn data_len(&self) -> usize {
|
||||
@ -86,19 +87,18 @@ fn convert_<T: WithDType>(view: st::TensorView<'_>, device: &Device) -> Result<T
|
||||
}
|
||||
}
|
||||
|
||||
fn convert_back_<T: WithDType>(value: Cow<'_, [T]>) -> Cow<'_, [u8]> {
|
||||
fn convert_back_<T: WithDType>(mut vs: Vec<T>) -> Vec<u8> {
|
||||
let size_in_bytes = T::DTYPE.size_in_bytes();
|
||||
let length = vs.len() * size_in_bytes;
|
||||
let capacity = vs.capacity() * size_in_bytes;
|
||||
let ptr = vs.as_mut_ptr() as *mut u8;
|
||||
// Don't run the destructor for Vec<T>
|
||||
std::mem::forget(vs);
|
||||
// SAFETY:
|
||||
//
|
||||
// Every T is at least as large as u8, so there is no issue regarding alignment.
|
||||
// This is safe only because we explicitly take the lifetime from the Cow's lifetime
|
||||
// and consume the original Cow.
|
||||
// This means that a borrowed Cow will keep its lifetime information, preventing
|
||||
// this slice from being accessed after freeing the original memory.
|
||||
let slice = unsafe {
|
||||
std::slice::from_raw_parts(value.as_ptr() as *const u8, value.len() * size_in_bytes)
|
||||
};
|
||||
Cow::Borrowed(slice)
|
||||
// This re-interprets the Vec<T> as a Vec<u8>.
|
||||
unsafe { Vec::from_raw_parts(ptr, length, capacity) }
|
||||
}
|
||||
|
||||
pub fn convert(view: st::TensorView<'_>, device: &Device) -> Result<Tensor> {
|
||||
@ -113,14 +113,16 @@ pub fn convert(view: st::TensorView<'_>, device: &Device) -> Result<Tensor> {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn convert_back(tensor: &Tensor) -> Result<Cow<[u8]>> {
|
||||
pub fn convert_back(tensor: &Tensor) -> Result<Vec<u8>> {
|
||||
// TODO: This makes an unnecessary copy when the tensor is on the cpu.
|
||||
let tensor = tensor.flatten_all()?;
|
||||
match tensor.dtype() {
|
||||
DType::U8 => Ok(convert_back_::<u8>(tensor.storage_data()?)),
|
||||
DType::U32 => Ok(convert_back_::<u32>(tensor.storage_data()?)),
|
||||
DType::F16 => Ok(convert_back_::<half::f16>(tensor.storage_data()?)),
|
||||
DType::BF16 => Ok(convert_back_::<half::bf16>(tensor.storage_data()?)),
|
||||
DType::F32 => Ok(convert_back_::<f32>(tensor.storage_data()?)),
|
||||
DType::F64 => Ok(convert_back_::<f64>(tensor.storage_data()?)),
|
||||
DType::U8 => Ok(convert_back_::<u8>(tensor.to_vec1()?)),
|
||||
DType::U32 => Ok(convert_back_::<u32>(tensor.to_vec1()?)),
|
||||
DType::F16 => Ok(convert_back_::<half::f16>(tensor.to_vec1()?)),
|
||||
DType::BF16 => Ok(convert_back_::<half::bf16>(tensor.to_vec1()?)),
|
||||
DType::F32 => Ok(convert_back_::<f32>(tensor.to_vec1()?)),
|
||||
DType::F64 => Ok(convert_back_::<f64>(tensor.to_vec1()?)),
|
||||
}
|
||||
}
|
||||
|
||||
@ -183,7 +185,7 @@ mod tests {
|
||||
let t = Tensor::zeros((2, 2), DType::F32, &Device::Cpu).unwrap();
|
||||
t.save_safetensors("t", "t.safetensors").unwrap();
|
||||
let bytes = std::fs::read("t.safetensors").unwrap();
|
||||
assert_eq!(bytes, b"@\0\0\0\0\0\0\0{\"t\":{\"dtype\":\"F32\",\"shape\":[2,2],\"data_offsets\":[0,16]}} \0\0\0\0");
|
||||
assert_eq!(bytes, b"@\0\0\0\0\0\0\0{\"t\":{\"dtype\":\"F32\",\"shape\":[2,2],\"data_offsets\":[0,16]}} \0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0");
|
||||
std::fs::remove_file("t.safetensors").unwrap();
|
||||
}
|
||||
|
||||
@ -194,7 +196,7 @@ mod tests {
|
||||
let map: HashMap<_, _> = [("t", t), ("u", u)].into_iter().collect();
|
||||
st::serialize_to_file(map, &None, std::path::Path::new("multi.safetensors")).unwrap();
|
||||
let bytes = std::fs::read("multi.safetensors").unwrap();
|
||||
assert_eq!(bytes, b"x\0\0\0\0\0\0\0{\"t\":{\"dtype\":\"F32\",\"shape\":[2,2],\"data_offsets\":[0,16]},\"u\":{\"dtype\":\"F32\",\"shape\":[1,2],\"data_offsets\":[16,24]}} \0\0\0\0\0\0\0\0");
|
||||
assert_eq!(bytes, b"x\0\0\0\0\0\0\0{\"t\":{\"dtype\":\"F32\",\"shape\":[2,2],\"data_offsets\":[0,16]},\"u\":{\"dtype\":\"F32\",\"shape\":[1,2],\"data_offsets\":[16,24]}} \0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0");
|
||||
std::fs::remove_file("multi.safetensors").unwrap();
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user