diff --git a/src/dtype.rs b/src/dtype.rs index b21aa208..a2c92aa7 100644 --- a/src/dtype.rs +++ b/src/dtype.rs @@ -1,4 +1,4 @@ -use crate::CpuStorage; +use crate::{CpuStorage, Error, Result}; #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] pub enum DType { @@ -19,6 +19,8 @@ pub trait WithDType: Sized + Copy { const DTYPE: DType; fn to_cpu_storage(data: &[Self]) -> CpuStorage; + + fn cpu_storage_as_slice(s: &CpuStorage) -> Result<&[Self]>; } impl WithDType for f32 { @@ -27,6 +29,16 @@ impl WithDType for f32 { fn to_cpu_storage(data: &[Self]) -> CpuStorage { CpuStorage::F32(data.to_vec()) } + + fn cpu_storage_as_slice(s: &CpuStorage) -> Result<&[Self]> { + match s { + CpuStorage::F32(data) => Ok(data), + _ => Err(Error::UnexpectedDType { + expected: DType::F32, + got: s.dtype(), + }), + } + } } impl WithDType for f64 { @@ -35,4 +47,14 @@ impl WithDType for f64 { fn to_cpu_storage(data: &[Self]) -> CpuStorage { CpuStorage::F64(data.to_vec()) } + + fn cpu_storage_as_slice(s: &CpuStorage) -> Result<&[Self]> { + match s { + CpuStorage::F64(data) => Ok(data), + _ => Err(Error::UnexpectedDType { + expected: DType::F64, + got: s.dtype(), + }), + } + } } diff --git a/src/error.rs b/src/error.rs index 12386268..7416ed76 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,10 +1,15 @@ +use crate::{DType, Shape}; + /// Main library error type. #[derive(thiserror::Error, Debug)] pub enum Error { - #[error("invalid shapes in {op}, lhs: {lhs:?}, rhs: {rhs:?}")] - BinaryInvalidShape { - lhs: Vec, - rhs: Vec, + #[error("unexpected dtype, expected: {expected:?}, got: {got:?}")] + UnexpectedDType { expected: DType, got: DType }, + + #[error("shape mismatch in {op}, lhs: {lhs:?}, rhs: {rhs:?}")] + ShapeMismatchBinaryOp { + lhs: Shape, + rhs: Shape, op: &'static str, }, @@ -12,7 +17,7 @@ pub enum Error { UnexpectedNumberOfDims { expected: usize, got: usize, - shape: Vec, + shape: Shape, }, } diff --git a/src/shape.rs b/src/shape.rs index a5fee614..4b186ca0 100644 --- a/src/shape.rs +++ b/src/shape.rs @@ -1,4 +1,6 @@ use crate::{Error, Result}; + +#[derive(Clone, PartialEq, Eq)] pub struct Shape(pub(crate) Vec); impl std::fmt::Debug for Shape { @@ -56,6 +58,10 @@ impl From<(usize, usize, usize)> for Shape { } impl Shape { + pub fn from_dims(dims: &[usize]) -> Self { + Self(dims.to_vec()) + } + pub fn rank(&self) -> usize { self.0.len() } @@ -76,7 +82,7 @@ impl Shape { Err(Error::UnexpectedNumberOfDims { expected: 0, got: shape.len(), - shape: shape.to_vec(), + shape: self.clone(), }) } } @@ -89,7 +95,7 @@ impl Shape { Err(Error::UnexpectedNumberOfDims { expected: 1, got: shape.len(), - shape: shape.to_vec(), + shape: self.clone(), }) } } @@ -102,7 +108,7 @@ impl Shape { Err(Error::UnexpectedNumberOfDims { expected: 2, got: shape.len(), - shape: shape.to_vec(), + shape: self.clone(), }) } } @@ -115,7 +121,7 @@ impl Shape { Err(Error::UnexpectedNumberOfDims { expected: 3, got: shape.len(), - shape: shape.to_vec(), + shape: self.clone(), }) } } @@ -128,7 +134,7 @@ impl Shape { Err(Error::UnexpectedNumberOfDims { expected: 4, got: shape.len(), - shape: shape.to_vec(), + shape: self.clone(), }) } } diff --git a/src/tensor.rs b/src/tensor.rs index 9b1e7d5b..83aa00d7 100644 --- a/src/tensor.rs +++ b/src/tensor.rs @@ -1,4 +1,4 @@ -use crate::{op::Op, storage::Storage, DType, Device, Result, Shape}; +use crate::{op::Op, storage::Storage, DType, Device, Error, Result, Shape}; use std::sync::Arc; #[allow(dead_code)] @@ -45,11 +45,46 @@ impl Tensor { Ok(Self(Arc::new(tensor_))) } - pub fn to_scalar(&self) -> Result { - // TODO: properly use the strides here. + pub(crate) fn same_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<()> { + let lhs = self.shape(); + let rhs = rhs.shape(); + if lhs != rhs { + Err(Error::ShapeMismatchBinaryOp { + lhs: lhs.clone(), + rhs: rhs.clone(), + op, + }) + } else { + Ok(()) + } + } + + pub fn add(&self, rhs: &Self) -> Result { + self.same_shape_binary_op(rhs, "add")?; todo!() } + pub fn mul(&self, rhs: &Self) -> Result { + self.same_shape_binary_op(rhs, "mul")?; + todo!() + } + + pub fn to_scalar(&self) -> Result { + if self.rank() != 0 { + return Err(Error::UnexpectedNumberOfDims { + expected: 0, + got: self.rank(), + shape: self.shape().clone(), + }); + } + match &self.0.storage { + Storage::Cpu(cpu_storage) => { + let data = S::cpu_storage_as_slice(cpu_storage)?; + Ok(data[0]) + } + } + } + pub fn to_vec1(&self) -> Result> { // TODO: properly use the strides here. todo!() diff --git a/tests/tensor_tests.rs b/tests/tensor_tests.rs index 54c1e987..a0f4630d 100644 --- a/tests/tensor_tests.rs +++ b/tests/tensor_tests.rs @@ -6,7 +6,7 @@ fn add() -> Result<()> { let (dim1, dim2) = tensor.shape().r2()?; assert_eq!(dim1, 5); assert_eq!(dim2, 2); - let tensor = Tensor::new([3., 1., 4.].as_slice(), Device::Cpu)?; + let tensor = Tensor::new([3f32, 1., 4.].as_slice(), Device::Cpu)?; let dim1 = tensor.shape().r1()?; assert_eq!(dim1, 3); Ok(())