Start adding some ops.

This commit is contained in:
laurent
2023-06-20 08:41:19 +01:00
parent ef6760117f
commit 7a31ba93e4
5 changed files with 83 additions and 15 deletions

View File

@ -1,4 +1,4 @@
use crate::{op::Op, storage::Storage, DType, Device, Result, Shape};
use crate::{op::Op, storage::Storage, DType, Device, Error, Result, Shape};
use std::sync::Arc;
#[allow(dead_code)]
@ -45,11 +45,46 @@ impl Tensor {
Ok(Self(Arc::new(tensor_)))
}
pub fn to_scalar<S: crate::WithDType>(&self) -> Result<S> {
// TODO: properly use the strides here.
pub(crate) fn same_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<()> {
let lhs = self.shape();
let rhs = rhs.shape();
if lhs != rhs {
Err(Error::ShapeMismatchBinaryOp {
lhs: lhs.clone(),
rhs: rhs.clone(),
op,
})
} else {
Ok(())
}
}
pub fn add(&self, rhs: &Self) -> Result<Self> {
self.same_shape_binary_op(rhs, "add")?;
todo!()
}
pub fn mul(&self, rhs: &Self) -> Result<Self> {
self.same_shape_binary_op(rhs, "mul")?;
todo!()
}
pub fn to_scalar<S: crate::WithDType>(&self) -> Result<S> {
if self.rank() != 0 {
return Err(Error::UnexpectedNumberOfDims {
expected: 0,
got: self.rank(),
shape: self.shape().clone(),
});
}
match &self.0.storage {
Storage::Cpu(cpu_storage) => {
let data = S::cpu_storage_as_slice(cpu_storage)?;
Ok(data[0])
}
}
}
pub fn to_vec1<S: crate::WithDType>(&self) -> Result<Vec<S>> {
// TODO: properly use the strides here.
todo!()