mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 19:47:12 +00:00
Cuda implementation for copying data around.
This commit is contained in:
@ -381,4 +381,42 @@ impl CudaStorage {
|
||||
let device = dev.clone();
|
||||
Ok(Self { slice, device })
|
||||
}
|
||||
|
||||
pub(crate) fn copy_strided_src(
|
||||
&self,
|
||||
dst: &mut Self,
|
||||
src_shape: &Shape,
|
||||
src_stride: &[usize],
|
||||
dst_offset: usize,
|
||||
) -> Result<()> {
|
||||
let dims = src_shape.dims();
|
||||
let el_count = src_shape.elem_count();
|
||||
let cfg = LaunchConfig::for_num_elems(el_count as u32);
|
||||
let dev = &self.device;
|
||||
let ds = dev.htod_copy([dims, src_stride].concat())?;
|
||||
match (&self.slice, &mut dst.slice) {
|
||||
(CudaStorageSlice::F32(src), CudaStorageSlice::F32(dst)) => {
|
||||
let func = dev.get_or_load_func("ucopy_f32", kernels::UNARY)?;
|
||||
let mut dst = dst.slice_mut(dst_offset..);
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let params = (el_count, dims.len(), &ds, src, &mut dst);
|
||||
// SAFETY: ffi.
|
||||
unsafe { func.launch(cfg, params) }?
|
||||
}
|
||||
(CudaStorageSlice::F64(src), CudaStorageSlice::F64(dst)) => {
|
||||
let func = dev.get_or_load_func("ucopy_64", kernels::UNARY)?;
|
||||
let mut dst = dst.slice_mut(dst_offset..);
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let params = (el_count, dims.len(), &ds, src, &mut dst);
|
||||
// SAFETY: ffi.
|
||||
unsafe { func.launch(cfg, params) }?;
|
||||
}
|
||||
_ => {
|
||||
return Err(CudaError::InternalError(
|
||||
"dtype mismatch in copy_strided op",
|
||||
))
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user