Helper function to build 3d arrays.

This commit is contained in:
laurent
2023-06-24 06:29:06 +01:00
parent ae5dc5fbc6
commit b4653e41be
2 changed files with 20 additions and 2 deletions

View File

@ -72,9 +72,8 @@ fn tensor_2d_transpose() -> Result<()> {
#[test]
fn softmax() -> Result<()> {
let data = &[3f32, 1., 4., 1., 5., 9., 2., 1., 7., 8., 2., 8.];
let data = &[[[3f32, 1., 4.], [1., 5., 9.]], [[2., 1., 7.], [8., 2., 8.]]];
let tensor = Tensor::new(data, &Device::Cpu)?;
let tensor = tensor.reshape((2, 2, 3))?;
let t0 = tensor.log()?.softmax(0)?;
let t1 = tensor.log()?.softmax(1)?;
let t2 = tensor.log()?.softmax(2)?;