mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 19:18:50 +00:00
Optimize the unary cuda kernels for the contiguous case.
This commit is contained in:
@ -6,22 +6,28 @@ extern "C" __global__ void FN_NAME( \
|
||||
const size_t numel, \
|
||||
const size_t num_dims, \
|
||||
const size_t *info, \
|
||||
const TYPENAME *x, \
|
||||
TYPENAME *y, \
|
||||
const TYPENAME *inp, \
|
||||
TYPENAME *out, \
|
||||
const TYPENAME mul, \
|
||||
const TYPENAME add \
|
||||
) { \
|
||||
const size_t *dims = info; \
|
||||
const size_t *strides = info + num_dims; \
|
||||
unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; \
|
||||
if (i >= numel) { \
|
||||
return; \
|
||||
} \
|
||||
unsigned strided_i = get_strided_index(i, num_dims, dims, strides); \
|
||||
y[strided_i] = x[i] * mul + add; \
|
||||
if (is_contiguous(num_dims, dims, strides)) { \
|
||||
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
|
||||
TYPENAME x = inp ? inp[i] : out[i]; \
|
||||
out[i] = x * mul + add; \
|
||||
} \
|
||||
} \
|
||||
else { \
|
||||
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
|
||||
unsigned strided_i = get_strided_index(i, num_dims, dims, strides); \
|
||||
TYPENAME x = inp ? inp[strided_i] : out[i]; \
|
||||
out[i] = x * mul + add; \
|
||||
} \
|
||||
} \
|
||||
} \
|
||||
|
||||
AFFINE_OP(float, affine_f32)
|
||||
AFFINE_OP(double, affine_f64)
|
||||
AFFINE_OP(uint32_t, affine_u32)
|
||||
|
||||
|
@ -1,6 +1,22 @@
|
||||
#include "cuda_fp16.h"
|
||||
#include "compatibility.cuh"
|
||||
|
||||
__device__ bool is_contiguous(
|
||||
const size_t num_dims,
|
||||
const size_t *dims,
|
||||
const size_t *strides
|
||||
) {
|
||||
size_t acc = 1;
|
||||
for (unsigned int d = 0; d < num_dims; d++) {
|
||||
unsigned int dim_idx = num_dims - 1 - d;
|
||||
if (acc != strides[dim_idx]) {
|
||||
return false;
|
||||
}
|
||||
acc *= dims[dim_idx];
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
__device__ unsigned int get_strided_index(
|
||||
unsigned int idx,
|
||||
const size_t num_dims,
|
||||
|
@ -10,10 +10,18 @@ extern "C" __global__ void FN_NAME( \
|
||||
) { \
|
||||
const size_t *dims = info; \
|
||||
const size_t *strides = info + num_dims; \
|
||||
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
|
||||
unsigned strided_i = get_strided_index(i, num_dims, dims, strides); \
|
||||
TYPENAME x = inp ? inp[strided_i] : out[i]; \
|
||||
out[i] = FUNC; \
|
||||
if (is_contiguous(num_dims, dims, strides)) { \
|
||||
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
|
||||
TYPENAME x = inp ? inp[i] : out[i]; \
|
||||
out[i] = FUNC; \
|
||||
} \
|
||||
} \
|
||||
else { \
|
||||
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
|
||||
unsigned strided_i = get_strided_index(i, num_dims, dims, strides); \
|
||||
TYPENAME x = inp ? inp[strided_i] : out[i]; \
|
||||
out[i] = FUNC; \
|
||||
} \
|
||||
} \
|
||||
} \
|
||||
|
||||
|
Reference in New Issue
Block a user