Adding size checking when creating a tensor from buffer + shape.

This commit is contained in:
Nicolas Patry
2023-06-22 13:08:57 +02:00
parent a8b6c848e0
commit 449af49b54
2 changed files with 11 additions and 1 deletions

View File

@ -12,6 +12,11 @@ pub enum Error {
#[error("the candle crate has not been built with cuda support")] #[error("the candle crate has not been built with cuda support")]
NotCompiledWithCudaSupport, NotCompiledWithCudaSupport,
#[error(
"Shape mismatch, got buffer of size {buffer_size} which is compatible with shape {shape:?}"
)]
ShapeMismatch { buffer_size: usize, shape: Shape },
#[error("shape mismatch in {op}, lhs: {lhs:?}, rhs: {rhs:?}")] #[error("shape mismatch in {op}, lhs: {lhs:?}, rhs: {rhs:?}")]
ShapeMismatchBinaryOp { ShapeMismatchBinaryOp {
lhs: Shape, lhs: Shape,
@ -40,6 +45,7 @@ pub enum Error {
shape: Shape, shape: Shape,
}, },
// TODO this is temporary when we support arbitrary matmul
#[error("temporary error where matmul doesn't support arbitrary striding")] #[error("temporary error where matmul doesn't support arbitrary striding")]
UnexpectedStriding, UnexpectedStriding,

View File

@ -151,7 +151,11 @@ impl Tensor {
device: &Device, device: &Device,
is_variable: bool, is_variable: bool,
) -> Result<Self> { ) -> Result<Self> {
// let shape = array.shape()?; let n: usize = shape.0.iter().product();
let buffer_size: usize = array.shape()?.0.iter().product();
if buffer_size != n {
return Err(Error::ShapeMismatch { buffer_size, shape });
}
let storage = device.storage(array)?; let storage = device.storage(array)?;
let stride = shape.stride_contiguous(); let stride = shape.stride_contiguous();
let tensor_ = Tensor_ { let tensor_ = Tensor_ {