mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
im2col version of the conv1d kernel. (#815)
* im2col version of the cuda conv1d kernel. * im2col version of the conv1d cpu kernel.
This commit is contained in:
@ -4,6 +4,7 @@ use crate::{DType, Error, IntDType, Layout, Result, Shape, WithDType};
|
|||||||
use half::{bf16, f16};
|
use half::{bf16, f16};
|
||||||
use rayon::prelude::*;
|
use rayon::prelude::*;
|
||||||
|
|
||||||
|
const USE_IM2COL_CONV1D: bool = true;
|
||||||
const USE_IM2COL_CONV2D: bool = true;
|
const USE_IM2COL_CONV2D: bool = true;
|
||||||
|
|
||||||
// TODO: Maybe we should not implement [Clone] here and instead have an explicit allocator +
|
// TODO: Maybe we should not implement [Clone] here and instead have an explicit allocator +
|
||||||
@ -1091,6 +1092,65 @@ impl<'a> Map2 for Conv1D<'a> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct Im2Col1D {
|
||||||
|
l_k: usize,
|
||||||
|
stride: usize,
|
||||||
|
dilation: usize,
|
||||||
|
padding: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Im2Col1D {
|
||||||
|
fn l_out(&self, l: usize) -> usize {
|
||||||
|
(l + 2 * self.padding - self.dilation * (self.l_k - 1) - 1) / self.stride + 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Map1 for Im2Col1D {
|
||||||
|
fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>> {
|
||||||
|
let &Self {
|
||||||
|
l_k,
|
||||||
|
stride,
|
||||||
|
dilation,
|
||||||
|
padding,
|
||||||
|
} = self;
|
||||||
|
let (b, c, l) = layout.shape().dims3()?;
|
||||||
|
let l_out = self.l_out(l);
|
||||||
|
let src = &vs[layout.start_offset()..];
|
||||||
|
let mut dst = vec![T::zero(); b * l_out * c * l_k];
|
||||||
|
let (src_s0, src_s1, src_s2) = {
|
||||||
|
let s = layout.stride();
|
||||||
|
(s[0], s[1], s[2])
|
||||||
|
};
|
||||||
|
// TODO: provide specialized kernels for the common use cases.
|
||||||
|
// - l_k = 1
|
||||||
|
// - padding = 0
|
||||||
|
// - stride = 1
|
||||||
|
// - dilation = 1
|
||||||
|
for b_idx in 0..b {
|
||||||
|
let src_idx = b_idx * src_s0;
|
||||||
|
let dst_idx = b_idx * l_out * c * l_k;
|
||||||
|
for l_idx in 0..l_out {
|
||||||
|
let dst_idx = dst_idx + l_idx * c * l_k;
|
||||||
|
for c_idx in 0..c {
|
||||||
|
let dst_idx = dst_idx + c_idx * l_k;
|
||||||
|
let src_idx = c_idx * src_s1 + src_idx;
|
||||||
|
for l_k_idx in 0..l_k {
|
||||||
|
let src_l = l_idx * stride + l_k_idx * dilation;
|
||||||
|
if padding != 0 && (src_l < padding || src_l >= l + padding) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
let src_l = src_l - padding;
|
||||||
|
let src_idx = src_idx + src_l * src_s2;
|
||||||
|
let dst_idx = dst_idx + l_k_idx;
|
||||||
|
dst[dst_idx] = src[src_idx]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(dst)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
struct Im2Col {
|
struct Im2Col {
|
||||||
h_k: usize,
|
h_k: usize,
|
||||||
w_k: usize,
|
w_k: usize,
|
||||||
@ -2305,7 +2365,40 @@ impl BackendStorage for CpuStorage {
|
|||||||
kernel_l: &Layout,
|
kernel_l: &Layout,
|
||||||
params: &crate::conv::ParamsConv1D,
|
params: &crate::conv::ParamsConv1D,
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
Conv1D(params).map(self, l, kernel, kernel_l)
|
if !USE_IM2COL_CONV1D {
|
||||||
|
return Conv1D(params).map(self, l, kernel, kernel_l);
|
||||||
|
}
|
||||||
|
let op = Im2Col1D {
|
||||||
|
l_k: params.k_size,
|
||||||
|
padding: params.padding,
|
||||||
|
stride: params.stride,
|
||||||
|
dilation: params.dilation,
|
||||||
|
};
|
||||||
|
let col = op.map(self, l)?;
|
||||||
|
let b = params.b_size;
|
||||||
|
let n = params.c_out;
|
||||||
|
let l_out = params.l_out();
|
||||||
|
let k = op.l_k * params.c_in;
|
||||||
|
let m = l_out;
|
||||||
|
let col_l = Layout::contiguous((b, m, k));
|
||||||
|
let res = if kernel_l.is_contiguous() {
|
||||||
|
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
|
||||||
|
.transpose(1, 2)?
|
||||||
|
.broadcast_as((b, k, n))?;
|
||||||
|
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
|
||||||
|
} else {
|
||||||
|
// Make the kernel contiguous if not already the case.
|
||||||
|
let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
|
||||||
|
kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
|
||||||
|
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
|
||||||
|
.transpose(1, 2)?
|
||||||
|
.broadcast_as((b, k, n))?;
|
||||||
|
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
|
||||||
|
};
|
||||||
|
let res_l = Layout::contiguous((b, l_out, params.c_out)).transpose(1, 2)?;
|
||||||
|
let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
|
||||||
|
res.copy_strided_src(&mut res_t, 0, &res_l)?;
|
||||||
|
Ok(res_t)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn conv2d(
|
fn conv2d(
|
||||||
|
@ -11,6 +11,7 @@ use cudarc::driver::{
|
|||||||
use half::{bf16, f16};
|
use half::{bf16, f16};
|
||||||
use std::sync::{Arc, Mutex};
|
use std::sync::{Arc, Mutex};
|
||||||
|
|
||||||
|
const USE_IM2COL_CONV1D: bool = true;
|
||||||
const USE_IM2COL_CONV2D: bool = true;
|
const USE_IM2COL_CONV2D: bool = true;
|
||||||
|
|
||||||
/// cudarc related errors
|
/// cudarc related errors
|
||||||
@ -602,6 +603,53 @@ impl Map1 for Elu {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct Im2Col1D {
|
||||||
|
l_k: usize,
|
||||||
|
stride: usize,
|
||||||
|
dilation: usize,
|
||||||
|
padding: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Im2Col1D {
|
||||||
|
fn l_out(&self, l: usize) -> usize {
|
||||||
|
(l + 2 * self.padding - self.dilation * (self.l_k - 1) - 1) / self.stride + 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Map1 for Im2Col1D {
|
||||||
|
fn f<T: DeviceRepr + WithDType>(
|
||||||
|
&self,
|
||||||
|
src: &CudaSlice<T>,
|
||||||
|
dev: &CudaDevice,
|
||||||
|
layout: &Layout,
|
||||||
|
) -> Result<CudaSlice<T>> {
|
||||||
|
let shape = layout.shape();
|
||||||
|
let dims = shape.dims();
|
||||||
|
let l_out = self.l_out(dims[2]);
|
||||||
|
let dst_el = dims[0] * l_out * dims[1] * self.l_k;
|
||||||
|
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
|
||||||
|
let ds = dev.htod_copy([dims, layout.stride()].concat()).w()?;
|
||||||
|
let src = &src.slice(layout.start_offset()..);
|
||||||
|
let func = dev.get_or_load_func(&kernel_name::<T>("im2col1d"), kernels::CONV)?;
|
||||||
|
// SAFETY: Set later by running the kernel.
|
||||||
|
let dst = unsafe { dev.alloc::<T>(dst_el) }.w()?;
|
||||||
|
let params = (
|
||||||
|
dst_el,
|
||||||
|
l_out,
|
||||||
|
self.l_k,
|
||||||
|
self.stride,
|
||||||
|
self.padding,
|
||||||
|
self.dilation,
|
||||||
|
&ds,
|
||||||
|
src,
|
||||||
|
&dst,
|
||||||
|
);
|
||||||
|
// SAFETY: ffi.
|
||||||
|
unsafe { func.launch(cfg, params) }.w()?;
|
||||||
|
Ok(dst)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
struct Im2Col {
|
struct Im2Col {
|
||||||
h_k: usize,
|
h_k: usize,
|
||||||
w_k: usize,
|
w_k: usize,
|
||||||
@ -1712,8 +1760,43 @@ impl BackendStorage for CudaStorage {
|
|||||||
params: &crate::conv::ParamsConv1D,
|
params: &crate::conv::ParamsConv1D,
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
let device = self.device().clone();
|
let device = self.device().clone();
|
||||||
let slice = Conv1D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?;
|
if !USE_IM2COL_CONV1D {
|
||||||
Ok(Self { slice, device })
|
let slice = Conv1D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?;
|
||||||
|
return Ok(Self { slice, device });
|
||||||
|
}
|
||||||
|
|
||||||
|
let col = Im2Col1D {
|
||||||
|
l_k: params.k_size,
|
||||||
|
stride: params.stride,
|
||||||
|
dilation: params.dilation,
|
||||||
|
padding: params.padding,
|
||||||
|
}
|
||||||
|
.map(&self.slice, &device, l)?;
|
||||||
|
let col = Self { slice: col, device };
|
||||||
|
let l_out = params.l_out();
|
||||||
|
let b = params.b_size;
|
||||||
|
let n = params.c_out;
|
||||||
|
let k = params.k_size * params.c_in;
|
||||||
|
let m = l_out;
|
||||||
|
let col_l = Layout::contiguous((b, m, k));
|
||||||
|
let res = if kernel_l.is_contiguous() {
|
||||||
|
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
|
||||||
|
.transpose(1, 2)?
|
||||||
|
.broadcast_as((b, k, n))?;
|
||||||
|
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
|
||||||
|
} else {
|
||||||
|
// Make the kernel contiguous if not already the case.
|
||||||
|
let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
|
||||||
|
kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
|
||||||
|
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
|
||||||
|
.transpose(1, 2)?
|
||||||
|
.broadcast_as((b, k, n))?;
|
||||||
|
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
|
||||||
|
};
|
||||||
|
let res_l = Layout::contiguous((b, l_out, n)).transpose(1, 2)?;
|
||||||
|
let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
|
||||||
|
res.copy_strided_src(&mut res_t, 0, &res_l)?;
|
||||||
|
Ok(res_t)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(not(feature = "cudnn"))]
|
#[cfg(not(feature = "cudnn"))]
|
||||||
|
@ -51,6 +51,53 @@ __device__ void conv1d(
|
|||||||
dst[dst_i] = static_cast<T>(d);
|
dst[dst_i] = static_cast<T>(d);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
__device__ void im2col1d(
|
||||||
|
const size_t dst_numel,
|
||||||
|
const size_t l_out,
|
||||||
|
const size_t l_k,
|
||||||
|
const size_t stride,
|
||||||
|
const size_t padding,
|
||||||
|
const size_t dilation,
|
||||||
|
const size_t *info,
|
||||||
|
const T *src,
|
||||||
|
T *dst
|
||||||
|
) {
|
||||||
|
const size_t dst_i = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
// dst: (b_size, l_out, c_in, l_k)
|
||||||
|
// src: (b_size, c_in, l_in)
|
||||||
|
if (dst_i >= dst_numel) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const size_t *src_dims = info;
|
||||||
|
const size_t *src_s = info + 3;
|
||||||
|
const size_t b_in = src_dims[0];
|
||||||
|
const size_t c_in = src_dims[1];
|
||||||
|
const size_t l_in = src_dims[2];
|
||||||
|
|
||||||
|
const size_t dst_s2 = l_k;
|
||||||
|
const size_t dst_s1 = c_in * dst_s2;
|
||||||
|
const size_t dst_s0 = l_out * dst_s1;
|
||||||
|
|
||||||
|
size_t tmp_dst_i = dst_i;
|
||||||
|
const size_t b_idx = tmp_dst_i / dst_s0;
|
||||||
|
tmp_dst_i -= b_idx * dst_s0;
|
||||||
|
const size_t l_idx = tmp_dst_i / dst_s1;
|
||||||
|
tmp_dst_i -= l_idx * dst_s1;
|
||||||
|
const size_t c_idx = tmp_dst_i / dst_s2;
|
||||||
|
tmp_dst_i -= c_idx * dst_s2;
|
||||||
|
const size_t l_k_idx = tmp_dst_i;
|
||||||
|
size_t src_l_idx = l_idx * stride + l_k_idx * dilation;
|
||||||
|
if (src_l_idx < padding || src_l_idx >= l_in + padding) {
|
||||||
|
dst[dst_i] = static_cast<T>(0);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
src_l_idx -= padding;
|
||||||
|
const size_t src_i = b_idx * src_s[0] + c_idx * src_s[1] + src_l_idx * src_s[2];
|
||||||
|
dst[dst_i] = src[src_i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
__device__ void im2col(
|
__device__ void im2col(
|
||||||
const size_t dst_numel,
|
const size_t dst_numel,
|
||||||
@ -78,7 +125,7 @@ __device__ void im2col(
|
|||||||
const size_t h_in = src_dims[2];
|
const size_t h_in = src_dims[2];
|
||||||
const size_t w_in = src_dims[3];
|
const size_t w_in = src_dims[3];
|
||||||
|
|
||||||
const size_t dst_s4 = w_k;
|
const size_t dst_s4 = w_k;
|
||||||
const size_t dst_s3 = h_k * dst_s4;
|
const size_t dst_s3 = h_k * dst_s4;
|
||||||
const size_t dst_s2 = c_in * dst_s3;
|
const size_t dst_s2 = c_in * dst_s3;
|
||||||
const size_t dst_s1 = w_out * dst_s2;
|
const size_t dst_s1 = w_out * dst_s2;
|
||||||
@ -428,6 +475,21 @@ extern "C" __global__ void FN_NAME( \
|
|||||||
conv2d<TYPENAME, TYPEACC>(src_numel, w_out, h_out, stride, padding, dilation, info, src, kernel, dst); \
|
conv2d<TYPENAME, TYPEACC>(src_numel, w_out, h_out, stride, padding, dilation, info, src, kernel, dst); \
|
||||||
} \
|
} \
|
||||||
|
|
||||||
|
#define IM2COL1D_OP(TYPENAME, FN_NAME) \
|
||||||
|
extern "C" __global__ void FN_NAME( \
|
||||||
|
const size_t dst_numel, \
|
||||||
|
const size_t l_out, \
|
||||||
|
const size_t l_k, \
|
||||||
|
const size_t stride, \
|
||||||
|
const size_t padding, \
|
||||||
|
const size_t dilation, \
|
||||||
|
const size_t *info, \
|
||||||
|
const TYPENAME *src, \
|
||||||
|
TYPENAME *dst \
|
||||||
|
) { \
|
||||||
|
im2col1d<TYPENAME>(dst_numel, l_out, l_k, stride, padding, dilation, info, src, dst); \
|
||||||
|
} \
|
||||||
|
|
||||||
#define IM2COL_OP(TYPENAME, FN_NAME) \
|
#define IM2COL_OP(TYPENAME, FN_NAME) \
|
||||||
extern "C" __global__ void FN_NAME( \
|
extern "C" __global__ void FN_NAME( \
|
||||||
const size_t dst_numel, \
|
const size_t dst_numel, \
|
||||||
@ -511,6 +573,7 @@ AVG_POOL2D_OP(__nv_bfloat16, float, avg_pool2d_bf16)
|
|||||||
MAX_POOL2D_OP(__nv_bfloat16, max_pool2d_bf16)
|
MAX_POOL2D_OP(__nv_bfloat16, max_pool2d_bf16)
|
||||||
UPSAMPLE_NEAREST2D_OP(__nv_bfloat16, upsample_nearest2d_bf16)
|
UPSAMPLE_NEAREST2D_OP(__nv_bfloat16, upsample_nearest2d_bf16)
|
||||||
IM2COL_OP(__nv_bfloat16, im2col_bf16)
|
IM2COL_OP(__nv_bfloat16, im2col_bf16)
|
||||||
|
IM2COL1D_OP(__nv_bfloat16, im2col1d_bf16)
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#if __CUDA_ARCH__ >= 530
|
#if __CUDA_ARCH__ >= 530
|
||||||
@ -521,6 +584,7 @@ AVG_POOL2D_OP(__half, float, avg_pool2d_f16)
|
|||||||
MAX_POOL2D_OP(__half, max_pool2d_f16)
|
MAX_POOL2D_OP(__half, max_pool2d_f16)
|
||||||
UPSAMPLE_NEAREST2D_OP(__half, upsample_nearest2d_f16)
|
UPSAMPLE_NEAREST2D_OP(__half, upsample_nearest2d_f16)
|
||||||
IM2COL_OP(__half, im2col_f16)
|
IM2COL_OP(__half, im2col_f16)
|
||||||
|
IM2COL1D_OP(__half, im2col1d_f16)
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
CONV1D_OP(float, float, conv1d_f32)
|
CONV1D_OP(float, float, conv1d_f32)
|
||||||
@ -557,3 +621,8 @@ IM2COL_OP(float, im2col_f32)
|
|||||||
IM2COL_OP(double, im2col_f64)
|
IM2COL_OP(double, im2col_f64)
|
||||||
IM2COL_OP(uint8_t, im2col_u8)
|
IM2COL_OP(uint8_t, im2col_u8)
|
||||||
IM2COL_OP(uint32_t, im2col_u32)
|
IM2COL_OP(uint32_t, im2col_u32)
|
||||||
|
|
||||||
|
IM2COL1D_OP(float, im2col1d_f32)
|
||||||
|
IM2COL1D_OP(double, im2col1d_f64)
|
||||||
|
IM2COL1D_OP(uint8_t, im2col1d_u8)
|
||||||
|
IM2COL1D_OP(uint32_t, im2col1d_u32)
|
||||||
|
Reference in New Issue
Block a user