mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 19:47:12 +00:00
Decompose the softmax op so that it can be run on cuda.
This commit is contained in:
@ -102,6 +102,13 @@ impl Device {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn is_cuda(&self) -> bool {
|
||||||
|
match self {
|
||||||
|
Self::Cpu => false,
|
||||||
|
Self::Cuda(_) => true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub(crate) fn ones(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
|
pub(crate) fn ones(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
|
||||||
match self {
|
match self {
|
||||||
Device::Cpu => {
|
Device::Cpu => {
|
||||||
|
@ -395,6 +395,14 @@ impl Tensor {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn softmax(&self, dim: usize) -> Result<Self> {
|
pub fn softmax(&self, dim: usize) -> Result<Self> {
|
||||||
|
// TODO: unify the two branches.
|
||||||
|
if self.device().is_cuda() {
|
||||||
|
// We do not have a cuda kernel for divide_by_sum_over_dim so split
|
||||||
|
// the operation.
|
||||||
|
let exp = self.exp()?;
|
||||||
|
let sum_exp = exp.sum(&[dim])?;
|
||||||
|
exp.broadcast_div(&sum_exp)
|
||||||
|
} else {
|
||||||
let shape = self.shape();
|
let shape = self.shape();
|
||||||
let mut storage = self
|
let mut storage = self
|
||||||
.storage
|
.storage
|
||||||
@ -408,6 +416,7 @@ impl Tensor {
|
|||||||
};
|
};
|
||||||
Ok(from_storage(storage, shape.clone(), op, false))
|
Ok(from_storage(storage, shape.clone(), op, false))
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn sum(&self, sum_dims: &[usize]) -> Result<Self> {
|
pub fn sum(&self, sum_dims: &[usize]) -> Result<Self> {
|
||||||
let storage = self.storage.sum(self.shape(), &self.stride, sum_dims)?;
|
let storage = self.storage.sum(self.shape(), &self.stride, sum_dims)?;
|
||||||
@ -436,15 +445,13 @@ impl Tensor {
|
|||||||
op: "matmul",
|
op: "matmul",
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
if let crate::DeviceLocation::Cuda { .. } = self.device().location() {
|
if self.device().is_cuda() && (!self.is_contiguous() || !rhs.is_contiguous()) {
|
||||||
if !self.is_contiguous() || !rhs.is_contiguous() {
|
|
||||||
// It looks like the cublas implementation of XgemmStridedBatched only supports
|
// It looks like the cublas implementation of XgemmStridedBatched only supports
|
||||||
// non-standard strides on the batch dimension.
|
// non-standard strides on the batch dimension.
|
||||||
return Err(Error::RequiresContiguous {
|
return Err(Error::RequiresContiguous {
|
||||||
op: "matmul-cublas",
|
op: "matmul-cublas",
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
let m = a_dims[dim - 2];
|
let m = a_dims[dim - 2];
|
||||||
let k = a_dims[dim - 1];
|
let k = a_dims[dim - 1];
|
||||||
|
Reference in New Issue
Block a user