Decompose the softmax op so that it can be run on cuda.

laurent
2023-06-26 15:36:21 +01:00
parent 33c0234a33
commit 687c5beb6a
2 changed files with 33 additions and 19 deletions
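
For reference, the decomposition this commit implements is just the definition of softmax written as three existing ops (exponentiate, sum over the target dim, broadcast divide):

    \operatorname{softmax}(x)_i = \frac{e^{x_i}}{\sum_j e^{x_j}}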


@@ -102,6 +102,13 @@ impl Device {
         }
     }
+    pub fn is_cuda(&self) -> bool {
+        match self {
+            Self::Cpu => false,
+            Self::Cuda(_) => true,
+        }
+    }
     pub(crate) fn ones(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
         match self {
             Device::Cpu => {

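A note on the new helper: the explicit match above can equivalently be written with the matches! macro. A minimal standalone sketch, assuming Device is the two-variant enum the hunk suggests (the Cuda payload here is a made-up device ordinal, for illustration only):

    // Stand-in for the Device enum above; the Cuda payload is hypothetical.
    enum Device {
        Cpu,
        Cuda(usize),
    }

    impl Device {
        // Same behavior as the match in the diff; `matches!` is the idiomatic shorthand.
        fn is_cuda(&self) -> bool {
            matches!(self, Self::Cuda(_))
        }
    }

    fn main() {
        assert!(!Device::Cpu.is_cuda());
        assert!(Device::Cuda(0).is_cuda());
    }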

@@ -395,6 +395,14 @@ impl Tensor {
     }
     pub fn softmax(&self, dim: usize) -> Result<Self> {
+        // TODO: unify the two branches.
+        if self.device().is_cuda() {
+            // We do not have a cuda kernel for divide_by_sum_over_dim so split
+            // the operation.
+            let exp = self.exp()?;
+            let sum_exp = exp.sum(&[dim])?;
+            exp.broadcast_div(&sum_exp)
+        } else {
             let shape = self.shape();
             let mut storage = self
                 .storage
@@ -408,6 +416,7 @@ impl Tensor {
             };
             Ok(from_storage(storage, shape.clone(), op, false))
+        }
     }
     pub fn sum(&self, sum_dims: &[usize]) -> Result<Self> {
         let storage = self.storage.sum(self.shape(), &self.stride, sum_dims)?;
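To see why the cuda branch above matches the fused CPU path, here is the same exp / sum / broadcast-divide decomposition spelled out on a plain row-major matrix, reducing over the last dimension. This is an illustrative sketch, not the crate's API; note that the broadcast division relies on having one divisor per row, i.e. on the sum reduction keeping the reduced dimension:

    // Softmax over the last dim of a row-major (rows x cols) matrix,
    // decomposed exactly as in the cuda branch: exp, then sum, then divide.
    fn softmax_last_dim(data: &[f32], rows: usize, cols: usize) -> Vec<f32> {
        assert_eq!(data.len(), rows * cols);
        let exp: Vec<f32> = data.iter().map(|x| x.exp()).collect();
        // One sum per row: the analogue of `exp.sum(&[dim])`.
        let sum_exp: Vec<f32> = exp.chunks(cols).map(|row| row.iter().sum()).collect();
        // Broadcast divide: every element is divided by its row's sum.
        exp.chunks(cols)
            .zip(sum_exp.iter())
            .flat_map(|(row, s)| row.iter().map(move |x| x / s))
            .collect()
    }

    fn main() {
        let out = softmax_last_dim(&[1.0, 2.0, 3.0, 1.0, 1.0, 1.0], 2, 3);
        // Each output row sums to 1.
        for row in out.chunks(3) {
            assert!((row.iter().sum::<f32>() - 1.0).abs() < 1e-6);
        }
    }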
@@ -436,15 +445,13 @@ impl Tensor {
op: "matmul",
});
}
if let crate::DeviceLocation::Cuda { .. } = self.device().location() {
if !self.is_contiguous() || !rhs.is_contiguous() {
if self.device().is_cuda() && (!self.is_contiguous() || !rhs.is_contiguous()) {
// It looks like the cublas implementation of XgemmStridedBatched only supports
// non-standard strides on the batch dimension.
return Err(Error::RequiresContiguous {
op: "matmul-cublas",
});
}
}
let m = a_dims[dim - 2];
let k = a_dims[dim - 1];
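
For context on the RequiresContiguous error: a tensor is contiguous when its strides are exactly the row-major strides implied by its shape, so cublas can treat the data as densely packed. A minimal sketch of that check, assuming strides counted in elements (this mirrors the usual definition, not necessarily this crate's internals):

    // Row-major contiguity: innermost stride is 1 and each outer stride
    // equals the product of all inner dimensions.
    fn is_contiguous(dims: &[usize], stride: &[usize]) -> bool {
        let mut expected = 1;
        for (&d, &s) in dims.iter().zip(stride.iter()).rev() {
            if s != expected {
                return false;
            }
            expected *= d;
        }
        true
    }

    fn main() {
        // A 2x3 row-major matrix has strides [3, 1]: contiguous.
        assert!(is_contiguous(&[2, 3], &[3, 1]));
        // Its transpose is a 3x2 view with strides [1, 3]: not contiguous.
        assert!(!is_contiguous(&[3, 2], &[1, 3]));
    }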