im2col based conv2d (#802)

* im2col implementation for conv2d.

* Fix for the im2col implementation to match the current conv2d.

* Small optimization.

* Add a cuda kernel.

* Handle arbitrary layouts.

* Im2Col cuda code.
This commit is contained in:
Laurent Mazare
2023-09-10 21:02:42 +01:00
committed by GitHub
parent 18d6db2180
commit 98d1242b8f
3 changed files with 210 additions and 16 deletions

View File

@ -600,6 +600,58 @@ impl Map1 for Elu {
}
}
struct Im2Col {
h_k: usize,
w_k: usize,
stride: usize,
dilation: usize,
padding: usize,
}
impl Im2Col {
fn hw_out(&self, h: usize, w: usize) -> (usize, usize) {
let h_out = (h + 2 * self.padding - self.dilation * (self.h_k - 1) - 1) / self.stride + 1;
let w_out = (w + 2 * self.padding - self.dilation * (self.w_k - 1) - 1) / self.stride + 1;
(h_out, w_out)
}
}
impl Map1 for Im2Col {
fn f<T: DeviceRepr + WithDType>(
&self,
src: &CudaSlice<T>,
dev: &CudaDevice,
layout: &Layout,
) -> Result<CudaSlice<T>> {
let shape = layout.shape();
let dims = shape.dims();
let (h_out, w_out) = self.hw_out(dims[2], dims[3]);
let dst_el = dims[0] * h_out * w_out * dims[1] * self.h_k * self.w_k;
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
let ds = dev.htod_copy([dims, layout.stride()].concat()).w()?;
let src = &src.slice(layout.start_offset()..);
let func = dev.get_or_load_func(&kernel_name::<T>("im2col"), kernels::CONV)?;
// SAFETY: Set later by running the kernel.
let dst = unsafe { dev.alloc::<T>(dst_el) }.w()?;
let params = (
dst_el,
h_out,
w_out,
self.h_k,
self.w_k,
self.stride,
self.padding,
self.dilation,
&ds,
src,
&dst,
);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?;
Ok(dst)
}
}
struct Powf(f64);
impl Map1 for Powf {
fn f<T: DeviceRepr + WithDType>(

View File

@ -51,6 +51,71 @@ __device__ void conv1d(
dst[dst_i] = static_cast<T>(d);
}
template <typename T>
__device__ void im2col(
const size_t dst_numel,
const size_t h_out,
const size_t w_out,
const size_t h_k,
const size_t w_k,
const size_t stride,
const size_t padding,
const size_t dilation,
const size_t *info,
const T *src,
T *dst
) {
const size_t dst_i = blockIdx.x * blockDim.x + threadIdx.x;
// dst: (b_size, h_out, w_out, c_in, h_k, w_k)
// src: (b_size, c_in, h_in, w_in)
if (dst_i >= dst_numel) {
return;
}
const size_t *src_dims = info;
const size_t *src_s = info + 4;
const size_t b_in = src_dims[0];
const size_t c_in = src_dims[1];
const size_t h_in = src_dims[2];
const size_t w_in = src_dims[3];
const size_t dst_s4 = w_k;
const size_t dst_s3 = h_k * dst_s4;
const size_t dst_s2 = c_in * dst_s3;
const size_t dst_s1 = w_out * dst_s2;
const size_t dst_s0 = h_out * dst_s1;
size_t tmp_dst_i = dst_i;
const size_t b_idx = tmp_dst_i / dst_s0;
tmp_dst_i -= b_idx * dst_s0;
const size_t h_idx = tmp_dst_i / dst_s1;
tmp_dst_i -= h_idx * dst_s1;
const size_t w_idx = tmp_dst_i / dst_s2;
tmp_dst_i -= w_idx * dst_s2;
const size_t c_idx = tmp_dst_i / dst_s3;
tmp_dst_i -= c_idx * dst_s3;
const size_t h_k_idx = tmp_dst_i / dst_s4;
tmp_dst_i -= h_k_idx * dst_s4;
const size_t w_k_idx = tmp_dst_i;
size_t src_h_idx = h_idx * stride + h_k_idx * dilation;
size_t src_w_idx = w_idx * stride + w_k_idx * dilation;
if (src_h_idx < padding || src_h_idx >= h_in + padding) {
dst[dst_i] = static_cast<T>(0);
}
else if (src_w_idx < padding || src_w_idx >= w_in + padding) {
dst[dst_i] = static_cast<T>(0);
}
else {
src_h_idx -= padding;
src_w_idx -= padding;
const size_t src_i =
b_idx * src_s[0]
+ c_idx * src_s[1]
+ src_h_idx * src_s[2]
+ src_w_idx * src_s[3];
dst[dst_i] = src[src_i];
}
}
// Naive implementation of conv2d.
template <typename T, typename A>
__device__ void conv2d(
@ -363,6 +428,23 @@ extern "C" __global__ void FN_NAME( \
conv2d<TYPENAME, TYPEACC>(src_numel, w_out, h_out, stride, padding, dilation, info, src, kernel, dst); \
} \
#define IM2COL_OP(TYPENAME, FN_NAME) \
extern "C" __global__ void FN_NAME( \
const size_t dst_numel, \
const size_t h_out, \
const size_t w_out, \
const size_t h_k, \
const size_t w_k, \
const size_t stride, \
const size_t padding, \
const size_t dilation, \
const size_t *info, \
const TYPENAME *src, \
TYPENAME *dst \
) { \
im2col<TYPENAME>(dst_numel, h_out, w_out, h_k, w_k, stride, padding, dilation, info, src, dst); \
} \
#define CONVT2D_OP(TYPENAME, TYPEACC, FN_NAME) \
extern "C" __global__ void FN_NAME( \
const size_t src_numel, \
@ -428,6 +510,7 @@ CONVT2D_OP(__nv_bfloat16, float, conv_transpose2d_bf16)
AVG_POOL2D_OP(__nv_bfloat16, float, avg_pool2d_bf16)
MAX_POOL2D_OP(__nv_bfloat16, max_pool2d_bf16)
UPSAMPLE_NEAREST2D_OP(__nv_bfloat16, upsample_nearest2d_bf16)
IM2COL_OP(__nv_bfloat16, im2col_bf16)
#endif
#if __CUDA_ARCH__ >= 530
@ -437,6 +520,7 @@ CONVT2D_OP(__half, float, conv_transpose2d_f16)
AVG_POOL2D_OP(__half, float, avg_pool2d_f16)
MAX_POOL2D_OP(__half, max_pool2d_f16)
UPSAMPLE_NEAREST2D_OP(__half, upsample_nearest2d_f16)
IM2COL_OP(__half, im2col_f16)
#endif
CONV1D_OP(float, float, conv1d_f32)
@ -468,3 +552,8 @@ UPSAMPLE_NEAREST2D_OP(float, upsample_nearest2d_f32)
UPSAMPLE_NEAREST2D_OP(double, upsample_nearest2d_f64)
UPSAMPLE_NEAREST2D_OP(uint8_t, upsample_nearest2d_u8)
UPSAMPLE_NEAREST2D_OP(uint32_t, upsample_nearest2d_u32)
IM2COL_OP(float, im2col_f32)
IM2COL_OP(double, im2col_f64)
IM2COL_OP(uint8_t, im2col_u8)
IM2COL_OP(uint32_t, im2col_u32)

View File

@ -9,6 +9,8 @@ use candle::quantized::GgmlType;
use candle::{CpuStorage, Device, Layout, Result, Shape, Tensor, D};
use clap::{Parser, Subcommand};
const CHECK_CONV2D: bool = false;
trait Benchmark {
type PreProcessData;
type RunResult;
@ -19,25 +21,51 @@ trait Benchmark {
const ITERS: usize;
}
struct Im2Col(usize, usize);
struct Im2Col {
h_k: usize,
w_k: usize,
stride: usize,
dilation: usize,
padding: usize,
}
impl Im2Col {
fn hw_out(&self, h: usize, w: usize) -> (usize, usize) {
let h_out = (h + 2 * self.padding - self.dilation * (self.h_k - 1) - 1) / self.stride + 1;
let w_out = (w + 2 * self.padding - self.dilation * (self.w_k - 1) - 1) / self.stride + 1;
(h_out, w_out)
}
}
impl candle::CustomOp1 for Im2Col {
fn name(&self) -> &'static str {
"im2col"
}
fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)> {
let &Self(h_k, w_k) = self;
let &Self {
h_k,
w_k,
stride,
dilation,
padding,
} = self;
let (b, c, h, w) = layout.shape().dims4()?;
let (h_out, w_out) = (h - h_k + 1, w - w_k + 1);
let (h_out, w_out) = self.hw_out(h, w);
let slice = storage.as_slice::<f32>()?;
let src = match layout.contiguous_offsets() {
None => candle::bail!("input has to be contiguous"),
Some((o1, o2)) => &slice[o1..o2],
};
let src = &slice[layout.start_offset()..];
let mut dst = vec![0f32; b * h_out * w_out * c * h_k * w_k];
let (s_b, s_c, s_h) = (c * h * w, h * w, w);
let (src_s0, src_s1, src_s2, src_s3) = {
let s = layout.stride();
(s[0], s[1], s[2], s[3])
};
// TODO: provide specialized kernels for the common use cases.
// - h_k = w_k = 1
// - padding = 0
// - stride = 1
// - dilation = 1
for b_idx in 0..b {
let src_idx = b_idx * s_b;
let src_idx = b_idx * src_s0;
let dst_idx = b_idx * h_out * w_out * c * h_k * w_k;
for h_idx in 0..h_out {
let dst_idx = dst_idx + h_idx * w_out * c * h_k * w_k;
@ -45,12 +73,25 @@ impl candle::CustomOp1 for Im2Col {
let dst_idx = dst_idx + w_idx * c * h_k * w_k;
for c_idx in 0..c {
let dst_idx = dst_idx + c_idx * h_k * w_k;
let src_idx = c_idx * s_c + src_idx;
let src_idx = c_idx * src_s1 + src_idx;
for h_k_idx in 0..h_k {
let src_idx = src_idx + (h_idx + h_k_idx) * s_h + w_idx;
let src_h = h_idx * stride + h_k_idx * dilation;
if padding != 0 && (src_h < padding || src_h >= h + padding) {
continue;
}
let src_h = src_h - padding;
let src_idx = src_idx + src_h * src_s2;
let dst_idx = dst_idx + h_k_idx * w_k;
dst[dst_idx..dst_idx + w_k]
.copy_from_slice(&src[src_idx..src_idx + w_k])
for w_k_idx in 0..w_k {
let src_w = w_idx * stride + w_k_idx * dilation;
if padding != 0 && (src_w < padding || src_w >= h + padding) {
continue;
}
let src_w = src_w - padding;
let src_idx = src_idx + src_w * src_s3;
let dst_idx = dst_idx + w_k_idx;
dst[dst_idx] = src[src_idx]
}
}
}
}
@ -113,14 +154,26 @@ impl Benchmark for Conv2dIm2Col {
fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> {
// d.0.conv2d(&d.1, 0, 1, 1, 1)
let (b, _, h, w) = d.0.dims4()?;
let (h_k, w_k) = (3, 3);
let (h_out, w_out) = (h - h_k + 1, w - w_k + 1);
let col = d.0.apply_op1_no_bwd(&Im2Col(h_k, w_k))?;
let (_, _, h_k, w_k) = d.1.dims4()?;
let op = Im2Col {
h_k,
w_k,
stride: 1,
dilation: 1,
padding: 0,
};
let (h_out, w_out) = op.hw_out(h, w);
let col = d.0.apply_op1_no_bwd(&op)?;
let res = col.matmul(&d.1.flatten_from(1)?.t()?)?;
let res = res
.reshape((b, h_out, w_out, ()))?
.permute((0, 3, 1, 2))?
.contiguous()?;
if CHECK_CONV2D {
let res2 = d.0.conv2d(&d.1, op.padding, op.stride, op.dilation, 1);
let diff = (&res - res2)?.sqr()?.mean_all()?;
println!("{diff}");
}
Ok(res)
}