Handle the contiguous case in an optimized way when copying CPU memory.

This commit is contained in:
laurent
2023-06-23 12:20:16 +01:00
parent d4054ab500
commit 52c503ba8f

View File

@@ -115,17 +115,25 @@ impl CpuStorage {
src_stride: &[usize],
dst_offset: usize,
) -> Result<()> {
// TODO: Optimize the contiguous case.
let src_indexes = StridedIndex::new(src_shape.dims(), src_stride);
match (self, dst) {
(Self::F32(src), Self::F32(dst)) => {
for (dst_index, src_index) in src_indexes.enumerate() {
dst[dst_index + dst_offset] = src[src_index]
if src_shape.is_contiguous(src_stride) {
dst[dst_offset..].copy_from_slice(src)
} else {
let src_indexes = StridedIndex::new(src_shape.dims(), src_stride);
for (dst_index, src_index) in src_indexes.enumerate() {
dst[dst_index + dst_offset] = src[src_index]
}
}
}
(Self::F64(src), Self::F64(dst)) => {
for (dst_index, src_index) in src_indexes.enumerate() {
dst[dst_index + dst_offset] = src[src_index]
if src_shape.is_contiguous(src_stride) {
dst[dst_offset..].copy_from_slice(src)
} else {
let src_indexes = StridedIndex::new(src_shape.dims(), src_stride);
for (dst_index, src_index) in src_indexes.enumerate() {
dst[dst_index + dst_offset] = src[src_index]
}
}
}
(_, dst) => {