Start adding some ops.

This commit is contained in:
laurent
2023-06-20 08:41:19 +01:00
parent ef6760117f
commit 7a31ba93e4
5 changed files with 83 additions and 15 deletions

View File

@ -1,4 +1,4 @@
use crate::CpuStorage;
use crate::{CpuStorage, Error, Result};
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum DType {
@ -19,6 +19,8 @@ pub trait WithDType: Sized + Copy {
const DTYPE: DType;
fn to_cpu_storage(data: &[Self]) -> CpuStorage;
fn cpu_storage_as_slice(s: &CpuStorage) -> Result<&[Self]>;
}
impl WithDType for f32 {
@ -27,6 +29,16 @@ impl WithDType for f32 {
fn to_cpu_storage(data: &[Self]) -> CpuStorage {
CpuStorage::F32(data.to_vec())
}
fn cpu_storage_as_slice(s: &CpuStorage) -> Result<&[Self]> {
match s {
CpuStorage::F32(data) => Ok(data),
_ => Err(Error::UnexpectedDType {
expected: DType::F32,
got: s.dtype(),
}),
}
}
}
impl WithDType for f64 {
@ -35,4 +47,14 @@ impl WithDType for f64 {
fn to_cpu_storage(data: &[Self]) -> CpuStorage {
CpuStorage::F64(data.to_vec())
}
fn cpu_storage_as_slice(s: &CpuStorage) -> Result<&[Self]> {
match s {
CpuStorage::F64(data) => Ok(data),
_ => Err(Error::UnexpectedDType {
expected: DType::F64,
got: s.dtype(),
}),
}
}
}

View File

@ -1,10 +1,15 @@
use crate::{DType, Shape};
/// Main library error type.
#[derive(thiserror::Error, Debug)]
pub enum Error {
#[error("invalid shapes in {op}, lhs: {lhs:?}, rhs: {rhs:?}")]
BinaryInvalidShape {
lhs: Vec<usize>,
rhs: Vec<usize>,
#[error("unexpected dtype, expected: {expected:?}, got: {got:?}")]
UnexpectedDType { expected: DType, got: DType },
#[error("shape mismatch in {op}, lhs: {lhs:?}, rhs: {rhs:?}")]
ShapeMismatchBinaryOp {
lhs: Shape,
rhs: Shape,
op: &'static str,
},
@ -12,7 +17,7 @@ pub enum Error {
UnexpectedNumberOfDims {
expected: usize,
got: usize,
shape: Vec<usize>,
shape: Shape,
},
}

View File

@ -1,4 +1,6 @@
use crate::{Error, Result};
#[derive(Clone, PartialEq, Eq)]
pub struct Shape(pub(crate) Vec<usize>);
impl std::fmt::Debug for Shape {
@ -56,6 +58,10 @@ impl From<(usize, usize, usize)> for Shape {
}
impl Shape {
pub fn from_dims(dims: &[usize]) -> Self {
Self(dims.to_vec())
}
pub fn rank(&self) -> usize {
self.0.len()
}
@ -76,7 +82,7 @@ impl Shape {
Err(Error::UnexpectedNumberOfDims {
expected: 0,
got: shape.len(),
shape: shape.to_vec(),
shape: self.clone(),
})
}
}
@ -89,7 +95,7 @@ impl Shape {
Err(Error::UnexpectedNumberOfDims {
expected: 1,
got: shape.len(),
shape: shape.to_vec(),
shape: self.clone(),
})
}
}
@ -102,7 +108,7 @@ impl Shape {
Err(Error::UnexpectedNumberOfDims {
expected: 2,
got: shape.len(),
shape: shape.to_vec(),
shape: self.clone(),
})
}
}
@ -115,7 +121,7 @@ impl Shape {
Err(Error::UnexpectedNumberOfDims {
expected: 3,
got: shape.len(),
shape: shape.to_vec(),
shape: self.clone(),
})
}
}
@ -128,7 +134,7 @@ impl Shape {
Err(Error::UnexpectedNumberOfDims {
expected: 4,
got: shape.len(),
shape: shape.to_vec(),
shape: self.clone(),
})
}
}

View File

@ -1,4 +1,4 @@
use crate::{op::Op, storage::Storage, DType, Device, Result, Shape};
use crate::{op::Op, storage::Storage, DType, Device, Error, Result, Shape};
use std::sync::Arc;
#[allow(dead_code)]
@ -45,11 +45,46 @@ impl Tensor {
Ok(Self(Arc::new(tensor_)))
}
pub fn to_scalar<S: crate::WithDType>(&self) -> Result<S> {
// TODO: properly use the strides here.
pub(crate) fn same_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<()> {
let lhs = self.shape();
let rhs = rhs.shape();
if lhs != rhs {
Err(Error::ShapeMismatchBinaryOp {
lhs: lhs.clone(),
rhs: rhs.clone(),
op,
})
} else {
Ok(())
}
}
pub fn add(&self, rhs: &Self) -> Result<Self> {
self.same_shape_binary_op(rhs, "add")?;
todo!()
}
pub fn mul(&self, rhs: &Self) -> Result<Self> {
self.same_shape_binary_op(rhs, "mul")?;
todo!()
}
pub fn to_scalar<S: crate::WithDType>(&self) -> Result<S> {
if self.rank() != 0 {
return Err(Error::UnexpectedNumberOfDims {
expected: 0,
got: self.rank(),
shape: self.shape().clone(),
});
}
match &self.0.storage {
Storage::Cpu(cpu_storage) => {
let data = S::cpu_storage_as_slice(cpu_storage)?;
Ok(data[0])
}
}
}
pub fn to_vec1<S: crate::WithDType>(&self) -> Result<Vec<S>> {
// TODO: properly use the strides here.
todo!()

View File

@ -6,7 +6,7 @@ fn add() -> Result<()> {
let (dim1, dim2) = tensor.shape().r2()?;
assert_eq!(dim1, 5);
assert_eq!(dim2, 2);
let tensor = Tensor::new([3., 1., 4.].as_slice(), Device::Cpu)?;
let tensor = Tensor::new([3f32, 1., 4.].as_slice(), Device::Cpu)?;
let dim1 = tensor.shape().r1()?;
assert_eq!(dim1, 3);
Ok(())