Final updates.

This commit is contained in:
Nicolas Patry
2023-06-22 12:39:33 +02:00
parent 04cf14f35a
commit a8b6c848e0
2 changed files with 4 additions and 8 deletions

View File

@ -128,7 +128,7 @@ impl CpuStorage {
let lhs_batch_stride = &lhs_stride[..rank - 2];
let rhs_batch_stride = &rhs_stride[..rank - 2];
if lhs_batch_stride != &[a_skip] || rhs_batch_stride != &[b_skip] {
if lhs_batch_stride != [a_skip] || rhs_batch_stride != [b_skip] {
// Temporary error before we support abitrary striding.
return Err(Error::UnexpectedStriding);
}

View File

@ -166,12 +166,12 @@ impl Tensor {
}
pub fn new<A: crate::device::NdArray>(array: A, device: &Device) -> Result<Self> {
let shape = array.shape()?.clone();
let shape = array.shape()?;
Self::new_impl(array, shape, device, false)
}
pub fn var<A: crate::device::NdArray>(array: A, device: &Device) -> Result<Self> {
let shape = array.shape()?.clone();
let shape = array.shape()?;
Self::new_impl(array, shape, device, true)
}
@ -259,11 +259,7 @@ impl Tensor {
let dim = a_dims.len();
// TODO
// if dim < 2 {
// return Err(SmeltError::InsufficientRank { minimum_rank: 2 });
// }
if b_dims.len() != dim {
if dim < 2 || b_dims.len() != dim {
return Err(Error::ShapeMismatchBinaryOp {
lhs: self.shape().clone(),
rhs: rhs.shape().clone(),