mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 02:38:10 +00:00
Cuda backend optimization (#1886)
* Attempt at making the kernel faster. * Also adapt the cast kernels. * Also apply to binary ops.
This commit is contained in:
@ -13,7 +13,7 @@ extern "C" __global__ void FN_NAME( \
|
||||
) { \
|
||||
const size_t *dims = info; \
|
||||
const size_t *strides = info + num_dims; \
|
||||
if (is_contiguous(num_dims, dims, strides)) { \
|
||||
if (info == nullptr || is_contiguous(num_dims, dims, strides)) { \
|
||||
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
|
||||
TYPENAME x = inp ? inp[i] : out[i]; \
|
||||
out[i] = x * mul + add; \
|
||||
|
@ -12,8 +12,8 @@ extern "C" __global__ void FN_NAME( \
|
||||
const size_t *dims = dims_and_strides; \
|
||||
const size_t *lhs_strides = dims_and_strides + 1 * num_dims; \
|
||||
const size_t *rhs_strides = dims_and_strides + 2 * num_dims; \
|
||||
bool lhs_cont = is_contiguous(num_dims, dims, lhs_strides); \
|
||||
bool rhs_cont = is_contiguous(num_dims, dims, rhs_strides); \
|
||||
bool lhs_cont = dims_and_strides == nullptr || is_contiguous(num_dims, dims, lhs_strides); \
|
||||
bool rhs_cont = dims_and_strides == nullptr || is_contiguous(num_dims, dims, rhs_strides); \
|
||||
if (lhs_cont && rhs_cont) { \
|
||||
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
|
||||
TYPENAME x = lhs[i]; \
|
||||
|
@ -11,7 +11,7 @@ __device__ void cast_(
|
||||
) {
|
||||
const size_t *dims = info;
|
||||
const size_t *strides = info + num_dims;
|
||||
if (is_contiguous(num_dims, dims, strides)) {
|
||||
if (info == nullptr || is_contiguous(num_dims, dims, strides)) {
|
||||
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
|
||||
out[i] = inp[i];
|
||||
}
|
||||
@ -34,7 +34,7 @@ __device__ void cast_through(
|
||||
) {
|
||||
const size_t *dims = info;
|
||||
const size_t *strides = info + num_dims;
|
||||
if (is_contiguous(num_dims, dims, strides)) {
|
||||
if (info == nullptr || is_contiguous(num_dims, dims, strides)) {
|
||||
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
|
||||
out[i] = static_cast<T>(static_cast<I>(inp[i]));
|
||||
}
|
||||
|
@ -13,7 +13,7 @@ extern "C" __global__ void FN_NAME( \
|
||||
) { \
|
||||
const size_t *dims = info; \
|
||||
const size_t *strides = info + num_dims; \
|
||||
if (is_contiguous(num_dims, dims, strides)) { \
|
||||
if (info == nullptr || is_contiguous(num_dims, dims, strides)) { \
|
||||
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
|
||||
TYPENAME x = inp ? inp[i] : out[i]; \
|
||||
out[i] = FUNC; \
|
||||
@ -71,7 +71,7 @@ extern "C" __global__ void FN_NAME( \
|
||||
) { \
|
||||
const size_t *dims = info; \
|
||||
const size_t *strides = info + num_dims; \
|
||||
if (is_contiguous(num_dims, dims, strides)) { \
|
||||
if (info == nullptr || is_contiguous(num_dims, dims, strides)) { \
|
||||
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
|
||||
TYPENAME x = inp ? inp[i] : out[i]; \
|
||||
out[i] = FUNC; \
|
||||
|
Reference in New Issue
Block a user