From 844704de5c461279dc906eb82056c1adcf592b18 Mon Sep 17 00:00:00 2001 From: laurent Date: Mon, 19 Jun 2023 17:34:13 +0100 Subject: [PATCH] Split the tensor file. --- src/device.rs | 18 ++++++++++++++++++ src/dtype.rs | 14 ++++++++++++++ src/lib.rs | 8 +++++++- src/storage.rs | 7 +++++++ src/tensor.rs | 33 ++------------------------------- 5 files changed, 48 insertions(+), 32 deletions(-) create mode 100644 src/device.rs create mode 100644 src/dtype.rs create mode 100644 src/storage.rs diff --git a/src/device.rs b/src/device.rs new file mode 100644 index 00000000..1e74fd76 --- /dev/null +++ b/src/device.rs @@ -0,0 +1,18 @@ +use crate::{DType, Storage}; + +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] +pub enum Device { + Cpu, +} + +impl Device { + pub(crate) fn zeros(&self, shape: &[usize], dtype: DType) -> Storage { + match self { + Device::Cpu => { + let elem_count: usize = shape.iter().product(); + let buffer = vec![0; elem_count * dtype.size_in_bytes()]; + Storage::Cpu { dtype, buffer } + } + } + } +} diff --git a/src/dtype.rs b/src/dtype.rs new file mode 100644 index 00000000..4d722e9d --- /dev/null +++ b/src/dtype.rs @@ -0,0 +1,14 @@ +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] +pub enum DType { + F32, + F64, +} + +impl DType { + pub fn size_in_bytes(&self) -> usize { + match self { + Self::F32 => 4, + Self::F64 => 8, + } + } +} diff --git a/src/lib.rs b/src/lib.rs index 578f86ac..8582da03 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,9 @@ +mod device; +mod dtype; +mod storage; mod tensor; -pub use tensor::{DType, Tensor}; +pub use device::Device; +pub use dtype::DType; +use storage::Storage; +pub use tensor::Tensor; diff --git a/src/storage.rs b/src/storage.rs new file mode 100644 index 00000000..056c4693 --- /dev/null +++ b/src/storage.rs @@ -0,0 +1,7 @@ +#[allow(dead_code)] +pub(crate) enum Storage { + Cpu { + dtype: crate::DType, + buffer: Vec, + }, +} diff --git a/src/tensor.rs b/src/tensor.rs index c9b4a056..e3560a67 100644 --- a/src/tensor.rs +++ b/src/tensor.rs @@ -1,17 +1,4 @@ -#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] -pub enum DType { - F32, - F64, -} - -impl DType { - pub fn size_in_bytes(&self) -> usize { - match self { - Self::F32 => 4, - Self::F64 => 8, - } - } -} +use crate::{DType, Device, Storage}; pub struct Tensor { storage: Storage, @@ -19,25 +6,9 @@ pub struct Tensor { stride: Vec, } -#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] -pub enum Device { - Cpu, -} - -#[allow(dead_code)] -enum Storage { - Cpu { dtype: DType, buffer: Vec }, -} - impl Tensor { pub fn zeros(shape: &[usize], dtype: DType, device: Device) -> Self { - let storage = match device { - Device::Cpu => { - let elem_count: usize = shape.iter().product(); - let buffer = vec![0; elem_count * dtype.size_in_bytes()]; - Storage::Cpu { dtype, buffer } - } - }; + let storage = device.zeros(shape, dtype); Self { storage, shape: shape.to_vec(),