Mirror of https://github.com/huggingface/candle.git, synced 2025-06-17 11:08:52 +00:00
Decompose the softmax op so that it can be run on cuda.
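The change splits softmax into primitives that the CUDA backend already supports: exp, a sum over the softmax dimension, and a broadcasted division, i.e. softmax(x) = exp(x) / sum(exp(x)) along dim. Below is a minimal, self-contained sketch of the two paths in plain Rust over a row-major matrix. It is illustrative only: the function names and layout are assumptions, not candle's API; softmax_fused stands in for the CPU path's in-place divide_by_sum_over_dim, softmax_decomposed for the new CUDA path.

// A minimal sketch of the idea behind this commit, in plain Rust rather than
// candle's API. Data is a row-major (rows, cols) matrix and softmax is taken
// over the last dimension.

// Stand-in for the CPU path: exp, then an in-place divide by the sum over the
// softmax dimension (what divide_by_sum_over_dim does on the storage).
fn softmax_fused(data: &mut [f32], rows: usize, cols: usize) {
    for v in data.iter_mut() {
        *v = v.exp();
    }
    for r in 0..rows {
        let row = &mut data[r * cols..(r + 1) * cols];
        let sum: f32 = row.iter().sum();
        for v in row.iter_mut() {
            *v /= sum;
        }
    }
}

// Stand-in for the new CUDA path: the same softmax expressed only with ops
// that already have kernels: exp, sum over the dim, and a broadcast divide.
fn softmax_decomposed(data: &[f32], rows: usize, cols: usize) -> Vec<f32> {
    let exp: Vec<f32> = data.iter().map(|v| v.exp()).collect();
    // One sum per row, i.e. a (rows, 1) tensor that is then broadcast.
    let sum_exp: Vec<f32> = (0..rows)
        .map(|r| exp[r * cols..(r + 1) * cols].iter().sum())
        .collect();
    exp.iter()
        .enumerate()
        .map(|(i, v)| v / sum_exp[i / cols])
        .collect()
}

fn main() {
    let (rows, cols) = (2, 3);
    let input = vec![1.0f32, 2.0, 3.0, 0.5, 0.5, 0.5];

    let mut fused = input.clone();
    softmax_fused(&mut fused, rows, cols);
    let decomposed = softmax_decomposed(&input, rows, cols);

    // Both paths produce the same softmax values.
    for (a, b) in fused.iter().zip(decomposed.iter()) {
        assert!((a - b).abs() < 1e-6);
    }
    println!("softmax: {fused:?}");
}

One consequence visible in the diff below: the CUDA branch assembles its result from the individual exp/sum/broadcast_div ops instead of recording a single Op::Softmax node, hence the TODO about unifying the two branches.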
@@ -102,6 +102,13 @@ impl Device {
         }
     }
 
+    pub fn is_cuda(&self) -> bool {
+        match self {
+            Self::Cpu => false,
+            Self::Cuda(_) => true,
+        }
+    }
+
     pub(crate) fn ones(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
         match self {
             Device::Cpu => {
@@ -395,18 +395,27 @@ impl Tensor {
     }
 
     pub fn softmax(&self, dim: usize) -> Result<Self> {
-        let shape = self.shape();
-        let mut storage = self
-            .storage
-            .unary_impl::<crate::op::Exp>(shape, self.stride())?;
-        // The resulting storage is contiguous.
-        storage.divide_by_sum_over_dim(shape, dim)?;
-        let op = if self.track_op() {
-            Some(Op::Softmax(self.clone(), dim))
+        // TODO: unify the two branches.
+        if self.device().is_cuda() {
+            // We do not have a cuda kernel for divide_by_sum_over_dim so split
+            // the operation.
+            let exp = self.exp()?;
+            let sum_exp = exp.sum(&[dim])?;
+            exp.broadcast_div(&sum_exp)
         } else {
-            None
-        };
-        Ok(from_storage(storage, shape.clone(), op, false))
+            let shape = self.shape();
+            let mut storage = self
+                .storage
+                .unary_impl::<crate::op::Exp>(shape, self.stride())?;
+            // The resulting storage is contiguous.
+            storage.divide_by_sum_over_dim(shape, dim)?;
+            let op = if self.track_op() {
+                Some(Op::Softmax(self.clone(), dim))
+            } else {
+                None
+            };
+            Ok(from_storage(storage, shape.clone(), op, false))
+        }
     }
 
     pub fn sum(&self, sum_dims: &[usize]) -> Result<Self> {
@@ -436,14 +445,12 @@ impl Tensor {
                 op: "matmul",
             });
         }
-        if let crate::DeviceLocation::Cuda { .. } = self.device().location() {
-            if !self.is_contiguous() || !rhs.is_contiguous() {
-                // It looks like the cublas implementation of XgemmStridedBatched only supports
-                // non-standard strides on the batch dimension.
-                return Err(Error::RequiresContiguous {
-                    op: "matmul-cublas",
-                });
-            }
+        if self.device().is_cuda() && (!self.is_contiguous() || !rhs.is_contiguous()) {
+            // It looks like the cublas implementation of XgemmStridedBatched only supports
+            // non-standard strides on the batch dimension.
+            return Err(Error::RequiresContiguous {
+                op: "matmul-cublas",
+            });
         }
 
         let m = a_dims[dim - 2];
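The matmul hunk above is a pure refactor enabled by the new is_cuda helper: the nested if let crate::DeviceLocation::Cuda { .. } check followed by a contiguity check collapses into a single boolean guard. A rough stand-alone sketch of the equivalence, using stand-in types rather than candle's real Device (the enum, error strings, and function names here are hypothetical):

// Stand-in types for illustration only; candle's real Device and error types differ.
enum Device {
    Cpu,
    Cuda(usize), // hypothetical device ordinal
}

impl Device {
    fn is_cuda(&self) -> bool {
        matches!(self, Device::Cuda(_))
    }
}

// Old shape of the guard: branch on the device location, then on contiguity.
fn check_old(device: &Device, lhs_contiguous: bool, rhs_contiguous: bool) -> Result<(), String> {
    if let Device::Cuda(_) = device {
        if !lhs_contiguous || !rhs_contiguous {
            return Err("matmul-cublas requires contiguous inputs".to_string());
        }
    }
    Ok(())
}

// New shape: a single boolean guard built on the is_cuda helper.
fn check_new(device: &Device, lhs_contiguous: bool, rhs_contiguous: bool) -> Result<(), String> {
    if device.is_cuda() && (!lhs_contiguous || !rhs_contiguous) {
        return Err("matmul-cublas requires contiguous inputs".to_string());
    }
    Ok(())
}

fn main() {
    // The two guards accept and reject exactly the same cases.
    for device in [Device::Cpu, Device::Cuda(0)] {
        for (lhs, rhs) in [(true, true), (true, false), (false, true), (false, false)] {
            assert_eq!(
                check_old(&device, lhs, rhs).is_ok(),
                check_new(&device, lhs, rhs).is_ok()
            );
        }
    }
    println!("old and new contiguity guards agree on every case");
}

Flattening the guard keeps the error return at a single nesting level and makes the CUDA-only condition explicit in one expression.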