Boilerplate code for the sum operator.

This commit is contained in:
laurent
2023-06-25 09:35:17 +01:00
parent 7ccf27dda2
commit 3852a85af3
7 changed files with 61 additions and 1 deletions

View File

@ -56,6 +56,7 @@ impl Tensor {
}
Op::Reshape(node)
| Op::Broadcast(node)
| Op::Sum(node, _)
| Op::ToDType(node)
| Op::ToDevice(node)
| Op::Transpose(node, _, _)
@ -188,6 +189,9 @@ impl Tensor {
Op::Broadcast(_arg) => {
return Err(Error::BackwardNotSupported { op: "broadcast" })
}
Op::Sum(_arg, _sum_dims) => {
return Err(Error::BackwardNotSupported { op: "sum" })
}
Op::ToDType(arg) => {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&grad.to_dtype(node.dtype())?)?

View File

@ -171,6 +171,15 @@ impl CpuStorage {
}
}
pub(crate) fn sum(
&self,
_shape: &Shape,
_stride: &[usize],
_sum_dims: &[usize],
) -> Result<Self> {
todo!()
}
pub(crate) fn divide_by_sum_over_dim(&mut self, shape: &Shape, dim: usize) -> Result<()> {
// [self] stores data in a contiguous way.
let dims = shape.dims();

View File

@ -291,6 +291,15 @@ impl CudaStorage {
Ok(Self { slice, device })
}
pub(crate) fn sum(
&self,
_shape: &Shape,
_stride: &[usize],
_sum_dims: &[usize],
) -> Result<Self> {
todo!()
}
pub(crate) fn divide_by_sum_over_dim(&mut self, _: &Shape, _: usize) -> Result<()> {
todo!()
}

View File

@ -64,6 +64,10 @@ impl CudaStorage {
Err(Error::NotCompiledWithCudaSupport)
}
pub(crate) fn sum(&self, _: &Shape, _: &[usize], _: &[usize]) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
pub(crate) fn divide_by_sum_over_dim(&mut self, _: &Shape, _: usize) -> Result<()> {
Err(Error::NotCompiledWithCudaSupport)
}

View File

@ -21,6 +21,7 @@ pub(crate) enum Op {
mul: f64,
add: f64,
},
Sum(Tensor, Vec<usize>),
ToDType(Tensor),
Broadcast(Tensor),
Exp(Tensor),

View File

@ -72,6 +72,19 @@ impl Storage {
}
}
pub(crate) fn sum(&self, shape: &Shape, stride: &[usize], s: &[usize]) -> Result<Self> {
match self {
Storage::Cpu(storage) => {
let storage = storage.sum(shape, stride, s)?;
Ok(Self::Cpu(storage))
}
Self::Cuda(storage) => {
let storage = storage.sum(shape, stride, s)?;
Ok(Self::Cuda(storage))
}
}
}
pub(crate) fn divide_by_sum_over_dim(&mut self, shape: &Shape, dim: usize) -> Result<()> {
match self {
Storage::Cpu(storage) => storage.divide_by_sum_over_dim(shape, dim)?,

View File

@ -102,7 +102,13 @@ macro_rules! broadcast_binary_op {
}
/// Creates a fresh tensor structure based on a storage and a shape, this uses contiguous strides.
fn from_storage(storage: Storage, shape: Shape, op: Option<Op>, is_variable: bool) -> Tensor {
fn from_storage<S: Into<Shape>>(
storage: Storage,
shape: S,
op: Option<Op>,
is_variable: bool,
) -> Tensor {
let shape = shape.into();
let stride = shape.stride_contiguous();
let tensor_ = Tensor_ {
id: TensorId::new(),
@ -347,6 +353,20 @@ impl Tensor {
Ok(from_storage(storage, shape.clone(), op, false))
}
pub fn sum(&self, sum_dims: &[usize]) -> Result<Self> {
let storage = self.storage.sum(self.shape(), &self.stride, sum_dims)?;
let op = if self.track_op() {
Some(Op::Sum(self.clone(), sum_dims.to_vec()))
} else {
None
};
let mut dims = self.dims().to_vec();
for &sum_dim in sum_dims.iter() {
dims[sum_dim] = 1
}
Ok(from_storage(storage, dims, op, false))
}
pub fn matmul(&self, rhs: &Self) -> Result<Self> {
let a_dims = self.shape().dims();
let b_dims = rhs.shape().dims();