From 8443963d4f0d3f96a375fbacc9010f6c1dc6a240 Mon Sep 17 00:00:00 2001
From: laurent
Date: Fri, 23 Jun 2023 22:00:13 +0100
Subject: [PATCH] Skeleton implementation for softmax.

---
 src/cpu_backend.rs        |  4 ++++
 src/cuda_backend.rs       |  4 ++++
 src/dummy_cuda_backend.rs |  2 ++
 src/storage.rs            |  7 +++++++
 src/tensor.rs             | 15 +++++++++++++++
 5 files changed, 32 insertions(+)

diff --git a/src/cpu_backend.rs b/src/cpu_backend.rs
index 1b618551..fecbe643 100644
--- a/src/cpu_backend.rs
+++ b/src/cpu_backend.rs
@@ -147,6 +147,10 @@ impl CpuStorage {
         }
     }
 
+    pub(crate) fn divide_by_sum_over_dim(&mut self, _shape: &Shape, _dim: usize) {
+        todo!()
+    }
+
     pub(crate) fn affine_impl(
         &self,
         shape: &Shape,
diff --git a/src/cuda_backend.rs b/src/cuda_backend.rs
index 345b84d8..f2a2ce43 100644
--- a/src/cuda_backend.rs
+++ b/src/cuda_backend.rs
@@ -291,6 +291,10 @@ impl CudaStorage {
         Ok(Self { slice, device })
     }
 
+    pub(crate) fn divide_by_sum_over_dim(&mut self, _: &Shape, _: usize) {
+        todo!()
+    }
+
     pub(crate) fn unary_impl(
         &self,
         shape: &Shape,
diff --git a/src/dummy_cuda_backend.rs b/src/dummy_cuda_backend.rs
index f2e0d36c..fb50f8f2 100644
--- a/src/dummy_cuda_backend.rs
+++ b/src/dummy_cuda_backend.rs
@@ -62,6 +62,8 @@ impl CudaStorage {
         Err(Error::NotCompiledWithCudaSupport)
     }
 
+    pub(crate) fn divide_by_sum_over_dim(&mut self, _: &Shape, _: usize) {}
+
     pub(crate) fn to_dtype(&self, _: &Shape, _: &[usize], _: DType) -> Result<Self> {
         Err(Error::NotCompiledWithCudaSupport)
     }
diff --git a/src/storage.rs b/src/storage.rs
index ebbdcbb2..c5544478 100644
--- a/src/storage.rs
+++ b/src/storage.rs
@@ -72,6 +72,13 @@ impl Storage {
         }
     }
 
+    pub(crate) fn divide_by_sum_over_dim(&mut self, shape: &Shape, dim: usize) {
+        match self {
+            Storage::Cpu(storage) => storage.divide_by_sum_over_dim(shape, dim),
+            Self::Cuda(storage) => storage.divide_by_sum_over_dim(shape, dim),
+        }
+    }
+
     pub(crate) fn to_dtype(&self, shape: &Shape, stride: &[usize], dtype: DType) -> Result<Self> {
         match self {
             Storage::Cpu(storage) => {
diff --git a/src/tensor.rs b/src/tensor.rs
index 264df5f6..82de7e17 100644
--- a/src/tensor.rs
+++ b/src/tensor.rs
@@ -295,6 +295,21 @@ impl Tensor {
         Ok(from_storage(storage, shape.clone(), op, false))
     }
 
+    pub fn softmax(&self, dim: usize) -> Result<Self> {
+        let shape = self.shape();
+        let mut storage = self
+            .storage
+            .unary_impl::<crate::op::Exp>(shape, self.stride())?;
+        // The resulting storage is contiguous.
+        storage.divide_by_sum_over_dim(shape, dim);
+        let op = if self.track_op() {
+            Some(Op::Softmax(self.clone(), dim))
+        } else {
+            None
+        };
+        Ok(from_storage(storage, shape.clone(), op, false))
+    }
+
     pub fn matmul(&self, rhs: &Self) -> Result<Self> {
         let a_dims = self.shape().dims();
         let b_dims = rhs.shape().dims();