diff --git a/src/cpu_backend.rs b/src/cpu_backend.rs
index 21f7cd3a..3e4e1826 100644
--- a/src/cpu_backend.rs
+++ b/src/cpu_backend.rs
@@ -91,16 +91,23 @@ fn binary_map<T: Copy, F: FnMut(T, T) -> T>(
 fn copy_strided_src_<T: Copy>(
     src: &[T],
     dst: &mut [T],
+    dst_offset: usize,
     src_shape: &Shape,
     src_stride: &[usize],
-    dst_offset: usize,
+    src_offset: usize,
 ) {
+    let src = &src[src_offset..];
     if src_shape.is_contiguous(src_stride) {
-        dst[dst_offset..].copy_from_slice(src)
+        let elem_to_copy = dst.len() - dst_offset;
+        dst[dst_offset..].copy_from_slice(&src[..elem_to_copy])
     } else {
         let src_indexes = StridedIndex::new(src_shape.dims(), src_stride);
         for (dst_index, src_index) in src_indexes.enumerate() {
-            dst[dst_index + dst_offset] = src[src_index]
+            let dst_index = dst_index + dst_offset;
+            if dst_index >= dst.len() {
+                break;
+            }
+            dst[dst_index] = src[src_index]
         }
     }
 }
@@ -289,22 +296,23 @@ impl CpuStorage {
     pub(crate) fn copy_strided_src(
         &self,
         dst: &mut Self,
+        dst_offset: usize,
         src_shape: &Shape,
         src_stride: &[usize],
-        dst_offset: usize,
+        src_offset: usize,
     ) -> Result<()> {
         if src_shape.rank() != src_stride.len() {
             panic!("incoherent shape and strides {src_shape:?} {src_stride:?}")
         }
         match (self, dst) {
             (Self::U32(src), Self::U32(dst)) => {
-                copy_strided_src_(src, dst, src_shape, src_stride, dst_offset)
+                copy_strided_src_(src, dst, dst_offset, src_shape, src_stride, src_offset)
             }
             (Self::F32(src), Self::F32(dst)) => {
-                copy_strided_src_(src, dst, src_shape, src_stride, dst_offset)
+                copy_strided_src_(src, dst, dst_offset, src_shape, src_stride, src_offset)
             }
             (Self::F64(src), Self::F64(dst)) => {
-                copy_strided_src_(src, dst, src_shape, src_stride, dst_offset)
+                copy_strided_src_(src, dst, dst_offset, src_shape, src_stride, src_offset)
             }
             (_, dst) => {
                 // This should be covered by the dtype check above.
diff --git a/src/cuda_backend.rs b/src/cuda_backend.rs
index 90437ed0..70084821 100644
--- a/src/cuda_backend.rs
+++ b/src/cuda_backend.rs
@@ -450,9 +450,10 @@ impl CudaStorage {
     pub(crate) fn copy_strided_src(
         &self,
         dst: &mut Self,
+        dst_offset: usize,
         src_shape: &Shape,
         src_stride: &[usize],
-        dst_offset: usize,
+        src_offset: usize,
     ) -> Result<()> {
         if src_shape.rank() != src_stride.len() {
             panic!("incoherent shape and strides {src_shape:?} {src_stride:?}")
@@ -464,26 +465,27 @@ impl CudaStorage {
         let ds = dev.htod_copy([dims, src_stride].concat())?;
         match (&self.slice, &mut dst.slice) {
             (CudaStorageSlice::F32(src), CudaStorageSlice::F32(dst)) => {
+                let src = src.slice(src_offset..);
                 let mut dst = dst.slice_mut(dst_offset..);
                 if src_shape.is_contiguous(src_stride) {
-                    dev.dtod_copy(src, &mut dst)?
+                    dev.dtod_copy(&src, &mut dst)?
                 } else {
                     let func = dev.get_or_load_func("ucopy_f32", kernels::UNARY)?;
                     // SAFETY: Set later by running the kernel.
-                    let params = (el_count, dims.len(), &ds, src, &mut dst);
+                    let params = (el_count, dims.len(), &ds, &src, &mut dst);
                     // SAFETY: ffi.
                     unsafe { func.launch(cfg, params) }?
                 }
             }
             (CudaStorageSlice::F64(src), CudaStorageSlice::F64(dst)) => {
+                let src = src.slice(src_offset..);
                 let mut dst = dst.slice_mut(dst_offset..);
                 if src_shape.is_contiguous(src_stride) {
-                    dev.dtod_copy(src, &mut dst)?
+                    dev.dtod_copy(&src, &mut dst)?
                 } else {
                     let func = dev.get_or_load_func("ucopy_64", kernels::UNARY)?;
-                    let mut dst = dst.slice_mut(dst_offset..);
                     // SAFETY: Set later by running the kernel.
-                    let params = (el_count, dims.len(), &ds, src, &mut dst);
+                    let params = (el_count, dims.len(), &ds, &src, &mut dst);
                     // SAFETY: ffi.
                     unsafe { func.launch(cfg, params) }?;
                 }
diff --git a/src/dummy_cuda_backend.rs b/src/dummy_cuda_backend.rs
index b3199f65..972ec008 100644
--- a/src/dummy_cuda_backend.rs
+++ b/src/dummy_cuda_backend.rs
@@ -101,6 +101,7 @@ impl CudaStorage {
     pub(crate) fn copy_strided_src(
         &self,
         _: &mut Self,
+        _: usize,
         _: &Shape,
         _: &[usize],
         _: usize,
diff --git a/src/storage.rs b/src/storage.rs
index 38a48b72..21992992 100644
--- a/src/storage.rs
+++ b/src/storage.rs
@@ -196,16 +196,17 @@ impl Storage {
     pub(crate) fn copy_strided_src(
         &self,
         dst: &mut Self,
+        dst_offset: usize,
         src_shape: &Shape,
         src_stride: &[usize],
-        dst_offset: usize,
+        src_offset: usize,
     ) -> Result<()> {
         match (self, dst) {
             (Self::Cpu(src), Self::Cpu(dst)) => {
-                src.copy_strided_src(dst, src_shape, src_stride, dst_offset)
+                src.copy_strided_src(dst, dst_offset, src_shape, src_stride, src_offset)
             }
             (Self::Cuda(src), Self::Cuda(dst)) => {
-                Ok(src.copy_strided_src(dst, src_shape, src_stride, dst_offset)?)
+                Ok(src.copy_strided_src(dst, dst_offset, src_shape, src_stride, src_offset)?)
             }
             (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
                 lhs: lhs.device().location(),
diff --git a/src/tensor.rs b/src/tensor.rs
index 5467584b..b58ed739 100644
--- a/src/tensor.rs
+++ b/src/tensor.rs
@@ -317,9 +317,10 @@ impl Tensor {
         let mut dims = dims.to_vec();
         dims[dim] = length;
         let shape = Shape::from(dims);
-        let storage = self.device().zeros(&shape, self.dtype())?;
-        // TODO: Actually copy the data, compared to copy_strided_src this requires a src start
-        // offset as well as a way to specify the number of elements to be copied.
+        let mut storage = self.device().zeros(&shape, self.dtype())?;
+        let src_offset = 0; // TODO
+        self.storage
+            .copy_strided_src(&mut storage, 0, &self.shape, &self.stride, src_offset)?;
         let op = if self.track_op() {
             Some(Op::Narrow(self.clone(), dim, start, length))
         } else {
@@ -666,7 +667,7 @@ impl Tensor {
         let shape = self.shape();
         let mut storage = self.device().zeros(shape, self.dtype())?;
         self.storage
-            .copy_strided_src(&mut storage, &self.shape, &self.stride, 0)?;
+            .copy_strided_src(&mut storage, 0, &self.shape, &self.stride, 0)?;
         Ok(from_storage(
             storage,
             shape.clone(),
@@ -709,7 +710,7 @@ impl Tensor {
         } else {
             let mut storage = self.device().zeros(&shape, self.dtype())?;
             self.storage
-                .copy_strided_src(&mut storage, &self.shape, &self.stride, 0)?;
+                .copy_strided_src(&mut storage, 0, &self.shape, &self.stride, 0)?;
             Ok(from_storage(storage, shape, op, false))
         }
     }
@@ -786,7 +787,7 @@ impl Tensor {
         let mut storage = device.zeros(&shape, dtype)?;
         for (arg, &offset) in args.iter().zip(offsets.iter()) {
             arg.storage
-                .copy_strided_src(&mut storage, &arg.shape, &arg.stride, offset)?
+                .copy_strided_src(&mut storage, offset, &arg.shape, &arg.stride, 0)?
         }
         Ok(from_storage(storage, shape, op, false))
     }
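For reference, here is a minimal standalone sketch of the copy logic this diff converges on, now parameterized by both a destination and a source offset. Everything below is illustrative rather than the crate's code: `strided_indexes` and `copy_strided` are hypothetical stand-ins for `StridedIndex` and `copy_strided_src_`, plain `&[usize]` dims replace the `Shape` type, and the contiguous fast path (`copy_from_slice` of `dst.len() - dst_offset` elements) is omitted to keep the strided path in focus.

/// Yields the physical element offsets of a (dims, stride) view, visiting
/// the logical indexes in row-major order (rightmost dimension fastest).
fn strided_indexes<'a>(
    dims: &'a [usize],
    stride: &'a [usize],
) -> impl Iterator<Item = usize> + 'a {
    let el_count: usize = dims.iter().product();
    let mut multi = vec![0usize; dims.len()];
    let mut first = true;
    (0..el_count).map(move |_| {
        if first {
            first = false;
        } else {
            // Increment the multi-dimensional index with carry,
            // rightmost dimension first.
            for d in (0..multi.len()).rev() {
                multi[d] += 1;
                if multi[d] < dims[d] {
                    break;
                }
                multi[d] = 0;
            }
        }
        // Map the logical index to a physical offset via the strides.
        multi.iter().zip(stride).map(|(i, s)| i * s).sum()
    })
}

/// Copies the view described by (dims, stride) out of `src`, starting at
/// `src_offset`, into the dense buffer `dst` starting at `dst_offset`,
/// stopping once `dst` is full, mirroring the CPU path above.
fn copy_strided<T: Copy>(
    src: &[T],
    dst: &mut [T],
    dst_offset: usize,
    dims: &[usize],
    stride: &[usize],
    src_offset: usize,
) {
    let src = &src[src_offset..];
    for (dst_index, src_index) in strided_indexes(dims, stride).enumerate() {
        let dst_index = dst_index + dst_offset;
        if dst_index >= dst.len() {
            break;
        }
        dst[dst_index] = src[src_index];
    }
}

fn main() {
    // A 2x3 row-major matrix [[1, 2, 3], [4, 5, 6]] viewed transposed as
    // 3x2 via swapped strides (dims = [3, 2], stride = [1, 3]), copied
    // densely into the middle of an 8-element buffer at dst_offset = 1.
    let src = [1.0f64, 2., 3., 4., 5., 6.];
    let mut dst = [0.0f64; 8];
    copy_strided(&src, &mut dst, 1, &[3, 2], &[1, 3], 0);
    assert_eq!(dst, [0., 1., 4., 2., 5., 3., 6., 0.]);
}

This split is also why `Tensor::cat` above can pass `offset` as the `dst_offset` and 0 as the `src_offset`: each source tensor is written whole into a different slice of the shared destination buffer, while `narrow` will eventually want the opposite, a nonzero `src_offset` (hence the remaining `// TODO`).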