Im2col cuda optimization. (#2885)

This commit is contained in:
Laurent Mazare
2025-04-13 10:07:53 +02:00
committed by GitHub
parent 15ed0b11ce
commit d9198deb37
2 changed files with 21 additions and 21 deletions

View File

@ -157,15 +157,15 @@ impl Map1 for Im2Col1D {
let shape = layout.shape(); let shape = layout.shape();
let dims = shape.dims(); let dims = shape.dims();
let l_out = self.l_out(dims[2]); let l_out = self.l_out(dims[2]);
let dst_el = dims[0] * l_out * dims[1] * self.l_k; let threads = dims[0] * l_out * dims[1];
let cfg = LaunchConfig::for_num_elems(dst_el as u32); let cfg = LaunchConfig::for_num_elems(threads as u32);
let ds = dev.memcpy_stod(&[dims, layout.stride()].concat())?; let ds = dev.memcpy_stod(&[dims, layout.stride()].concat())?;
let src = &src.slice(layout.start_offset()..); let src = &src.slice(layout.start_offset()..);
let func = dev.get_or_load_func(&kernel_name::<T>("im2col1d"), &kernels::CONV)?; let func = dev.get_or_load_func(&kernel_name::<T>("im2col1d"), &kernels::CONV)?;
// SAFETY: Set later by running the kernel. // SAFETY: Set later by running the kernel.
let dst = unsafe { dev.alloc::<T>(dst_el)? }; let dst = unsafe { dev.alloc::<T>(threads * self.l_k)? };
let mut builder = func.builder(); let mut builder = func.builder();
barg!(builder, dst_el); barg!(builder, threads);
barg!(builder, l_out); barg!(builder, l_out);
barg!(builder, self.l_k); barg!(builder, self.l_k);
barg!(builder, self.stride); barg!(builder, self.stride);

View File

@ -53,7 +53,7 @@ __device__ void conv1d(
template <typename T> template <typename T>
__device__ void im2col1d( __device__ void im2col1d(
const size_t dst_numel, const size_t numel,
const size_t l_out, const size_t l_out,
const size_t l_k, const size_t l_k,
const size_t stride, const size_t stride,
@ -63,10 +63,10 @@ __device__ void im2col1d(
const T *src, const T *src,
T *dst T *dst
) { ) {
const size_t dst_i = blockIdx.x * blockDim.x + threadIdx.x; const size_t thread_i = blockIdx.x * blockDim.x + threadIdx.x;
// dst: (b_size, l_out, c_in, l_k) // dst: (b_size, l_out, c_in, l_k)
// src: (b_size, c_in, l_in) // src: (b_size, c_in, l_in)
if (dst_i >= dst_numel) { if (thread_i >= numel) {
return; return;
} }
const size_t *src_dims = info; const size_t *src_dims = info;
@ -74,19 +74,18 @@ __device__ void im2col1d(
const size_t c_in = src_dims[1]; const size_t c_in = src_dims[1];
const size_t l_in = src_dims[2]; const size_t l_in = src_dims[2];
const size_t dst_s2 = l_k; const size_t dst_s1 = c_in;
const size_t dst_s1 = c_in * dst_s2;
const size_t dst_s0 = l_out * dst_s1; const size_t dst_s0 = l_out * dst_s1;
size_t tmp_dst_i = dst_i; size_t tmp_dst_i = thread_i;
const size_t b_idx = tmp_dst_i / dst_s0; const size_t b_idx = tmp_dst_i / dst_s0;
tmp_dst_i -= b_idx * dst_s0; tmp_dst_i -= b_idx * dst_s0;
const size_t l_idx = tmp_dst_i / dst_s1; const size_t l_idx = tmp_dst_i / dst_s1;
tmp_dst_i -= l_idx * dst_s1; tmp_dst_i -= l_idx * dst_s1;
const size_t c_idx = tmp_dst_i / dst_s2; const size_t c_idx = tmp_dst_i;
tmp_dst_i -= c_idx * dst_s2; for (size_t l_k_idx = 0; l_k_idx < l_k; ++l_k_idx) {
const size_t l_k_idx = tmp_dst_i;
size_t src_l_idx = l_idx * stride + l_k_idx * dilation; size_t src_l_idx = l_idx * stride + l_k_idx * dilation;
size_t dst_i = thread_i * l_k + l_k_idx;
if (src_l_idx < padding || src_l_idx >= l_in + padding) { if (src_l_idx < padding || src_l_idx >= l_in + padding) {
dst[dst_i] = static_cast<T>(0); dst[dst_i] = static_cast<T>(0);
} }
@ -96,6 +95,7 @@ __device__ void im2col1d(
dst[dst_i] = src[src_i]; dst[dst_i] = src[src_i];
} }
} }
}
template <typename T> template <typename T>
__device__ void col2im1d( __device__ void col2im1d(