mirror of
https://github.com/huggingface/candle.git
synced 2025-06-21 12:20:46 +00:00
im2col based conv2d (#802)
* im2col implementation for conv2d. * Fix for the im2col implementation to match the current conv2d. * Small optimization. * Add a cuda kernel. * Handle arbitrary layouts. * Im2Col cuda code.
This commit is contained in:
@ -9,6 +9,8 @@ use candle::quantized::GgmlType;
|
||||
use candle::{CpuStorage, Device, Layout, Result, Shape, Tensor, D};
|
||||
use clap::{Parser, Subcommand};
|
||||
|
||||
const CHECK_CONV2D: bool = false;
|
||||
|
||||
trait Benchmark {
|
||||
type PreProcessData;
|
||||
type RunResult;
|
||||
@ -19,25 +21,51 @@ trait Benchmark {
|
||||
const ITERS: usize;
|
||||
}
|
||||
|
||||
struct Im2Col(usize, usize);
|
||||
struct Im2Col {
|
||||
h_k: usize,
|
||||
w_k: usize,
|
||||
stride: usize,
|
||||
dilation: usize,
|
||||
padding: usize,
|
||||
}
|
||||
|
||||
impl Im2Col {
|
||||
fn hw_out(&self, h: usize, w: usize) -> (usize, usize) {
|
||||
let h_out = (h + 2 * self.padding - self.dilation * (self.h_k - 1) - 1) / self.stride + 1;
|
||||
let w_out = (w + 2 * self.padding - self.dilation * (self.w_k - 1) - 1) / self.stride + 1;
|
||||
(h_out, w_out)
|
||||
}
|
||||
}
|
||||
|
||||
impl candle::CustomOp1 for Im2Col {
|
||||
fn name(&self) -> &'static str {
|
||||
"im2col"
|
||||
}
|
||||
|
||||
fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)> {
|
||||
let &Self(h_k, w_k) = self;
|
||||
let &Self {
|
||||
h_k,
|
||||
w_k,
|
||||
stride,
|
||||
dilation,
|
||||
padding,
|
||||
} = self;
|
||||
let (b, c, h, w) = layout.shape().dims4()?;
|
||||
let (h_out, w_out) = (h - h_k + 1, w - w_k + 1);
|
||||
let (h_out, w_out) = self.hw_out(h, w);
|
||||
let slice = storage.as_slice::<f32>()?;
|
||||
let src = match layout.contiguous_offsets() {
|
||||
None => candle::bail!("input has to be contiguous"),
|
||||
Some((o1, o2)) => &slice[o1..o2],
|
||||
};
|
||||
let src = &slice[layout.start_offset()..];
|
||||
let mut dst = vec![0f32; b * h_out * w_out * c * h_k * w_k];
|
||||
let (s_b, s_c, s_h) = (c * h * w, h * w, w);
|
||||
let (src_s0, src_s1, src_s2, src_s3) = {
|
||||
let s = layout.stride();
|
||||
(s[0], s[1], s[2], s[3])
|
||||
};
|
||||
// TODO: provide specialized kernels for the common use cases.
|
||||
// - h_k = w_k = 1
|
||||
// - padding = 0
|
||||
// - stride = 1
|
||||
// - dilation = 1
|
||||
for b_idx in 0..b {
|
||||
let src_idx = b_idx * s_b;
|
||||
let src_idx = b_idx * src_s0;
|
||||
let dst_idx = b_idx * h_out * w_out * c * h_k * w_k;
|
||||
for h_idx in 0..h_out {
|
||||
let dst_idx = dst_idx + h_idx * w_out * c * h_k * w_k;
|
||||
@ -45,12 +73,25 @@ impl candle::CustomOp1 for Im2Col {
|
||||
let dst_idx = dst_idx + w_idx * c * h_k * w_k;
|
||||
for c_idx in 0..c {
|
||||
let dst_idx = dst_idx + c_idx * h_k * w_k;
|
||||
let src_idx = c_idx * s_c + src_idx;
|
||||
let src_idx = c_idx * src_s1 + src_idx;
|
||||
for h_k_idx in 0..h_k {
|
||||
let src_idx = src_idx + (h_idx + h_k_idx) * s_h + w_idx;
|
||||
let src_h = h_idx * stride + h_k_idx * dilation;
|
||||
if padding != 0 && (src_h < padding || src_h >= h + padding) {
|
||||
continue;
|
||||
}
|
||||
let src_h = src_h - padding;
|
||||
let src_idx = src_idx + src_h * src_s2;
|
||||
let dst_idx = dst_idx + h_k_idx * w_k;
|
||||
dst[dst_idx..dst_idx + w_k]
|
||||
.copy_from_slice(&src[src_idx..src_idx + w_k])
|
||||
for w_k_idx in 0..w_k {
|
||||
let src_w = w_idx * stride + w_k_idx * dilation;
|
||||
if padding != 0 && (src_w < padding || src_w >= h + padding) {
|
||||
continue;
|
||||
}
|
||||
let src_w = src_w - padding;
|
||||
let src_idx = src_idx + src_w * src_s3;
|
||||
let dst_idx = dst_idx + w_k_idx;
|
||||
dst[dst_idx] = src[src_idx]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -113,14 +154,26 @@ impl Benchmark for Conv2dIm2Col {
|
||||
fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> {
|
||||
// d.0.conv2d(&d.1, 0, 1, 1, 1)
|
||||
let (b, _, h, w) = d.0.dims4()?;
|
||||
let (h_k, w_k) = (3, 3);
|
||||
let (h_out, w_out) = (h - h_k + 1, w - w_k + 1);
|
||||
let col = d.0.apply_op1_no_bwd(&Im2Col(h_k, w_k))?;
|
||||
let (_, _, h_k, w_k) = d.1.dims4()?;
|
||||
let op = Im2Col {
|
||||
h_k,
|
||||
w_k,
|
||||
stride: 1,
|
||||
dilation: 1,
|
||||
padding: 0,
|
||||
};
|
||||
let (h_out, w_out) = op.hw_out(h, w);
|
||||
let col = d.0.apply_op1_no_bwd(&op)?;
|
||||
let res = col.matmul(&d.1.flatten_from(1)?.t()?)?;
|
||||
let res = res
|
||||
.reshape((b, h_out, w_out, ()))?
|
||||
.permute((0, 3, 1, 2))?
|
||||
.contiguous()?;
|
||||
if CHECK_CONV2D {
|
||||
let res2 = d.0.conv2d(&d.1, op.padding, op.stride, op.dilation, 1);
|
||||
let diff = (&res - res2)?.sqr()?.mean_all()?;
|
||||
println!("{diff}");
|
||||
}
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user