diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs index 91f0472b..9cbf82be 100644 --- a/candle-core/src/cuda_backend.rs +++ b/candle-core/src/cuda_backend.rs @@ -240,6 +240,24 @@ enum CudaStorageSlice { F64(CudaSlice<f64>), } +fn slice_src_and_dst<'a, T>( + src: &'a CudaSlice<T>, + src_offset: usize, + dst: &'a mut CudaSlice<T>, + dst_offset: usize, +) -> ( + cudarc::driver::CudaView<'a, T>, + cudarc::driver::CudaViewMut<'a, T>, +) { + let to_copy = dst + .len() + .saturating_sub(dst_offset) + .min(src.len().saturating_sub(src_offset)); + let src = src.slice(src_offset..src_offset + to_copy); + let dst = dst.slice_mut(dst_offset..dst_offset + to_copy); + (src, dst) +} + #[derive(Debug)] pub struct CudaStorage { slice: CudaStorageSlice, @@ -903,10 +921,7 @@ impl CudaStorage { let ds = dev.htod_copy([dims, src_stride].concat())?; match (&self.slice, &mut dst.slice) { (CudaStorageSlice::BF16(src), CudaStorageSlice::BF16(dst)) => { - let src = src.slice(src_offset..); - let elem_to_copy = (dst.len() - dst_offset).min(src.len()); - let src = src.slice(..elem_to_copy); - let mut dst = dst.slice_mut(dst_offset..dst_offset + elem_to_copy); + let (src, mut dst) = slice_src_and_dst(src, src_offset, dst, dst_offset); if src_shape.is_contiguous(src_stride) { dev.dtod_copy(&src, &mut dst)? } else { @@ -918,10 +933,7 @@ impl CudaStorage { } } (CudaStorageSlice::F16(src), CudaStorageSlice::F16(dst)) => { - let src = src.slice(src_offset..); - let elem_to_copy = (dst.len() - dst_offset).min(src.len()); - let src = src.slice(..elem_to_copy); - let mut dst = dst.slice_mut(dst_offset..dst_offset + elem_to_copy); + let (src, mut dst) = slice_src_and_dst(src, src_offset, dst, dst_offset); if src_shape.is_contiguous(src_stride) { dev.dtod_copy(&src, &mut dst)? 
} else { @@ -933,10 +945,7 @@ impl CudaStorage { } } (CudaStorageSlice::F32(src), CudaStorageSlice::F32(dst)) => { - let src = src.slice(src_offset..); - let elem_to_copy = (dst.len() - dst_offset).min(src.len()); - let src = src.slice(..elem_to_copy); - let mut dst = dst.slice_mut(dst_offset..dst_offset + elem_to_copy); + let (src, mut dst) = slice_src_and_dst(src, src_offset, dst, dst_offset); if src_shape.is_contiguous(src_stride) { dev.dtod_copy(&src, &mut dst)? } else { @@ -948,10 +957,7 @@ impl CudaStorage { } } (CudaStorageSlice::U32(src), CudaStorageSlice::U32(dst)) => { - let src = src.slice(src_offset..); - let elem_to_copy = (dst.len() - dst_offset).min(src.len()); - let src = src.slice(..elem_to_copy); - let mut dst = dst.slice_mut(dst_offset..dst_offset + elem_to_copy); + let (src, mut dst) = slice_src_and_dst(src, src_offset, dst, dst_offset); if src_shape.is_contiguous(src_stride) { dev.dtod_copy(&src, &mut dst)? } else { @@ -963,9 +969,7 @@ impl CudaStorage { } } (CudaStorageSlice::F64(src), CudaStorageSlice::F64(dst)) => { - let src = src.slice(src_offset..); - let elem_to_copy = (dst.len() - dst_offset).min(src.len()); - let mut dst = dst.slice_mut(dst_offset..dst_offset + elem_to_copy); + let (src, mut dst) = slice_src_and_dst(src, src_offset, dst, dst_offset); if src_shape.is_contiguous(src_stride) { dev.dtod_copy(&src, &mut dst)? } else {