Flesh out some ops bits.

This commit is contained in:
laurent
2023-06-19 19:28:33 +01:00
parent ce718bb807
commit 8e2c534d1f
3 changed files with 35 additions and 9 deletions

View File

@ -1,4 +1,7 @@
use crate::{storage::Storage, DType};
use crate::{
storage::{CpuStorage, Storage},
DType,
};
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum Device {
@ -10,8 +13,17 @@ impl Device {
match self {
Device::Cpu => {
let elem_count: usize = shape.iter().product();
let buffer = vec![0; elem_count * dtype.size_in_bytes()];
Storage::Cpu { dtype, buffer }
let storage = match dtype {
DType::F32 => {
let data = vec![0f32; elem_count];
CpuStorage::F32(data)
}
DType::F64 => {
let data = vec![0f64; elem_count];
CpuStorage::F64(data)
}
};
Storage::Cpu(storage)
}
}
}

View File

@ -4,4 +4,5 @@ use crate::Tensor;
pub(crate) enum Op {
Add(Tensor, Tensor),
Mul(Tensor, Tensor),
// TODO: Support for custom ops.
}

View File

@ -1,23 +1,36 @@
use crate::{DType, Device};
// TODO: Think about whether we would be better off with a dtype and
// a buffer as an owned slice of bytes.
pub(crate) enum CpuStorage {
F32(Vec<f32>),
F64(Vec<f64>),
}
impl CpuStorage {
pub(crate) fn dtype(&self) -> DType {
match self {
Self::F32(_) => DType::F32,
Self::F64(_) => DType::F64,
}
}
}
#[allow(dead_code)]
pub(crate) enum Storage {
Cpu {
dtype: crate::DType,
buffer: Vec<u8>,
},
Cpu(CpuStorage),
}
impl Storage {
pub(crate) fn device(&self) -> Device {
match self {
Self::Cpu { .. } => Device::Cpu,
Self::Cpu(_) => Device::Cpu,
}
}
pub(crate) fn dtype(&self) -> DType {
match self {
Self::Cpu { dtype, .. } => *dtype,
Self::Cpu(storage) => storage.dtype(),
}
}
}