mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 03:28:50 +00:00
Final updates.
This commit is contained in:
@ -128,7 +128,7 @@ impl CpuStorage {
|
|||||||
let lhs_batch_stride = &lhs_stride[..rank - 2];
|
let lhs_batch_stride = &lhs_stride[..rank - 2];
|
||||||
let rhs_batch_stride = &rhs_stride[..rank - 2];
|
let rhs_batch_stride = &rhs_stride[..rank - 2];
|
||||||
|
|
||||||
if lhs_batch_stride != &[a_skip] || rhs_batch_stride != &[b_skip] {
|
if lhs_batch_stride != [a_skip] || rhs_batch_stride != [b_skip] {
|
||||||
// Temporary error before we support abitrary striding.
|
// Temporary error before we support abitrary striding.
|
||||||
return Err(Error::UnexpectedStriding);
|
return Err(Error::UnexpectedStriding);
|
||||||
}
|
}
|
||||||
|
@ -166,12 +166,12 @@ impl Tensor {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn new<A: crate::device::NdArray>(array: A, device: &Device) -> Result<Self> {
|
pub fn new<A: crate::device::NdArray>(array: A, device: &Device) -> Result<Self> {
|
||||||
let shape = array.shape()?.clone();
|
let shape = array.shape()?;
|
||||||
Self::new_impl(array, shape, device, false)
|
Self::new_impl(array, shape, device, false)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn var<A: crate::device::NdArray>(array: A, device: &Device) -> Result<Self> {
|
pub fn var<A: crate::device::NdArray>(array: A, device: &Device) -> Result<Self> {
|
||||||
let shape = array.shape()?.clone();
|
let shape = array.shape()?;
|
||||||
Self::new_impl(array, shape, device, true)
|
Self::new_impl(array, shape, device, true)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -259,11 +259,7 @@ impl Tensor {
|
|||||||
|
|
||||||
let dim = a_dims.len();
|
let dim = a_dims.len();
|
||||||
|
|
||||||
// TODO
|
if dim < 2 || b_dims.len() != dim {
|
||||||
// if dim < 2 {
|
|
||||||
// return Err(SmeltError::InsufficientRank { minimum_rank: 2 });
|
|
||||||
// }
|
|
||||||
if b_dims.len() != dim {
|
|
||||||
return Err(Error::ShapeMismatchBinaryOp {
|
return Err(Error::ShapeMismatchBinaryOp {
|
||||||
lhs: self.shape().clone(),
|
lhs: self.shape().clone(),
|
||||||
rhs: rhs.shape().clone(),
|
rhs: rhs.shape().clone(),
|
||||||
|
Reference in New Issue
Block a user