From 26d6288eb6175e202433e1830dc87ac566831bc6 Mon Sep 17 00:00:00 2001
From: laurent
Date: Mon, 19 Jun 2023 20:59:26 +0100
Subject: [PATCH] Add an easy way to create tensor objects.

---
 src/device.rs         | 39 ++++++++++++++++++++++++++++++++++++---
 src/dtype.rs          | 14 +++++++++++++-
 src/lib.rs            |  2 ++
 src/storage.rs        |  8 ++++----
 src/tensor.rs         | 23 ++++++++++++++++++-----
 tests/tensor_tests.rs |  5 ++++-
 6 files changed, 77 insertions(+), 14 deletions(-)

diff --git a/src/device.rs b/src/device.rs
index b0249838..6741e582 100644
--- a/src/device.rs
+++ b/src/device.rs
@@ -1,6 +1,6 @@
 use crate::{
     storage::{CpuStorage, Storage},
-    DType,
+    DType, Result, Shape,
 };
 
 #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
@@ -8,11 +8,38 @@ pub enum Device {
     Cpu,
 }
 
+// TODO: Should we back the cpu implementation using the NdArray crate or similar?
+pub trait NdArray {
+    fn shape(&self) -> Result<Shape>;
+
+    fn to_cpu_storage(&self) -> CpuStorage;
+}
+
+impl<S: crate::WithDType> NdArray for S {
+    fn shape(&self) -> Result<Shape> {
+        Ok(Shape::from(()))
+    }
+
+    fn to_cpu_storage(&self) -> CpuStorage {
+        S::to_cpu_storage(&[*self])
+    }
+}
+
+impl<S: crate::WithDType> NdArray for &[S] {
+    fn shape(&self) -> Result<Shape> {
+        Ok(Shape::from(self.len()))
+    }
+
+    fn to_cpu_storage(&self) -> CpuStorage {
+        S::to_cpu_storage(self)
+    }
+}
+
 impl Device {
-    pub(crate) fn zeros(&self, shape: &[usize], dtype: DType) -> Storage {
+    pub(crate) fn zeros(&self, shape: &Shape, dtype: DType) -> Storage {
         match self {
             Device::Cpu => {
-                let elem_count: usize = shape.iter().product();
+                let elem_count = shape.elem_count();
                 let storage = match dtype {
                     DType::F32 => {
                         let data = vec![0f32; elem_count];
@@ -27,4 +54,10 @@ impl Device {
             }
         }
     }
+
+    pub(crate) fn tensor<A: NdArray>(&self, array: A) -> Storage {
+        match self {
+            Device::Cpu => Storage::Cpu(array.to_cpu_storage()),
+        }
+    }
 }
diff --git a/src/dtype.rs b/src/dtype.rs
index 761b21bd..b21aa208 100644
--- a/src/dtype.rs
+++ b/src/dtype.rs
@@ -1,3 +1,5 @@
+use crate::CpuStorage;
+
 #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
 pub enum DType {
     F32,
@@ -13,14 +15,24 @@ impl DType {
     }
 }
 
-pub trait WithDType {
+pub trait WithDType: Sized + Copy {
     const DTYPE: DType;
+
+    fn to_cpu_storage(data: &[Self]) -> CpuStorage;
 }
 
 impl WithDType for f32 {
     const DTYPE: DType = DType::F32;
+
+    fn to_cpu_storage(data: &[Self]) -> CpuStorage {
+        CpuStorage::F32(data.to_vec())
+    }
 }
 
 impl WithDType for f64 {
     const DTYPE: DType = DType::F64;
+
+    fn to_cpu_storage(data: &[Self]) -> CpuStorage {
+        CpuStorage::F64(data.to_vec())
+    }
 }
diff --git a/src/lib.rs b/src/lib.rs
index f1a73c5f..95373540 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -9,4 +9,6 @@ mod tensor;
 pub use device::Device;
 pub use dtype::{DType, WithDType};
 pub use error::{Error, Result};
+pub use shape::Shape;
+pub use storage::{CpuStorage, Storage};
 pub use tensor::Tensor;
diff --git a/src/storage.rs b/src/storage.rs
index f9308633..10502d43 100644
--- a/src/storage.rs
+++ b/src/storage.rs
@@ -2,7 +2,7 @@ use crate::{DType, Device};
 
 // TODO: Think about whether we would be better off with a dtype and
 // a buffer as an owned slice of bytes.
-pub(crate) enum CpuStorage {
+pub enum CpuStorage {
     F32(Vec<f32>),
     F64(Vec<f64>),
 }
@@ -17,18 +17,18 @@ impl CpuStorage {
 }
 
 #[allow(dead_code)]
-pub(crate) enum Storage {
+pub enum Storage {
     Cpu(CpuStorage),
 }
 
 impl Storage {
-    pub(crate) fn device(&self) -> Device {
+    pub fn device(&self) -> Device {
         match self {
             Self::Cpu(_) => Device::Cpu,
         }
     }
 
-    pub(crate) fn dtype(&self) -> DType {
+    pub fn dtype(&self) -> DType {
         match self {
             Self::Cpu(storage) => storage.dtype(),
         }
diff --git a/src/tensor.rs b/src/tensor.rs
index 99fb2cf0..64438ce6 100644
--- a/src/tensor.rs
+++ b/src/tensor.rs
@@ -1,4 +1,4 @@
-use crate::{op::Op, shape, storage::Storage, DType, Device};
+use crate::{op::Op, shape, storage::Storage, DType, Device, Result};
 use std::sync::Arc;
 
 #[allow(dead_code)]
@@ -14,15 +14,28 @@ pub struct Tensor(Arc<Tensor_>);
 impl Tensor {
     pub fn zeros<S: Into<shape::Shape>>(shape: S, dtype: DType, device: Device) -> Self {
         let shape = shape.into();
-        let storage = device.zeros(&shape.0, dtype);
-        let rank = shape.0.len();
+        let storage = device.zeros(&shape, dtype);
+        let rank = shape.rank();
         let tensor_ = Tensor_ {
             storage,
             shape,
             stride: vec![1; rank],
             op: None,
         };
-        Tensor(Arc::new(tensor_))
+        Self(Arc::new(tensor_))
+    }
+
+    pub fn new<A: crate::device::NdArray>(array: A, device: Device) -> Result<Self> {
+        let shape = array.shape()?;
+        let storage = device.tensor(array);
+        let rank = shape.rank();
+        let tensor_ = Tensor_ {
+            storage,
+            shape,
+            stride: vec![1; rank],
+            op: None,
+        };
+        Ok(Self(Arc::new(tensor_)))
     }
 
     pub fn dtype(&self) -> DType {
@@ -38,7 +51,7 @@ impl Tensor {
     }
 
     pub fn dims(&self) -> &[usize] {
-        &self.shape().dims()
+        self.shape().dims()
     }
 
     pub fn stride(&self) -> &[usize] {
diff --git a/tests/tensor_tests.rs b/tests/tensor_tests.rs
index 4b94f40d..54c1e987 100644
--- a/tests/tensor_tests.rs
+++ b/tests/tensor_tests.rs
@@ -2,9 +2,12 @@ use candle::{DType, Device, Result, Tensor};
 
 #[test]
 fn add() -> Result<()> {
-    let tensor = Tensor::zeros(&[5, 2], DType::F32, Device::Cpu);
+    let tensor = Tensor::zeros((5, 2), DType::F32, Device::Cpu);
     let (dim1, dim2) = tensor.shape().r2()?;
     assert_eq!(dim1, 5);
     assert_eq!(dim2, 2);
+    let tensor = Tensor::new([3., 1., 4.].as_slice(), Device::Cpu)?;
+    let dim1 = tensor.shape().r1()?;
+    assert_eq!(dim1, 3);
     Ok(())
 }
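Usage sketch (not part of the patch): a minimal example of the constructor this change introduces, written in the same style as tests/tensor_tests.rs. The test name tensor_new is only illustrative, and the rank-0 assertion assumes Shape::from(()) yields an empty dims list, as the scalar NdArray impl suggests; every call used here (Tensor::new, dims, dtype, shape().r1()) appears in the diff above.

use candle::{DType, Device, Result, Tensor};

#[test]
fn tensor_new() -> Result<()> {
    // A bare scalar implements the new NdArray helper trait, so it turns
    // directly into a rank-0 tensor backed by a one-element CpuStorage.
    let scalar = Tensor::new(2.5f64, Device::Cpu)?;
    assert!(scalar.dims().is_empty());
    assert_eq!(scalar.dtype(), DType::F64);

    // A slice becomes a rank-1 tensor whose single dimension is the slice length.
    let vector = Tensor::new([3f32, 1., 4.].as_slice(), Device::Cpu)?;
    assert_eq!(vector.shape().r1()?, 3);
    assert_eq!(vector.dtype(), DType::F32);
    Ok(())
}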