diff --git a/src/device.rs b/src/device.rs
index 62afd905..e83c844f 100644
--- a/src/device.rs
+++ b/src/device.rs
@@ -61,6 +61,25 @@ impl<S: WithDType, const N: usize, const M: usize> NdArray for &[[S; N]; M] {
     }
 }
 
+impl<S: WithDType, const N1: usize, const N2: usize, const N3: usize> NdArray
+    for &[[[S; N3]; N2]; N1]
+{
+    fn shape(&self) -> Result<Shape> {
+        Ok(Shape::from((N1, N2, N3)))
+    }
+
+    fn to_cpu_storage(&self) -> CpuStorage {
+        // Total element count is known up front, so size the buffer once.
+        let mut vec = Vec::with_capacity(N1 * N2 * N3);
+        for i1 in 0..N1 {
+            for i2 in 0..N2 {
+                vec.extend(self[i1][i2])
+            }
+        }
+        S::to_cpu_storage_owned(vec)
+    }
+}
+
 impl Device {
     pub fn new_cuda(ordinal: usize) -> Result<Self> {
         Ok(Self::Cuda(crate::CudaDevice::new(ordinal)?))
diff --git a/tests/tensor_tests.rs b/tests/tensor_tests.rs
index 96df88f5..0ffcad62 100644
--- a/tests/tensor_tests.rs
+++ b/tests/tensor_tests.rs
@@ -72,9 +72,8 @@ fn tensor_2d_transpose() -> Result<()> {
 
 #[test]
 fn softmax() -> Result<()> {
-    let data = &[3f32, 1., 4., 1., 5., 9., 2., 1., 7., 8., 2., 8.];
+    let data = &[[[3f32, 1., 4.], [1., 5., 9.]], [[2., 1., 7.], [8., 2., 8.]]];
     let tensor = Tensor::new(data, &Device::Cpu)?;
-    let tensor = tensor.reshape((2, 2, 3))?;
     let t0 = tensor.log()?.softmax(0)?;
     let t1 = tensor.log()?.softmax(1)?;
     let t2 = tensor.log()?.softmax(2)?;