Add Im2Col support on the gpu side. (#808)
* Add Im2Col support on the gpu side.
* Actually enable.
@@ -11,6 +11,8 @@ use cudarc::driver::{
 use half::{bf16, f16};
 use std::sync::{Arc, Mutex};
 
+const USE_IM2COL_CONV2D: bool = true;
+
 /// cudarc related errors
 #[derive(thiserror::Error, Debug)]
 pub enum CudaError {
@@ -1723,8 +1725,37 @@ impl BackendStorage for CudaStorage {
         params: &crate::conv::ParamsConv2D,
     ) -> Result<Self> {
         let device = self.device().clone();
-        let slice = Conv2D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?;
-        Ok(Self { slice, device })
+        if !USE_IM2COL_CONV2D {
+            let slice = Conv2D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?;
+            return Ok(Self { slice, device });
+        }
+
+        let col = Im2Col {
+            h_k: params.k_h,
+            w_k: params.k_w,
+            stride: params.stride,
+            dilation: params.dilation,
+            padding: params.padding,
+        }
+        .map(&self.slice, &device, l)?;
+        let col = Self { slice: col, device };
+        let h_out = params.out_h();
+        let w_out = params.out_w();
+        let b = params.b_size;
+        let n = params.c_out;
+        let k = params.k_h * params.k_w * params.c_in;
+        let m = h_out * w_out;
+        let col_l = Layout::contiguous((b, m, k));
+        let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
+            .transpose(1, 2)?
+            .broadcast_as((b, k, n))?;
+        let res = col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?;
+        let res_l = Layout::contiguous((b, h_out, w_out, n))
+            .transpose(1, 2)?
+            .transpose(1, 3)?;
+        let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
+        res.copy_strided_src(&mut res_t, 0, &res_l)?;
+        Ok(res_t)
     }
 
     #[cfg(feature = "cudnn")]
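For context, the new GPU path lowers the convolution to a plain matmul: im2col gathers every k_h * k_w * c_in receptive-field patch into one row of a (b, m, k) matrix with m = h_out * w_out, the kernel is viewed as a (k, n) matrix with n = c_out, and the batched matmul plus a couple of transposes produces the usual (b, c_out, h_out, w_out) output. The sketch below is a minimal CPU illustration of that idea, assuming a single image in CHW layout; the function names and signatures are hypothetical and are not candle's internal API.

// Minimal CPU sketch of the im2col + matmul lowering used by the GPU path above.
// Assumes a single image in CHW layout; names are illustrative, not candle's API.

/// Gather every receptive-field patch into a row of an (m, k) matrix,
/// with m = h_out * w_out and k = c_in * k_h * k_w.
fn im2col(
    input: &[f32],
    (c_in, h, w): (usize, usize, usize),
    (k_h, k_w): (usize, usize),
    stride: usize,
    padding: usize,
    dilation: usize,
) -> (Vec<f32>, usize, usize) {
    let h_out = (h + 2 * padding - dilation * (k_h - 1) - 1) / stride + 1;
    let w_out = (w + 2 * padding - dilation * (k_w - 1) - 1) / stride + 1;
    let (m, k) = (h_out * w_out, c_in * k_h * k_w);
    let mut col = vec![0f32; m * k];
    for oy in 0..h_out {
        for ox in 0..w_out {
            let row = oy * w_out + ox;
            for c in 0..c_in {
                for ky in 0..k_h {
                    for kx in 0..k_w {
                        // Input coordinates for this kernel tap; padded
                        // positions are simply left at zero.
                        let iy = oy * stride + ky * dilation;
                        let ix = ox * stride + kx * dilation;
                        if iy < padding || ix < padding {
                            continue;
                        }
                        let (iy, ix) = (iy - padding, ix - padding);
                        if iy >= h || ix >= w {
                            continue;
                        }
                        col[row * k + (c * k_h + ky) * k_w + kx] = input[(c * h + iy) * w + ix];
                    }
                }
            }
        }
    }
    (col, m, k)
}

/// The convolution then reduces to an (m, k) x (k, n) matmul against the kernel
/// reshaped to (n, k) = (c_out, c_in * k_h * k_w); the result is
/// (h_out * w_out, c_out) and only needs a transpose back to CHW.
fn conv2d_via_im2col(col: &[f32], kernel: &[f32], m: usize, k: usize, n: usize) -> Vec<f32> {
    let mut out = vec![0f32; m * n];
    for i in 0..m {
        for j in 0..n {
            out[i * n + j] = (0..k).map(|p| col[i * k + p] * kernel[j * k + p]).sum();
        }
    }
    out
}

The payoff of this lowering is that the per-position gather is cheap and embarrassingly parallel, while the heavy lifting lands in a single batched matmul, which is exactly the operation GPUs (and cuBLAS) execute fastest; the USE_IM2COL_CONV2D flag keeps the direct Conv2D kernel available as a fallback.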