Decompose the softmax op so that it can be run on CUDA.

Author: laurent
Date: 2023-06-26 15:36:21 +01:00
Parent: 33c0234a33
Commit: 687c5beb6a
2 changed files with 33 additions and 19 deletions
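The commit splits softmax into primitives that already have CUDA kernels, relying on the identity softmax(x)_i = exp(x_i) / Σ_j exp(x_j): an element-wise exponential, a sum over the softmax dimension, and a broadcast division. A minimal standalone sketch of that decomposition, in plain Rust over a row-major buffer rather than the crate's Tensor API:

```rust
/// Row-wise softmax over a row-major `rows x cols` buffer, mirroring the
/// exp -> sum -> broadcast-divide split that this commit uses on CUDA.
fn softmax_rows(xs: &[f32], rows: usize, cols: usize) -> Vec<f32> {
    let mut out = vec![0f32; rows * cols];
    for r in 0..rows {
        let row = &xs[r * cols..(r + 1) * cols];
        // Step 1: element-wise exp (the `self.exp()` stage).
        let exps: Vec<f32> = row.iter().map(|x| x.exp()).collect();
        // Step 2: sum over the softmax dimension (the `exp.sum(&[dim])` stage).
        let sum: f32 = exps.iter().sum();
        // Step 3: divide by the broadcast sum (the `broadcast_div` stage).
        for c in 0..cols {
            out[r * cols + c] = exps[c] / sum;
        }
    }
    out
}
```

Note that, like both branches in the diff below, this sketch skips the usual max-subtraction trick (exponentiating x - max(x)) that production softmax kernels apply for numerical stability.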


@@ -102,6 +102,13 @@ impl Device {
         }
     }
 
+    pub fn is_cuda(&self) -> bool {
+        match self {
+            Self::Cpu => false,
+            Self::Cuda(_) => true,
+        }
+    }
+
     pub(crate) fn ones(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
         match self {
             Device::Cpu => {
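For illustration, a self-contained stub of the new helper and a call site; the stub enum only mirrors the shape of the crate's Device (the real Cuda variant wraps a device handle, not a usize):

```rust
enum Device {
    Cpu,
    Cuda(usize), // stand-in payload; the crate stores a CUDA device handle here
}

impl Device {
    fn is_cuda(&self) -> bool {
        match self {
            Self::Cpu => false,
            Self::Cuda(_) => true,
        }
    }
}

fn main() {
    let dev = Device::Cpu;
    // The softmax and matmul changes below use this helper to pick a
    // backend-specific code path.
    assert!(!dev.is_cuda());
    println!("is_cuda: {}", dev.is_cuda());
}
```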


@@ -395,18 +395,27 @@ impl Tensor {
     }
 
     pub fn softmax(&self, dim: usize) -> Result<Self> {
-        let shape = self.shape();
-        let mut storage = self
-            .storage
-            .unary_impl::<crate::op::Exp>(shape, self.stride())?;
-        // The resulting storage is contiguous.
-        storage.divide_by_sum_over_dim(shape, dim)?;
-        let op = if self.track_op() {
-            Some(Op::Softmax(self.clone(), dim))
-        } else {
-            None
-        };
-        Ok(from_storage(storage, shape.clone(), op, false))
+        // TODO: unify the two branches.
+        if self.device().is_cuda() {
+            // We do not have a cuda kernel for divide_by_sum_over_dim so split
+            // the operation.
+            let exp = self.exp()?;
+            let sum_exp = exp.sum(&[dim])?;
+            exp.broadcast_div(&sum_exp)
+        } else {
+            let shape = self.shape();
+            let mut storage = self
+                .storage
+                .unary_impl::<crate::op::Exp>(shape, self.stride())?;
+            // The resulting storage is contiguous.
+            storage.divide_by_sum_over_dim(shape, dim)?;
+            let op = if self.track_op() {
+                Some(Op::Softmax(self.clone(), dim))
+            } else {
+                None
+            };
+            Ok(from_storage(storage, shape.clone(), op, false))
+        }
     }
 
     pub fn sum(&self, sum_dims: &[usize]) -> Result<Self> {
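The two branches compute the same expression, so they should agree numerically; a quick plain-Rust check of that equivalence on one row, with no crate types involved:

```rust
fn main() {
    let xs = [0.5f32, -1.0, 2.0, 0.0];
    // Fused-style reference: exp and divide in one pass.
    let total: f32 = xs.iter().map(|x| x.exp()).sum();
    let reference: Vec<f32> = xs.iter().map(|x| x.exp() / total).collect();
    // Decomposed path: exp, then sum, then broadcast divide.
    let exps: Vec<f32> = xs.iter().map(|x| x.exp()).collect();
    let sum_exp: f32 = exps.iter().sum();
    let decomposed: Vec<f32> = exps.iter().map(|e| e / sum_exp).collect();
    for (a, b) in reference.iter().zip(&decomposed) {
        assert!((a - b).abs() < 1e-6);
    }
    println!("branches agree: {decomposed:?}");
}
```

One behavioral difference worth noting: only the CPU branch records a dedicated Op::Softmax node; on CUDA, gradient tracking presumably falls out of the composed exp, sum, and broadcast_div ops instead.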
@@ -436,14 +445,12 @@ impl Tensor {
                 op: "matmul",
             });
         }
-        if let crate::DeviceLocation::Cuda { .. } = self.device().location() {
-            if !self.is_contiguous() || !rhs.is_contiguous() {
-                // It looks like the cublas implementation of XgemmStridedBatched only supports
-                // non-standard strides on the batch dimension.
-                return Err(Error::RequiresContiguous {
-                    op: "matmul-cublas",
-                });
-            }
+        if self.device().is_cuda() && (!self.is_contiguous() || !rhs.is_contiguous()) {
+            // It looks like the cublas implementation of XgemmStridedBatched only supports
+            // non-standard strides on the batch dimension.
+            return Err(Error::RequiresContiguous {
+                op: "matmul-cublas",
+            });
         }
         let m = a_dims[dim - 2];
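The cublas guard relies on is_contiguous. For a dense row-major tensor, contiguity means each dimension's stride equals the product of the dimensions after it; a minimal sketch of that check under that assumption, in plain Rust rather than the crate's Shape/stride machinery:

```rust
/// True when `strides` describe a dense row-major (C-contiguous) layout for `dims`.
/// Real implementations often also accept arbitrary strides on size-1 dimensions.
fn is_c_contiguous(dims: &[usize], strides: &[usize]) -> bool {
    let mut expected = 1;
    // Walk from the innermost dimension outwards, accumulating the dense stride.
    for (dim, stride) in dims.iter().zip(strides).rev() {
        if *stride != expected {
            return false;
        }
        expected *= dim;
    }
    true
}

fn main() {
    // A 2x3 row-major matrix has strides [3, 1]; a transposed view has [1, 3],
    // which is the kind of input the matmul-cublas guard rejects.
    assert!(is_c_contiguous(&[2, 3], &[3, 1]));
    assert!(!is_c_contiguous(&[2, 3], &[1, 3]));
}
```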