Helper function to build 3d arrays.

This commit is contained in:
laurent
2023-06-24 06:29:06 +01:00
parent ae5dc5fbc6
commit b4653e41be
2 changed files with 20 additions and 2 deletions

View File

@ -61,6 +61,25 @@ impl<S: crate::WithDType, const N: usize, const M: usize> NdArray for &[[S; N];
}
}
impl<S: crate::WithDType, const N1: usize, const N2: usize, const N3: usize> NdArray
for &[[[S; N3]; N2]; N1]
{
fn shape(&self) -> Result<Shape> {
Ok(Shape::from((N1, N2, N3)))
}
fn to_cpu_storage(&self) -> CpuStorage {
let mut vec = Vec::new();
vec.reserve(N1 * N2 * N3);
for i1 in 0..N1 {
for i2 in 0..N2 {
vec.extend(self[i1][i2])
}
}
S::to_cpu_storage_owned(vec)
}
}
impl Device {
pub fn new_cuda(ordinal: usize) -> Result<Self> {
Ok(Self::Cuda(crate::CudaDevice::new(ordinal)?))