diff --git a/src/backprop.rs b/src/backprop.rs index ca463cdb..072a9005 100644 --- a/src/backprop.rs +++ b/src/backprop.rs @@ -56,6 +56,7 @@ impl Tensor { } Op::Reshape(node) | Op::Broadcast(node) + | Op::Sum(node, _) | Op::ToDType(node) | Op::ToDevice(node) | Op::Transpose(node, _, _) @@ -188,6 +189,9 @@ impl Tensor { Op::Broadcast(_arg) => { return Err(Error::BackwardNotSupported { op: "broadcast" }) } + Op::Sum(_arg, _sum_dims) => { + return Err(Error::BackwardNotSupported { op: "sum" }) + } Op::ToDType(arg) => { let sum_grad = grads.or_insert(arg)?; *sum_grad = sum_grad.add(&grad.to_dtype(node.dtype())?)? diff --git a/src/cpu_backend.rs b/src/cpu_backend.rs index 3e4e1826..8c57cce3 100644 --- a/src/cpu_backend.rs +++ b/src/cpu_backend.rs @@ -171,6 +171,15 @@ impl CpuStorage { } } + pub(crate) fn sum( + &self, + _shape: &Shape, + _stride: &[usize], + _sum_dims: &[usize], + ) -> Result<Self> { + todo!() + } + pub(crate) fn divide_by_sum_over_dim(&mut self, shape: &Shape, dim: usize) -> Result<()> { // [self] stores data in a contiguous way.
let dims = shape.dims(); diff --git a/src/cuda_backend.rs b/src/cuda_backend.rs index 70084821..fdfca801 100644 --- a/src/cuda_backend.rs +++ b/src/cuda_backend.rs @@ -291,6 +291,15 @@ impl CudaStorage { Ok(Self { slice, device }) } + pub(crate) fn sum( + &self, + _shape: &Shape, + _stride: &[usize], + _sum_dims: &[usize], + ) -> Result<Self> { + todo!() + } + pub(crate) fn divide_by_sum_over_dim(&mut self, _: &Shape, _: usize) -> Result<()> { todo!() } diff --git a/src/dummy_cuda_backend.rs b/src/dummy_cuda_backend.rs index 08563e87..98762277 100644 --- a/src/dummy_cuda_backend.rs +++ b/src/dummy_cuda_backend.rs @@ -64,6 +64,10 @@ impl CudaStorage { Err(Error::NotCompiledWithCudaSupport) } + pub(crate) fn sum(&self, _: &Shape, _: &[usize], _: &[usize]) -> Result<Self> { + Err(Error::NotCompiledWithCudaSupport) + } + pub(crate) fn divide_by_sum_over_dim(&mut self, _: &Shape, _: usize) -> Result<()> { Err(Error::NotCompiledWithCudaSupport) } diff --git a/src/op.rs b/src/op.rs index fa1373f6..1d93eee8 100644 --- a/src/op.rs +++ b/src/op.rs @@ -21,6 +21,7 @@ pub(crate) enum Op { mul: f64, add: f64, }, + Sum(Tensor, Vec<usize>), ToDType(Tensor), Broadcast(Tensor), Exp(Tensor), diff --git a/src/storage.rs b/src/storage.rs index 21992992..c13a01a6 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -72,6 +72,19 @@ impl Storage { } } + pub(crate) fn sum(&self, shape: &Shape, stride: &[usize], s: &[usize]) -> Result<Self> { + match self { + Storage::Cpu(storage) => { + let storage = storage.sum(shape, stride, s)?; + Ok(Self::Cpu(storage)) + } + Self::Cuda(storage) => { + let storage = storage.sum(shape, stride, s)?; + Ok(Self::Cuda(storage)) + } + } + } + pub(crate) fn divide_by_sum_over_dim(&mut self, shape: &Shape, dim: usize) -> Result<()> { match self { Storage::Cpu(storage) => storage.divide_by_sum_over_dim(shape, dim)?, diff --git a/src/tensor.rs b/src/tensor.rs index fb8c0960..c206ae30 100644 --- a/src/tensor.rs +++ b/src/tensor.rs @@ -102,7 +102,13 @@ macro_rules!
broadcast_binary_op { } /// Creates a fresh tensor structure based on a storage and a shape, this uses contiguous strides. -fn from_storage(storage: Storage, shape: Shape, op: Option<Op>, is_variable: bool) -> Tensor { +fn from_storage<S: Into<Shape>>( + storage: Storage, + shape: S, + op: Option<Op>, + is_variable: bool, +) -> Tensor { + let shape = shape.into(); let stride = shape.stride_contiguous(); let tensor_ = Tensor_ { id: TensorId::new(), @@ -347,6 +353,20 @@ impl Tensor { Ok(from_storage(storage, shape.clone(), op, false)) } + pub fn sum(&self, sum_dims: &[usize]) -> Result<Self> { + let storage = self.storage.sum(self.shape(), &self.stride, sum_dims)?; + let op = if self.track_op() { + Some(Op::Sum(self.clone(), sum_dims.to_vec())) + } else { + None + }; + let mut dims = self.dims().to_vec(); + for &sum_dim in sum_dims.iter() { + dims[sum_dim] = 1 + } + Ok(from_storage(storage, dims, op, false)) + } + pub fn matmul(&self, rhs: &Self) -> Result<Self> { let a_dims = self.shape().dims(); let b_dims = rhs.shape().dims();