diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs index 7dfbb468..94abd37a 100644 --- a/candle-core/src/cuda_backend.rs +++ b/candle-core/src/cuda_backend.rs @@ -1,7 +1,9 @@ use crate::{CpuStorage, DType, Layout, Shape, WithDType}; use candle_kernels as kernels; use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig}; -use cudarc::driver::{CudaFunction, CudaSlice, DeviceRepr, DeviceSlice, LaunchAsync, LaunchConfig}; +use cudarc::driver::{ + CudaFunction, CudaSlice, DeviceRepr, DeviceSlice, LaunchAsync, LaunchConfig, ValidAsZeroBits, +}; use half::{bf16, f16}; use std::sync::Arc; @@ -244,7 +246,7 @@ enum CudaStorageSlice { } trait Map1 { - fn f( + fn f( &self, src: &CudaSlice, dev: &CudaDevice, @@ -276,7 +278,6 @@ impl Map1 for Clone { } struct Affine(f64, f64); - impl Map1 for Affine { fn f( &self, @@ -309,6 +310,43 @@ impl Map1 for Affine { } } +struct Sum<'a>(&'a [usize]); +impl<'a> Map1 for Sum<'a> { + fn f( + &self, + src: &CudaSlice, + dev: &CudaDevice, + layout: &Layout, + ) -> Result> { + let shape = layout.shape(); + let src_dims = shape.dims(); + let el = shape.elem_count(); + let mut dst_el = el; + for &sum_dim in self.0.iter() { + dst_el /= src_dims[sum_dim]; + } + let mut sum_dims = self.0.to_vec(); + // Sort the sum_dims as they have to be processed from left to right when converting the + // indexes. + sum_dims.sort(); + let sum_dims_l: Vec = sum_dims.iter().map(|&d| src_dims[d]).collect(); + let sum_dims_s: Vec = sum_dims + .iter() + .map(|&d| src_dims[d + 1..].iter().product::()) + .collect(); + let cfg = LaunchConfig::for_num_elems(el as u32); + let ds = dev.htod_copy([src_dims, layout.stride(), &sum_dims_l, &sum_dims_s].concat())?; + let src = &src.slice(layout.start_offset()..); + let kernel_name = format!("sum_{}", T::DTYPE.as_str()); + let func = dev.get_or_load_func(&kernel_name, kernels::REDUCE)?; + let out = dev.alloc_zeros::(dst_el)?; + let params = (el, src_dims.len(), sum_dims.len(), &ds, src, &out); + // SAFETY: ffi. + unsafe { func.launch(cfg, params) }?; + Ok(out) + } +} + fn slice_src_and_dst<'a, T>( src: &'a CudaSlice, src_l: &Layout, @@ -486,73 +524,8 @@ impl CudaStorage { } pub(crate) fn sum(&self, layout: &Layout, sum_dims: &[usize]) -> Result { - let shape = layout.shape(); - let src_dims = shape.dims(); - let el = shape.elem_count(); - let mut dst_el = el; - for &sum_dim in sum_dims.iter() { - dst_el /= src_dims[sum_dim]; - } - let mut sum_dims = sum_dims.to_vec(); - // Sort the sum_dims as they have to be processed from left to right when converting the - // indexes. - sum_dims.sort(); - let sum_dims_l: Vec = sum_dims.iter().map(|&d| src_dims[d]).collect(); - let sum_dims_s: Vec = sum_dims - .iter() - .map(|&d| src_dims[d + 1..].iter().product::()) - .collect(); - let cfg = LaunchConfig::for_num_elems(el as u32); - let dev = self.device(); - let ds = dev.htod_copy([src_dims, layout.stride(), &sum_dims_l, &sum_dims_s].concat())?; - let slice = match &self.slice { - CudaStorageSlice::U32(arg) => { - let arg = &arg.slice(layout.start_offset()..); - let func = dev.get_or_load_func("sum_u32", kernels::REDUCE)?; - let out = dev.alloc_zeros::(dst_el)?; - let params = (el, src_dims.len(), sum_dims.len(), &ds, arg, &out); - // SAFETY: ffi. - unsafe { func.launch(cfg, params) }?; - CudaStorageSlice::U32(out) - } - CudaStorageSlice::BF16(arg) => { - let arg = &arg.slice(layout.start_offset()..); - let func = dev.get_or_load_func("sum_bf16", kernels::REDUCE)?; - let out = dev.alloc_zeros::(dst_el)?; - let params = (el, src_dims.len(), sum_dims.len(), &ds, arg, &out); - // SAFETY: ffi. - unsafe { func.launch(cfg, params) }?; - CudaStorageSlice::BF16(out) - } - CudaStorageSlice::F16(arg) => { - let arg = &arg.slice(layout.start_offset()..); - let func = dev.get_or_load_func("sum_f16", kernels::REDUCE)?; - let out = dev.alloc_zeros::(dst_el)?; - let params = (el, src_dims.len(), sum_dims.len(), &ds, arg, &out); - // SAFETY: ffi. - unsafe { func.launch(cfg, params) }?; - CudaStorageSlice::F16(out) - } - CudaStorageSlice::F32(arg) => { - let arg = &arg.slice(layout.start_offset()..); - let func = dev.get_or_load_func("sum_f32", kernels::REDUCE)?; - let out = dev.alloc_zeros::(dst_el)?; - let params = (el, src_dims.len(), sum_dims.len(), &ds, arg, &out); - // SAFETY: ffi. - unsafe { func.launch(cfg, params) }?; - CudaStorageSlice::F32(out) - } - CudaStorageSlice::F64(arg) => { - let arg = &arg.slice(layout.start_offset()..); - let func = dev.get_or_load_func("sum_f64", kernels::REDUCE)?; - let out = dev.alloc_zeros::(dst_el)?; - let params = (el, src_dims.len(), sum_dims.len(), &ds, arg, &out); - // SAFETY: ffi. - unsafe { func.launch(cfg, params) }?; - CudaStorageSlice::F64(out) - } - }; - let device = dev.clone(); + let device = self.device().clone(); + let slice = Sum(sum_dims).map(&self.slice, &device, layout)?; Ok(Self { slice, device }) }