Optimize for the contiguous case.

This commit is contained in:
laurent
2023-06-23 11:23:49 +01:00
parent 1a90f9d3a6
commit 4ffdeb4e23

View File

@ -396,20 +396,29 @@ impl CudaStorage {
let ds = dev.htod_copy([dims, src_stride].concat())?; let ds = dev.htod_copy([dims, src_stride].concat())?;
match (&self.slice, &mut dst.slice) { match (&self.slice, &mut dst.slice) {
(CudaStorageSlice::F32(src), CudaStorageSlice::F32(dst)) => { (CudaStorageSlice::F32(src), CudaStorageSlice::F32(dst)) => {
let func = dev.get_or_load_func("ucopy_f32", kernels::UNARY)?;
let mut dst = dst.slice_mut(dst_offset..); let mut dst = dst.slice_mut(dst_offset..);
// SAFETY: Set later by running the kernel. if src_shape.is_contiguous(src_stride) {
let params = (el_count, dims.len(), &ds, src, &mut dst); dev.dtod_copy(src, &mut dst)?
// SAFETY: ffi. } else {
unsafe { func.launch(cfg, params) }? let func = dev.get_or_load_func("ucopy_f32", kernels::UNARY)?;
// SAFETY: Set later by running the kernel.
let params = (el_count, dims.len(), &ds, src, &mut dst);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }?
}
} }
(CudaStorageSlice::F64(src), CudaStorageSlice::F64(dst)) => { (CudaStorageSlice::F64(src), CudaStorageSlice::F64(dst)) => {
let func = dev.get_or_load_func("ucopy_64", kernels::UNARY)?;
let mut dst = dst.slice_mut(dst_offset..); let mut dst = dst.slice_mut(dst_offset..);
// SAFETY: Set later by running the kernel. if src_shape.is_contiguous(src_stride) {
let params = (el_count, dims.len(), &ds, src, &mut dst); dev.dtod_copy(src, &mut dst)?
// SAFETY: ffi. } else {
unsafe { func.launch(cfg, params) }?; let func = dev.get_or_load_func("ucopy_64", kernels::UNARY)?;
let mut dst = dst.slice_mut(dst_offset..);
// SAFETY: Set later by running the kernel.
let params = (el_count, dims.len(), &ds, src, &mut dst);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }?;
}
} }
_ => { _ => {
return Err(CudaError::InternalError( return Err(CudaError::InternalError(