Addressing comments.

This commit is contained in:
Nicolas Patry
2023-06-22 13:13:35 +02:00
parent 449af49b54
commit 77712d4348

View File

@ -151,8 +151,8 @@ impl Tensor {
device: &Device,
is_variable: bool,
) -> Result<Self> {
let n: usize = shape.0.iter().product();
let buffer_size: usize = array.shape()?.0.iter().product();
let n: usize = shape.elem_count();
let buffer_size: usize = array.shape()?.elem_count();
if buffer_size != n {
return Err(Error::ShapeMismatch { buffer_size, shape });
}
@ -285,7 +285,7 @@ impl Tensor {
let mut c_shape: Vec<_> = a_dims[..dim - 2].into();
c_shape.extend(&[m, n]);
let c_shape: Shape = Shape(c_shape);
let c_shape = Shape(c_shape);
let batching: usize = a_dims[..dim - 2].iter().product();
let storage = self.storage.matmul_impl(