Mirror of https://github.com/huggingface/candle.git, synced 2025-06-17 19:18:50 +00:00
Adding size checking when creating a tensor from buffer + shape.
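The check is simple arithmetic: the number of elements implied by the requested shape (the product of its dimensions) must equal the number of elements in the source buffer; a shape of (2, 3) implies 2 * 3 = 6 elements, so a 5-element buffer has to be rejected. A minimal sketch of that comparison in plain Rust, with a hypothetical elem_count helper rather than candle's own Shape type:

fn elem_count(dims: &[usize]) -> usize {
    // Product of all dimensions, i.e. the number of elements the shape implies.
    dims.iter().product()
}

fn main() {
    let buffer = [1.0f32, 2.0, 3.0, 4.0, 5.0];
    let dims = [2usize, 3];
    // 2 * 3 = 6 expected elements, but the buffer only holds 5, so the
    // constructor should refuse to build a tensor from this pair.
    assert_ne!(elem_count(&dims), buffer.len());
}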
@@ -12,6 +12,11 @@ pub enum Error {
     #[error("the candle crate has not been built with cuda support")]
     NotCompiledWithCudaSupport,
 
+    #[error(
+        "Shape mismatch, got buffer of size {buffer_size} which is compatible with shape {shape:?}"
+    )]
+    ShapeMismatch { buffer_size: usize, shape: Shape },
+
     #[error("shape mismatch in {op}, lhs: {lhs:?}, rhs: {rhs:?}")]
     ShapeMismatchBinaryOp {
         lhs: Shape,
@@ -40,6 +45,7 @@ pub enum Error {
         shape: Shape,
     },
 
+    // TODO this is temporary when we support arbitrary matmul
     #[error("temporary error where matmul doesn't support arbitrary striding")]
     UnexpectedStriding,
 
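The new ShapeMismatch variant follows the same pattern as the surrounding variants: a thiserror-style #[error] attribute whose format string interpolates the variant's named fields. A standalone sketch of that pattern, assuming the thiserror crate and a cut-down stand-in for candle's Shape newtype (not the crate's actual error type):

use thiserror::Error;

// Simplified stand-in for candle's Shape newtype around a list of dimensions.
#[derive(Debug)]
pub struct Shape(pub Vec<usize>);

#[derive(Debug, Error)]
pub enum Error {
    #[error(
        "Shape mismatch, got buffer of size {buffer_size} which is compatible with shape {shape:?}"
    )]
    ShapeMismatch { buffer_size: usize, shape: Shape },
}

fn main() {
    let err = Error::ShapeMismatch {
        buffer_size: 5,
        shape: Shape(vec![2, 3]),
    };
    // Display is generated from the #[error] format string by the derive.
    println!("{err}");
}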
@@ -151,7 +151,11 @@ impl Tensor {
         device: &Device,
         is_variable: bool,
     ) -> Result<Self> {
-        // let shape = array.shape()?;
+        let n: usize = shape.0.iter().product();
+        let buffer_size: usize = array.shape()?.0.iter().product();
+        if buffer_size != n {
+            return Err(Error::ShapeMismatch { buffer_size, shape });
+        }
         let storage = device.storage(array)?;
         let stride = shape.stride_contiguous();
         let tensor_ = Tensor_ {
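In the constructor itself, the expected element count comes from the caller-supplied shape and the actual count from the buffer, and the mismatch is reported before any device storage is allocated. A freestanding sketch of that control flow, using hypothetical Shape and Error types rather than candle's real Device/storage machinery:

#[derive(Debug)]
pub struct Shape(pub Vec<usize>);

#[derive(Debug)]
pub enum Error {
    ShapeMismatch { buffer_size: usize, shape: Shape },
}

// Validate the buffer length against the shape before doing any real work,
// mirroring the check added by this commit.
fn check_shape(buffer: &[f32], shape: Shape) -> Result<Shape, Error> {
    let n: usize = shape.0.iter().product();
    let buffer_size = buffer.len();
    if buffer_size != n {
        return Err(Error::ShapeMismatch { buffer_size, shape });
    }
    Ok(shape)
}

fn main() {
    let data = [1.0f32; 5];
    match check_shape(&data, Shape(vec![2, 3])) {
        Ok(shape) => println!("ok: {shape:?}"),
        Err(err) => println!("rejected: {err:?}"),
    }
}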