mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 04:00:28 +00:00
Stride support for unary ops.
This commit is contained in:
@ -167,19 +167,17 @@ impl CudaStorage {
|
||||
shape: &Shape,
|
||||
stride: &[usize],
|
||||
) -> Result<Self> {
|
||||
if !shape.is_contiguous(stride) {
|
||||
return Err(CudaError::RequiresContiguous { op: "affine" });
|
||||
}
|
||||
|
||||
let elem_count = shape.elem_count();
|
||||
let cfg = LaunchConfig::for_num_elems(elem_count as u32);
|
||||
let dims = shape.dims();
|
||||
let el_count = shape.elem_count();
|
||||
let cfg = LaunchConfig::for_num_elems(el_count as u32);
|
||||
let dev = self.device();
|
||||
let ds = dev.0.htod_copy([dims, stride].concat())?;
|
||||
match self {
|
||||
Self::F32(arg) => {
|
||||
let func = dev.get_or_load_func(U::KERNEL_F32, kernels::UNARY)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let out = unsafe { dev.0.alloc::<f32>(elem_count) }?;
|
||||
let params = (elem_count, arg, &out);
|
||||
let out = unsafe { dev.0.alloc::<f32>(el_count) }?;
|
||||
let params = (el_count, dims.len(), &ds, arg, &out);
|
||||
// SAFETY: ffi.
|
||||
unsafe { func.launch(cfg, params) }?;
|
||||
Ok(Self::F32(out))
|
||||
@ -187,8 +185,8 @@ impl CudaStorage {
|
||||
Self::F64(arg) => {
|
||||
let func = dev.get_or_load_func(U::KERNEL_F64, kernels::UNARY)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let out = unsafe { dev.0.alloc::<f64>(elem_count) }?;
|
||||
let params = (elem_count, arg, &out);
|
||||
let out = unsafe { dev.0.alloc::<f64>(el_count) }?;
|
||||
let params = (el_count, dims.len(), &ds, arg, &out);
|
||||
// SAFETY: ffi.
|
||||
unsafe { func.launch(cfg, params) }?;
|
||||
Ok(Self::F64(out))
|
||||
|
Reference in New Issue
Block a user