mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Optimize for the contiguous case.
This commit is contained in:
@ -396,20 +396,29 @@ impl CudaStorage {
|
|||||||
let ds = dev.htod_copy([dims, src_stride].concat())?;
|
let ds = dev.htod_copy([dims, src_stride].concat())?;
|
||||||
match (&self.slice, &mut dst.slice) {
|
match (&self.slice, &mut dst.slice) {
|
||||||
(CudaStorageSlice::F32(src), CudaStorageSlice::F32(dst)) => {
|
(CudaStorageSlice::F32(src), CudaStorageSlice::F32(dst)) => {
|
||||||
let func = dev.get_or_load_func("ucopy_f32", kernels::UNARY)?;
|
|
||||||
let mut dst = dst.slice_mut(dst_offset..);
|
let mut dst = dst.slice_mut(dst_offset..);
|
||||||
// SAFETY: Set later by running the kernel.
|
if src_shape.is_contiguous(src_stride) {
|
||||||
let params = (el_count, dims.len(), &ds, src, &mut dst);
|
dev.dtod_copy(src, &mut dst)?
|
||||||
// SAFETY: ffi.
|
} else {
|
||||||
unsafe { func.launch(cfg, params) }?
|
let func = dev.get_or_load_func("ucopy_f32", kernels::UNARY)?;
|
||||||
|
// SAFETY: Set later by running the kernel.
|
||||||
|
let params = (el_count, dims.len(), &ds, src, &mut dst);
|
||||||
|
// SAFETY: ffi.
|
||||||
|
unsafe { func.launch(cfg, params) }?
|
||||||
|
}
|
||||||
}
|
}
|
||||||
(CudaStorageSlice::F64(src), CudaStorageSlice::F64(dst)) => {
|
(CudaStorageSlice::F64(src), CudaStorageSlice::F64(dst)) => {
|
||||||
let func = dev.get_or_load_func("ucopy_64", kernels::UNARY)?;
|
|
||||||
let mut dst = dst.slice_mut(dst_offset..);
|
let mut dst = dst.slice_mut(dst_offset..);
|
||||||
// SAFETY: Set later by running the kernel.
|
if src_shape.is_contiguous(src_stride) {
|
||||||
let params = (el_count, dims.len(), &ds, src, &mut dst);
|
dev.dtod_copy(src, &mut dst)?
|
||||||
// SAFETY: ffi.
|
} else {
|
||||||
unsafe { func.launch(cfg, params) }?;
|
let func = dev.get_or_load_func("ucopy_64", kernels::UNARY)?;
|
||||||
|
let mut dst = dst.slice_mut(dst_offset..);
|
||||||
|
// SAFETY: Set later by running the kernel.
|
||||||
|
let params = (el_count, dims.len(), &ds, src, &mut dst);
|
||||||
|
// SAFETY: ffi.
|
||||||
|
unsafe { func.launch(cfg, params) }?;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
_ => {
|
_ => {
|
||||||
return Err(CudaError::InternalError(
|
return Err(CudaError::InternalError(
|
||||||
|
Reference in New Issue
Block a user