mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00
Add the reduce-sum kernel.
This commit is contained in:
@ -2,4 +2,5 @@ pub const AFFINE: &str = include_str!(concat!(env!("OUT_DIR"), "/affine.ptx"));
|
|||||||
pub const BINARY: &str = include_str!(concat!(env!("OUT_DIR"), "/binary.ptx"));
|
pub const BINARY: &str = include_str!(concat!(env!("OUT_DIR"), "/binary.ptx"));
|
||||||
pub const EMBEDDINGS: &str = include_str!(concat!(env!("OUT_DIR"), "/embeddings.ptx"));
|
pub const EMBEDDINGS: &str = include_str!(concat!(env!("OUT_DIR"), "/embeddings.ptx"));
|
||||||
pub const FILL: &str = include_str!(concat!(env!("OUT_DIR"), "/fill.ptx"));
|
pub const FILL: &str = include_str!(concat!(env!("OUT_DIR"), "/fill.ptx"));
|
||||||
|
pub const REDUCE: &str = include_str!(concat!(env!("OUT_DIR"), "/reduce.ptx"));
|
||||||
pub const UNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/unary.ptx"));
|
pub const UNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/unary.ptx"));
|
||||||
|
46
kernels/src/reduce.cu
Normal file
46
kernels/src/reduce.cu
Normal file
@ -0,0 +1,46 @@
|
|||||||
|
#include "cuda_utils.cuh"
|
||||||
|
#include<stdint.h>
|
||||||
|
|
||||||
|
#define SUM_OP(TYPENAME, FN_NAME) \
|
||||||
|
extern "C" __global__ void FN_NAME( \
|
||||||
|
const size_t numel, \
|
||||||
|
const size_t num_dims, \
|
||||||
|
const size_t num_sum_dims, \
|
||||||
|
const size_t *info, \
|
||||||
|
const TYPENAME *inp, \
|
||||||
|
TYPENAME *out \
|
||||||
|
) { \
|
||||||
|
const size_t *dims = info; \
|
||||||
|
const size_t *strides = info + num_dims; \
|
||||||
|
const size_t *sum_dims_l = info + 2*num_dims; \
|
||||||
|
const size_t *sum_dims_s = info + 2*num_dims + num_sum_dims; \
|
||||||
|
if (is_contiguous(num_dims, dims, strides)) { \
|
||||||
|
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
|
||||||
|
size_t dst_index = i; \
|
||||||
|
for (unsigned int nd = 0; nd < num_sum_dims; ++nd) { \
|
||||||
|
size_t stride = sum_dims_s[nd]; \
|
||||||
|
size_t pre = dst_index / stride; \
|
||||||
|
size_t post = dst_index / stride; \
|
||||||
|
dst_index = (pre / sum_dims_l[nd]) * stride + post; \
|
||||||
|
} \
|
||||||
|
out[dst_index] += inp[i]; \
|
||||||
|
} \
|
||||||
|
} \
|
||||||
|
else { \
|
||||||
|
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
|
||||||
|
unsigned strided_i = get_strided_index(i, num_dims, dims, strides); \
|
||||||
|
size_t dst_index = i; \
|
||||||
|
for (unsigned int nd = 0; nd < num_sum_dims; ++nd) { \
|
||||||
|
size_t stride = sum_dims_s[nd]; \
|
||||||
|
size_t pre = dst_index / stride; \
|
||||||
|
size_t post = dst_index / stride; \
|
||||||
|
dst_index = (pre / sum_dims_l[nd]) * stride + post; \
|
||||||
|
} \
|
||||||
|
out[dst_index] += inp[strided_i]; \
|
||||||
|
} \
|
||||||
|
} \
|
||||||
|
} \
|
||||||
|
|
||||||
|
SUM_OP(float, sum_f32)
|
||||||
|
SUM_OP(double, sum_f64)
|
||||||
|
SUM_OP(uint32_t, sum_u32)
|
@ -298,13 +298,53 @@ impl CudaStorage {
|
|||||||
Ok(Self { slice, device })
|
Ok(Self { slice, device })
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn sum(
|
pub(crate) fn sum(&self, shape: &Shape, stride: &[usize], sum_dims: &[usize]) -> Result<Self> {
|
||||||
&self,
|
let src_dims = shape.dims();
|
||||||
_shape: &Shape,
|
let el = shape.elem_count();
|
||||||
_stride: &[usize],
|
let mut dst_el = el;
|
||||||
_sum_dims: &[usize],
|
for &sum_dim in sum_dims.iter() {
|
||||||
) -> Result<Self> {
|
dst_el /= src_dims[sum_dim];
|
||||||
todo!()
|
}
|
||||||
|
let mut sum_dims = sum_dims.to_vec();
|
||||||
|
// Sort the sum_dims as they have to be processed from left to right when converting the
|
||||||
|
// indexes.
|
||||||
|
sum_dims.sort();
|
||||||
|
let sum_dims_l: Vec<usize> = sum_dims.iter().map(|&d| src_dims[d]).collect();
|
||||||
|
let sum_dims_s: Vec<usize> = sum_dims
|
||||||
|
.iter()
|
||||||
|
.map(|&d| src_dims[d + 1..].iter().product::<usize>())
|
||||||
|
.collect();
|
||||||
|
let cfg = LaunchConfig::for_num_elems(el as u32);
|
||||||
|
let dev = self.device();
|
||||||
|
let ds = dev.htod_copy([src_dims, stride, &sum_dims_l, &sum_dims_s].concat())?;
|
||||||
|
let slice = match &self.slice {
|
||||||
|
CudaStorageSlice::U32(arg) => {
|
||||||
|
let func = dev.get_or_load_func("sum_u32", kernels::REDUCE)?;
|
||||||
|
let out = dev.alloc_zeros::<u32>(dst_el)?;
|
||||||
|
let params = (el, src_dims.len(), sum_dims.len(), &ds, arg, &out);
|
||||||
|
// SAFETY: ffi.
|
||||||
|
unsafe { func.launch(cfg, params) }?;
|
||||||
|
CudaStorageSlice::U32(out)
|
||||||
|
}
|
||||||
|
CudaStorageSlice::F32(arg) => {
|
||||||
|
let func = dev.get_or_load_func("sum_f32", kernels::REDUCE)?;
|
||||||
|
let out = dev.alloc_zeros::<f32>(dst_el)?;
|
||||||
|
let params = (el, src_dims.len(), sum_dims.len(), &ds, arg, &out);
|
||||||
|
// SAFETY: ffi.
|
||||||
|
unsafe { func.launch(cfg, params) }?;
|
||||||
|
CudaStorageSlice::F32(out)
|
||||||
|
}
|
||||||
|
CudaStorageSlice::F64(arg) => {
|
||||||
|
let func = dev.get_or_load_func("sum_f64", kernels::REDUCE)?;
|
||||||
|
let out = dev.alloc_zeros::<f64>(dst_el)?;
|
||||||
|
let params = (el, src_dims.len(), sum_dims.len(), &ds, arg, &out);
|
||||||
|
// SAFETY: ffi.
|
||||||
|
unsafe { func.launch(cfg, params) }?;
|
||||||
|
CudaStorageSlice::F64(out)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let device = dev.clone();
|
||||||
|
Ok(Self { slice, device })
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn divide_by_sum_over_dim(&mut self, _: &Shape, _: usize) -> Result<()> {
|
pub(crate) fn divide_by_sum_over_dim(&mut self, _: &Shape, _: usize) -> Result<()> {
|
||||||
|
Reference in New Issue
Block a user