mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00
Softmax cpu implementation.
This commit is contained in:
@ -147,8 +147,52 @@ impl CpuStorage {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn divide_by_sum_over_dim(&mut self, _shape: &Shape, _dim: usize) {
|
pub(crate) fn divide_by_sum_over_dim(&mut self, shape: &Shape, dim: usize) -> Result<()> {
|
||||||
todo!()
|
// [self] stores data in a contiguous way.
|
||||||
|
let dims = shape.dims();
|
||||||
|
let number_of_slices = dims[dim];
|
||||||
|
let prod_pre_dim = dims[..dim].iter().product();
|
||||||
|
let prod_post_dim = dims[dim + 1..].iter().product();
|
||||||
|
let elem_count = shape.elem_count();
|
||||||
|
match self {
|
||||||
|
Self::F32(storage) => {
|
||||||
|
for pre_idx in 0..prod_pre_dim {
|
||||||
|
for post_idx in 0..prod_post_dim {
|
||||||
|
let mut sum = 0f64;
|
||||||
|
let mut idx = pre_idx * prod_post_dim * number_of_slices + post_idx;
|
||||||
|
while idx < elem_count {
|
||||||
|
sum += storage[idx] as f64;
|
||||||
|
idx += prod_post_dim
|
||||||
|
}
|
||||||
|
let sum = sum as f32;
|
||||||
|
let mut idx = pre_idx * prod_post_dim * number_of_slices + post_idx;
|
||||||
|
while idx < elem_count {
|
||||||
|
storage[idx] /= sum;
|
||||||
|
idx += prod_post_dim
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Self::F64(storage) => {
|
||||||
|
for pre_idx in 0..prod_pre_dim {
|
||||||
|
for post_idx in 0..prod_post_dim {
|
||||||
|
let mut sum = 0f64;
|
||||||
|
let mut idx = pre_idx * prod_post_dim * number_of_slices + post_idx;
|
||||||
|
while idx < elem_count {
|
||||||
|
sum += storage[idx];
|
||||||
|
idx += prod_post_dim
|
||||||
|
}
|
||||||
|
let mut idx = pre_idx * prod_post_dim * number_of_slices + post_idx;
|
||||||
|
while idx < elem_count {
|
||||||
|
storage[idx] /= sum;
|
||||||
|
idx += prod_post_dim
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Self::U32(_) => {}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn affine_impl(
|
pub(crate) fn affine_impl(
|
||||||
|
@ -291,7 +291,7 @@ impl CudaStorage {
|
|||||||
Ok(Self { slice, device })
|
Ok(Self { slice, device })
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn divide_by_sum_over_dim(&mut self, _: &Shape, _: usize) {
|
pub(crate) fn divide_by_sum_over_dim(&mut self, _: &Shape, _: usize) -> Result<()> {
|
||||||
todo!()
|
todo!()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -62,7 +62,9 @@ impl CudaStorage {
|
|||||||
Err(Error::NotCompiledWithCudaSupport)
|
Err(Error::NotCompiledWithCudaSupport)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn divide_by_sum_over_dim(&mut self, _: &Shape, _: usize) {}
|
pub(crate) fn divide_by_sum_over_dim(&mut self, _: &Shape, _: usize) -> Result<()> {
|
||||||
|
Err(Error::NotCompiledWithCudaSupport)
|
||||||
|
}
|
||||||
|
|
||||||
pub(crate) fn to_dtype(&self, _: &Shape, _: &[usize], _: DType) -> Result<Self> {
|
pub(crate) fn to_dtype(&self, _: &Shape, _: &[usize], _: DType) -> Result<Self> {
|
||||||
Err(Error::NotCompiledWithCudaSupport)
|
Err(Error::NotCompiledWithCudaSupport)
|
||||||
|
@ -72,11 +72,12 @@ impl Storage {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn divide_by_sum_over_dim(&mut self, shape: &Shape, dim: usize) {
|
pub(crate) fn divide_by_sum_over_dim(&mut self, shape: &Shape, dim: usize) -> Result<()> {
|
||||||
match self {
|
match self {
|
||||||
Storage::Cpu(storage) => storage.divide_by_sum_over_dim(shape, dim),
|
Storage::Cpu(storage) => storage.divide_by_sum_over_dim(shape, dim)?,
|
||||||
Self::Cuda(storage) => storage.divide_by_sum_over_dim(shape, dim),
|
Self::Cuda(storage) => storage.divide_by_sum_over_dim(shape, dim)?,
|
||||||
}
|
}
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn to_dtype(&self, shape: &Shape, stride: &[usize], dtype: DType) -> Result<Self> {
|
pub(crate) fn to_dtype(&self, shape: &Shape, stride: &[usize], dtype: DType) -> Result<Self> {
|
||||||
|
@ -301,7 +301,7 @@ impl Tensor {
|
|||||||
.storage
|
.storage
|
||||||
.unary_impl::<crate::op::Exp>(shape, self.stride())?;
|
.unary_impl::<crate::op::Exp>(shape, self.stride())?;
|
||||||
// The resulting storage is contiguous.
|
// The resulting storage is contiguous.
|
||||||
storage.divide_by_sum_over_dim(shape, dim);
|
storage.divide_by_sum_over_dim(shape, dim)?;
|
||||||
let op = if self.track_op() {
|
let op = if self.track_op() {
|
||||||
Some(Op::Softmax(self.clone(), dim))
|
Some(Op::Softmax(self.clone(), dim))
|
||||||
} else {
|
} else {
|
||||||
|
Reference in New Issue
Block a user