From 4712dcc2f670d8bc0ee643fd170f2ccc5eb45f3f Mon Sep 17 00:00:00 2001
From: laurent
Date: Fri, 23 Jun 2023 10:24:02 +0100
Subject: [PATCH] Actually copy the data around in cat (cpu only).

---
 src/cpu_backend.rs | 32 ++++++++++++++++++++++++++++++++
 src/storage.rs     | 26 +++++++++++++++++++-------
 src/tensor.rs      |  2 +-
 3 files changed, 52 insertions(+), 8 deletions(-)

diff --git a/src/cpu_backend.rs b/src/cpu_backend.rs
index 1bd272d8..bf372589 100644
--- a/src/cpu_backend.rs
+++ b/src/cpu_backend.rs
@@ -108,6 +108,38 @@ impl CpuStorage {
         }
     }
 
+    pub(crate) fn copy_strided_src(
+        &self,
+        dst: &mut Self,
+        src_shape: &Shape,
+        src_stride: &[usize],
+        dst_offset: usize,
+    ) -> Result<()> {
+        // TODO: Optimize the contiguous case.
+        let src_indexes = StridedIndex::new(src_shape.dims(), src_stride);
+        match (self, dst) {
+            (Self::F32(src), Self::F32(dst)) => {
+                for (dst_index, src_index) in src_indexes.enumerate() {
+                    dst[dst_index + dst_offset] = src[src_index]
+                }
+            }
+            (Self::F64(src), Self::F64(dst)) => {
+                for (dst_index, src_index) in src_indexes.enumerate() {
+                    dst[dst_index + dst_offset] = src[src_index]
+                }
+            }
+            (_, dst) => {
+                // This should be covered by the dtype check above.
+                return Err(Error::DTypeMismatchBinaryOp {
+                    lhs: self.dtype(),
+                    rhs: dst.dtype(),
+                    op: "copy_strided",
+                });
+            }
+        }
+        Ok(())
+    }
+
     pub(crate) fn matmul_impl(
         &self,
         rhs: &Self,
diff --git a/src/storage.rs b/src/storage.rs
index 55934064..a1f5d121 100644
--- a/src/storage.rs
+++ b/src/storage.rs
@@ -132,7 +132,7 @@ impl Storage {
         self.same_device(rhs, "matmul")?;
         self.same_dtype(rhs, "matmul")?;
         match (self, rhs) {
-            (Storage::Cpu(lhs), Storage::Cpu(rhs)) => {
+            (Self::Cpu(lhs), Self::Cpu(rhs)) => {
                 let storage = lhs.matmul_impl(rhs, bmnk, lhs_stride, rhs_stride)?;
                 Ok(Self::Cpu(storage))
             }
@@ -151,11 +151,23 @@ impl Storage {
     // self, the source can be strided whereas dst is contiguous.
     pub(crate) fn copy_strided_src(
         &self,
-        _dst: &mut Self,
-        _shape: &Shape,
-        _stride: &[usize],
-        _offset: usize,
-    ) {
-        todo!()
+        dst: &mut Self,
+        src_shape: &Shape,
+        src_stride: &[usize],
+        dst_offset: usize,
+    ) -> Result<()> {
+        match (self, dst) {
+            (Self::Cpu(src), Self::Cpu(dst)) => {
+                src.copy_strided_src(dst, src_shape, src_stride, dst_offset)
+            }
+            (Self::Cuda(_src), Self::Cuda(_dst)) => {
+                todo!()
+            }
+            (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
+                lhs: lhs.device().location(),
+                rhs: rhs.device().location(),
+                op: "copy",
+            }),
+        }
     }
 }
diff --git a/src/tensor.rs b/src/tensor.rs
index 8467f099..5872a244 100644
--- a/src/tensor.rs
+++ b/src/tensor.rs
@@ -648,7 +648,7 @@ impl Tensor {
         let mut storage = device.zeros(&shape, dtype)?;
         for (arg, &offset) in args.iter().zip(offsets.iter()) {
             arg.storage
-                .copy_strided_src(&mut storage, &arg.shape, &arg.stride, offset)
+                .copy_strided_src(&mut storage, &arg.shape, &arg.stride, offset)?
         }
         let tensor_ = Tensor_ {
             id: TensorId::new(),