CUDA implementation for copying data around.

This commit is contained in:
laurent
2023-06-23 11:18:29 +01:00
parent 79e4b29c2f
commit 1a90f9d3a6
5 changed files with 56 additions and 2 deletions

View File

@ -381,4 +381,42 @@ impl CudaStorage {
let device = dev.clone();
Ok(Self { slice, device })
}
/// Copies the elements of `self`, read through a strided layout, into `dst`.
///
/// The source layout is described by `src_shape` (dimensions) and
/// `src_stride` (one stride per dimension); both are uploaded to the device
/// and handed to the copy kernel. The kernel receives the sub-slice of `dst`
/// that starts at element index `dst_offset`.
///
/// Only matching `F32`/`F32` and `F64`/`F64` storage pairs are handled.
///
/// # Errors
///
/// Returns `CudaError::InternalError` when `self` and `dst` do not hold the
/// same handled dtype, and propagates any error raised by the host-to-device
/// copy, the kernel lookup or the kernel launch.
pub(crate) fn copy_strided_src(
    &self,
    dst: &mut Self,
    src_shape: &Shape,
    src_stride: &[usize],
    dst_offset: usize,
) -> Result<()> {
    let dims = src_shape.dims();
    let el_count = src_shape.elem_count();
    // NOTE(review): `as u32` truncates silently for element counts above
    // `u32::MAX` — confirm such shapes cannot reach this point.
    let cfg = LaunchConfig::for_num_elems(el_count as u32);
    let dev = &self.device;
    // Dimensions followed by strides, uploaded as one device buffer.
    let ds = dev.htod_copy([dims, src_stride].concat())?;
    match (&self.slice, &mut dst.slice) {
        (CudaStorageSlice::F32(src), CudaStorageSlice::F32(dst)) => {
            let func = dev.get_or_load_func("ucopy_f32", kernels::UNARY)?;
            // Only the tail of the destination, from `dst_offset` on, is exposed to the kernel.
            let mut dst = dst.slice_mut(dst_offset..);
            let params = (el_count, dims.len(), &ds, src, &mut dst);
            // SAFETY: ffi call; the argument tuple is assumed to match the
            // kernel's parameter list — TODO confirm against the kernel source.
            unsafe { func.launch(cfg, params) }?;
        }
        (CudaStorageSlice::F64(src), CudaStorageSlice::F64(dst)) => {
            // Kernel names follow the `ucopy_<dtype>` pattern used by the f32 arm.
            let func = dev.get_or_load_func("ucopy_f64", kernels::UNARY)?;
            // Only the tail of the destination, from `dst_offset` on, is exposed to the kernel.
            let mut dst = dst.slice_mut(dst_offset..);
            let params = (el_count, dims.len(), &ds, src, &mut dst);
            // SAFETY: ffi call; the argument tuple is assumed to match the
            // kernel's parameter list — TODO confirm against the kernel source.
            unsafe { func.launch(cfg, params) }?;
        }
        _ => {
            return Err(CudaError::InternalError(
                "dtype mismatch in copy_strided op",
            ))
        }
    }
    Ok(())
}
}