Start adding some ops.

This commit is contained in:
laurent
2023-06-20 08:41:19 +01:00
parent ef6760117f
commit 7a31ba93e4
5 changed files with 83 additions and 15 deletions

View File

@ -1,4 +1,4 @@
use crate::CpuStorage; use crate::{CpuStorage, Error, Result};
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum DType { pub enum DType {
@ -19,6 +19,8 @@ pub trait WithDType: Sized + Copy {
const DTYPE: DType; const DTYPE: DType;
fn to_cpu_storage(data: &[Self]) -> CpuStorage; fn to_cpu_storage(data: &[Self]) -> CpuStorage;
fn cpu_storage_as_slice(s: &CpuStorage) -> Result<&[Self]>;
} }
impl WithDType for f32 { impl WithDType for f32 {
@ -27,6 +29,16 @@ impl WithDType for f32 {
fn to_cpu_storage(data: &[Self]) -> CpuStorage { fn to_cpu_storage(data: &[Self]) -> CpuStorage {
CpuStorage::F32(data.to_vec()) CpuStorage::F32(data.to_vec())
} }
fn cpu_storage_as_slice(s: &CpuStorage) -> Result<&[Self]> {
match s {
CpuStorage::F32(data) => Ok(data),
_ => Err(Error::UnexpectedDType {
expected: DType::F32,
got: s.dtype(),
}),
}
}
} }
impl WithDType for f64 { impl WithDType for f64 {
@ -35,4 +47,14 @@ impl WithDType for f64 {
fn to_cpu_storage(data: &[Self]) -> CpuStorage { fn to_cpu_storage(data: &[Self]) -> CpuStorage {
CpuStorage::F64(data.to_vec()) CpuStorage::F64(data.to_vec())
} }
fn cpu_storage_as_slice(s: &CpuStorage) -> Result<&[Self]> {
match s {
CpuStorage::F64(data) => Ok(data),
_ => Err(Error::UnexpectedDType {
expected: DType::F64,
got: s.dtype(),
}),
}
}
} }

View File

@ -1,10 +1,15 @@
use crate::{DType, Shape};
/// Main library error type. /// Main library error type.
#[derive(thiserror::Error, Debug)] #[derive(thiserror::Error, Debug)]
pub enum Error { pub enum Error {
#[error("invalid shapes in {op}, lhs: {lhs:?}, rhs: {rhs:?}")] #[error("unexpected dtype, expected: {expected:?}, got: {got:?}")]
BinaryInvalidShape { UnexpectedDType { expected: DType, got: DType },
lhs: Vec<usize>,
rhs: Vec<usize>, #[error("shape mismatch in {op}, lhs: {lhs:?}, rhs: {rhs:?}")]
ShapeMismatchBinaryOp {
lhs: Shape,
rhs: Shape,
op: &'static str, op: &'static str,
}, },
@ -12,7 +17,7 @@ pub enum Error {
UnexpectedNumberOfDims { UnexpectedNumberOfDims {
expected: usize, expected: usize,
got: usize, got: usize,
shape: Vec<usize>, shape: Shape,
}, },
} }

View File

@ -1,4 +1,6 @@
use crate::{Error, Result}; use crate::{Error, Result};
#[derive(Clone, PartialEq, Eq)]
pub struct Shape(pub(crate) Vec<usize>); pub struct Shape(pub(crate) Vec<usize>);
impl std::fmt::Debug for Shape { impl std::fmt::Debug for Shape {
@ -56,6 +58,10 @@ impl From<(usize, usize, usize)> for Shape {
} }
impl Shape { impl Shape {
/// Builds a `Shape` by copying the given dimension sizes into an owned vector.
pub fn from_dims(dims: &[usize]) -> Self {
    Self(Vec::from(dims))
}
pub fn rank(&self) -> usize { pub fn rank(&self) -> usize {
self.0.len() self.0.len()
} }
@ -76,7 +82,7 @@ impl Shape {
Err(Error::UnexpectedNumberOfDims { Err(Error::UnexpectedNumberOfDims {
expected: 0, expected: 0,
got: shape.len(), got: shape.len(),
shape: shape.to_vec(), shape: self.clone(),
}) })
} }
} }
@ -89,7 +95,7 @@ impl Shape {
Err(Error::UnexpectedNumberOfDims { Err(Error::UnexpectedNumberOfDims {
expected: 1, expected: 1,
got: shape.len(), got: shape.len(),
shape: shape.to_vec(), shape: self.clone(),
}) })
} }
} }
@ -102,7 +108,7 @@ impl Shape {
Err(Error::UnexpectedNumberOfDims { Err(Error::UnexpectedNumberOfDims {
expected: 2, expected: 2,
got: shape.len(), got: shape.len(),
shape: shape.to_vec(), shape: self.clone(),
}) })
} }
} }
@ -115,7 +121,7 @@ impl Shape {
Err(Error::UnexpectedNumberOfDims { Err(Error::UnexpectedNumberOfDims {
expected: 3, expected: 3,
got: shape.len(), got: shape.len(),
shape: shape.to_vec(), shape: self.clone(),
}) })
} }
} }
@ -128,7 +134,7 @@ impl Shape {
Err(Error::UnexpectedNumberOfDims { Err(Error::UnexpectedNumberOfDims {
expected: 4, expected: 4,
got: shape.len(), got: shape.len(),
shape: shape.to_vec(), shape: self.clone(),
}) })
} }
} }

View File

@ -1,4 +1,4 @@
use crate::{op::Op, storage::Storage, DType, Device, Result, Shape}; use crate::{op::Op, storage::Storage, DType, Device, Error, Result, Shape};
use std::sync::Arc; use std::sync::Arc;
#[allow(dead_code)] #[allow(dead_code)]
@ -45,11 +45,46 @@ impl Tensor {
Ok(Self(Arc::new(tensor_))) Ok(Self(Arc::new(tensor_)))
} }
pub fn to_scalar<S: crate::WithDType>(&self) -> Result<S> { pub(crate) fn same_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<()> {
// TODO: properly use the strides here. let lhs = self.shape();
let rhs = rhs.shape();
if lhs != rhs {
Err(Error::ShapeMismatchBinaryOp {
lhs: lhs.clone(),
rhs: rhs.clone(),
op,
})
} else {
Ok(())
}
}
pub fn add(&self, rhs: &Self) -> Result<Self> {
self.same_shape_binary_op(rhs, "add")?;
todo!() todo!()
} }
/// Elementwise multiplication; both tensors must share the same shape.
pub fn mul(&self, rhs: &Self) -> Result<Self> {
    // Validate shapes up front so mismatches error out before any kernel work.
    self.same_shape_binary_op(rhs, "mul")?;
    // TODO: dispatch to a storage-level multiply kernel once binary ops exist.
    todo!()
}
/// Extracts the single value of a rank-0 (scalar) tensor as type `S`.
///
/// # Errors
/// - `Error::UnexpectedNumberOfDims` when the tensor is not rank 0.
/// - `Error::UnexpectedDType` when `S` does not match the storage dtype
///   (propagated from `cpu_storage_as_slice`).
pub fn to_scalar<S: crate::WithDType>(&self) -> Result<S> {
    let rank = self.rank();
    if rank != 0 {
        return Err(Error::UnexpectedNumberOfDims {
            expected: 0,
            got: rank,
            shape: self.shape().clone(),
        });
    }
    // A rank-0 tensor holds exactly one element, so index 0 is in bounds
    // and strides are irrelevant here.
    match &self.0.storage {
        Storage::Cpu(cpu_storage) => Ok(S::cpu_storage_as_slice(cpu_storage)?[0]),
    }
}
pub fn to_vec1<S: crate::WithDType>(&self) -> Result<Vec<S>> { pub fn to_vec1<S: crate::WithDType>(&self) -> Result<Vec<S>> {
// TODO: properly use the strides here. // TODO: properly use the strides here.
todo!() todo!()

View File

@ -6,7 +6,7 @@ fn add() -> Result<()> {
let (dim1, dim2) = tensor.shape().r2()?; let (dim1, dim2) = tensor.shape().r2()?;
assert_eq!(dim1, 5); assert_eq!(dim1, 5);
assert_eq!(dim2, 2); assert_eq!(dim2, 2);
let tensor = Tensor::new([3., 1., 4.].as_slice(), Device::Cpu)?; let tensor = Tensor::new([3f32, 1., 4.].as_slice(), Device::Cpu)?;
let dim1 = tensor.shape().r1()?; let dim1 = tensor.shape().r1()?;
assert_eq!(dim1, 3); assert_eq!(dim1, 3);
Ok(()) Ok(())