Handle copying for the u32 type.

This commit is contained in:
laurent
2023-06-24 08:24:06 +01:00
parent dd657397b2
commit 4db972781f

View File

@ -88,6 +88,23 @@ fn binary_map<T: Copy, F: FnMut(T, T) -> T>(
}
}
fn copy_strided_src_<T: Copy>(
src: &[T],
dst: &mut [T],
src_shape: &Shape,
src_stride: &[usize],
dst_offset: usize,
) {
if src_shape.is_contiguous(src_stride) {
dst[dst_offset..].copy_from_slice(src)
} else {
let src_indexes = StridedIndex::new(src_shape.dims(), src_stride);
for (dst_index, src_index) in src_indexes.enumerate() {
dst[dst_index + dst_offset] = src[src_index]
}
}
}
impl CpuStorage {
pub fn dtype(&self) -> DType {
match self {
@ -280,25 +297,14 @@ impl CpuStorage {
panic!("incoherent shape and strides {src_shape:?} {src_stride:?}")
}
match (self, dst) {
(Self::U32(src), Self::U32(dst)) => {
copy_strided_src_(src, dst, src_shape, src_stride, dst_offset)
}
(Self::F32(src), Self::F32(dst)) => {
if src_shape.is_contiguous(src_stride) {
dst[dst_offset..].copy_from_slice(src)
} else {
let src_indexes = StridedIndex::new(src_shape.dims(), src_stride);
for (dst_index, src_index) in src_indexes.enumerate() {
dst[dst_index + dst_offset] = src[src_index]
}
}
copy_strided_src_(src, dst, src_shape, src_stride, dst_offset)
}
(Self::F64(src), Self::F64(dst)) => {
if src_shape.is_contiguous(src_stride) {
dst[dst_offset..].copy_from_slice(src)
} else {
let src_indexes = StridedIndex::new(src_shape.dims(), src_stride);
for (dst_index, src_index) in src_indexes.enumerate() {
dst[dst_index + dst_offset] = src[src_index]
}
}
copy_strided_src_(src, dst, src_shape, src_stride, dst_offset)
}
(_, dst) => {
// This should be covered by the dtype check above.