mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00
Final updates.
This commit is contained in:
@ -128,7 +128,7 @@ impl CpuStorage {
|
||||
let lhs_batch_stride = &lhs_stride[..rank - 2];
|
||||
let rhs_batch_stride = &rhs_stride[..rank - 2];
|
||||
|
||||
if lhs_batch_stride != &[a_skip] || rhs_batch_stride != &[b_skip] {
|
||||
if lhs_batch_stride != [a_skip] || rhs_batch_stride != [b_skip] {
|
||||
// Temporary error before we support abitrary striding.
|
||||
return Err(Error::UnexpectedStriding);
|
||||
}
|
||||
|
@ -166,12 +166,12 @@ impl Tensor {
|
||||
}
|
||||
|
||||
pub fn new<A: crate::device::NdArray>(array: A, device: &Device) -> Result<Self> {
|
||||
let shape = array.shape()?.clone();
|
||||
let shape = array.shape()?;
|
||||
Self::new_impl(array, shape, device, false)
|
||||
}
|
||||
|
||||
pub fn var<A: crate::device::NdArray>(array: A, device: &Device) -> Result<Self> {
|
||||
let shape = array.shape()?.clone();
|
||||
let shape = array.shape()?;
|
||||
Self::new_impl(array, shape, device, true)
|
||||
}
|
||||
|
||||
@ -259,11 +259,7 @@ impl Tensor {
|
||||
|
||||
let dim = a_dims.len();
|
||||
|
||||
// TODO
|
||||
// if dim < 2 {
|
||||
// return Err(SmeltError::InsufficientRank { minimum_rank: 2 });
|
||||
// }
|
||||
if b_dims.len() != dim {
|
||||
if dim < 2 || b_dims.len() != dim {
|
||||
return Err(Error::ShapeMismatchBinaryOp {
|
||||
lhs: self.shape().clone(),
|
||||
rhs: rhs.shape().clone(),
|
||||
|
Reference in New Issue
Block a user