From 449af49b5404b96ae19e6926921571c459abb183 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Thu, 22 Jun 2023 13:08:57 +0200 Subject: [PATCH] Adding size checking when creating a tensor from buffer + shape. --- src/error.rs | 6 ++++++ src/tensor.rs | 6 +++++- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/src/error.rs b/src/error.rs index 6f40622c..723edaa1 100644 --- a/src/error.rs +++ b/src/error.rs @@ -12,6 +12,11 @@ pub enum Error { #[error("the candle crate has not been built with cuda support")] NotCompiledWithCudaSupport, + #[error( + "Shape mismatch, got buffer of size {buffer_size} which is compatible with shape {shape:?}" + )] + ShapeMismatch { buffer_size: usize, shape: Shape }, + #[error("shape mismatch in {op}, lhs: {lhs:?}, rhs: {rhs:?}")] ShapeMismatchBinaryOp { lhs: Shape, @@ -40,6 +45,7 @@ pub enum Error { shape: Shape, }, + // TODO this is temporary when we support arbitrary matmul #[error("temporary error where matmul doesn't support arbitrary striding")] UnexpectedStriding, diff --git a/src/tensor.rs b/src/tensor.rs index 571b0399..40b72c00 100644 --- a/src/tensor.rs +++ b/src/tensor.rs @@ -151,7 +151,11 @@ impl Tensor { device: &Device, is_variable: bool, ) -> Result { - // let shape = array.shape()?; + let n: usize = shape.0.iter().product(); + let buffer_size: usize = array.shape()?.0.iter().product(); + if buffer_size != n { + return Err(Error::ShapeMismatch { buffer_size, shape }); + } let storage = device.storage(array)?; let stride = shape.stride_contiguous(); let tensor_ = Tensor_ {