Actually copy the data around in cat (cpu only).

This commit is contained in:
laurent
2023-06-23 10:24:02 +01:00
parent 6110db31c9
commit 4712dcc2f6
3 changed files with 52 additions and 8 deletions

View File

@ -108,6 +108,38 @@ impl CpuStorage {
}
}
pub(crate) fn copy_strided_src(
&self,
dst: &mut Self,
src_shape: &Shape,
src_stride: &[usize],
dst_offset: usize,
) -> Result<()> {
// TODO: Optimize the contiguous case.
let src_indexes = StridedIndex::new(src_shape.dims(), src_stride);
match (self, dst) {
(Self::F32(src), Self::F32(dst)) => {
for (dst_index, src_index) in src_indexes.enumerate() {
dst[dst_index + dst_offset] = src[src_index]
}
}
(Self::F64(src), Self::F64(dst)) => {
for (dst_index, src_index) in src_indexes.enumerate() {
dst[dst_index + dst_offset] = src[src_index]
}
}
(_, dst) => {
// This should be covered by the dtype check above.
return Err(Error::DTypeMismatchBinaryOp {
lhs: self.dtype(),
rhs: dst.dtype(),
op: "copy_strided",
});
}
}
Ok(())
}
pub(crate) fn matmul_impl(
&self,
rhs: &Self,