From 1cd74129d47745a292c5f54ae3f9a1a1348cdf3e Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Mon, 11 Sep 2023 08:52:33 +0100 Subject: [PATCH] Add Im2Col support on the gpu side. (#808) * Add Im2Col support on the gpu side. * Actually enable. --- candle-core/src/cuda_backend.rs | 35 +++++++++++++++++++++++++++++++-- 1 file changed, 33 insertions(+), 2 deletions(-) diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs index b4bdc6cf..dbfaa928 100644 --- a/candle-core/src/cuda_backend.rs +++ b/candle-core/src/cuda_backend.rs @@ -11,6 +11,8 @@ use cudarc::driver::{ use half::{bf16, f16}; use std::sync::{Arc, Mutex}; +const USE_IM2COL_CONV2D: bool = true; + /// cudarc related errors #[derive(thiserror::Error, Debug)] pub enum CudaError { @@ -1723,8 +1725,37 @@ impl BackendStorage for CudaStorage { params: &crate::conv::ParamsConv2D, ) -> Result<Self> { let device = self.device().clone(); - let slice = Conv2D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?; - Ok(Self { slice, device }) + if !USE_IM2COL_CONV2D { + let slice = Conv2D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?; + return Ok(Self { slice, device }); + } + + let col = Im2Col { + h_k: params.k_h, + w_k: params.k_w, + stride: params.stride, + dilation: params.dilation, + padding: params.padding, + } + .map(&self.slice, &device, l)?; + let col = Self { slice: col, device }; + let h_out = params.out_h(); + let w_out = params.out_w(); + let b = params.b_size; + let n = params.c_out; + let k = params.k_h * params.k_w * params.c_in; + let m = h_out * w_out; + let col_l = Layout::contiguous((b, m, k)); + let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset()) + .transpose(1, 2)? + .broadcast_as((b, k, n))?; + let res = col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?; + let res_l = Layout::contiguous((b, h_out, w_out, n)) + .transpose(1, 2)? 
+ .transpose(1, 3)?; + let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?; + res.copy_strided_src(&mut res_t, 0, &res_l)?; + Ok(res_t) } #[cfg(feature = "cudnn")]