mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 19:18:50 +00:00
Addressing comments.
This commit is contained in:
@ -151,8 +151,8 @@ impl Tensor {
|
|||||||
device: &Device,
|
device: &Device,
|
||||||
is_variable: bool,
|
is_variable: bool,
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
let n: usize = shape.0.iter().product();
|
let n: usize = shape.elem_count();
|
||||||
let buffer_size: usize = array.shape()?.0.iter().product();
|
let buffer_size: usize = array.shape()?.elem_count();
|
||||||
if buffer_size != n {
|
if buffer_size != n {
|
||||||
return Err(Error::ShapeMismatch { buffer_size, shape });
|
return Err(Error::ShapeMismatch { buffer_size, shape });
|
||||||
}
|
}
|
||||||
@ -285,7 +285,7 @@ impl Tensor {
|
|||||||
|
|
||||||
let mut c_shape: Vec<_> = a_dims[..dim - 2].into();
|
let mut c_shape: Vec<_> = a_dims[..dim - 2].into();
|
||||||
c_shape.extend(&[m, n]);
|
c_shape.extend(&[m, n]);
|
||||||
let c_shape: Shape = Shape(c_shape);
|
let c_shape = Shape(c_shape);
|
||||||
let batching: usize = a_dims[..dim - 2].iter().product();
|
let batching: usize = a_dims[..dim - 2].iter().product();
|
||||||
|
|
||||||
let storage = self.storage.matmul_impl(
|
let storage = self.storage.matmul_impl(
|
||||||
|
Reference in New Issue
Block a user