Add the op.

This commit is contained in:
laurent
2023-06-19 18:34:54 +01:00
parent 844704de5c
commit ce718bb807
5 changed files with 45 additions and 13 deletions

View File

@ -1,4 +1,4 @@
use crate::{DType, Storage};
use crate::{storage::Storage, DType};
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum Device {

View File

@ -1,9 +1,9 @@
mod device;
mod dtype;
mod op;
mod storage;
mod tensor;
pub use device::Device;
pub use dtype::DType;
use storage::Storage;
pub use tensor::Tensor;

7
src/op.rs Normal file
View File

@ -0,0 +1,7 @@
use crate::Tensor;
#[allow(dead_code)]
pub(crate) enum Op {
Add(Tensor, Tensor),
Mul(Tensor, Tensor),
}

View File

@ -1,3 +1,5 @@
use crate::{DType, Device};
#[allow(dead_code)]
pub(crate) enum Storage {
Cpu {
@ -5,3 +7,17 @@ pub(crate) enum Storage {
buffer: Vec<u8>,
},
}
impl Storage {
pub(crate) fn device(&self) -> Device {
match self {
Self::Cpu { .. } => Device::Cpu,
}
}
pub(crate) fn dtype(&self) -> DType {
match self {
Self::Cpu { dtype, .. } => *dtype,
}
}
}

View File

@ -1,40 +1,49 @@
use crate::{DType, Device, Storage};
use crate::{op::Op, storage::Storage, DType, Device};
use std::sync::Arc;
pub struct Tensor {
#[allow(dead_code)]
pub(crate) struct Tensor_ {
storage: Storage,
shape: Vec<usize>,
stride: Vec<usize>,
op: Option<Op>,
}
pub struct Tensor(Arc<Tensor_>);
impl Tensor {
pub fn zeros(shape: &[usize], dtype: DType, device: Device) -> Self {
let storage = device.zeros(shape, dtype);
Self {
let tensor_ = Tensor_ {
storage,
shape: shape.to_vec(),
stride: vec![1; shape.len()],
}
op: None,
};
Tensor(Arc::new(tensor_))
}
pub fn dtype(&self) -> DType {
self.0.storage.dtype()
}
pub fn device(&self) -> Device {
match self.storage {
Storage::Cpu { .. } => Device::Cpu,
}
self.0.storage.device()
}
pub fn shape(&self) -> &[usize] {
&self.shape
&self.0.shape
}
pub fn stride(&self) -> &[usize] {
&self.stride
&self.0.stride
}
pub fn rank(&self) -> usize {
self.shape.len()
self.0.shape.len()
}
pub fn elem_count(&self) -> usize {
self.shape.iter().product()
self.0.shape.iter().product()
}
}