mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 19:47:12 +00:00
CUDA kernels for fast min/max reductions (#203)
* Add the min/max CUDA kernels.
* Better integration of the CUDA kernels.
This commit is contained in:
@ -515,8 +515,8 @@ impl<'a> Map1 for Sum<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
struct FastSum<'a>(&'a [usize]);
|
||||
impl<'a> Map1 for FastSum<'a> {
|
||||
struct FastReduce<'a>(&'a [usize], crate::op::ReduceOp);
|
||||
impl<'a> Map1 for FastReduce<'a> {
|
||||
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
||||
&self,
|
||||
src: &CudaSlice<T>,
|
||||
@ -557,8 +557,14 @@ impl<'a> Map1 for FastSum<'a> {
|
||||
.htod_copy([dims.as_slice(), stride.as_slice()].concat())
|
||||
.w()?;
|
||||
let src = &src.slice(layout.start_offset()..);
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>("fast_sum"), kernels::REDUCE)?;
|
||||
let out = dev.alloc_zeros::<T>(dst_el).w()?;
|
||||
let name = match self.1 {
|
||||
crate::op::ReduceOp::Sum => "fast_sum",
|
||||
crate::op::ReduceOp::Min => "fast_min",
|
||||
crate::op::ReduceOp::Max => "fast_max",
|
||||
};
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::REDUCE)?;
|
||||
// SAFETY: filled in by the follow up kernel.
|
||||
let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
|
||||
let params = (src_el, el_to_sum_per_block, src_dims.len(), &ds, src, &out);
|
||||
// SAFETY: ffi.
|
||||
unsafe { func.launch(cfg, params) }.w()?;
|
||||
@ -961,15 +967,9 @@ impl BackendStorage for CudaStorage {
|
||||
layout: &Layout,
|
||||
sum_dims: &[usize],
|
||||
) -> Result<Self> {
|
||||
match op {
|
||||
crate::op::ReduceOp::Sum => {
|
||||
let device = self.device().clone();
|
||||
let slice = FastSum(sum_dims).map(&self.slice, &device, layout)?;
|
||||
Ok(Self { slice, device })
|
||||
}
|
||||
crate::op::ReduceOp::Min => Err(CudaError::InternalError("TODO: implement min").into()),
|
||||
crate::op::ReduceOp::Max => Err(CudaError::InternalError("TODO: implement max").into()),
|
||||
}
|
||||
let device = self.device().clone();
|
||||
let slice = FastReduce(sum_dims, op).map(&self.slice, &device, layout)?;
|
||||
Ok(Self { slice, device })
|
||||
}
|
||||
|
||||
fn divide_by_sum_over_dim(&mut self, _: &Shape, _: usize) -> Result<()> {
|
||||
|
Reference in New Issue
Block a user