From 01eeb0e72f07d3e7dbcf2d62720c2b7f2665e55a Mon Sep 17 00:00:00 2001 From: laurent Date: Mon, 19 Jun 2023 20:22:12 +0100 Subject: [PATCH] Shuffle the shape bits around. --- src/dtype.rs | 12 ++++ src/lib.rs | 3 +- src/shape.rs | 129 ++++++++++++++++++++++++++++++++++++++++++ src/tensor.rs | 76 +++++-------------------- tests/tensor_tests.rs | 2 +- 5 files changed, 159 insertions(+), 63 deletions(-) create mode 100644 src/shape.rs diff --git a/src/dtype.rs b/src/dtype.rs index 4d722e9d..761b21bd 100644 --- a/src/dtype.rs +++ b/src/dtype.rs @@ -12,3 +12,15 @@ impl DType { } } } + +pub trait WithDType { + const DTYPE: DType; +} + +impl WithDType for f32 { + const DTYPE: DType = DType::F32; +} + +impl WithDType for f64 { + const DTYPE: DType = DType::F64; +} diff --git a/src/lib.rs b/src/lib.rs index dfda3430..f1a73c5f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,10 +2,11 @@ mod device; mod dtype; mod error; mod op; +mod shape; mod storage; mod tensor; pub use device::Device; -pub use dtype::DType; +pub use dtype::{DType, WithDType}; pub use error::{Error, Result}; pub use tensor::Tensor; diff --git a/src/shape.rs b/src/shape.rs new file mode 100644 index 00000000..d9112aa7 --- /dev/null +++ b/src/shape.rs @@ -0,0 +1,129 @@ +use crate::{Error, Result}; +pub struct Shape(pub(crate) Vec); + +impl From<&[usize; 1]> for Shape { + fn from(dims: &[usize; 1]) -> Self { + Self(dims.to_vec()) + } +} + +impl From<&[usize; 2]> for Shape { + fn from(dims: &[usize; 2]) -> Self { + Self(dims.to_vec()) + } +} + +impl From<&[usize; 3]> for Shape { + fn from(dims: &[usize; 3]) -> Self { + Self(dims.to_vec()) + } +} + +impl From<&[usize]> for Shape { + fn from(dims: &[usize]) -> Self { + Self(dims.to_vec()) + } +} + +impl From<()> for Shape { + fn from(_: ()) -> Self { + Self(vec![]) + } +} + +impl From for Shape { + fn from(d1: usize) -> Self { + Self(vec![d1]) + } +} + +impl From<(usize, usize)> for Shape { + fn from(d12: (usize, usize)) -> Self { + Self(vec![d12.0, d12.1]) + } +} + +impl From<(usize, usize, usize)> for Shape { + fn from(d123: (usize, usize, usize)) -> Self { + Self(vec![d123.0, d123.1, d123.2]) + } +} + +impl Shape { + pub fn rank(&self) -> usize { + self.0.len() + } + + pub fn dims(&self) -> &[usize] { + &self.0 + } + + pub fn elem_count(&self) -> usize { + self.0.iter().product() + } + + pub fn r0(&self) -> Result<()> { + let shape = &self.0; + if shape.is_empty() { + Ok(()) + } else { + Err(Error::UnexpectedNumberOfDims { + expected: 0, + got: shape.len(), + shape: shape.to_vec(), + }) + } + } + + pub fn r1(&self) -> Result { + let shape = &self.0; + if shape.len() == 1 { + Ok(shape[0]) + } else { + Err(Error::UnexpectedNumberOfDims { + expected: 1, + got: shape.len(), + shape: shape.to_vec(), + }) + } + } + + pub fn r2(&self) -> Result<(usize, usize)> { + let shape = &self.0; + if shape.len() == 2 { + Ok((shape[0], shape[1])) + } else { + Err(Error::UnexpectedNumberOfDims { + expected: 2, + got: shape.len(), + shape: shape.to_vec(), + }) + } + } + + pub fn r3(&self) -> Result<(usize, usize, usize)> { + let shape = &self.0; + if shape.len() == 3 { + Ok((shape[0], shape[1], shape[2])) + } else { + Err(Error::UnexpectedNumberOfDims { + expected: 3, + got: shape.len(), + shape: shape.to_vec(), + }) + } + } + + pub fn r4(&self) -> Result<(usize, usize, usize, usize)> { + let shape = &self.0; + if shape.len() == 4 { + Ok((shape[0], shape[1], shape[2], shape[4])) + } else { + Err(Error::UnexpectedNumberOfDims { + expected: 4, + got: shape.len(), + shape: shape.to_vec(), + }) + } + } +} diff --git a/src/tensor.rs b/src/tensor.rs index 034a1428..99fb2cf0 100644 --- a/src/tensor.rs +++ b/src/tensor.rs @@ -1,10 +1,10 @@ -use crate::{op::Op, storage::Storage, DType, Device, Error, Result}; +use crate::{op::Op, shape, storage::Storage, DType, Device}; use std::sync::Arc; #[allow(dead_code)] pub(crate) struct Tensor_ { storage: Storage, - shape: Vec, + shape: shape::Shape, stride: Vec, op: Option, } @@ -12,12 +12,14 @@ pub(crate) struct Tensor_ { pub struct Tensor(Arc); impl Tensor { - pub fn zeros(shape: &[usize], dtype: DType, device: Device) -> Self { - let storage = device.zeros(shape, dtype); + pub fn zeros>(shape: S, dtype: DType, device: Device) -> Self { + let shape = shape.into(); + let storage = device.zeros(&shape.0, dtype); + let rank = shape.0.len(); let tensor_ = Tensor_ { storage, - shape: shape.to_vec(), - stride: vec![1; shape.len()], + shape, + stride: vec![1; rank], op: None, }; Tensor(Arc::new(tensor_)) @@ -31,71 +33,23 @@ impl Tensor { self.0.storage.device() } - pub fn shape(&self) -> &[usize] { + pub fn shape(&self) -> &shape::Shape { &self.0.shape } + pub fn dims(&self) -> &[usize] { + &self.shape().dims() + } + pub fn stride(&self) -> &[usize] { &self.0.stride } pub fn rank(&self) -> usize { - self.0.shape.len() + self.shape().rank() } pub fn elem_count(&self) -> usize { - self.0.shape.iter().product() - } - - pub fn shape1(&self) -> Result { - let shape = self.shape(); - if shape.len() == 1 { - Ok(shape[0]) - } else { - Err(Error::UnexpectedNumberOfDims { - expected: 1, - got: shape.len(), - shape: shape.to_vec(), - }) - } - } - - pub fn shape2(&self) -> Result<(usize, usize)> { - let shape = self.shape(); - if shape.len() == 2 { - Ok((shape[0], shape[1])) - } else { - Err(Error::UnexpectedNumberOfDims { - expected: 2, - got: shape.len(), - shape: shape.to_vec(), - }) - } - } - - pub fn shape3(&self) -> Result<(usize, usize, usize)> { - let shape = self.shape(); - if shape.len() == 3 { - Ok((shape[0], shape[1], shape[2])) - } else { - Err(Error::UnexpectedNumberOfDims { - expected: 3, - got: shape.len(), - shape: shape.to_vec(), - }) - } - } - - pub fn shape4(&self) -> Result<(usize, usize, usize, usize)> { - let shape = self.shape(); - if shape.len() == 4 { - Ok((shape[0], shape[1], shape[2], shape[4])) - } else { - Err(Error::UnexpectedNumberOfDims { - expected: 4, - got: shape.len(), - shape: shape.to_vec(), - }) - } + self.shape().elem_count() } } diff --git a/tests/tensor_tests.rs b/tests/tensor_tests.rs index f772ab0a..4b94f40d 100644 --- a/tests/tensor_tests.rs +++ b/tests/tensor_tests.rs @@ -3,7 +3,7 @@ use candle::{DType, Device, Result, Tensor}; #[test] fn add() -> Result<()> { let tensor = Tensor::zeros(&[5, 2], DType::F32, Device::Cpu); - let (dim1, dim2) = tensor.shape2()?; + let (dim1, dim2) = tensor.shape().r2()?; assert_eq!(dim1, 5); assert_eq!(dim2, 2); Ok(())