diff --git a/src/error.rs b/src/error.rs new file mode 100644 index 00000000..12386268 --- /dev/null +++ b/src/error.rs @@ -0,0 +1,19 @@ +/// Main library error type. +#[derive(thiserror::Error, Debug)] +pub enum Error { + #[error("invalid shapes in {op}, lhs: {lhs:?}, rhs: {rhs:?}")] + BinaryInvalidShape { + lhs: Vec, + rhs: Vec, + op: &'static str, + }, + + #[error("unexpected rank, expected: {expected}, got: {got} ({shape:?})")] + UnexpectedNumberOfDims { + expected: usize, + got: usize, + shape: Vec, + }, +} + +pub type Result = std::result::Result; diff --git a/src/lib.rs b/src/lib.rs index 518fa071..dfda3430 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,9 +1,11 @@ mod device; mod dtype; +mod error; mod op; mod storage; mod tensor; pub use device::Device; pub use dtype::DType; +pub use error::{Error, Result}; pub use tensor::Tensor; diff --git a/src/tensor.rs b/src/tensor.rs index 551d5998..034a1428 100644 --- a/src/tensor.rs +++ b/src/tensor.rs @@ -1,4 +1,4 @@ -use crate::{op::Op, storage::Storage, DType, Device}; +use crate::{op::Op, storage::Storage, DType, Device, Error, Result}; use std::sync::Arc; #[allow(dead_code)] @@ -46,4 +46,56 @@ impl Tensor { pub fn elem_count(&self) -> usize { self.0.shape.iter().product() } + + pub fn shape1(&self) -> Result { + let shape = self.shape(); + if shape.len() == 1 { + Ok(shape[0]) + } else { + Err(Error::UnexpectedNumberOfDims { + expected: 1, + got: shape.len(), + shape: shape.to_vec(), + }) + } + } + + pub fn shape2(&self) -> Result<(usize, usize)> { + let shape = self.shape(); + if shape.len() == 2 { + Ok((shape[0], shape[1])) + } else { + Err(Error::UnexpectedNumberOfDims { + expected: 2, + got: shape.len(), + shape: shape.to_vec(), + }) + } + } + + pub fn shape3(&self) -> Result<(usize, usize, usize)> { + let shape = self.shape(); + if shape.len() == 3 { + Ok((shape[0], shape[1], shape[2])) + } else { + Err(Error::UnexpectedNumberOfDims { + expected: 3, + got: shape.len(), + shape: shape.to_vec(), + }) + } + } + + pub fn shape4(&self) -> Result<(usize, usize, usize, usize)> { + let shape = self.shape(); + if shape.len() == 4 { + Ok((shape[0], shape[1], shape[2], shape[4])) + } else { + Err(Error::UnexpectedNumberOfDims { + expected: 4, + got: shape.len(), + shape: shape.to_vec(), + }) + } + } } diff --git a/tests/tensor_tests.rs b/tests/tensor_tests.rs new file mode 100644 index 00000000..f772ab0a --- /dev/null +++ b/tests/tensor_tests.rs @@ -0,0 +1,10 @@ +use candle::{DType, Device, Result, Tensor}; + +#[test] +fn add() -> Result<()> { + let tensor = Tensor::zeros(&[5, 2], DType::F32, Device::Cpu); + let (dim1, dim2) = tensor.shape2()?; + assert_eq!(dim1, 5); + assert_eq!(dim2, 2); + Ok(()) +}