mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Stride support for unary ops.
This commit is contained in:
@ -3,11 +3,16 @@
|
||||
#define UNARY_OP(TYPENAME, FN_NAME, FUNC) \
|
||||
extern "C" __global__ void FN_NAME( \
|
||||
const size_t numel, \
|
||||
const size_t num_dims, \
|
||||
const size_t *info, \
|
||||
const TYPENAME *inp, \
|
||||
TYPENAME *out \
|
||||
) { \
|
||||
const size_t *dims = info; \
|
||||
const size_t *strides = info + num_dims; \
|
||||
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
|
||||
TYPENAME x = inp ? inp[i] : out[i]; \
|
||||
unsigned strided_i = get_strided_index(i, num_dims, dims, strides); \
|
||||
TYPENAME x = inp ? inp[strided_i] : out[i]; \
|
||||
out[i] = FUNC; \
|
||||
} \
|
||||
} \
|
||||
|
@ -167,19 +167,17 @@ impl CudaStorage {
|
||||
shape: &Shape,
|
||||
stride: &[usize],
|
||||
) -> Result<Self> {
|
||||
if !shape.is_contiguous(stride) {
|
||||
return Err(CudaError::RequiresContiguous { op: "affine" });
|
||||
}
|
||||
|
||||
let elem_count = shape.elem_count();
|
||||
let cfg = LaunchConfig::for_num_elems(elem_count as u32);
|
||||
let dims = shape.dims();
|
||||
let el_count = shape.elem_count();
|
||||
let cfg = LaunchConfig::for_num_elems(el_count as u32);
|
||||
let dev = self.device();
|
||||
let ds = dev.0.htod_copy([dims, stride].concat())?;
|
||||
match self {
|
||||
Self::F32(arg) => {
|
||||
let func = dev.get_or_load_func(U::KERNEL_F32, kernels::UNARY)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let out = unsafe { dev.0.alloc::<f32>(elem_count) }?;
|
||||
let params = (elem_count, arg, &out);
|
||||
let out = unsafe { dev.0.alloc::<f32>(el_count) }?;
|
||||
let params = (el_count, dims.len(), &ds, arg, &out);
|
||||
// SAFETY: ffi.
|
||||
unsafe { func.launch(cfg, params) }?;
|
||||
Ok(Self::F32(out))
|
||||
@ -187,8 +185,8 @@ impl CudaStorage {
|
||||
Self::F64(arg) => {
|
||||
let func = dev.get_or_load_func(U::KERNEL_F64, kernels::UNARY)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let out = unsafe { dev.0.alloc::<f64>(elem_count) }?;
|
||||
let params = (elem_count, arg, &out);
|
||||
let out = unsafe { dev.0.alloc::<f64>(el_count) }?;
|
||||
let params = (el_count, dims.len(), &ds, arg, &out);
|
||||
// SAFETY: ffi.
|
||||
unsafe { func.launch(cfg, params) }?;
|
||||
Ok(Self::F64(out))
|
||||
|
Reference in New Issue
Block a user