diff --git a/src/device.rs b/src/device.rs
index 1b56c178..3562d374 100644
--- a/src/device.rs
+++ b/src/device.rs
@@ -102,6 +102,13 @@ impl Device {
         }
     }
 
+    pub fn is_cuda(&self) -> bool {
+        match self {
+            Self::Cpu => false,
+            Self::Cuda(_) => true,
+        }
+    }
+
     pub(crate) fn ones(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
         match self {
             Device::Cpu => {
diff --git a/src/tensor.rs b/src/tensor.rs
index 82f80f1a..6a47ef4c 100644
--- a/src/tensor.rs
+++ b/src/tensor.rs
@@ -395,18 +395,27 @@ impl Tensor {
     }
 
     pub fn softmax(&self, dim: usize) -> Result<Self> {
-        let shape = self.shape();
-        let mut storage = self
-            .storage
-            .unary_impl::<crate::op::Exp>(shape, self.stride())?;
-        // The resulting storage is contiguous.
-        storage.divide_by_sum_over_dim(shape, dim)?;
-        let op = if self.track_op() {
-            Some(Op::Softmax(self.clone(), dim))
+        // TODO: unify the two branches.
+        if self.device().is_cuda() {
+            // We do not have a cuda kernel for divide_by_sum_over_dim so split
+            // the operation.
+            let exp = self.exp()?;
+            let sum_exp = exp.sum(&[dim])?;
+            exp.broadcast_div(&sum_exp)
         } else {
-            None
-        };
-        Ok(from_storage(storage, shape.clone(), op, false))
+            let shape = self.shape();
+            let mut storage = self
+                .storage
+                .unary_impl::<crate::op::Exp>(shape, self.stride())?;
+            // The resulting storage is contiguous.
+            storage.divide_by_sum_over_dim(shape, dim)?;
+            let op = if self.track_op() {
+                Some(Op::Softmax(self.clone(), dim))
+            } else {
+                None
+            };
+            Ok(from_storage(storage, shape.clone(), op, false))
+        }
     }
 
     pub fn sum(&self, sum_dims: &[usize]) -> Result<Self> {
@@ -436,14 +445,12 @@ impl Tensor {
                 op: "matmul",
             });
         }
-        if let crate::DeviceLocation::Cuda { .. } = self.device().location() {
-            if !self.is_contiguous() || !rhs.is_contiguous() {
-                // It looks like the cublas implementation of XgemmStridedBatched only supports
-                // non-standard strides on the batch dimension.
-                return Err(Error::RequiresContiguous {
-                    op: "matmul-cublas",
-                });
-            }
+        if self.device().is_cuda() && (!self.is_contiguous() || !rhs.is_contiguous()) {
+            // It looks like the cublas implementation of XgemmStridedBatched only supports
+            // non-standard strides on the batch dimension.
+            return Err(Error::RequiresContiguous {
+                op: "matmul-cublas",
+            });
         }
 
         let m = a_dims[dim - 2];
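
For context, the new CUDA branch of `softmax` decomposes the operation into existing ops (`exp`, a `sum` over the softmax dimension, then a broadcast division) instead of the fused `divide_by_sum_over_dim` path kept for the CPU. The snippet below is a minimal, standalone sketch of that same decomposition on a plain `f32` buffer; the helper name `softmax_rows` and the flat row-major layout are illustrative assumptions and are not part of this crate.

```rust
// Illustrative only: mirrors the exp / sum / broadcast_div split used in the
// CUDA branch of `Tensor::softmax`. `softmax_rows` is a hypothetical helper.
fn softmax_rows(data: &[f32], row_len: usize) -> Vec<f32> {
    let mut out = Vec::with_capacity(data.len());
    for row in data.chunks(row_len) {
        // Step 1: exponentiate every element (mirrors `self.exp()?`).
        let exp: Vec<f32> = row.iter().map(|x| x.exp()).collect();
        // Step 2: reduce over the softmax dimension (mirrors `exp.sum(&[dim])?`).
        let sum_exp: f32 = exp.iter().sum();
        // Step 3: divide each element by the sum (mirrors `exp.broadcast_div(&sum_exp)`).
        out.extend(exp.iter().map(|x| x / sum_exp));
    }
    out
}

fn main() {
    // One row of three equal logits; softmax should be uniform.
    let probs = softmax_rows(&[1.0, 1.0, 1.0], 3);
    assert!(probs.iter().all(|p| (p - 1.0 / 3.0).abs() < 1e-6));
    println!("{probs:?}");
}
```

Like the CUDA branch in the patch, this sketch does not subtract the row maximum before exponentiating, so very large logits can overflow; it only illustrates the op decomposition, not a numerically hardened softmax.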