From 6fb665004ca340808c049514e738ac1e814d9a23 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Mon, 11 Sep 2023 09:28:13 +0100
Subject: [PATCH] Enable im2col on the cpu side. (#805)

* Enable im2col on the cpu side.

* Hook im2col on the cpu backend.

* Use the kernel offset.

* Avoid an unnecessary copy.

* Handle non-contiguous kernels.

* Add a const to select the conv2d kernel.
---
 candle-core/src/cpu_backend.rs | 115 ++++++++++++++++++++++++++++++++-
 1 file changed, 114 insertions(+), 1 deletion(-)

diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index 01ccfde7..3cdc538a 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -4,6 +4,8 @@ use crate::{DType, Error, IntDType, Layout, Result, Shape, WithDType};
 use half::{bf16, f16};
 use rayon::prelude::*;
 
+const USE_IM2COL_CONV2D: bool = true;
+
 // TODO: Maybe we should not implement [Clone] here and instead have an explicit allocator +
 // intercept the oom errors to avoid panicking and provide a proper error.
 #[derive(Debug, Clone)]
@@ -1089,6 +1091,81 @@ impl<'a> Map2 for Conv1D<'a> {
     }
 }
 
+struct Im2Col {
+    h_k: usize,
+    w_k: usize,
+    stride: usize,
+    dilation: usize,
+    padding: usize,
+}
+
+impl Im2Col {
+    fn hw_out(&self, h: usize, w: usize) -> (usize, usize) {
+        let h_out = (h + 2 * self.padding - self.dilation * (self.h_k - 1) - 1) / self.stride + 1;
+        let w_out = (w + 2 * self.padding - self.dilation * (self.w_k - 1) - 1) / self.stride + 1;
+        (h_out, w_out)
+    }
+}
+
+impl Map1 for Im2Col {
+    fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>> {
+        let &Self {
+            h_k,
+            w_k,
+            stride,
+            dilation,
+            padding,
+        } = self;
+        let (b, c, h, w) = layout.shape().dims4()?;
+        let (h_out, w_out) = self.hw_out(h, w);
+        let src = &vs[layout.start_offset()..];
+        let mut dst = vec![T::zero(); b * h_out * w_out * c * h_k * w_k];
+        let (src_s0, src_s1, src_s2, src_s3) = {
+            let s = layout.stride();
+            (s[0], s[1], s[2], s[3])
+        };
+        // TODO: provide specialized kernels for the common use cases.
+        // - h_k = w_k = 1
+        // - padding = 0
+        // - stride = 1
+        // - dilation = 1
+        for b_idx in 0..b {
+            let src_idx = b_idx * src_s0;
+            let dst_idx = b_idx * h_out * w_out * c * h_k * w_k;
+            for h_idx in 0..h_out {
+                let dst_idx = dst_idx + h_idx * w_out * c * h_k * w_k;
+                for w_idx in 0..w_out {
+                    let dst_idx = dst_idx + w_idx * c * h_k * w_k;
+                    for c_idx in 0..c {
+                        let dst_idx = dst_idx + c_idx * h_k * w_k;
+                        let src_idx = c_idx * src_s1 + src_idx;
+                        for h_k_idx in 0..h_k {
+                            let src_h = h_idx * stride + h_k_idx * dilation;
+                            if padding != 0 && (src_h < padding || src_h >= h + padding) {
+                                continue;
+                            }
+                            let src_h = src_h - padding;
+                            let src_idx = src_idx + src_h * src_s2;
+                            let dst_idx = dst_idx + h_k_idx * w_k;
+                            for w_k_idx in 0..w_k {
+                                let src_w = w_idx * stride + w_k_idx * dilation;
+                                if padding != 0 && (src_w < padding || src_w >= w + padding) {
+                                    continue;
+                                }
+                                let src_w = src_w - padding;
+                                let src_idx = src_idx + src_w * src_s3;
+                                let dst_idx = dst_idx + w_k_idx;
+                                dst[dst_idx] = src[src_idx];
+                            }
+                        }
+                    }
+                }
+            }
+        }
+        Ok(dst)
+    }
+}
+
 struct Conv2D<'a>(&'a crate::conv::ParamsConv2D);
 
 impl<'a> Map2 for Conv2D<'a> {
@@ -2237,7 +2314,43 @@ impl BackendStorage for CpuStorage {
         kernel_l: &Layout,
         params: &crate::conv::ParamsConv2D,
     ) -> Result<Self> {
-        Conv2D(params).map(self, l, kernel, kernel_l)
+        if !USE_IM2COL_CONV2D {
+            return Conv2D(params).map(self, l, kernel, kernel_l);
+        }
+        let op = Im2Col {
+            h_k: params.k_h,
+            w_k: params.k_w,
+            padding: params.padding,
+            stride: params.stride,
+            dilation: params.dilation,
+        };
+        let col = op.map(self, l)?;
+        let b = params.b_size;
+        let n = params.c_out;
+        let (h_out, w_out) = (params.out_h(), params.out_w());
+        let k = op.h_k * op.w_k * params.c_in;
+        let m = h_out * w_out;
+        let col_l = Layout::contiguous((b, m, k));
+        let res = if kernel_l.is_contiguous() {
+            let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
+                .transpose(1, 2)?
+                .broadcast_as((b, k, n))?;
+            col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
+        } else {
+            // Make the kernel contiguous if not already the case.
+            let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
+            kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
+            let kernel_l = Layout::contiguous_with_offset((1, n, k), 0)
+                .transpose(1, 2)?
+                .broadcast_as((b, k, n))?;
+            col.matmul(&kernel_c, (b, m, n, k), &col_l, &kernel_l)?
+        };
+        let res_l = Layout::contiguous((b, h_out, w_out, params.c_out))
+            .transpose(1, 2)?
+            .transpose(1, 3)?;
+        let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
+        res.copy_strided_src(&mut res_t, 0, &res_l)?;
+        Ok(res_t)
     }
 
     fn conv_transpose2d(
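--
The patch rewrites conv2d as im2col followed by a single batched matmul:
the (b, c_in, h, w) input is expanded into a (b, h_out * w_out,
c_in * h_k * w_k) column matrix, multiplied by the kernel viewed as a
(c_in * h_k * w_k, c_out) matrix, and the (b, h_out, w_out, c_out) result
is copied back to NCHW. The sketch below shows the same expansion on a
plain f32 buffer for the simplest case (contiguous input, padding = 0,
dilation = 1); im2col_2d is a hypothetical standalone helper for
illustration, not part of the candle API.

// Minimal im2col sketch: contiguous NCHW input, no padding, no dilation.
// Produces a row-major (b, h_out * w_out, c * h_k * w_k) buffer whose rows
// are flattened convolution patches, ready for a matmul with the kernel.
fn im2col_2d(
    src: &[f32],
    (b, c, h, w): (usize, usize, usize, usize),
    (h_k, w_k): (usize, usize),
    stride: usize,
) -> Vec<f32> {
    let h_out = (h - h_k) / stride + 1;
    let w_out = (w - w_k) / stride + 1;
    let mut dst = vec![0f32; b * h_out * w_out * c * h_k * w_k];
    let mut dst_idx = 0;
    for b_idx in 0..b {
        for h_idx in 0..h_out {
            for w_idx in 0..w_out {
                for c_idx in 0..c {
                    for h_k_idx in 0..h_k {
                        for w_k_idx in 0..w_k {
                            let src_h = h_idx * stride + h_k_idx;
                            let src_w = w_idx * stride + w_k_idx;
                            // Row-major NCHW indexing of the source.
                            let src_idx = ((b_idx * c + c_idx) * h + src_h) * w + src_w;
                            // The destination is written sequentially because the
                            // loop nest matches its (b, m, k) layout exactly.
                            dst[dst_idx] = src[src_idx];
                            dst_idx += 1;
                        }
                    }
                }
            }
        }
    }
    dst
}

For a 1x1x3x3 input with a 2x2 window at stride 1, this yields a 4x4 matrix
whose four rows are the four overlapping patches; multiplying it by a
(4, c_out) kernel matrix reproduces the convolution output, matching the
(b, m, k) x (b, k, n) matmul set up in the hunk above.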