Fast CPU kernel for transposed 1d convolutions. (#1822)

* Fast CPU kernel for transposed 1d convolutions.

* Bugfix.
This commit is contained in:
Laurent Mazare
2024-03-08 22:43:07 +01:00
committed by GitHub
parent e7fc1daa21
commit 3440cec3a0
3 changed files with 100 additions and 22 deletions

View File

@ -5,6 +5,7 @@ use half::{bf16, f16};
use rayon::prelude::*;
const USE_IM2COL_CONV1D: bool = true;
const USE_IM2COL_CONV1D_TR: bool = true;
const USE_IM2COL_CONV2D: bool = true;
// TODO: Maybe we should not implement [Clone] here and instead have an explicit allocator +
@ -1256,6 +1257,34 @@ impl Map1 for Im2Col {
}
}
struct Col2Im1D {
stride: usize,
}
impl Map1 for Col2Im1D {
fn f<T: WithDType>(&self, col: &[T], l: &Layout) -> Result<Vec<T>> {
let (b_size, l_in, c_out, k_size) = l.shape().dims4()?;
let stride = self.stride;
let l_out = (l_in - 1) * stride + k_size;
let mut im = vec![T::zero(); b_size * c_out * l_out];
let (dst_s0, dst_s1) = (c_out * l_out, l_out);
let (src_s0, src_s1, src_s2) = (c_out * k_size * l_in, c_out * k_size, k_size);
for l_in_i in 0..l_in {
for k_i in 0..k_size {
let l_out_i = l_in_i * stride + k_i;
for b_i in 0..b_size {
for c_i in 0..c_out {
let dst_idx = b_i * dst_s0 + c_i * dst_s1 + l_out_i;
let src_idx = b_i * src_s0 + l_in_i * src_s1 + c_i * src_s2 + k_i;
im[dst_idx] += col[src_idx]
}
}
}
}
Ok(im)
}
}
struct ConvTranspose1D<'a>(&'a crate::conv::ParamsConvTranspose1D);
impl<'a> Map2 for ConvTranspose1D<'a> {
@ -2511,7 +2540,52 @@ impl BackendStorage for CpuStorage {
kernel_l: &Layout,
params: &crate::conv::ParamsConvTranspose1D,
) -> Result<Self> {
ConvTranspose1D(params).map(self, l, kernel, kernel_l)
let can_use_col2im = kernel_l.is_contiguous()
&& params.dilation == 1
&& params.padding == 0
&& params.output_padding == 0;
if USE_IM2COL_CONV1D_TR && can_use_col2im {
let (b_size, c_in, l_in) = l.shape().dims3()?;
let (c_in2, c_out, k_size) = kernel_l.shape().dims3()?;
if !kernel_l.is_contiguous() {
crate::bail!(
"convtr1d: the second argument (kernel) has to be contiguous {kernel_l:?}"
)
}
if c_in != c_in2 {
crate::bail!(
"convtr1d: shape mismatch on c_in {:?} {:?}",
l.shape(),
kernel_l.shape()
)
}
let col = {
// This merges the last two dimensions of the kernel together.
let kernel_l_mm = Layout::new(
(b_size, c_in, k_size * c_out).into(),
vec![0, k_size * c_out, 1],
kernel_l.start_offset(),
);
self.matmul(
kernel,
(
b_size,
/* m */ l_in,
/* n */ c_out * k_size,
/* k */ c_in,
),
&l.transpose(1, 2)?,
&kernel_l_mm,
)?
};
let col_l = Layout::contiguous((b_size, l_in, c_out, k_size));
Col2Im1D {
stride: params.stride,
}
.map(&col, &col_l)
} else {
ConvTranspose1D(params).map(self, l, kernel, kernel_l)
}
}
fn conv2d(