Cuda kernels for fast min/max reductions (#203)

* Add the min/max cuda kernels.

* Better integration of the cuda kernels.
This commit is contained in:
Laurent Mazare
2023-07-19 19:12:27 +02:00
committed by GitHub
parent 001f9a59ce
commit 536c5e702e
3 changed files with 130 additions and 22 deletions

View File

@ -515,8 +515,8 @@ impl<'a> Map1 for Sum<'a> {
}
}
struct FastSum<'a>(&'a [usize]);
impl<'a> Map1 for FastSum<'a> {
struct FastReduce<'a>(&'a [usize], crate::op::ReduceOp);
impl<'a> Map1 for FastReduce<'a> {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
src: &CudaSlice<T>,
@ -557,8 +557,14 @@ impl<'a> Map1 for FastSum<'a> {
.htod_copy([dims.as_slice(), stride.as_slice()].concat())
.w()?;
let src = &src.slice(layout.start_offset()..);
let func = dev.get_or_load_func(&kernel_name::<T>("fast_sum"), kernels::REDUCE)?;
let out = dev.alloc_zeros::<T>(dst_el).w()?;
let name = match self.1 {
crate::op::ReduceOp::Sum => "fast_sum",
crate::op::ReduceOp::Min => "fast_min",
crate::op::ReduceOp::Max => "fast_max",
};
let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::REDUCE)?;
// SAFETY: filled in by the follow up kernel.
let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
let params = (src_el, el_to_sum_per_block, src_dims.len(), &ds, src, &out);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?;
@ -961,15 +967,9 @@ impl BackendStorage for CudaStorage {
layout: &Layout,
sum_dims: &[usize],
) -> Result<Self> {
match op {
crate::op::ReduceOp::Sum => {
let device = self.device().clone();
let slice = FastSum(sum_dims).map(&self.slice, &device, layout)?;
Ok(Self { slice, device })
}
crate::op::ReduceOp::Min => Err(CudaError::InternalError("TODO: implement min").into()),
crate::op::ReduceOp::Max => Err(CudaError::InternalError("TODO: implement max").into()),
}
let device = self.device().clone();
let slice = FastReduce(sum_dims, op).map(&self.slice, &device, layout)?;
Ok(Self { slice, device })
}
fn divide_by_sum_over_dim(&mut self, _: &Shape, _: usize) -> Result<()> {