diff --git a/src/cpu_backend.rs b/src/cpu_backend.rs index bf372589..f4055173 100644 --- a/src/cpu_backend.rs +++ b/src/cpu_backend.rs @@ -115,17 +115,25 @@ impl CpuStorage { src_stride: &[usize], dst_offset: usize, ) -> Result<()> { - // TODO: Optimize the contiguous case. - let src_indexes = StridedIndex::new(src_shape.dims(), src_stride); match (self, dst) { (Self::F32(src), Self::F32(dst)) => { - for (dst_index, src_index) in src_indexes.enumerate() { - dst[dst_index + dst_offset] = src[src_index] + if src_shape.is_contiguous(src_stride) { + dst[dst_offset..].copy_from_slice(src) + } else { + let src_indexes = StridedIndex::new(src_shape.dims(), src_stride); + for (dst_index, src_index) in src_indexes.enumerate() { + dst[dst_index + dst_offset] = src[src_index] + } } } (Self::F64(src), Self::F64(dst)) => { - for (dst_index, src_index) in src_indexes.enumerate() { - dst[dst_index + dst_offset] = src[src_index] + if src_shape.is_contiguous(src_stride) { + dst[dst_offset..].copy_from_slice(src) + } else { + let src_indexes = StridedIndex::new(src_shape.dims(), src_stride); + for (dst_index, src_index) in src_indexes.enumerate() { + dst[dst_index + dst_offset] = src[src_index] + } } } (_, dst) => {