From 98d1242b8fd917baa95c9143252962f8fad3ebf7 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Sun, 10 Sep 2023 21:02:42 +0100
Subject: [PATCH] im2col based conv2d (#802)

* im2col implementation for conv2d.

* Fix for the im2col implementation to match the current conv2d.

* Small optimization.

* Add a cuda kernel.

* Handle arbitrary layouts.

* Im2Col cuda code.
---
 candle-core/src/cuda_backend.rs      | 52 ++++++++++++++++
 candle-kernels/src/conv.cu           | 89 ++++++++++++++++++++++++++++
 candle-nn/examples/cpu_benchmarks.rs | 85 +++++++++++++++++++++-----
 3 files changed, 210 insertions(+), 16 deletions(-)

diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index 7cc85489..b4bdc6cf 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -600,6 +600,58 @@ impl Map1 for Elu {
     }
 }
 
+struct Im2Col {
+    h_k: usize,
+    w_k: usize,
+    stride: usize,
+    dilation: usize,
+    padding: usize,
+}
+
+impl Im2Col {
+    fn hw_out(&self, h: usize, w: usize) -> (usize, usize) {
+        let h_out = (h + 2 * self.padding - self.dilation * (self.h_k - 1) - 1) / self.stride + 1;
+        let w_out = (w + 2 * self.padding - self.dilation * (self.w_k - 1) - 1) / self.stride + 1;
+        (h_out, w_out)
+    }
+}
+
+impl Map1 for Im2Col {
+    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
+        &self,
+        src: &CudaSlice<T>,
+        dev: &CudaDevice,
+        layout: &Layout,
+    ) -> Result<CudaSlice<T>> {
+        let shape = layout.shape();
+        let dims = shape.dims();
+        let (h_out, w_out) = self.hw_out(dims[2], dims[3]);
+        let dst_el = dims[0] * h_out * w_out * dims[1] * self.h_k * self.w_k;
+        let cfg = LaunchConfig::for_num_elems(dst_el as u32);
+        let ds = dev.htod_copy([dims, layout.stride()].concat()).w()?;
+        let src = &src.slice(layout.start_offset()..);
+        let func = dev.get_or_load_func(&kernel_name::<T>("im2col"), kernels::CONV)?;
+        // SAFETY: Set later by running the kernel.
+        let dst = unsafe { dev.alloc::<T>(dst_el) }.w()?;
+        let params = (
+            dst_el,
+            h_out,
+            w_out,
+            self.h_k,
+            self.w_k,
+            self.stride,
+            self.padding,
+            self.dilation,
+            &ds,
+            src,
+            &dst,
+        );
+        // SAFETY: ffi.
+        unsafe { func.launch(cfg, params) }.w()?;
+        Ok(dst)
+    }
+}
+
 struct Powf(f64);
 impl Map1 for Powf {
     fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
diff --git a/candle-kernels/src/conv.cu b/candle-kernels/src/conv.cu
index ba2fa1ad..51c393cb 100644
--- a/candle-kernels/src/conv.cu
+++ b/candle-kernels/src/conv.cu
@@ -51,6 +51,71 @@ __device__ void conv1d(
   dst[dst_i] = static_cast<T>(d);
 }
 
+template <typename T>
+__device__ void im2col(
+    const size_t dst_numel,
+    const size_t h_out,
+    const size_t w_out,
+    const size_t h_k,
+    const size_t w_k,
+    const size_t stride,
+    const size_t padding,
+    const size_t dilation,
+    const size_t *info,
+    const T *src,
+    T *dst
+) {
+  const size_t dst_i = blockIdx.x * blockDim.x + threadIdx.x;
+  // dst: (b_size, h_out, w_out, c_in, h_k, w_k)
+  // src: (b_size, c_in, h_in, w_in)
+  if (dst_i >= dst_numel) {
+    return;
+  }
+  const size_t *src_dims = info;
+  const size_t *src_s = info + 4;
+  const size_t b_in = src_dims[0];
+  const size_t c_in = src_dims[1];
+  const size_t h_in = src_dims[2];
+  const size_t w_in = src_dims[3];
+
+  const size_t dst_s4 = w_k;
+  const size_t dst_s3 = h_k * dst_s4;
+  const size_t dst_s2 = c_in * dst_s3;
+  const size_t dst_s1 = w_out * dst_s2;
+  const size_t dst_s0 = h_out * dst_s1;
+
+  size_t tmp_dst_i = dst_i;
+  const size_t b_idx = tmp_dst_i / dst_s0;
+  tmp_dst_i -= b_idx * dst_s0;
+  const size_t h_idx = tmp_dst_i / dst_s1;
+  tmp_dst_i -= h_idx * dst_s1;
+  const size_t w_idx = tmp_dst_i / dst_s2;
+  tmp_dst_i -= w_idx * dst_s2;
+  const size_t c_idx = tmp_dst_i / dst_s3;
+  tmp_dst_i -= c_idx * dst_s3;
+  const size_t h_k_idx = tmp_dst_i / dst_s4;
+  tmp_dst_i -= h_k_idx * dst_s4;
+  const size_t w_k_idx = tmp_dst_i;
+  size_t src_h_idx = h_idx * stride + h_k_idx * dilation;
+  size_t src_w_idx = w_idx * stride + w_k_idx * dilation;
+  if (src_h_idx < padding || src_h_idx >= h_in + padding) {
+    dst[dst_i] = static_cast<T>(0);
+  }
+  else if (src_w_idx < padding || src_w_idx >= w_in + padding) {
+    dst[dst_i] = static_cast<T>(0);
+  }
+  else {
+    src_h_idx -= padding;
+    src_w_idx -= padding;
+    const size_t src_i =
+      b_idx * src_s[0]
+      + c_idx * src_s[1]
+      + src_h_idx * src_s[2]
+      + src_w_idx * src_s[3];
+    dst[dst_i] = src[src_i];
+  }
+}
+
 // Naive implementation of conv2d.
 template <typename T, typename A>
 __device__ void conv2d(
@@ -363,6 +428,23 @@ extern "C" __global__ void FN_NAME( \
   conv2d<TYPENAME, TYPEACC>(src_numel, w_out, h_out, stride, padding, dilation, info, src, kernel, dst); \
 } \
 
+#define IM2COL_OP(TYPENAME, FN_NAME) \
+extern "C" __global__ void FN_NAME( \
+    const size_t dst_numel, \
+    const size_t h_out, \
+    const size_t w_out, \
+    const size_t h_k, \
+    const size_t w_k, \
+    const size_t stride, \
+    const size_t padding, \
+    const size_t dilation, \
+    const size_t *info, \
+    const TYPENAME *src, \
+    TYPENAME *dst \
+) { \
+  im2col<TYPENAME>(dst_numel, h_out, w_out, h_k, w_k, stride, padding, dilation, info, src, dst); \
+} \
+
 #define CONVT2D_OP(TYPENAME, TYPEACC, FN_NAME) \
 extern "C" __global__ void FN_NAME( \
     const size_t src_numel, \
@@ -428,6 +510,7 @@ CONVT2D_OP(__nv_bfloat16, float, conv_transpose2d_bf16)
 AVG_POOL2D_OP(__nv_bfloat16, float, avg_pool2d_bf16)
 MAX_POOL2D_OP(__nv_bfloat16, max_pool2d_bf16)
 UPSAMPLE_NEAREST2D_OP(__nv_bfloat16, upsample_nearest2d_bf16)
+IM2COL_OP(__nv_bfloat16, im2col_bf16)
 #endif
 
 #if __CUDA_ARCH__ >= 530
@@ -437,6 +520,7 @@ CONVT2D_OP(__half, float, conv_transpose2d_f16)
 AVG_POOL2D_OP(__half, float, avg_pool2d_f16)
 MAX_POOL2D_OP(__half, max_pool2d_f16)
 UPSAMPLE_NEAREST2D_OP(__half, upsample_nearest2d_f16)
+IM2COL_OP(__half, im2col_f16)
 #endif
 
 CONV1D_OP(float, float, conv1d_f32)
@@ -468,3 +552,8 @@ UPSAMPLE_NEAREST2D_OP(float, upsample_nearest2d_f32)
 UPSAMPLE_NEAREST2D_OP(double, upsample_nearest2d_f64)
 UPSAMPLE_NEAREST2D_OP(uint8_t, upsample_nearest2d_u8)
 UPSAMPLE_NEAREST2D_OP(uint32_t, upsample_nearest2d_u32)
+
+IM2COL_OP(float, im2col_f32)
+IM2COL_OP(double, im2col_f64)
+IM2COL_OP(uint8_t, im2col_u8)
+IM2COL_OP(uint32_t, im2col_u32)
diff --git a/candle-nn/examples/cpu_benchmarks.rs b/candle-nn/examples/cpu_benchmarks.rs
index 3ba30f94..5fb99625 100644
--- a/candle-nn/examples/cpu_benchmarks.rs
+++ b/candle-nn/examples/cpu_benchmarks.rs
@@ -9,6 +9,8 @@ use candle::quantized::GgmlType;
 use candle::{CpuStorage, Device, Layout, Result, Shape, Tensor, D};
 use clap::{Parser, Subcommand};
 
+const CHECK_CONV2D: bool = false;
+
 trait Benchmark {
     type PreProcessData;
     type RunResult;
@@ -19,25 +21,51 @@ trait Benchmark {
     const ITERS: usize;
 }
 
-struct Im2Col(usize, usize);
+struct Im2Col {
+    h_k: usize,
+    w_k: usize,
+    stride: usize,
+    dilation: usize,
+    padding: usize,
+}
+
+impl Im2Col {
+    fn hw_out(&self, h: usize, w: usize) -> (usize, usize) {
+        let h_out = (h + 2 * self.padding - self.dilation * (self.h_k - 1) - 1) / self.stride + 1;
+        let w_out = (w + 2 * self.padding - self.dilation * (self.w_k - 1) - 1) / self.stride + 1;
+        (h_out, w_out)
+    }
+}
+
 impl candle::CustomOp1 for Im2Col {
     fn name(&self) -> &'static str {
         "im2col"
     }
 
     fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)> {
-        let &Self(h_k, w_k) = self;
+        let &Self {
+            h_k,
+            w_k,
+            stride,
+            dilation,
+            padding,
+        } = self;
         let (b, c, h, w) = layout.shape().dims4()?;
-        let (h_out, w_out) = (h - h_k + 1, w - w_k + 1);
+        let (h_out, w_out) = self.hw_out(h, w);
         let slice = storage.as_slice::<f32>()?;
-        let src = match layout.contiguous_offsets() {
-            None => candle::bail!("input has to be contiguous"),
-            Some((o1, o2)) => &slice[o1..o2],
-        };
+        let src = &slice[layout.start_offset()..];
         let mut dst = vec![0f32; b * h_out * w_out * c * h_k * w_k];
-        let (s_b, s_c, s_h) = (c * h * w, h * w, w);
+        let (src_s0, src_s1, src_s2, src_s3) = {
+            let s = layout.stride();
+            (s[0], s[1], s[2], s[3])
+        };
+        // TODO: provide specialized kernels for the common use cases.
+        // - h_k = w_k = 1
+        // - padding = 0
+        // - stride = 1
+        // - dilation = 1
         for b_idx in 0..b {
-            let src_idx = b_idx * s_b;
+            let src_idx = b_idx * src_s0;
             let dst_idx = b_idx * h_out * w_out * c * h_k * w_k;
             for h_idx in 0..h_out {
                 let dst_idx = dst_idx + h_idx * w_out * c * h_k * w_k;
@@ -45,12 +73,25 @@ impl candle::CustomOp1 for Im2Col {
                     let dst_idx = dst_idx + w_idx * c * h_k * w_k;
                     for c_idx in 0..c {
                         let dst_idx = dst_idx + c_idx * h_k * w_k;
-                        let src_idx = c_idx * s_c + src_idx;
+                        let src_idx = c_idx * src_s1 + src_idx;
                         for h_k_idx in 0..h_k {
-                            let src_idx = src_idx + (h_idx + h_k_idx) * s_h + w_idx;
+                            let src_h = h_idx * stride + h_k_idx * dilation;
+                            if padding != 0 && (src_h < padding || src_h >= h + padding) {
+                                continue;
+                            }
+                            let src_h = src_h - padding;
+                            let src_idx = src_idx + src_h * src_s2;
                             let dst_idx = dst_idx + h_k_idx * w_k;
-                            dst[dst_idx..dst_idx + w_k]
-                                .copy_from_slice(&src[src_idx..src_idx + w_k])
+                            for w_k_idx in 0..w_k {
+                                let src_w = w_idx * stride + w_k_idx * dilation;
+                                if padding != 0 && (src_w < padding || src_w >= w + padding) {
+                                    continue;
+                                }
+                                let src_w = src_w - padding;
+                                let src_idx = src_idx + src_w * src_s3;
+                                let dst_idx = dst_idx + w_k_idx;
+                                dst[dst_idx] = src[src_idx];
+                            }
                         }
                     }
                 }
@@ -113,14 +154,26 @@ impl Benchmark for Conv2dIm2Col {
     fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> {
         // d.0.conv2d(&d.1, 0, 1, 1, 1)
         let (b, _, h, w) = d.0.dims4()?;
-        let (h_k, w_k) = (3, 3);
-        let (h_out, w_out) = (h - h_k + 1, w - w_k + 1);
-        let col = d.0.apply_op1_no_bwd(&Im2Col(h_k, w_k))?;
+        let (_, _, h_k, w_k) = d.1.dims4()?;
+        let op = Im2Col {
+            h_k,
+            w_k,
+            stride: 1,
+            dilation: 1,
+            padding: 0,
+        };
+        let (h_out, w_out) = op.hw_out(h, w);
+        let col = d.0.apply_op1_no_bwd(&op)?;
         let res = col.matmul(&d.1.flatten_from(1)?.t()?)?;
         let res = res
            .reshape((b, h_out, w_out, ()))?
            .permute((0, 3, 1, 2))?
            .contiguous()?;
+        if CHECK_CONV2D {
+            let res2 = d.0.conv2d(&d.1, op.padding, op.stride, op.dilation, 1);
+            let diff = (&res - res2)?.sqr()?.mean_all()?;
+            println!("{diff}");
+        }
         Ok(res)
     }
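
For reference, the gather that the patch implements for both the CPU and CUDA paths can be summarized in a short dependency-free sketch. The helper below is an illustration only, not part of the patch or of candle's API: it assumes a contiguous NCHW f32 buffer and produces the same (b, h_out, w_out, c_in, h_k, w_k) destination layout used above.

// Illustration only: plain-Rust im2col over a contiguous NCHW buffer.
fn im2col(
    src: &[f32],
    (b, c, h, w): (usize, usize, usize, usize),
    (h_k, w_k): (usize, usize),
    stride: usize,
    dilation: usize,
    padding: usize,
) -> (Vec<f32>, usize, usize) {
    // Output size, same formula as hw_out above.
    let h_out = (h + 2 * padding - dilation * (h_k - 1) - 1) / stride + 1;
    let w_out = (w + 2 * padding - dilation * (w_k - 1) - 1) / stride + 1;
    // dst: (b, h_out, w_out, c, h_k, w_k), zero-initialized so padded taps stay 0.
    let mut dst = vec![0f32; b * h_out * w_out * c * h_k * w_k];
    let mut dst_i = 0;
    for b_i in 0..b {
        for h_i in 0..h_out {
            for w_i in 0..w_out {
                for c_i in 0..c {
                    for hk in 0..h_k {
                        for wk in 0..w_k {
                            let src_h = h_i * stride + hk * dilation;
                            let src_w = w_i * stride + wk * dilation;
                            // Copy the tap unless it falls in the padded border.
                            if src_h >= padding && src_h < h + padding
                                && src_w >= padding && src_w < w + padding
                            {
                                let (sh, sw) = (src_h - padding, src_w - padding);
                                dst[dst_i] = src[((b_i * c + c_i) * h + sh) * w + sw];
                            }
                            dst_i += 1;
                        }
                    }
                }
            }
        }
    }
    (dst, h_out, w_out)
}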
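
The CUDA kernel assigns one thread per destination element and recovers the six coordinates from the flat thread index by repeated division with the destination strides (the dst_s* values). The same decode in Rust, as a hypothetical helper for clarity:

// Decode a flat row-major index into coordinates for dims
// (b, h_out, w_out, c_in, h_k, w_k), mirroring the kernel's dst_s* math.
fn decode(mut i: usize, dims: [usize; 6]) -> [usize; 6] {
    let mut out = [0; 6];
    let mut stride: usize = dims.iter().product();
    for k in 0..6 {
        stride /= dims[k]; // stride of axis k = product of the trailing dims
        out[k] = i / stride;
        i -= out[k] * stride;
    }
    out
}

E.g. decode(dst_i, [b, h_out, w_out, c_in, h_k, w_k]) yields the (b_idx, h_idx, w_idx, c_idx, h_k_idx, w_k_idx) that the kernel computes inline.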
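
With this layout, each run of c_in * h_k * w_k consecutive destination values is exactly the receptive field of one output pixel, so conv2d reduces to a (b * h_out * w_out, c_in * h_k * w_k) x (c_in * h_k * w_k, c_out) matmul against the flattened filters; that is what col.matmul(&d.1.flatten_from(1)?.t()?) computes in the benchmark. A spot check using the sketch above (again illustration only, not part of the patch):

fn main() {
    // 1x1x3x3 input holding 0..=8 and a single 2x2 mean filter.
    let src: Vec<f32> = (0..9).map(|v| v as f32).collect();
    let (col, h_out, w_out) = im2col(&src, (1, 1, 3, 3), (2, 2), 1, 1, 0);
    assert_eq!((h_out, w_out), (2, 2));
    // Row 0 of col is the top-left patch [0, 1, 3, 4]; its dot product with
    // the flattened filter is the first conv2d output value.
    let k = [0.25f32; 4];
    let y0: f32 = col[..4].iter().zip(k).map(|(a, b)| a * b).sum();
    assert_eq!(y0, 2.0); // (0 + 1 + 3 + 4) / 4
}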