Actually copy the data around in cat (cpu only).

This commit is contained in:
laurent
2023-06-23 10:24:02 +01:00
parent 6110db31c9
commit 4712dcc2f6
3 changed files with 52 additions and 8 deletions

View File

@ -108,6 +108,38 @@ impl CpuStorage {
} }
} }
pub(crate) fn copy_strided_src(
&self,
dst: &mut Self,
src_shape: &Shape,
src_stride: &[usize],
dst_offset: usize,
) -> Result<()> {
// TODO: Optimize the contiguous case.
let src_indexes = StridedIndex::new(src_shape.dims(), src_stride);
match (self, dst) {
(Self::F32(src), Self::F32(dst)) => {
for (dst_index, src_index) in src_indexes.enumerate() {
dst[dst_index + dst_offset] = src[src_index]
}
}
(Self::F64(src), Self::F64(dst)) => {
for (dst_index, src_index) in src_indexes.enumerate() {
dst[dst_index + dst_offset] = src[src_index]
}
}
(_, dst) => {
// This should be covered by the dtype check above.
return Err(Error::DTypeMismatchBinaryOp {
lhs: self.dtype(),
rhs: dst.dtype(),
op: "copy_strided",
});
}
}
Ok(())
}
pub(crate) fn matmul_impl( pub(crate) fn matmul_impl(
&self, &self,
rhs: &Self, rhs: &Self,

View File

@ -132,7 +132,7 @@ impl Storage {
self.same_device(rhs, "matmul")?; self.same_device(rhs, "matmul")?;
self.same_dtype(rhs, "matmul")?; self.same_dtype(rhs, "matmul")?;
match (self, rhs) { match (self, rhs) {
(Storage::Cpu(lhs), Storage::Cpu(rhs)) => { (Self::Cpu(lhs), Self::Cpu(rhs)) => {
let storage = lhs.matmul_impl(rhs, bmnk, lhs_stride, rhs_stride)?; let storage = lhs.matmul_impl(rhs, bmnk, lhs_stride, rhs_stride)?;
Ok(Self::Cpu(storage)) Ok(Self::Cpu(storage))
} }
@ -151,11 +151,23 @@ impl Storage {
// self, the source can be strided whereas dst is contiguous. // self, the source can be strided whereas dst is contiguous.
pub(crate) fn copy_strided_src( pub(crate) fn copy_strided_src(
&self, &self,
_dst: &mut Self, dst: &mut Self,
_shape: &Shape, src_shape: &Shape,
_stride: &[usize], src_stride: &[usize],
_offset: usize, dst_offset: usize,
) { ) -> Result<()> {
match (self, dst) {
(Self::Cpu(src), Self::Cpu(dst)) => {
src.copy_strided_src(dst, src_shape, src_stride, dst_offset)
}
(Self::Cuda(_src), Self::Cuda(_dst)) => {
todo!() todo!()
} }
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
lhs: lhs.device().location(),
rhs: rhs.device().location(),
op: "copy",
}),
}
}
} }

View File

@ -648,7 +648,7 @@ impl Tensor {
let mut storage = device.zeros(&shape, dtype)?; let mut storage = device.zeros(&shape, dtype)?;
for (arg, &offset) in args.iter().zip(offsets.iter()) { for (arg, &offset) in args.iter().zip(offsets.iter()) {
arg.storage arg.storage
.copy_strided_src(&mut storage, &arg.shape, &arg.stride, offset) .copy_strided_src(&mut storage, &arg.shape, &arg.stride, offset)?
} }
let tensor_ = Tensor_ { let tensor_ = Tensor_ {
id: TensorId::new(), id: TensorId::new(),