From 4db972781f9cac46049bb3d80c29a21397926cc3 Mon Sep 17 00:00:00 2001
From: laurent
Date: Sat, 24 Jun 2023 08:24:06 +0100
Subject: [PATCH] Handle copying for the u32 type.

---
 src/cpu_backend.rs | 38 ++++++++++++++++++++++----------------
 1 file changed, 22 insertions(+), 16 deletions(-)

diff --git a/src/cpu_backend.rs b/src/cpu_backend.rs
index 6bae718a..21f7cd3a 100644
--- a/src/cpu_backend.rs
+++ b/src/cpu_backend.rs
@@ -88,6 +88,23 @@ fn binary_map T>(
     }
 }
 
+fn copy_strided_src_<T: Copy>(
+    src: &[T],
+    dst: &mut [T],
+    src_shape: &Shape,
+    src_stride: &[usize],
+    dst_offset: usize,
+) {
+    if src_shape.is_contiguous(src_stride) {
+        dst[dst_offset..].copy_from_slice(src)
+    } else {
+        let src_indexes = StridedIndex::new(src_shape.dims(), src_stride);
+        for (dst_index, src_index) in src_indexes.enumerate() {
+            dst[dst_index + dst_offset] = src[src_index]
+        }
+    }
+}
+
 impl CpuStorage {
     pub fn dtype(&self) -> DType {
         match self {
@@ -280,25 +297,14 @@ impl CpuStorage {
             panic!("incoherent shape and strides {src_shape:?} {src_stride:?}")
         }
         match (self, dst) {
+            (Self::U32(src), Self::U32(dst)) => {
+                copy_strided_src_(src, dst, src_shape, src_stride, dst_offset)
+            }
             (Self::F32(src), Self::F32(dst)) => {
-                if src_shape.is_contiguous(src_stride) {
-                    dst[dst_offset..].copy_from_slice(src)
-                } else {
-                    let src_indexes = StridedIndex::new(src_shape.dims(), src_stride);
-                    for (dst_index, src_index) in src_indexes.enumerate() {
-                        dst[dst_index + dst_offset] = src[src_index]
-                    }
-                }
+                copy_strided_src_(src, dst, src_shape, src_stride, dst_offset)
             }
             (Self::F64(src), Self::F64(dst)) => {
-                if src_shape.is_contiguous(src_stride) {
-                    dst[dst_offset..].copy_from_slice(src)
-                } else {
-                    let src_indexes = StridedIndex::new(src_shape.dims(), src_stride);
-                    for (dst_index, src_index) in src_indexes.enumerate() {
-                        dst[dst_index + dst_offset] = src[src_index]
-                    }
-                }
+                copy_strided_src_(src, dst, src_shape, src_stride, dst_offset)
             }
             (_, dst) => {
                 // This should be covered by the dtype check above.