mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Reduce the number of threads.
This commit is contained in:
@ -116,13 +116,13 @@ __device__ void conv2d(
|
||||
// Naive implementation of conv_transpose2d.
|
||||
template <typename T, typename A>
|
||||
__device__ void conv_transpose2d(
|
||||
const size_t src_numel,
|
||||
const size_t w_out,
|
||||
const size_t h_out,
|
||||
const size_t stride,
|
||||
const size_t padding,
|
||||
const size_t out_padding,
|
||||
const size_t dilation,
|
||||
const size_t groups,
|
||||
const size_t *info,
|
||||
const T *src,
|
||||
const T *kernel,
|
||||
@ -130,17 +130,18 @@ __device__ void conv_transpose2d(
|
||||
) {
|
||||
const size_t dst_i = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
// src: (b_size, c_in, h_in, w_in)
|
||||
// k: (c_in, c_out, h_k, w_k)
|
||||
// k: (c_in, c_out / groups, h_k, w_k)
|
||||
const size_t *src_dims = info;
|
||||
const size_t *src_s = info + 4;
|
||||
const size_t *k_dims = info + 8;
|
||||
const size_t *k_s = info + 12;
|
||||
const size_t h_k = k_dims[2];
|
||||
const size_t w_k = k_dims[3];
|
||||
const size_t c_out = k_dims[1];
|
||||
const size_t c_out_per_group = k_dims[1];
|
||||
const size_t c_in = src_dims[1];
|
||||
const size_t h_in = src_dims[2];
|
||||
const size_t w_in = src_dims[3];
|
||||
const size_t c_out = c_out_per_group * groups;
|
||||
if (dst_i >= src_dims[0] * c_out * w_out * h_out) {
|
||||
return;
|
||||
}
|
||||
@ -148,6 +149,10 @@ __device__ void conv_transpose2d(
|
||||
// TODO
|
||||
const size_t b_idx = dst_i / (w_out * h_out * c_out);
|
||||
const size_t dst_c_idx = (dst_i / (w_out * h_out)) % c_out;
|
||||
const size_t c_idx_in_group = dst_c_idx % c_out_per_group;
|
||||
const size_t c_in_per_group = c_in / groups;
|
||||
const size_t group_idx = dst_c_idx / c_out_per_group;
|
||||
// const size_t c_in_per_group = c_in;
|
||||
// NCHW layout.
|
||||
const size_t out_y = (dst_i / w_out) % h_out;
|
||||
const size_t out_x = dst_i % w_out;
|
||||
@ -169,9 +174,9 @@ __device__ void conv_transpose2d(
|
||||
}
|
||||
int inp_y = inp_y_stride / stride;
|
||||
if (inp_y >= h_in) continue;
|
||||
for (size_t src_c_idx = 0; src_c_idx < c_in; ++src_c_idx) {
|
||||
for (size_t src_c_idx = group_idx * c_in_per_group; src_c_idx < (group_idx + 1) * c_in_per_group; ++src_c_idx) {
|
||||
const size_t src_idx = src_idx0 + src_c_idx * src_s[1] + inp_y * src_s[2] + inp_x * src_s[3];
|
||||
const size_t k_idx = src_c_idx * k_s[0] + dst_c_idx * k_s[1] + k_y * k_s[2] + k_x * k_s[3];
|
||||
const size_t k_idx = src_c_idx * k_s[0] + c_idx_in_group * k_s[1] + k_y * k_s[2] + k_x * k_s[3];
|
||||
d += static_cast<A>(src[src_idx]) * static_cast<A>(kernel[k_idx]);
|
||||
}
|
||||
}
|
||||
@ -365,19 +370,19 @@ extern "C" __global__ void FN_NAME( \
|
||||
|
||||
#define CONVT2D_OP(TYPENAME, TYPEACC, FN_NAME) \
|
||||
extern "C" __global__ void FN_NAME( \
|
||||
const size_t src_numel, \
|
||||
const size_t w_out, \
|
||||
const size_t h_out, \
|
||||
const size_t stride, \
|
||||
const size_t padding, \
|
||||
const size_t out_padding, \
|
||||
const size_t dilation, \
|
||||
const size_t groups, \
|
||||
const size_t *info, \
|
||||
const TYPENAME *src, \
|
||||
const TYPENAME *kernel, \
|
||||
TYPENAME *dst \
|
||||
) { \
|
||||
conv_transpose2d<TYPENAME, TYPEACC>(src_numel, w_out, h_out, stride, padding, out_padding, dilation, info, src, kernel, dst); \
|
||||
conv_transpose2d<TYPENAME, TYPEACC>(w_out, h_out, stride, padding, out_padding, dilation, groups, info, src, kernel, dst); \
|
||||
} \
|
||||
|
||||
#define AVG_POOL2D_OP(TYPENAME, TYPEACC, FN_NAME) \
|
||||
|
Reference in New Issue
Block a user