Softmax tests + fix.

laurent committed 2023-06-23 22:46:36 +01:00
parent d0a91db8fd
commit ae5dc5fbc6
3 changed files with 48 additions and 11 deletions


@@ -150,23 +150,22 @@ impl CpuStorage {
     pub(crate) fn divide_by_sum_over_dim(&mut self, shape: &Shape, dim: usize) -> Result<()> {
         // [self] stores data in a contiguous way.
         let dims = shape.dims();
-        let number_of_slices = dims[dim];
+        let elem_per_slice = dims[dim];
         let prod_pre_dim = dims[..dim].iter().product();
         let prod_post_dim = dims[dim + 1..].iter().product();
-        let elem_count = shape.elem_count();
         match self {
             Self::F32(storage) => {
                 for pre_idx in 0..prod_pre_dim {
                     for post_idx in 0..prod_post_dim {
                         let mut sum = 0f64;
-                        let mut idx = pre_idx * prod_post_dim * number_of_slices + post_idx;
-                        while idx < elem_count {
+                        let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx;
+                        for _ in 0..elem_per_slice {
                             sum += storage[idx] as f64;
                             idx += prod_post_dim
                         }
                         let sum = sum as f32;
-                        let mut idx = pre_idx * prod_post_dim * number_of_slices + post_idx;
-                        while idx < elem_count {
+                        let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx;
+                        for _ in 0..elem_per_slice {
                             storage[idx] /= sum;
                             idx += prod_post_dim
                         }
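The loop-bound change is the heart of the fix: in a contiguous layout, element (pre_idx, k, post_idx) lives at flat index pre_idx * prod_post_dim * elem_per_slice + k * prod_post_dim + post_idx, so one slice along dim is exactly elem_per_slice steps of prod_post_dim. The old `while idx < elem_count` bound kept stepping past the slice into the blocks of later pre_idx values. A tiny self-contained check of the two traversals (illustrative only, not part of the commit):

fn main() {
    // Shape (2, 2, 3) with dim = 1: elem_per_slice = 2, prod_post_dim = 3, elem_count = 12.
    let (elem_per_slice, prod_post_dim, elem_count) = (2usize, 3usize, 12usize);
    // Old traversal from (pre_idx = 0, post_idx = 0): bounded only by elem_count.
    let old: Vec<usize> = (0..elem_count).step_by(prod_post_dim).collect();
    // New traversal: exactly elem_per_slice steps.
    let new: Vec<usize> = (0..elem_per_slice).map(|k| k * prod_post_dim).collect();
    assert_eq!(old, vec![0, 3, 6, 9]); // 6 and 9 already belong to the pre_idx = 1 block
    assert_eq!(new, vec![0, 3]); // only the (0, k, 0) elements, as intended
}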
@@ -177,13 +176,13 @@ impl CpuStorage {
                 for pre_idx in 0..prod_pre_dim {
                     for post_idx in 0..prod_post_dim {
                         let mut sum = 0f64;
-                        let mut idx = pre_idx * prod_post_dim * number_of_slices + post_idx;
-                        while idx < elem_count {
+                        let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx;
+                        for _ in 0..elem_per_slice {
                             sum += storage[idx];
                             idx += prod_post_dim
                         }
-                        let mut idx = pre_idx * prod_post_dim * number_of_slices + post_idx;
-                        while idx < elem_count {
+                        let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx;
+                        for _ in 0..elem_per_slice {
                             storage[idx] /= sum;
                             idx += prod_post_dim
                         }
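The F64 branch gets the same fix. For reference, a minimal self-contained sketch of the corrected traversal as a free function (hypothetical, not the crate's API), under the same contiguity assumption:

// Divide each slice along `dim` by its sum, for a contiguous buffer with shape `dims`.
fn normalize_over_dim(storage: &mut [f32], dims: &[usize], dim: usize) {
    let elem_per_slice = dims[dim];
    let prod_pre_dim: usize = dims[..dim].iter().product();
    let prod_post_dim: usize = dims[dim + 1..].iter().product();
    for pre_idx in 0..prod_pre_dim {
        for post_idx in 0..prod_post_dim {
            let start = pre_idx * prod_post_dim * elem_per_slice + post_idx;
            // Exactly elem_per_slice steps of prod_post_dim: one slice, no overrun.
            let sum: f32 = (0..elem_per_slice)
                .map(|k| storage[start + k * prod_post_dim])
                .sum();
            for k in 0..elem_per_slice {
                storage[start + k * prod_post_dim] /= sum;
            }
        }
    }
}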


@@ -660,7 +660,7 @@ impl Tensor {
         }
         let mut storage = self.device().zeros(&shape, self.dtype())?;
         self.storage
-            .copy_strided_src(&mut storage, &shape, &self.stride, 0)?;
+            .copy_strided_src(&mut storage, &self.shape, &self.stride, 0)?;
         let op = if self.track_op() {
             Some(Op::Reshape(self.clone()))
         } else {
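The reshape fix above matters because copy_strided_src has to walk the source according to the shape that matches self.stride; handing it the target shape reads the wrong elements whenever the two shapes differ. Roughly what a strided copy does (hypothetical helper, not the crate's actual signature):

// Copy a strided source into a contiguous destination. The walk is over the
// *source* dims; `src_stride` maps each multi-index to a flat source offset.
fn copy_strided(src: &[f32], dst: &mut [f32], src_dims: &[usize], src_stride: &[usize]) {
    let elem_count: usize = src_dims.iter().product();
    let mut index = vec![0usize; src_dims.len()];
    for dst_i in 0..elem_count {
        let src_i: usize = index.iter().zip(src_stride).map(|(i, s)| i * s).sum();
        dst[dst_i] = src[src_i];
        // Odometer-style increment of the source multi-index.
        for d in (0..src_dims.len()).rev() {
            index[d] += 1;
            if index[d] < src_dims[d] {
                break;
            }
            index[d] = 0;
        }
    }
}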


@@ -69,3 +69,41 @@ fn tensor_2d_transpose() -> Result<()> {
     assert_eq!(((tensor + 1.)?.t()? - 1.)?.to_vec2::<f32>()?, data);
     Ok(())
 }
+
+#[test]
+fn softmax() -> Result<()> {
+    let data = &[3f32, 1., 4., 1., 5., 9., 2., 1., 7., 8., 2., 8.];
+    let tensor = Tensor::new(data, &Device::Cpu)?;
+    let tensor = tensor.reshape((2, 2, 3))?;
+    let t0 = tensor.log()?.softmax(0)?;
+    let t1 = tensor.log()?.softmax(1)?;
+    let t2 = tensor.log()?.softmax(2)?;
+    assert_eq!(
+        t0.to_vec3::<f32>()?,
+        &[
+            // 3/5, 1/2, 4/11
+            [[0.6, 0.5, 0.36363637], [0.11111111, 0.71428573, 0.5294118]],
+            // 2/5, 1/2, 7/11
+            [[0.4, 0.5, 0.63636357], [0.8888889, 0.2857143, 0.47058824]]
+        ]
+    );
+    assert_eq!(
+        t1.to_vec3::<f32>()?,
+        &[
+            // 3/4, 1/6, 4/13
+            [[0.75, 0.16666667, 0.30769232], [0.25, 0.8333333, 0.6923077]],
+            // 2/10, 1/3, 7/15
+            [[0.2, 0.33333334, 0.46666664], [0.8, 0.6666667, 0.53333336]]
+        ]
+    );
+    assert_eq!(
+        t2.to_vec3::<f32>()?,
+        &[
+            // (3, 1, 4) / 8, (1, 5, 9) / 15
+            [[0.375, 0.125, 0.5], [0.06666667, 0.33333334, 0.6]],
+            // (2, 1, 7) / 10, (8, 2, 8) / 18
+            [[0.2, 0.1, 0.6999999], [0.44444445, 0.11111111, 0.44444445]]
+        ]
+    );
+    Ok(())
+}
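A note on why the tests take log first: softmax(ln x)_i = exp(ln x_i) / sum_j exp(ln x_j) = x_i / sum_j x_j, so every expected value is a plain ratio of the inputs, exactly as the inline comments spell out. One slice checked by hand (illustrative, outside the test suite):

fn main() {
    // First row of the (2, 2, 3) tensor: softmax(ln x) along the last dim is x / sum(x).
    let row = [3f32, 1., 4.];
    let sum: f32 = row.iter().sum();
    let ratios: Vec<f32> = row.iter().map(|x| x / sum).collect();
    assert_eq!(ratios, vec![0.375, 0.125, 0.5]); // matches the first t2 row
}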