diff --git a/src/cpu_backend.rs b/src/cpu_backend.rs index 0eb4270a..2c708389 100644 --- a/src/cpu_backend.rs +++ b/src/cpu_backend.rs @@ -128,7 +128,7 @@ impl CpuStorage { let lhs_batch_stride = &lhs_stride[..rank - 2]; let rhs_batch_stride = &rhs_stride[..rank - 2]; - if lhs_batch_stride != &[a_skip] || rhs_batch_stride != &[b_skip] { + if lhs_batch_stride != [a_skip] || rhs_batch_stride != [b_skip] { // Temporary error before we support abitrary striding. return Err(Error::UnexpectedStriding); } diff --git a/src/tensor.rs b/src/tensor.rs index 7274c557..571b0399 100644 --- a/src/tensor.rs +++ b/src/tensor.rs @@ -166,12 +166,12 @@ impl Tensor { } pub fn new(array: A, device: &Device) -> Result { - let shape = array.shape()?.clone(); + let shape = array.shape()?; Self::new_impl(array, shape, device, false) } pub fn var(array: A, device: &Device) -> Result { - let shape = array.shape()?.clone(); + let shape = array.shape()?; Self::new_impl(array, shape, device, true) } @@ -259,11 +259,7 @@ impl Tensor { let dim = a_dims.len(); - // TODO - // if dim < 2 { - // return Err(SmeltError::InsufficientRank { minimum_rank: 2 }); - // } - if b_dims.len() != dim { + if dim < 2 || b_dims.len() != dim { return Err(Error::ShapeMismatchBinaryOp { lhs: self.shape().clone(), rhs: rhs.shape().clone(),