Add StorageRef. (#2113)

* Add the storage-ref bits.

* Add the metal implementation.
This commit is contained in:
Laurent Mazare
2024-04-23 13:23:27 +02:00
committed by GitHub
parent b2e816752b
commit 8a05743a21
10 changed files with 108 additions and 5 deletions

View File

@ -456,7 +456,15 @@ impl Tensor {
shape: S,
device: &Device,
) -> Result<Self> {
Self::new_impl(array, shape.into(), device, false)
let shape = shape.into();
let n: usize = shape.elem_count();
let buffer_size: usize = array.len();
if buffer_size != n {
return Err(Error::ShapeMismatch { buffer_size, shape }.bt());
}
let storage = device.storage_from_slice(array)?;
let none = BackpropOp::none();
Ok(from_storage(storage, shape, none, false))
}
pub(crate) fn same_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<&Shape> {