Softmax tests + fix.

This commit is contained in:
laurent
2023-06-23 22:46:36 +01:00
parent d0a91db8fd
commit ae5dc5fbc6
3 changed files with 48 additions and 11 deletions

View File

@ -69,3 +69,41 @@ fn tensor_2d_transpose() -> Result<()> {
assert_eq!(((tensor + 1.)?.t()? - 1.)?.to_vec2::<f32>()?, data);
Ok(())
}
#[test]
fn softmax() -> Result<()> {
let data = &[3f32, 1., 4., 1., 5., 9., 2., 1., 7., 8., 2., 8.];
let tensor = Tensor::new(data, &Device::Cpu)?;
let tensor = tensor.reshape((2, 2, 3))?;
let t0 = tensor.log()?.softmax(0)?;
let t1 = tensor.log()?.softmax(1)?;
let t2 = tensor.log()?.softmax(2)?;
assert_eq!(
t0.to_vec3::<f32>()?,
&[
// 3/5, 1/2, 4/11
[[0.6, 0.5, 0.36363637], [0.11111111, 0.71428573, 0.5294118]],
// 2/5, 1/2, 7/11
[[0.4, 0.5, 0.63636357], [0.8888889, 0.2857143, 0.47058824]]
]
);
assert_eq!(
t1.to_vec3::<f32>()?,
&[
// 3/4, 1/6, 4/13
[[0.75, 0.16666667, 0.30769232], [0.25, 0.8333333, 0.6923077]],
// 2/10, 1/3, 7/15
[[0.2, 0.33333334, 0.46666664], [0.8, 0.6666667, 0.53333336]]
]
);
assert_eq!(
t2.to_vec3::<f32>()?,
&[
// (3, 1, 4) / 8, (1, 5, 9) / 15
[[0.375, 0.125, 0.5], [0.06666667, 0.33333334, 0.6]],
// (2, 1, 7) / 10, (8, 2, 8) / 18
[[0.2, 0.1, 0.6999999], [0.44444445, 0.11111111, 0.44444445]]
]
);
Ok(())
}