mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 19:58:35 +00:00
im2col version of the conv1d kernel. (#815)
* im2col version of the cuda conv1d kernel. * im2col version of the conv1d cpu kernel.
This commit is contained in:
@ -4,6 +4,7 @@ use crate::{DType, Error, IntDType, Layout, Result, Shape, WithDType};
|
||||
use half::{bf16, f16};
|
||||
use rayon::prelude::*;
|
||||
|
||||
const USE_IM2COL_CONV1D: bool = true;
|
||||
const USE_IM2COL_CONV2D: bool = true;
|
||||
|
||||
// TODO: Maybe we should not implement [Clone] here and instead have an explicit allocator +
|
||||
@ -1091,6 +1092,65 @@ impl<'a> Map2 for Conv1D<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
struct Im2Col1D {
|
||||
l_k: usize,
|
||||
stride: usize,
|
||||
dilation: usize,
|
||||
padding: usize,
|
||||
}
|
||||
|
||||
impl Im2Col1D {
|
||||
fn l_out(&self, l: usize) -> usize {
|
||||
(l + 2 * self.padding - self.dilation * (self.l_k - 1) - 1) / self.stride + 1
|
||||
}
|
||||
}
|
||||
|
||||
impl Map1 for Im2Col1D {
|
||||
fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>> {
|
||||
let &Self {
|
||||
l_k,
|
||||
stride,
|
||||
dilation,
|
||||
padding,
|
||||
} = self;
|
||||
let (b, c, l) = layout.shape().dims3()?;
|
||||
let l_out = self.l_out(l);
|
||||
let src = &vs[layout.start_offset()..];
|
||||
let mut dst = vec![T::zero(); b * l_out * c * l_k];
|
||||
let (src_s0, src_s1, src_s2) = {
|
||||
let s = layout.stride();
|
||||
(s[0], s[1], s[2])
|
||||
};
|
||||
// TODO: provide specialized kernels for the common use cases.
|
||||
// - l_k = 1
|
||||
// - padding = 0
|
||||
// - stride = 1
|
||||
// - dilation = 1
|
||||
for b_idx in 0..b {
|
||||
let src_idx = b_idx * src_s0;
|
||||
let dst_idx = b_idx * l_out * c * l_k;
|
||||
for l_idx in 0..l_out {
|
||||
let dst_idx = dst_idx + l_idx * c * l_k;
|
||||
for c_idx in 0..c {
|
||||
let dst_idx = dst_idx + c_idx * l_k;
|
||||
let src_idx = c_idx * src_s1 + src_idx;
|
||||
for l_k_idx in 0..l_k {
|
||||
let src_l = l_idx * stride + l_k_idx * dilation;
|
||||
if padding != 0 && (src_l < padding || src_l >= l + padding) {
|
||||
continue;
|
||||
}
|
||||
let src_l = src_l - padding;
|
||||
let src_idx = src_idx + src_l * src_s2;
|
||||
let dst_idx = dst_idx + l_k_idx;
|
||||
dst[dst_idx] = src[src_idx]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(dst)
|
||||
}
|
||||
}
|
||||
|
||||
struct Im2Col {
|
||||
h_k: usize,
|
||||
w_k: usize,
|
||||
@ -2305,7 +2365,40 @@ impl BackendStorage for CpuStorage {
|
||||
kernel_l: &Layout,
|
||||
params: &crate::conv::ParamsConv1D,
|
||||
) -> Result<Self> {
|
||||
Conv1D(params).map(self, l, kernel, kernel_l)
|
||||
if !USE_IM2COL_CONV1D {
|
||||
return Conv1D(params).map(self, l, kernel, kernel_l);
|
||||
}
|
||||
let op = Im2Col1D {
|
||||
l_k: params.k_size,
|
||||
padding: params.padding,
|
||||
stride: params.stride,
|
||||
dilation: params.dilation,
|
||||
};
|
||||
let col = op.map(self, l)?;
|
||||
let b = params.b_size;
|
||||
let n = params.c_out;
|
||||
let l_out = params.l_out();
|
||||
let k = op.l_k * params.c_in;
|
||||
let m = l_out;
|
||||
let col_l = Layout::contiguous((b, m, k));
|
||||
let res = if kernel_l.is_contiguous() {
|
||||
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
|
||||
.transpose(1, 2)?
|
||||
.broadcast_as((b, k, n))?;
|
||||
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
|
||||
} else {
|
||||
// Make the kernel contiguous if not already the case.
|
||||
let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
|
||||
kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
|
||||
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
|
||||
.transpose(1, 2)?
|
||||
.broadcast_as((b, k, n))?;
|
||||
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
|
||||
};
|
||||
let res_l = Layout::contiguous((b, l_out, params.c_out)).transpose(1, 2)?;
|
||||
let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
|
||||
res.copy_strided_src(&mut res_t, 0, &res_l)?;
|
||||
Ok(res_t)
|
||||
}
|
||||
|
||||
fn conv2d(
|
||||
|
@ -11,6 +11,7 @@ use cudarc::driver::{
|
||||
use half::{bf16, f16};
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
const USE_IM2COL_CONV1D: bool = true;
|
||||
const USE_IM2COL_CONV2D: bool = true;
|
||||
|
||||
/// cudarc related errors
|
||||
@ -602,6 +603,53 @@ impl Map1 for Elu {
|
||||
}
|
||||
}
|
||||
|
||||
struct Im2Col1D {
|
||||
l_k: usize,
|
||||
stride: usize,
|
||||
dilation: usize,
|
||||
padding: usize,
|
||||
}
|
||||
|
||||
impl Im2Col1D {
|
||||
fn l_out(&self, l: usize) -> usize {
|
||||
(l + 2 * self.padding - self.dilation * (self.l_k - 1) - 1) / self.stride + 1
|
||||
}
|
||||
}
|
||||
|
||||
impl Map1 for Im2Col1D {
|
||||
fn f<T: DeviceRepr + WithDType>(
|
||||
&self,
|
||||
src: &CudaSlice<T>,
|
||||
dev: &CudaDevice,
|
||||
layout: &Layout,
|
||||
) -> Result<CudaSlice<T>> {
|
||||
let shape = layout.shape();
|
||||
let dims = shape.dims();
|
||||
let l_out = self.l_out(dims[2]);
|
||||
let dst_el = dims[0] * l_out * dims[1] * self.l_k;
|
||||
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
|
||||
let ds = dev.htod_copy([dims, layout.stride()].concat()).w()?;
|
||||
let src = &src.slice(layout.start_offset()..);
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>("im2col1d"), kernels::CONV)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let dst = unsafe { dev.alloc::<T>(dst_el) }.w()?;
|
||||
let params = (
|
||||
dst_el,
|
||||
l_out,
|
||||
self.l_k,
|
||||
self.stride,
|
||||
self.padding,
|
||||
self.dilation,
|
||||
&ds,
|
||||
src,
|
||||
&dst,
|
||||
);
|
||||
// SAFETY: ffi.
|
||||
unsafe { func.launch(cfg, params) }.w()?;
|
||||
Ok(dst)
|
||||
}
|
||||
}
|
||||
|
||||
struct Im2Col {
|
||||
h_k: usize,
|
||||
w_k: usize,
|
||||
@ -1712,8 +1760,43 @@ impl BackendStorage for CudaStorage {
|
||||
params: &crate::conv::ParamsConv1D,
|
||||
) -> Result<Self> {
|
||||
let device = self.device().clone();
|
||||
let slice = Conv1D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?;
|
||||
Ok(Self { slice, device })
|
||||
if !USE_IM2COL_CONV1D {
|
||||
let slice = Conv1D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?;
|
||||
return Ok(Self { slice, device });
|
||||
}
|
||||
|
||||
let col = Im2Col1D {
|
||||
l_k: params.k_size,
|
||||
stride: params.stride,
|
||||
dilation: params.dilation,
|
||||
padding: params.padding,
|
||||
}
|
||||
.map(&self.slice, &device, l)?;
|
||||
let col = Self { slice: col, device };
|
||||
let l_out = params.l_out();
|
||||
let b = params.b_size;
|
||||
let n = params.c_out;
|
||||
let k = params.k_size * params.c_in;
|
||||
let m = l_out;
|
||||
let col_l = Layout::contiguous((b, m, k));
|
||||
let res = if kernel_l.is_contiguous() {
|
||||
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
|
||||
.transpose(1, 2)?
|
||||
.broadcast_as((b, k, n))?;
|
||||
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
|
||||
} else {
|
||||
// Make the kernel contiguous if not already the case.
|
||||
let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
|
||||
kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
|
||||
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
|
||||
.transpose(1, 2)?
|
||||
.broadcast_as((b, k, n))?;
|
||||
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
|
||||
};
|
||||
let res_l = Layout::contiguous((b, l_out, n)).transpose(1, 2)?;
|
||||
let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
|
||||
res.copy_strided_src(&mut res_t, 0, &res_l)?;
|
||||
Ok(res_t)
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "cudnn"))]
|
||||
|
Reference in New Issue
Block a user