Final updates.

This commit is contained in:
Nicolas Patry
2023-06-22 12:39:33 +02:00
parent 04cf14f35a
commit a8b6c848e0
2 changed files with 4 additions and 8 deletions

View File

@ -128,7 +128,7 @@ impl CpuStorage {
let lhs_batch_stride = &lhs_stride[..rank - 2]; let lhs_batch_stride = &lhs_stride[..rank - 2];
let rhs_batch_stride = &rhs_stride[..rank - 2]; let rhs_batch_stride = &rhs_stride[..rank - 2];
if lhs_batch_stride != &[a_skip] || rhs_batch_stride != &[b_skip] { if lhs_batch_stride != [a_skip] || rhs_batch_stride != [b_skip] {
// Temporary error before we support abitrary striding. // Temporary error before we support abitrary striding.
return Err(Error::UnexpectedStriding); return Err(Error::UnexpectedStriding);
} }

View File

@ -166,12 +166,12 @@ impl Tensor {
} }
pub fn new<A: crate::device::NdArray>(array: A, device: &Device) -> Result<Self> { pub fn new<A: crate::device::NdArray>(array: A, device: &Device) -> Result<Self> {
let shape = array.shape()?.clone(); let shape = array.shape()?;
Self::new_impl(array, shape, device, false) Self::new_impl(array, shape, device, false)
} }
pub fn var<A: crate::device::NdArray>(array: A, device: &Device) -> Result<Self> { pub fn var<A: crate::device::NdArray>(array: A, device: &Device) -> Result<Self> {
let shape = array.shape()?.clone(); let shape = array.shape()?;
Self::new_impl(array, shape, device, true) Self::new_impl(array, shape, device, true)
} }
@ -259,11 +259,7 @@ impl Tensor {
let dim = a_dims.len(); let dim = a_dims.len();
// TODO if dim < 2 || b_dims.len() != dim {
// if dim < 2 {
// return Err(SmeltError::InsufficientRank { minimum_rank: 2 });
// }
if b_dims.len() != dim {
return Err(Error::ShapeMismatchBinaryOp { return Err(Error::ShapeMismatchBinaryOp {
lhs: self.shape().clone(), lhs: self.shape().clone(),
rhs: rhs.shape().clone(), rhs: rhs.shape().clone(),