mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Factor the slicing code in cuda.
This commit is contained in:
@ -240,6 +240,24 @@ enum CudaStorageSlice {
|
|||||||
F64(CudaSlice<f64>),
|
F64(CudaSlice<f64>),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn slice_src_and_dst<'a, T>(
|
||||||
|
src: &'a CudaSlice<T>,
|
||||||
|
src_offset: usize,
|
||||||
|
dst: &'a mut CudaSlice<T>,
|
||||||
|
dst_offset: usize,
|
||||||
|
) -> (
|
||||||
|
cudarc::driver::CudaView<'a, T>,
|
||||||
|
cudarc::driver::CudaViewMut<'a, T>,
|
||||||
|
) {
|
||||||
|
let to_copy = dst
|
||||||
|
.len()
|
||||||
|
.saturating_sub(dst_offset)
|
||||||
|
.min(src.len().saturating_sub(src_offset));
|
||||||
|
let src = src.slice(src_offset..src_offset + to_copy);
|
||||||
|
let dst = dst.slice_mut(dst_offset..dst_offset + to_copy);
|
||||||
|
(src, dst)
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct CudaStorage {
|
pub struct CudaStorage {
|
||||||
slice: CudaStorageSlice,
|
slice: CudaStorageSlice,
|
||||||
@ -903,10 +921,7 @@ impl CudaStorage {
|
|||||||
let ds = dev.htod_copy([dims, src_stride].concat())?;
|
let ds = dev.htod_copy([dims, src_stride].concat())?;
|
||||||
match (&self.slice, &mut dst.slice) {
|
match (&self.slice, &mut dst.slice) {
|
||||||
(CudaStorageSlice::BF16(src), CudaStorageSlice::BF16(dst)) => {
|
(CudaStorageSlice::BF16(src), CudaStorageSlice::BF16(dst)) => {
|
||||||
let src = src.slice(src_offset..);
|
let (src, mut dst) = slice_src_and_dst(src, src_offset, dst, dst_offset);
|
||||||
let elem_to_copy = (dst.len() - dst_offset).min(src.len());
|
|
||||||
let src = src.slice(..elem_to_copy);
|
|
||||||
let mut dst = dst.slice_mut(dst_offset..dst_offset + elem_to_copy);
|
|
||||||
if src_shape.is_contiguous(src_stride) {
|
if src_shape.is_contiguous(src_stride) {
|
||||||
dev.dtod_copy(&src, &mut dst)?
|
dev.dtod_copy(&src, &mut dst)?
|
||||||
} else {
|
} else {
|
||||||
@ -918,10 +933,7 @@ impl CudaStorage {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
(CudaStorageSlice::F16(src), CudaStorageSlice::F16(dst)) => {
|
(CudaStorageSlice::F16(src), CudaStorageSlice::F16(dst)) => {
|
||||||
let src = src.slice(src_offset..);
|
let (src, mut dst) = slice_src_and_dst(src, src_offset, dst, dst_offset);
|
||||||
let elem_to_copy = (dst.len() - dst_offset).min(src.len());
|
|
||||||
let src = src.slice(..elem_to_copy);
|
|
||||||
let mut dst = dst.slice_mut(dst_offset..dst_offset + elem_to_copy);
|
|
||||||
if src_shape.is_contiguous(src_stride) {
|
if src_shape.is_contiguous(src_stride) {
|
||||||
dev.dtod_copy(&src, &mut dst)?
|
dev.dtod_copy(&src, &mut dst)?
|
||||||
} else {
|
} else {
|
||||||
@ -933,10 +945,7 @@ impl CudaStorage {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
(CudaStorageSlice::F32(src), CudaStorageSlice::F32(dst)) => {
|
(CudaStorageSlice::F32(src), CudaStorageSlice::F32(dst)) => {
|
||||||
let src = src.slice(src_offset..);
|
let (src, mut dst) = slice_src_and_dst(src, src_offset, dst, dst_offset);
|
||||||
let elem_to_copy = (dst.len() - dst_offset).min(src.len());
|
|
||||||
let src = src.slice(..elem_to_copy);
|
|
||||||
let mut dst = dst.slice_mut(dst_offset..dst_offset + elem_to_copy);
|
|
||||||
if src_shape.is_contiguous(src_stride) {
|
if src_shape.is_contiguous(src_stride) {
|
||||||
dev.dtod_copy(&src, &mut dst)?
|
dev.dtod_copy(&src, &mut dst)?
|
||||||
} else {
|
} else {
|
||||||
@ -948,10 +957,7 @@ impl CudaStorage {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
(CudaStorageSlice::U32(src), CudaStorageSlice::U32(dst)) => {
|
(CudaStorageSlice::U32(src), CudaStorageSlice::U32(dst)) => {
|
||||||
let src = src.slice(src_offset..);
|
let (src, mut dst) = slice_src_and_dst(src, src_offset, dst, dst_offset);
|
||||||
let elem_to_copy = (dst.len() - dst_offset).min(src.len());
|
|
||||||
let src = src.slice(..elem_to_copy);
|
|
||||||
let mut dst = dst.slice_mut(dst_offset..dst_offset + elem_to_copy);
|
|
||||||
if src_shape.is_contiguous(src_stride) {
|
if src_shape.is_contiguous(src_stride) {
|
||||||
dev.dtod_copy(&src, &mut dst)?
|
dev.dtod_copy(&src, &mut dst)?
|
||||||
} else {
|
} else {
|
||||||
@ -963,9 +969,7 @@ impl CudaStorage {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
(CudaStorageSlice::F64(src), CudaStorageSlice::F64(dst)) => {
|
(CudaStorageSlice::F64(src), CudaStorageSlice::F64(dst)) => {
|
||||||
let src = src.slice(src_offset..);
|
let (src, mut dst) = slice_src_and_dst(src, src_offset, dst, dst_offset);
|
||||||
let elem_to_copy = (dst.len() - dst_offset).min(src.len());
|
|
||||||
let mut dst = dst.slice_mut(dst_offset..dst_offset + elem_to_copy);
|
|
||||||
if src_shape.is_contiguous(src_stride) {
|
if src_shape.is_contiguous(src_stride) {
|
||||||
dev.dtod_copy(&src, &mut dst)?
|
dev.dtod_copy(&src, &mut dst)?
|
||||||
} else {
|
} else {
|
||||||
|
Reference in New Issue
Block a user