mirror of https://github.com/huggingface/candle.git
synced 2025-06-17 19:18:50 +00:00
Decompose the softmax op so that it can be run on cuda.
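The decomposition relies on softmax(x)_i = exp(x_i) / sum_j exp(x_j): an element-wise exp, a sum over the softmax dimension, and a broadcast division. The diff's own comment notes that only divide_by_sum_over_dim lacks a CUDA kernel, so the CUDA path below computes these three steps separately while the CPU path keeps the existing fused implementation.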
@@ -102,6 +102,13 @@ impl Device {
         }
     }

+    pub fn is_cuda(&self) -> bool {
+        match self {
+            Self::Cpu => false,
+            Self::Cuda(_) => true,
+        }
+    }
+
     pub(crate) fn ones(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
         match self {
             Device::Cpu => {
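A minimal stand-alone sketch of the pattern this hunk introduces: a device enum exposing a boolean helper so call sites can branch on the backend with a simple predicate instead of matching on the enum themselves. The Cuda(usize) payload here is a placeholder, not candle's actual CUDA device type.

// Sketch only: `Cuda(usize)` stands in for candle's CUDA device type.
#[derive(Clone, Debug)]
enum Device {
    Cpu,
    Cuda(usize),
}

impl Device {
    fn is_cuda(&self) -> bool {
        match self {
            Self::Cpu => false,
            Self::Cuda(_) => true,
        }
    }
}

fn main() {
    let dev = Device::Cuda(0);
    // Call sites read as a predicate, e.g. `if self.device().is_cuda() { ... }`.
    println!("{dev:?} is_cuda = {}", dev.is_cuda());
}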
@@ -395,18 +395,27 @@ impl Tensor {
     }

     pub fn softmax(&self, dim: usize) -> Result<Self> {
-        let shape = self.shape();
-        let mut storage = self
-            .storage
-            .unary_impl::<crate::op::Exp>(shape, self.stride())?;
-        // The resulting storage is contiguous.
-        storage.divide_by_sum_over_dim(shape, dim)?;
-        let op = if self.track_op() {
-            Some(Op::Softmax(self.clone(), dim))
-        } else {
-            None
-        };
-        Ok(from_storage(storage, shape.clone(), op, false))
+        // TODO: unify the two branches.
+        if self.device().is_cuda() {
+            // We do not have a cuda kernel for divide_by_sum_over_dim so split
+            // the operation.
+            let exp = self.exp()?;
+            let sum_exp = exp.sum(&[dim])?;
+            exp.broadcast_div(&sum_exp)
+        } else {
+            let shape = self.shape();
+            let mut storage = self
+                .storage
+                .unary_impl::<crate::op::Exp>(shape, self.stride())?;
+            // The resulting storage is contiguous.
+            storage.divide_by_sum_over_dim(shape, dim)?;
+            let op = if self.track_op() {
+                Some(Op::Softmax(self.clone(), dim))
+            } else {
+                None
+            };
+            Ok(from_storage(storage, shape.clone(), op, false))
+        }
     }

     pub fn sum(&self, sum_dims: &[usize]) -> Result<Self> {
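As a sanity check on the decomposition, here is a stand-alone sketch (plain Rust over a slice, not the candle API) that computes softmax via the same three steps as the CUDA branch: exponentiate, sum, then divide each exponential by the sum.

// softmax(x)[i] = exp(x[i]) / sum_j exp(x[j]), computed as three separate
// steps rather than the fused divide_by_sum_over_dim used on the CPU path.
fn softmax_decomposed(xs: &[f32]) -> Vec<f32> {
    let exp: Vec<f32> = xs.iter().map(|x| x.exp()).collect();
    let sum_exp: f32 = exp.iter().sum();
    exp.iter().map(|e| e / sum_exp).collect()
}

fn main() {
    let probs = softmax_decomposed(&[1.0, 2.0, 3.0]);
    println!("{probs:?}");
    // The outputs sum to 1 up to floating point error.
    println!("sum = {}", probs.iter().sum::<f32>());
}

The real implementation handles an arbitrary dimension of a multi-dimensional tensor via sum(&[dim]) and broadcast_div; the sketch only covers the 1-D case.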
@@ -436,14 +445,12 @@ impl Tensor {
                 op: "matmul",
             });
         }
-        if let crate::DeviceLocation::Cuda { .. } = self.device().location() {
-            if !self.is_contiguous() || !rhs.is_contiguous() {
-                // It looks like the cublas implementation of XgemmStridedBatched only supports
-                // non-standard strides on the batch dimension.
-                return Err(Error::RequiresContiguous {
-                    op: "matmul-cublas",
-                });
-            }
+        if self.device().is_cuda() && (!self.is_contiguous() || !rhs.is_contiguous()) {
+            // It looks like the cublas implementation of XgemmStridedBatched only supports
+            // non-standard strides on the batch dimension.
+            return Err(Error::RequiresContiguous {
+                op: "matmul-cublas",
+            });
         }

         let m = a_dims[dim - 2];
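For context on the is_contiguous checks above, here is a simplified stand-alone contiguity test (an assumption about the semantics, not candle's exact implementation): a tensor is row-major contiguous when its strides match the strides implied by its dims, which is what a transposed view violates and why the cublas path rejects it.

// Simplified contiguity check: strides must be the row-major strides implied
// by the dims (size-1 dims are ignored since any stride works for them).
fn is_contiguous(dims: &[usize], strides: &[usize]) -> bool {
    let mut expected = 1;
    for (&dim, &stride) in dims.iter().zip(strides.iter()).rev() {
        if dim > 1 && stride != expected {
            return false;
        }
        expected *= dim;
    }
    true
}

fn main() {
    // A 2x3 row-major tensor has strides [3, 1] and is contiguous.
    assert!(is_contiguous(&[2, 3], &[3, 1]));
    // Its transposed view has dims [3, 2] with strides [1, 3]: not contiguous,
    // so a CUDA matmul on it would hit the RequiresContiguous error above.
    assert!(!is_contiguous(&[3, 2], &[1, 3]));
    println!("ok");
}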