diff --git a/candle-core/src/cuda_backend/mod.rs b/candle-core/src/cuda_backend/mod.rs index 62b0bd15..9d5a76b5 100644 --- a/candle-core/src/cuda_backend/mod.rs +++ b/candle-core/src/cuda_backend/mod.rs @@ -157,15 +157,15 @@ impl Map1 for Im2Col1D { let shape = layout.shape(); let dims = shape.dims(); let l_out = self.l_out(dims[2]); - let dst_el = dims[0] * l_out * dims[1] * self.l_k; - let cfg = LaunchConfig::for_num_elems(dst_el as u32); + let threads = dims[0] * l_out * dims[1]; + let cfg = LaunchConfig::for_num_elems(threads as u32); let ds = dev.memcpy_stod(&[dims, layout.stride()].concat())?; let src = &src.slice(layout.start_offset()..); let func = dev.get_or_load_func(&kernel_name::("im2col1d"), &kernels::CONV)?; // SAFETY: Set later by running the kernel. - let dst = unsafe { dev.alloc::(dst_el)? }; + let dst = unsafe { dev.alloc::(threads * self.l_k)? }; let mut builder = func.builder(); - barg!(builder, dst_el); + barg!(builder, threads); barg!(builder, l_out); barg!(builder, self.l_k); barg!(builder, self.stride); diff --git a/candle-kernels/src/conv.cu b/candle-kernels/src/conv.cu index fa834faa..53569e2d 100644 --- a/candle-kernels/src/conv.cu +++ b/candle-kernels/src/conv.cu @@ -53,7 +53,7 @@ __device__ void conv1d( template __device__ void im2col1d( - const size_t dst_numel, + const size_t numel, const size_t l_out, const size_t l_k, const size_t stride, @@ -63,10 +63,10 @@ __device__ void im2col1d( const T *src, T *dst ) { - const size_t dst_i = blockIdx.x * blockDim.x + threadIdx.x; + const size_t thread_i = blockIdx.x * blockDim.x + threadIdx.x; // dst: (b_size, l_out, c_in, l_k) // src: (b_size, c_in, l_in) - if (dst_i >= dst_numel) { + if (thread_i >= numel) { return; } const size_t *src_dims = info; @@ -74,26 +74,26 @@ __device__ void im2col1d( const size_t c_in = src_dims[1]; const size_t l_in = src_dims[2]; - const size_t dst_s2 = l_k; - const size_t dst_s1 = c_in * dst_s2; + const size_t dst_s1 = c_in; const size_t dst_s0 = l_out * dst_s1; - size_t tmp_dst_i = dst_i; + size_t tmp_dst_i = thread_i; const size_t b_idx = tmp_dst_i / dst_s0; tmp_dst_i -= b_idx * dst_s0; const size_t l_idx = tmp_dst_i / dst_s1; tmp_dst_i -= l_idx * dst_s1; - const size_t c_idx = tmp_dst_i / dst_s2; - tmp_dst_i -= c_idx * dst_s2; - const size_t l_k_idx = tmp_dst_i; - size_t src_l_idx = l_idx * stride + l_k_idx * dilation; - if (src_l_idx < padding || src_l_idx >= l_in + padding) { - dst[dst_i] = static_cast(0); - } - else { - src_l_idx -= padding; - const size_t src_i = b_idx * src_s[0] + c_idx * src_s[1] + src_l_idx * src_s[2]; - dst[dst_i] = src[src_i]; + const size_t c_idx = tmp_dst_i; + for (size_t l_k_idx = 0; l_k_idx < l_k; ++l_k_idx) { + size_t src_l_idx = l_idx * stride + l_k_idx * dilation; + size_t dst_i = thread_i * l_k + l_k_idx; + if (src_l_idx < padding || src_l_idx >= l_in + padding) { + dst[dst_i] = static_cast(0); + } + else { + src_l_idx -= padding; + const size_t src_i = b_idx * src_s[0] + c_idx * src_s[1] + src_l_idx * src_s[2]; + dst[dst_i] = src[src_i]; + } } }