mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 19:18:50 +00:00
Softmax tests + fix.
This commit is contained in:
@ -150,23 +150,22 @@ impl CpuStorage {
|
||||
pub(crate) fn divide_by_sum_over_dim(&mut self, shape: &Shape, dim: usize) -> Result<()> {
|
||||
// [self] stores data in a contiguous way.
|
||||
let dims = shape.dims();
|
||||
let number_of_slices = dims[dim];
|
||||
let elem_per_slice = dims[dim];
|
||||
let prod_pre_dim = dims[..dim].iter().product();
|
||||
let prod_post_dim = dims[dim + 1..].iter().product();
|
||||
let elem_count = shape.elem_count();
|
||||
match self {
|
||||
Self::F32(storage) => {
|
||||
for pre_idx in 0..prod_pre_dim {
|
||||
for post_idx in 0..prod_post_dim {
|
||||
let mut sum = 0f64;
|
||||
let mut idx = pre_idx * prod_post_dim * number_of_slices + post_idx;
|
||||
while idx < elem_count {
|
||||
let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx;
|
||||
for _ in 0..elem_per_slice {
|
||||
sum += storage[idx] as f64;
|
||||
idx += prod_post_dim
|
||||
}
|
||||
let sum = sum as f32;
|
||||
let mut idx = pre_idx * prod_post_dim * number_of_slices + post_idx;
|
||||
while idx < elem_count {
|
||||
let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx;
|
||||
for _ in 0..elem_per_slice {
|
||||
storage[idx] /= sum;
|
||||
idx += prod_post_dim
|
||||
}
|
||||
@ -177,13 +176,13 @@ impl CpuStorage {
|
||||
for pre_idx in 0..prod_pre_dim {
|
||||
for post_idx in 0..prod_post_dim {
|
||||
let mut sum = 0f64;
|
||||
let mut idx = pre_idx * prod_post_dim * number_of_slices + post_idx;
|
||||
while idx < elem_count {
|
||||
let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx;
|
||||
for _ in 0..elem_per_slice {
|
||||
sum += storage[idx];
|
||||
idx += prod_post_dim
|
||||
}
|
||||
let mut idx = pre_idx * prod_post_dim * number_of_slices + post_idx;
|
||||
while idx < elem_count {
|
||||
let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx;
|
||||
for _ in 0..elem_per_slice {
|
||||
storage[idx] /= sum;
|
||||
idx += prod_post_dim
|
||||
}
|
||||
|
@ -660,7 +660,7 @@ impl Tensor {
|
||||
}
|
||||
let mut storage = self.device().zeros(&shape, self.dtype())?;
|
||||
self.storage
|
||||
.copy_strided_src(&mut storage, &shape, &self.stride, 0)?;
|
||||
.copy_strided_src(&mut storage, &self.shape, &self.stride, 0)?;
|
||||
let op = if self.track_op() {
|
||||
Some(Op::Reshape(self.clone()))
|
||||
} else {
|
||||
|
@ -69,3 +69,41 @@ fn tensor_2d_transpose() -> Result<()> {
|
||||
assert_eq!(((tensor + 1.)?.t()? - 1.)?.to_vec2::<f32>()?, data);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn softmax() -> Result<()> {
|
||||
let data = &[3f32, 1., 4., 1., 5., 9., 2., 1., 7., 8., 2., 8.];
|
||||
let tensor = Tensor::new(data, &Device::Cpu)?;
|
||||
let tensor = tensor.reshape((2, 2, 3))?;
|
||||
let t0 = tensor.log()?.softmax(0)?;
|
||||
let t1 = tensor.log()?.softmax(1)?;
|
||||
let t2 = tensor.log()?.softmax(2)?;
|
||||
assert_eq!(
|
||||
t0.to_vec3::<f32>()?,
|
||||
&[
|
||||
// 3/5, 1/2, 4/11
|
||||
[[0.6, 0.5, 0.36363637], [0.11111111, 0.71428573, 0.5294118]],
|
||||
// 2/5, 1/2, 7/11
|
||||
[[0.4, 0.5, 0.63636357], [0.8888889, 0.2857143, 0.47058824]]
|
||||
]
|
||||
);
|
||||
assert_eq!(
|
||||
t1.to_vec3::<f32>()?,
|
||||
&[
|
||||
// 3/4, 1/6, 4/13
|
||||
[[0.75, 0.16666667, 0.30769232], [0.25, 0.8333333, 0.6923077]],
|
||||
// 2/10, 1/3, 7/15
|
||||
[[0.2, 0.33333334, 0.46666664], [0.8, 0.6666667, 0.53333336]]
|
||||
]
|
||||
);
|
||||
assert_eq!(
|
||||
t2.to_vec3::<f32>()?,
|
||||
&[
|
||||
// (3, 1, 4) / 8, (1, 5, 9) / 15
|
||||
[[0.375, 0.125, 0.5], [0.06666667, 0.33333334, 0.6]],
|
||||
// (2, 1, 7) / 10, (8, 2, 8) / 18
|
||||
[[0.2, 0.1, 0.6999999], [0.44444445, 0.11111111, 0.44444445]]
|
||||
]
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
Reference in New Issue
Block a user