mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Add a cuda kernel for avg-pool2d. (#440)
* Add a cuda kernel for avg-pool2d. * Avoid running out of bounds. * Finish wiring the avg pool kernel + add some testing. * Support for max-pool + testing.
This commit is contained in:
@ -24,6 +24,9 @@ __device__ void conv1d(
|
||||
const size_t c_out = k_dims[0];
|
||||
const size_t c_in = src_dims[1];
|
||||
const size_t l_in = src_dims[2];
|
||||
if (dst_i >= src_dims[0] * c_out * l_out) {
|
||||
return;
|
||||
}
|
||||
|
||||
// TODO
|
||||
const size_t b_idx = dst_i / (l_out * c_out);
|
||||
@ -61,9 +64,6 @@ __device__ void conv2d(
|
||||
T *dst
|
||||
) {
|
||||
const size_t dst_i = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (dst_i >= src_numel) {
|
||||
return;
|
||||
}
|
||||
// src: (b_size, c_in, w_in, h_in)
|
||||
// k: (c_out, c_in, w_k, h_k)
|
||||
const size_t *src_dims = info;
|
||||
@ -76,6 +76,9 @@ __device__ void conv2d(
|
||||
const size_t c_in = src_dims[1];
|
||||
const size_t w_in = src_dims[2];
|
||||
const size_t h_in = src_dims[3];
|
||||
if (dst_i >= src_dims[0] * c_out * w_out * h_out) {
|
||||
return;
|
||||
}
|
||||
|
||||
// TODO
|
||||
const size_t b_idx = dst_i / (w_out * h_out * c_out);
|
||||
@ -107,6 +110,116 @@ __device__ void conv2d(
|
||||
dst[dst_i] = static_cast<T>(d);
|
||||
}
|
||||
|
||||
template <typename T, typename A>
|
||||
__device__ void avg_pool2d(
|
||||
const size_t src_numel,
|
||||
const size_t w_k,
|
||||
const size_t h_k,
|
||||
const size_t w_stride,
|
||||
const size_t h_stride,
|
||||
const size_t *info,
|
||||
const T *src,
|
||||
T *dst
|
||||
) {
|
||||
const size_t dst_i = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
// src: (b_size, c_in, w_in, h_in)
|
||||
const size_t *src_dims = info;
|
||||
const size_t *src_s = info + 4;
|
||||
|
||||
const size_t c = src_dims[1];
|
||||
const size_t w_in = src_dims[2];
|
||||
const size_t h_in = src_dims[3];
|
||||
|
||||
const size_t w_out = (w_in - w_k) / w_stride + 1;
|
||||
const size_t h_out = (h_in - h_k) / h_stride + 1;
|
||||
if (dst_i >= src_dims[0] * c * w_out * h_out) {
|
||||
return;
|
||||
}
|
||||
|
||||
// TODO: Improve this.
|
||||
const size_t b_idx = dst_i / (w_out * h_out * c);
|
||||
const size_t c_idx = (dst_i / (w_out * h_out)) % c;
|
||||
const size_t dst_w = (dst_i / h_out) % w_out;
|
||||
const size_t dst_h = dst_i % h_out;
|
||||
|
||||
const size_t src_idx0 = b_idx * src_s[0];
|
||||
const float scale = 1.0 / (w_k * h_k);
|
||||
A d = 0;
|
||||
for (size_t w_offset = 0; w_offset < w_k; ++w_offset) {
|
||||
size_t src_w = w_stride * dst_w + w_offset;
|
||||
if (src_w >= w_in) {
|
||||
continue;
|
||||
}
|
||||
for (size_t h_offset = 0; h_offset < h_k; ++h_offset) {
|
||||
size_t src_h = h_stride * dst_h + h_offset;
|
||||
if (src_h >= h_in) {
|
||||
continue;
|
||||
}
|
||||
const size_t src_idx = src_idx0 + c_idx * src_s[1] + src_w * src_s[2] + src_h * src_s[3];
|
||||
d += static_cast<A>(src[src_idx]);
|
||||
}
|
||||
}
|
||||
dst[dst_i] = static_cast<T>(d * scale);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ void max_pool2d(
|
||||
const size_t src_numel,
|
||||
const size_t w_k,
|
||||
const size_t h_k,
|
||||
const size_t w_stride,
|
||||
const size_t h_stride,
|
||||
const size_t *info,
|
||||
const T *src,
|
||||
T *dst
|
||||
) {
|
||||
const size_t dst_i = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
// src: (b_size, c_in, w_in, h_in)
|
||||
const size_t *src_dims = info;
|
||||
const size_t *src_s = info + 4;
|
||||
|
||||
const size_t c = src_dims[1];
|
||||
const size_t w_in = src_dims[2];
|
||||
const size_t h_in = src_dims[3];
|
||||
|
||||
const size_t w_out = (w_in - w_k) / w_stride + 1;
|
||||
const size_t h_out = (h_in - h_k) / h_stride + 1;
|
||||
if (dst_i >= src_dims[0] * c * w_out * h_out) {
|
||||
return;
|
||||
}
|
||||
|
||||
// TODO: Improve this.
|
||||
const size_t b_idx = dst_i / (w_out * h_out * c);
|
||||
const size_t c_idx = (dst_i / (w_out * h_out)) % c;
|
||||
const size_t dst_w = (dst_i / h_out) % w_out;
|
||||
const size_t dst_h = dst_i % h_out;
|
||||
|
||||
const size_t src_idx0 = b_idx * src_s[0];
|
||||
T d = 0;
|
||||
bool set = false;
|
||||
for (size_t w_offset = 0; w_offset < w_k; ++w_offset) {
|
||||
size_t src_w = w_stride * dst_w + w_offset;
|
||||
if (src_w >= w_in) {
|
||||
continue;
|
||||
}
|
||||
for (size_t h_offset = 0; h_offset < h_k; ++h_offset) {
|
||||
size_t src_h = h_stride * dst_h + h_offset;
|
||||
if (src_h >= h_in) {
|
||||
continue;
|
||||
}
|
||||
const size_t src_idx = src_idx0 + c_idx * src_s[1] + src_w * src_s[2] + src_h * src_s[3];
|
||||
if (set) {
|
||||
d = maxg(d, src[src_idx]);
|
||||
}
|
||||
else {
|
||||
d = src[src_idx];
|
||||
set = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
dst[dst_i] = d;
|
||||
}
|
||||
|
||||
|
||||
#define CONV1D_OP(TYPENAME, TYPEACC, FN_NAME) \
|
||||
extern "C" __global__ void FN_NAME( \
|
||||
@ -137,14 +250,46 @@ extern "C" __global__ void FN_NAME( \
|
||||
conv2d<TYPENAME, TYPEACC>(src_numel, w_out, h_out, stride, padding, info, src, kernel, dst); \
|
||||
} \
|
||||
|
||||
#define AVG_POOL2D_OP(TYPENAME, TYPEACC, FN_NAME) \
|
||||
extern "C" __global__ void FN_NAME( \
|
||||
const size_t src_numel, \
|
||||
const size_t w_k, \
|
||||
const size_t h_k, \
|
||||
const size_t w_stride, \
|
||||
const size_t h_stride, \
|
||||
const size_t *info, \
|
||||
const TYPENAME *src, \
|
||||
TYPENAME *dst \
|
||||
) { \
|
||||
avg_pool2d<TYPENAME, TYPEACC>(src_numel, w_k, h_k, w_stride, h_stride, info, src, dst); \
|
||||
} \
|
||||
|
||||
#define MAX_POOL2D_OP(TYPENAME, FN_NAME) \
|
||||
extern "C" __global__ void FN_NAME( \
|
||||
const size_t src_numel, \
|
||||
const size_t w_k, \
|
||||
const size_t h_k, \
|
||||
const size_t w_stride, \
|
||||
const size_t h_stride, \
|
||||
const size_t *info, \
|
||||
const TYPENAME *src, \
|
||||
TYPENAME *dst \
|
||||
) { \
|
||||
max_pool2d<TYPENAME>(src_numel, w_k, h_k, w_stride, h_stride, info, src, dst); \
|
||||
} \
|
||||
|
||||
#if __CUDA_ARCH__ >= 800
|
||||
CONV1D_OP(__nv_bfloat16, float, conv1d_bf16)
|
||||
CONV2D_OP(__nv_bfloat16, float, conv2d_bf16)
|
||||
AVG_POOL2D_OP(__nv_bfloat16, float, avg_pool2d_bf16)
|
||||
MAX_POOL2D_OP(__nv_bfloat16, max_pool2d_bf16)
|
||||
#endif
|
||||
|
||||
#if __CUDA_ARCH__ >= 530
|
||||
CONV1D_OP(__half, float, conv1d_f16)
|
||||
CONV2D_OP(__half, float, conv2d_f16)
|
||||
AVG_POOL2D_OP(__half, float, avg_pool2d_f16)
|
||||
MAX_POOL2D_OP(__half, max_pool2d_f16)
|
||||
#endif
|
||||
|
||||
CONV1D_OP(float, float, conv1d_f32)
|
||||
@ -157,3 +302,12 @@ CONV2D_OP(double, double, conv2d_f64)
|
||||
CONV2D_OP(uint8_t, uint8_t, conv2d_u8)
|
||||
CONV2D_OP(uint32_t, uint32_t, conv2d_u32)
|
||||
|
||||
AVG_POOL2D_OP(float, float, avg_pool2d_f32)
|
||||
AVG_POOL2D_OP(double, double, avg_pool2d_f64)
|
||||
AVG_POOL2D_OP(uint8_t, uint8_t, avg_pool2d_u8)
|
||||
AVG_POOL2D_OP(uint32_t, uint32_t, avg_pool2d_u32)
|
||||
|
||||
MAX_POOL2D_OP(float, max_pool2d_f32)
|
||||
MAX_POOL2D_OP(double, max_pool2d_f64)
|
||||
MAX_POOL2D_OP(uint8_t, max_pool2d_u8)
|
||||
MAX_POOL2D_OP(uint32_t, max_pool2d_u32)
|
||||
|
Reference in New Issue
Block a user