mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00
Also optimize the contiguous case for the binary cuda kernels.
This commit is contained in:
@ -12,18 +12,54 @@ extern "C" __global__ void FN_NAME( \
|
|||||||
const size_t *dims = dims_and_strides; \
|
const size_t *dims = dims_and_strides; \
|
||||||
const size_t *lhs_strides = dims_and_strides + 1 * num_dims; \
|
const size_t *lhs_strides = dims_and_strides + 1 * num_dims; \
|
||||||
const size_t *rhs_strides = dims_and_strides + 2 * num_dims; \
|
const size_t *rhs_strides = dims_and_strides + 2 * num_dims; \
|
||||||
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
|
bool lhs_cont = is_contiguous(num_dims, dims, lhs_strides); \
|
||||||
unsigned int tmp_i = i; \
|
bool rhs_cont = is_contiguous(num_dims, dims, rhs_strides); \
|
||||||
unsigned int lhs_i = 0; \
|
if (lhs_cont && rhs_cont) { \
|
||||||
unsigned int rhs_i = 0; \
|
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
|
||||||
for (int d = num_dims - 1; d >= 0; d--) { \
|
TYPENAME x = lhs ? lhs[i] : out[i]; \
|
||||||
unsigned int i_dim = tmp_i % dims[d]; \
|
TYPENAME y = rhs ? rhs[i] : out[i]; \
|
||||||
lhs_i += i_dim * lhs_strides[d]; \
|
out[i] = FUNC; \
|
||||||
rhs_i += i_dim * rhs_strides[d]; \
|
} \
|
||||||
tmp_i /= dims[d]; \
|
} else if (lhs_cont) { \
|
||||||
|
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
|
||||||
|
unsigned int tmp_i = i; \
|
||||||
|
unsigned int rhs_i = 0; \
|
||||||
|
for (int d = num_dims - 1; d >= 0; d--) { \
|
||||||
|
unsigned int i_dim = tmp_i % dims[d]; \
|
||||||
|
rhs_i += i_dim * rhs_strides[d]; \
|
||||||
|
tmp_i /= dims[d]; \
|
||||||
|
} \
|
||||||
|
TYPENAME x = lhs ? lhs[i] : out[i]; \
|
||||||
|
TYPENAME y = rhs ? rhs[rhs_i] : out[i]; \
|
||||||
|
out[i] = FUNC; \
|
||||||
|
} \
|
||||||
|
} else if (rhs_cont) { \
|
||||||
|
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
|
||||||
|
unsigned int tmp_i = i; \
|
||||||
|
unsigned int lhs_i = 0; \
|
||||||
|
for (int d = num_dims - 1; d >= 0; d--) { \
|
||||||
|
unsigned int i_dim = tmp_i % dims[d]; \
|
||||||
|
lhs_i += i_dim * lhs_strides[d]; \
|
||||||
|
tmp_i /= dims[d]; \
|
||||||
|
} \
|
||||||
|
TYPENAME x = lhs ? lhs[lhs_i] : out[i]; \
|
||||||
|
TYPENAME y = rhs ? rhs[i] : out[i]; \
|
||||||
|
out[i] = FUNC; \
|
||||||
|
} \
|
||||||
|
} else { \
|
||||||
|
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
|
||||||
|
unsigned int tmp_i = i; \
|
||||||
|
unsigned int lhs_i = 0; \
|
||||||
|
unsigned int rhs_i = 0; \
|
||||||
|
for (int d = num_dims - 1; d >= 0; d--) { \
|
||||||
|
unsigned int i_dim = tmp_i % dims[d]; \
|
||||||
|
lhs_i += i_dim * lhs_strides[d]; \
|
||||||
|
rhs_i += i_dim * rhs_strides[d]; \
|
||||||
|
tmp_i /= dims[d]; \
|
||||||
|
} \
|
||||||
|
TYPENAME x = lhs ? lhs[lhs_i] : out[i]; \
|
||||||
|
TYPENAME y = rhs ? rhs[rhs_i] : out[i]; \
|
||||||
|
out[i] = FUNC; \
|
||||||
} \
|
} \
|
||||||
TYPENAME x = lhs ? lhs[lhs_i] : out[i]; \
|
|
||||||
TYPENAME y = rhs ? rhs[rhs_i] : out[i]; \
|
|
||||||
out[i] = FUNC; \
|
|
||||||
} \
|
} \
|
||||||
} \
|
} \
|
||||||
|
Reference in New Issue
Block a user