Im2col cuda optimization. (#2885)

This commit is contained in:
Laurent Mazare
2025-04-13 10:07:53 +02:00
committed by GitHub
parent 15ed0b11ce
commit d9198deb37
2 changed files with 21 additions and 21 deletions

View File

@ -157,15 +157,15 @@ impl Map1 for Im2Col1D {
let shape = layout.shape();
let dims = shape.dims();
let l_out = self.l_out(dims[2]);
let dst_el = dims[0] * l_out * dims[1] * self.l_k;
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
let threads = dims[0] * l_out * dims[1];
let cfg = LaunchConfig::for_num_elems(threads as u32);
let ds = dev.memcpy_stod(&[dims, layout.stride()].concat())?;
let src = &src.slice(layout.start_offset()..);
let func = dev.get_or_load_func(&kernel_name::<T>("im2col1d"), &kernels::CONV)?;
// SAFETY: Set later by running the kernel.
let dst = unsafe { dev.alloc::<T>(dst_el)? };
let dst = unsafe { dev.alloc::<T>(threads * self.l_k)? };
let mut builder = func.builder();
barg!(builder, dst_el);
barg!(builder, threads);
barg!(builder, l_out);
barg!(builder, self.l_k);
barg!(builder, self.stride);

View File

@ -53,7 +53,7 @@ __device__ void conv1d(
template <typename T>
__device__ void im2col1d(
const size_t dst_numel,
const size_t numel,
const size_t l_out,
const size_t l_k,
const size_t stride,
@ -63,10 +63,10 @@ __device__ void im2col1d(
const T *src,
T *dst
) {
const size_t dst_i = blockIdx.x * blockDim.x + threadIdx.x;
const size_t thread_i = blockIdx.x * blockDim.x + threadIdx.x;
// dst: (b_size, l_out, c_in, l_k)
// src: (b_size, c_in, l_in)
if (dst_i >= dst_numel) {
if (thread_i >= numel) {
return;
}
const size_t *src_dims = info;
@ -74,26 +74,26 @@ __device__ void im2col1d(
const size_t c_in = src_dims[1];
const size_t l_in = src_dims[2];
const size_t dst_s2 = l_k;
const size_t dst_s1 = c_in * dst_s2;
const size_t dst_s1 = c_in;
const size_t dst_s0 = l_out * dst_s1;
size_t tmp_dst_i = dst_i;
size_t tmp_dst_i = thread_i;
const size_t b_idx = tmp_dst_i / dst_s0;
tmp_dst_i -= b_idx * dst_s0;
const size_t l_idx = tmp_dst_i / dst_s1;
tmp_dst_i -= l_idx * dst_s1;
const size_t c_idx = tmp_dst_i / dst_s2;
tmp_dst_i -= c_idx * dst_s2;
const size_t l_k_idx = tmp_dst_i;
size_t src_l_idx = l_idx * stride + l_k_idx * dilation;
if (src_l_idx < padding || src_l_idx >= l_in + padding) {
dst[dst_i] = static_cast<T>(0);
}
else {
src_l_idx -= padding;
const size_t src_i = b_idx * src_s[0] + c_idx * src_s[1] + src_l_idx * src_s[2];
dst[dst_i] = src[src_i];
const size_t c_idx = tmp_dst_i;
for (size_t l_k_idx = 0; l_k_idx < l_k; ++l_k_idx) {
size_t src_l_idx = l_idx * stride + l_k_idx * dilation;
size_t dst_i = thread_i * l_k + l_k_idx;
if (src_l_idx < padding || src_l_idx >= l_in + padding) {
dst[dst_i] = static_cast<T>(0);
}
else {
src_l_idx -= padding;
const size_t src_i = b_idx * src_s[0] + c_idx * src_s[1] + src_l_idx * src_s[2];
dst[dst_i] = src[src_i];
}
}
}