mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00
Enable im2col on the cpu side. (#805)
* Enable im2col on the cpu side. * Hook im2col on the cpu backend. * Use the kernel offset. * Avoid an unnecessary copy. * Handle non-contiguous kernels. * Add a const to select the conv2d kernel.
This commit is contained in:
@ -4,6 +4,8 @@ use crate::{DType, Error, IntDType, Layout, Result, Shape, WithDType};
|
||||
use half::{bf16, f16};
|
||||
use rayon::prelude::*;
|
||||
|
||||
const USE_IM2COL_CONV2D: bool = true;
|
||||
|
||||
// TODO: Maybe we should not implement [Clone] here and instead have an explicit allocator +
|
||||
// intercept the oom errors to avoid panicking and provide a proper error.
|
||||
#[derive(Debug, Clone)]
|
||||
@ -1089,6 +1091,81 @@ impl<'a> Map2 for Conv1D<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
struct Im2Col {
|
||||
h_k: usize,
|
||||
w_k: usize,
|
||||
stride: usize,
|
||||
dilation: usize,
|
||||
padding: usize,
|
||||
}
|
||||
|
||||
impl Im2Col {
|
||||
fn hw_out(&self, h: usize, w: usize) -> (usize, usize) {
|
||||
let h_out = (h + 2 * self.padding - self.dilation * (self.h_k - 1) - 1) / self.stride + 1;
|
||||
let w_out = (w + 2 * self.padding - self.dilation * (self.w_k - 1) - 1) / self.stride + 1;
|
||||
(h_out, w_out)
|
||||
}
|
||||
}
|
||||
|
||||
impl Map1 for Im2Col {
|
||||
fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>> {
|
||||
let &Self {
|
||||
h_k,
|
||||
w_k,
|
||||
stride,
|
||||
dilation,
|
||||
padding,
|
||||
} = self;
|
||||
let (b, c, h, w) = layout.shape().dims4()?;
|
||||
let (h_out, w_out) = self.hw_out(h, w);
|
||||
let src = &vs[layout.start_offset()..];
|
||||
let mut dst = vec![T::zero(); b * h_out * w_out * c * h_k * w_k];
|
||||
let (src_s0, src_s1, src_s2, src_s3) = {
|
||||
let s = layout.stride();
|
||||
(s[0], s[1], s[2], s[3])
|
||||
};
|
||||
// TODO: provide specialized kernels for the common use cases.
|
||||
// - h_k = w_k = 1
|
||||
// - padding = 0
|
||||
// - stride = 1
|
||||
// - dilation = 1
|
||||
for b_idx in 0..b {
|
||||
let src_idx = b_idx * src_s0;
|
||||
let dst_idx = b_idx * h_out * w_out * c * h_k * w_k;
|
||||
for h_idx in 0..h_out {
|
||||
let dst_idx = dst_idx + h_idx * w_out * c * h_k * w_k;
|
||||
for w_idx in 0..w_out {
|
||||
let dst_idx = dst_idx + w_idx * c * h_k * w_k;
|
||||
for c_idx in 0..c {
|
||||
let dst_idx = dst_idx + c_idx * h_k * w_k;
|
||||
let src_idx = c_idx * src_s1 + src_idx;
|
||||
for h_k_idx in 0..h_k {
|
||||
let src_h = h_idx * stride + h_k_idx * dilation;
|
||||
if padding != 0 && (src_h < padding || src_h >= h + padding) {
|
||||
continue;
|
||||
}
|
||||
let src_h = src_h - padding;
|
||||
let src_idx = src_idx + src_h * src_s2;
|
||||
let dst_idx = dst_idx + h_k_idx * w_k;
|
||||
for w_k_idx in 0..w_k {
|
||||
let src_w = w_idx * stride + w_k_idx * dilation;
|
||||
if padding != 0 && (src_w < padding || src_w >= h + padding) {
|
||||
continue;
|
||||
}
|
||||
let src_w = src_w - padding;
|
||||
let src_idx = src_idx + src_w * src_s3;
|
||||
let dst_idx = dst_idx + w_k_idx;
|
||||
dst[dst_idx] = src[src_idx]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(dst)
|
||||
}
|
||||
}
|
||||
|
||||
struct Conv2D<'a>(&'a crate::conv::ParamsConv2D);
|
||||
|
||||
impl<'a> Map2 for Conv2D<'a> {
|
||||
@ -2237,7 +2314,43 @@ impl BackendStorage for CpuStorage {
|
||||
kernel_l: &Layout,
|
||||
params: &crate::conv::ParamsConv2D,
|
||||
) -> Result<Self> {
|
||||
Conv2D(params).map(self, l, kernel, kernel_l)
|
||||
if !USE_IM2COL_CONV2D {
|
||||
return Conv2D(params).map(self, l, kernel, kernel_l);
|
||||
}
|
||||
let op = Im2Col {
|
||||
h_k: params.k_h,
|
||||
w_k: params.k_w,
|
||||
padding: params.padding,
|
||||
stride: params.stride,
|
||||
dilation: params.dilation,
|
||||
};
|
||||
let col = op.map(self, l)?;
|
||||
let b = params.b_size;
|
||||
let n = params.c_out;
|
||||
let (h_out, w_out) = (params.out_h(), params.out_w());
|
||||
let k = op.h_k * op.w_k * params.c_in;
|
||||
let m = h_out * w_out;
|
||||
let col_l = Layout::contiguous((b, m, k));
|
||||
let res = if kernel_l.is_contiguous() {
|
||||
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
|
||||
.transpose(1, 2)?
|
||||
.broadcast_as((b, k, n))?;
|
||||
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
|
||||
} else {
|
||||
// Make the kernel contiguous if not already the case.
|
||||
let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
|
||||
kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
|
||||
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
|
||||
.transpose(1, 2)?
|
||||
.broadcast_as((b, k, n))?;
|
||||
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
|
||||
};
|
||||
let res_l = Layout::contiguous((b, h_out, w_out, params.c_out))
|
||||
.transpose(1, 2)?
|
||||
.transpose(1, 3)?;
|
||||
let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
|
||||
res.copy_strided_src(&mut res_t, 0, &res_l)?;
|
||||
Ok(res_t)
|
||||
}
|
||||
|
||||
fn conv_transpose2d(
|
||||
|
Reference in New Issue
Block a user