mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Add Im2Col support on the gpu side. (#808)
* Add Im2Col support on the gpu side. * Actually enable.
This commit is contained in:
@ -11,6 +11,8 @@ use cudarc::driver::{
|
||||
use half::{bf16, f16};
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
const USE_IM2COL_CONV2D: bool = true;
|
||||
|
||||
/// cudarc related errors
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum CudaError {
|
||||
@ -1723,8 +1725,37 @@ impl BackendStorage for CudaStorage {
|
||||
params: &crate::conv::ParamsConv2D,
|
||||
) -> Result<Self> {
|
||||
let device = self.device().clone();
|
||||
let slice = Conv2D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?;
|
||||
Ok(Self { slice, device })
|
||||
if !USE_IM2COL_CONV2D {
|
||||
let slice = Conv2D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?;
|
||||
return Ok(Self { slice, device });
|
||||
}
|
||||
|
||||
let col = Im2Col {
|
||||
h_k: params.k_h,
|
||||
w_k: params.k_w,
|
||||
stride: params.stride,
|
||||
dilation: params.dilation,
|
||||
padding: params.padding,
|
||||
}
|
||||
.map(&self.slice, &device, l)?;
|
||||
let col = Self { slice: col, device };
|
||||
let h_out = params.out_h();
|
||||
let w_out = params.out_w();
|
||||
let b = params.b_size;
|
||||
let n = params.c_out;
|
||||
let k = params.k_h * params.k_w * params.c_in;
|
||||
let m = h_out * w_out;
|
||||
let col_l = Layout::contiguous((b, m, k));
|
||||
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
|
||||
.transpose(1, 2)?
|
||||
.broadcast_as((b, k, n))?;
|
||||
let res = col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?;
|
||||
let res_l = Layout::contiguous((b, h_out, w_out, n))
|
||||
.transpose(1, 2)?
|
||||
.transpose(1, 3)?;
|
||||
let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
|
||||
res.copy_strided_src(&mut res_t, 0, &res_l)?;
|
||||
Ok(res_t)
|
||||
}
|
||||
|
||||
#[cfg(feature = "cudnn")]
|
||||
|
Reference in New Issue
Block a user