diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index 6a06836d..3fdcbcc6 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -3,7 +3,7 @@ use crate::backend::{BackendDevice, BackendStorage}; use crate::op::{BackpropOp, BinaryOp, CmpOp, Op, ReduceOp, UnaryOp}; use crate::scalar::TensorOrScalar; -use crate::shape::{Dim, Dims}; +use crate::shape::{Dim, Dims, ShapeWithOneHole}; use crate::{bail, storage::Storage, DType, Device, Error, Layout, Result, Shape}; use std::sync::{Arc, RwLock}; @@ -452,17 +452,13 @@ impl Tensor { Self::from_vec_impl(data, len, device, false) } - pub(crate) fn from_vec_impl, D: crate::WithDType>( + pub(crate) fn from_vec_impl( data: Vec, shape: S, device: &Device, is_variable: bool, ) -> Result { - let shape = shape.into(); - let buffer_size = data.len(); - if buffer_size != shape.elem_count() { - return Err(Error::ShapeMismatch { buffer_size, shape }.bt()); - } + let shape = shape.into_shape(data.len())?; let storage = device.storage_owned(data)?; let none = BackpropOp::none(); Ok(from_storage(storage, shape, none, is_variable)) @@ -481,7 +477,7 @@ impl Tensor { /// ]); /// # Ok::<(), candle_core::Error>(()) /// ``` - pub fn from_vec, D: crate::WithDType>( + pub fn from_vec( data: Vec, shape: S, device: &Device, @@ -502,17 +498,12 @@ impl Tensor { /// ]); /// # Ok::<(), candle_core::Error>(()) /// ``` - pub fn from_slice, D: crate::WithDType>( + pub fn from_slice( array: &[D], shape: S, device: &Device, ) -> Result { - let shape = shape.into(); - let n: usize = shape.elem_count(); - let buffer_size: usize = array.len(); - if buffer_size != n { - return Err(Error::ShapeMismatch { buffer_size, shape }.bt()); - } + let shape = shape.into_shape(array.len())?; let storage = device.storage_from_slice(array)?; let none = BackpropOp::none(); Ok(from_storage(storage, shape, none, false)) @@ -2197,7 +2188,7 @@ impl Tensor { /// /// # Ok::<(), candle_core::Error>(()) /// ``` - pub fn reshape(&self, s: S) -> Result { + pub fn reshape(&self, s: S) -> Result { let shape = s.into_shape(self.elem_count())?; if shape.elem_count() != self.elem_count() { return Err(Error::ShapeMismatchBinaryOp {