Add an easy way to create tensor objects.

This commit is contained in:
laurent
2023-06-19 20:59:26 +01:00
parent 01eeb0e72f
commit 26d6288eb6
6 changed files with 77 additions and 14 deletions

View File

@ -1,6 +1,6 @@
use crate::{
storage::{CpuStorage, Storage},
DType,
DType, Result, Shape,
};
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
@ -8,11 +8,38 @@ pub enum Device {
Cpu,
}
// TODO: Should we back the cpu implementation using the NdArray crate or similar?
pub trait NdArray {
fn shape(&self) -> Result<Shape>;
fn to_cpu_storage(&self) -> CpuStorage;
}
impl<S: crate::WithDType> NdArray for S {
fn shape(&self) -> Result<Shape> {
Ok(Shape::from(()))
}
fn to_cpu_storage(&self) -> CpuStorage {
S::to_cpu_storage(&[*self])
}
}
impl<S: crate::WithDType> NdArray for &[S] {
fn shape(&self) -> Result<Shape> {
Ok(Shape::from(self.len()))
}
fn to_cpu_storage(&self) -> CpuStorage {
S::to_cpu_storage(self)
}
}
impl Device {
pub(crate) fn zeros(&self, shape: &[usize], dtype: DType) -> Storage {
pub(crate) fn zeros(&self, shape: &Shape, dtype: DType) -> Storage {
match self {
Device::Cpu => {
let elem_count: usize = shape.iter().product();
let elem_count = shape.elem_count();
let storage = match dtype {
DType::F32 => {
let data = vec![0f32; elem_count];
@ -27,4 +54,10 @@ impl Device {
}
}
}
pub(crate) fn tensor<A: NdArray>(&self, array: A) -> Storage {
match self {
Device::Cpu => Storage::Cpu(array.to_cpu_storage()),
}
}
}