Split the tensor file.

This commit is contained in:
laurent
2023-06-19 17:34:13 +01:00
parent 9698211d56
commit 844704de5c
5 changed files with 48 additions and 32 deletions

18
src/device.rs Normal file
View File

@ -0,0 +1,18 @@
use crate::{DType, Storage};
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum Device {
Cpu,
}
impl Device {
pub(crate) fn zeros(&self, shape: &[usize], dtype: DType) -> Storage {
match self {
Device::Cpu => {
let elem_count: usize = shape.iter().product();
let buffer = vec![0; elem_count * dtype.size_in_bytes()];
Storage::Cpu { dtype, buffer }
}
}
}
}

14
src/dtype.rs Normal file
View File

@ -0,0 +1,14 @@
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum DType {
F32,
F64,
}
impl DType {
pub fn size_in_bytes(&self) -> usize {
match self {
Self::F32 => 4,
Self::F64 => 8,
}
}
}

View File

@ -1,3 +1,9 @@
mod device;
mod dtype;
mod storage;
mod tensor; mod tensor;
pub use tensor::{DType, Tensor}; pub use device::Device;
pub use dtype::DType;
use storage::Storage;
pub use tensor::Tensor;

7
src/storage.rs Normal file
View File

@ -0,0 +1,7 @@
#[allow(dead_code)]
pub(crate) enum Storage {
Cpu {
dtype: crate::DType,
buffer: Vec<u8>,
},
}

View File

@ -1,17 +1,4 @@
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] use crate::{DType, Device, Storage};
pub enum DType {
F32,
F64,
}
impl DType {
pub fn size_in_bytes(&self) -> usize {
match self {
Self::F32 => 4,
Self::F64 => 8,
}
}
}
pub struct Tensor { pub struct Tensor {
storage: Storage, storage: Storage,
@ -19,25 +6,9 @@ pub struct Tensor {
stride: Vec<usize>, stride: Vec<usize>,
} }
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum Device {
Cpu,
}
#[allow(dead_code)]
enum Storage {
Cpu { dtype: DType, buffer: Vec<u8> },
}
impl Tensor { impl Tensor {
pub fn zeros(shape: &[usize], dtype: DType, device: Device) -> Self { pub fn zeros(shape: &[usize], dtype: DType, device: Device) -> Self {
let storage = match device { let storage = device.zeros(shape, dtype);
Device::Cpu => {
let elem_count: usize = shape.iter().product();
let buffer = vec![0; elem_count * dtype.size_in_bytes()];
Storage::Cpu { dtype, buffer }
}
};
Self { Self {
storage, storage,
shape: shape.to_vec(), shape: shape.to_vec(),