mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 02:16:37 +00:00
Im2col cuda optimization. (#2885)
This commit is contained in:
@ -157,15 +157,15 @@ impl Map1 for Im2Col1D {
|
||||
let shape = layout.shape();
|
||||
let dims = shape.dims();
|
||||
let l_out = self.l_out(dims[2]);
|
||||
let dst_el = dims[0] * l_out * dims[1] * self.l_k;
|
||||
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
|
||||
let threads = dims[0] * l_out * dims[1];
|
||||
let cfg = LaunchConfig::for_num_elems(threads as u32);
|
||||
let ds = dev.memcpy_stod(&[dims, layout.stride()].concat())?;
|
||||
let src = &src.slice(layout.start_offset()..);
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>("im2col1d"), &kernels::CONV)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let dst = unsafe { dev.alloc::<T>(dst_el)? };
|
||||
let dst = unsafe { dev.alloc::<T>(threads * self.l_k)? };
|
||||
let mut builder = func.builder();
|
||||
barg!(builder, dst_el);
|
||||
barg!(builder, threads);
|
||||
barg!(builder, l_out);
|
||||
barg!(builder, self.l_k);
|
||||
barg!(builder, self.stride);
|
||||
|
@ -53,7 +53,7 @@ __device__ void conv1d(
|
||||
|
||||
template <typename T>
|
||||
__device__ void im2col1d(
|
||||
const size_t dst_numel,
|
||||
const size_t numel,
|
||||
const size_t l_out,
|
||||
const size_t l_k,
|
||||
const size_t stride,
|
||||
@ -63,10 +63,10 @@ __device__ void im2col1d(
|
||||
const T *src,
|
||||
T *dst
|
||||
) {
|
||||
const size_t dst_i = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
const size_t thread_i = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
// dst: (b_size, l_out, c_in, l_k)
|
||||
// src: (b_size, c_in, l_in)
|
||||
if (dst_i >= dst_numel) {
|
||||
if (thread_i >= numel) {
|
||||
return;
|
||||
}
|
||||
const size_t *src_dims = info;
|
||||
@ -74,26 +74,26 @@ __device__ void im2col1d(
|
||||
const size_t c_in = src_dims[1];
|
||||
const size_t l_in = src_dims[2];
|
||||
|
||||
const size_t dst_s2 = l_k;
|
||||
const size_t dst_s1 = c_in * dst_s2;
|
||||
const size_t dst_s1 = c_in;
|
||||
const size_t dst_s0 = l_out * dst_s1;
|
||||
|
||||
size_t tmp_dst_i = dst_i;
|
||||
size_t tmp_dst_i = thread_i;
|
||||
const size_t b_idx = tmp_dst_i / dst_s0;
|
||||
tmp_dst_i -= b_idx * dst_s0;
|
||||
const size_t l_idx = tmp_dst_i / dst_s1;
|
||||
tmp_dst_i -= l_idx * dst_s1;
|
||||
const size_t c_idx = tmp_dst_i / dst_s2;
|
||||
tmp_dst_i -= c_idx * dst_s2;
|
||||
const size_t l_k_idx = tmp_dst_i;
|
||||
size_t src_l_idx = l_idx * stride + l_k_idx * dilation;
|
||||
if (src_l_idx < padding || src_l_idx >= l_in + padding) {
|
||||
dst[dst_i] = static_cast<T>(0);
|
||||
}
|
||||
else {
|
||||
src_l_idx -= padding;
|
||||
const size_t src_i = b_idx * src_s[0] + c_idx * src_s[1] + src_l_idx * src_s[2];
|
||||
dst[dst_i] = src[src_i];
|
||||
const size_t c_idx = tmp_dst_i;
|
||||
for (size_t l_k_idx = 0; l_k_idx < l_k; ++l_k_idx) {
|
||||
size_t src_l_idx = l_idx * stride + l_k_idx * dilation;
|
||||
size_t dst_i = thread_i * l_k + l_k_idx;
|
||||
if (src_l_idx < padding || src_l_idx >= l_in + padding) {
|
||||
dst[dst_i] = static_cast<T>(0);
|
||||
}
|
||||
else {
|
||||
src_l_idx -= padding;
|
||||
const size_t src_i = b_idx * src_s[0] + c_idx * src_s[1] + src_l_idx * src_s[2];
|
||||
dst[dst_i] = src[src_i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user