From 77712d4348a31fb2e8f9676421ed05f3b5c2292e Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Thu, 22 Jun 2023 13:13:35 +0200 Subject: [PATCH] Addressing comments. --- src/tensor.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/tensor.rs b/src/tensor.rs index 40b72c00..09e5d66c 100644 --- a/src/tensor.rs +++ b/src/tensor.rs @@ -151,8 +151,8 @@ impl Tensor { device: &Device, is_variable: bool, ) -> Result { - let n: usize = shape.0.iter().product(); - let buffer_size: usize = array.shape()?.0.iter().product(); + let n: usize = shape.elem_count(); + let buffer_size: usize = array.shape()?.elem_count(); if buffer_size != n { return Err(Error::ShapeMismatch { buffer_size, shape }); } @@ -285,7 +285,7 @@ impl Tensor { let mut c_shape: Vec<_> = a_dims[..dim - 2].into(); c_shape.extend(&[m, n]); - let c_shape: Shape = Shape(c_shape); + let c_shape = Shape(c_shape); let batching: usize = a_dims[..dim - 2].iter().product(); let storage = self.storage.matmul_impl(