Factor the slicing code in cuda.

This commit is contained in:
laurent
2023-06-27 15:45:59 +01:00
parent 07a682c2ff
commit e221d38819

View File

@ -240,6 +240,24 @@ enum CudaStorageSlice {
F64(CudaSlice<f64>),
}
/// Builds a pair of equally sized views over `src` and `dst` for a copy.
///
/// The source view starts at `src_offset` and the destination view starts at
/// `dst_offset`. Both views have the same length: the smaller of the number of
/// elements remaining in `src` after `src_offset` and the number remaining in
/// `dst` after `dst_offset`. Each remaining count is computed with a saturating
/// subtraction, so an offset beyond the end of a buffer counts as zero elements
/// rather than underflowing.
///
/// Returns `(source_view, destination_view)`.
///
/// NOTE(review): when an offset lies past the end of its buffer, the requested
/// range is empty but still starts beyond the buffer; whether `slice` /
/// `slice_mut` accept such a range depends on cudarc — confirm.
fn slice_src_and_dst<'a, T>(
    src: &'a CudaSlice<T>,
    src_offset: usize,
    dst: &'a mut CudaSlice<T>,
    dst_offset: usize,
) -> (
    cudarc::driver::CudaView<'a, T>,
    cudarc::driver::CudaViewMut<'a, T>,
) {
    // Elements still available on each side once the offsets are skipped.
    let dst_remaining = dst.len().saturating_sub(dst_offset);
    let src_remaining = src.len().saturating_sub(src_offset);
    // Never read past the source nor write past the destination.
    let count = usize::min(dst_remaining, src_remaining);

    let src_view = src.slice(src_offset..src_offset + count);
    let dst_view = dst.slice_mut(dst_offset..dst_offset + count);
    (src_view, dst_view)
}
#[derive(Debug)]
pub struct CudaStorage {
slice: CudaStorageSlice,
@ -903,10 +921,7 @@ impl CudaStorage {
let ds = dev.htod_copy([dims, src_stride].concat())?;
match (&self.slice, &mut dst.slice) {
(CudaStorageSlice::BF16(src), CudaStorageSlice::BF16(dst)) => {
let src = src.slice(src_offset..);
let elem_to_copy = (dst.len() - dst_offset).min(src.len());
let src = src.slice(..elem_to_copy);
let mut dst = dst.slice_mut(dst_offset..dst_offset + elem_to_copy);
let (src, mut dst) = slice_src_and_dst(src, src_offset, dst, dst_offset);
if src_shape.is_contiguous(src_stride) {
dev.dtod_copy(&src, &mut dst)?
} else {
@ -918,10 +933,7 @@ impl CudaStorage {
}
}
(CudaStorageSlice::F16(src), CudaStorageSlice::F16(dst)) => {
let src = src.slice(src_offset..);
let elem_to_copy = (dst.len() - dst_offset).min(src.len());
let src = src.slice(..elem_to_copy);
let mut dst = dst.slice_mut(dst_offset..dst_offset + elem_to_copy);
let (src, mut dst) = slice_src_and_dst(src, src_offset, dst, dst_offset);
if src_shape.is_contiguous(src_stride) {
dev.dtod_copy(&src, &mut dst)?
} else {
@ -933,10 +945,7 @@ impl CudaStorage {
}
}
(CudaStorageSlice::F32(src), CudaStorageSlice::F32(dst)) => {
let src = src.slice(src_offset..);
let elem_to_copy = (dst.len() - dst_offset).min(src.len());
let src = src.slice(..elem_to_copy);
let mut dst = dst.slice_mut(dst_offset..dst_offset + elem_to_copy);
let (src, mut dst) = slice_src_and_dst(src, src_offset, dst, dst_offset);
if src_shape.is_contiguous(src_stride) {
dev.dtod_copy(&src, &mut dst)?
} else {
@ -948,10 +957,7 @@ impl CudaStorage {
}
}
(CudaStorageSlice::U32(src), CudaStorageSlice::U32(dst)) => {
let src = src.slice(src_offset..);
let elem_to_copy = (dst.len() - dst_offset).min(src.len());
let src = src.slice(..elem_to_copy);
let mut dst = dst.slice_mut(dst_offset..dst_offset + elem_to_copy);
let (src, mut dst) = slice_src_and_dst(src, src_offset, dst, dst_offset);
if src_shape.is_contiguous(src_stride) {
dev.dtod_copy(&src, &mut dst)?
} else {
@ -963,9 +969,7 @@ impl CudaStorage {
}
}
(CudaStorageSlice::F64(src), CudaStorageSlice::F64(dst)) => {
let src = src.slice(src_offset..);
let elem_to_copy = (dst.len() - dst_offset).min(src.len());
let mut dst = dst.slice_mut(dst_offset..dst_offset + elem_to_copy);
let (src, mut dst) = slice_src_and_dst(src, src_offset, dst, dst_offset);
if src_shape.is_contiguous(src_stride) {
dev.dtod_copy(&src, &mut dst)?
} else {