Fast CPU kernel for transposed 1d convolutions. (#1822)
* Fast CPU kernel for transposed 1d convolutions.
* Bugfix.
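The new fast path mirrors the existing `USE_IM2COL_CONV1D` / `USE_IM2COL_CONV2D` shortcuts: when the kernel is contiguous, dilation is 1, and padding and output padding are both 0, the transposed 1d convolution is computed as a batched matmul against the flattened kernel followed by a col2im scatter-add, with output length `l_out = (l_in - 1) * stride + k_size`. The bugfix is unrelated: in the last hunk, the second-stage weight path fell back to `args.first_stage_weights` instead of `args.second_stage_weights`.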
@@ -5,6 +5,7 @@ use half::{bf16, f16};
 use rayon::prelude::*;
 
 const USE_IM2COL_CONV1D: bool = true;
+const USE_IM2COL_CONV1D_TR: bool = true;
 const USE_IM2COL_CONV2D: bool = true;
 
 // TODO: Maybe we should not implement [Clone] here and instead have an explicit allocator +
@@ -1256,6 +1257,34 @@ impl Map1 for Im2Col {
     }
 }
 
+struct Col2Im1D {
+    stride: usize,
+}
+
+impl Map1 for Col2Im1D {
+    fn f<T: WithDType>(&self, col: &[T], l: &Layout) -> Result<Vec<T>> {
+        let (b_size, l_in, c_out, k_size) = l.shape().dims4()?;
+        let stride = self.stride;
+        let l_out = (l_in - 1) * stride + k_size;
+        let mut im = vec![T::zero(); b_size * c_out * l_out];
+        let (dst_s0, dst_s1) = (c_out * l_out, l_out);
+        let (src_s0, src_s1, src_s2) = (c_out * k_size * l_in, c_out * k_size, k_size);
+        for l_in_i in 0..l_in {
+            for k_i in 0..k_size {
+                let l_out_i = l_in_i * stride + k_i;
+                for b_i in 0..b_size {
+                    for c_i in 0..c_out {
+                        let dst_idx = b_i * dst_s0 + c_i * dst_s1 + l_out_i;
+                        let src_idx = b_i * src_s0 + l_in_i * src_s1 + c_i * src_s2 + k_i;
+                        im[dst_idx] += col[src_idx]
+                    }
+                }
+            }
+        }
+        Ok(im)
+    }
+}
+
 struct ConvTranspose1D<'a>(&'a crate::conv::ParamsConvTranspose1D);
 
 impl<'a> Map2 for ConvTranspose1D<'a> {
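To see what the scatter-add above computes, here is a minimal standalone sketch (mine, not part of the diff) of the same accumulation for b_size = 1 and c_out = 1: row i of the column matrix holds the k_size products for input position i and is added into the output window starting at i * stride, so overlapping windows accumulate.

// Simplified col2im for a single batch and output channel; the diff's
// Col2Im1D::f is the same loop with batch and channel strides added.
fn col2im_1d(col: &[f32], l_in: usize, k_size: usize, stride: usize) -> Vec<f32> {
    let l_out = (l_in - 1) * stride + k_size;
    let mut im = vec![0f32; l_out];
    for l_in_i in 0..l_in {
        for k_i in 0..k_size {
            im[l_in_i * stride + k_i] += col[l_in_i * k_size + k_i];
        }
    }
    im
}

fn main() {
    // l_in = 2, k_size = 3, stride = 2 => l_out = (2 - 1) * 2 + 3 = 5.
    // The two windows overlap at position 2, which receives 3.0 + 10.0.
    let col = [1.0, 2.0, 3.0, 10.0, 20.0, 30.0];
    assert_eq!(col2im_1d(&col, 2, 3, 2), vec![1.0, 2.0, 13.0, 20.0, 30.0]);
}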
@@ -2511,7 +2540,52 @@ impl BackendStorage for CpuStorage {
         kernel_l: &Layout,
         params: &crate::conv::ParamsConvTranspose1D,
     ) -> Result<Self> {
-        ConvTranspose1D(params).map(self, l, kernel, kernel_l)
+        let can_use_col2im = kernel_l.is_contiguous()
+            && params.dilation == 1
+            && params.padding == 0
+            && params.output_padding == 0;
+        if USE_IM2COL_CONV1D_TR && can_use_col2im {
+            let (b_size, c_in, l_in) = l.shape().dims3()?;
+            let (c_in2, c_out, k_size) = kernel_l.shape().dims3()?;
+            if !kernel_l.is_contiguous() {
+                crate::bail!(
+                    "convtr1d: the second argument (kernel) has to be contiguous {kernel_l:?}"
+                )
+            }
+            if c_in != c_in2 {
+                crate::bail!(
+                    "convtr1d: shape mismatch on c_in {:?} {:?}",
+                    l.shape(),
+                    kernel_l.shape()
+                )
+            }
+            let col = {
+                // This merges the last two dimensions of the kernel together.
+                let kernel_l_mm = Layout::new(
+                    (b_size, c_in, k_size * c_out).into(),
+                    vec![0, k_size * c_out, 1],
+                    kernel_l.start_offset(),
+                );
+                self.matmul(
+                    kernel,
+                    (
+                        b_size,
+                        /* m */ l_in,
+                        /* n */ c_out * k_size,
+                        /* k */ c_in,
+                    ),
+                    &l.transpose(1, 2)?,
+                    &kernel_l_mm,
+                )?
+            };
+            let col_l = Layout::contiguous((b_size, l_in, c_out, k_size));
+            Col2Im1D {
+                stride: params.stride,
+            }
+            .map(&col, &col_l)
+        } else {
+            ConvTranspose1D(params).map(self, l, kernel, kernel_l)
+        }
     }
 
     fn conv2d(
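One detail worth noting in the dispatch above: `kernel_l_mm` is built with strides `[0, k_size * c_out, 1]`, and a stride of 0 on the batch dimension broadcasts the single kernel across all batch elements of the matmul without copying it. A small sketch (assumptions mine, not from the diff) of the offset arithmetic:

// A layout maps an index to a flat offset by a dot product with its strides;
// stride 0 on the batch axis means every batch reads the same kernel data.
fn offset(strides: &[usize; 3], idx: &[usize; 3]) -> usize {
    strides.iter().zip(idx).map(|(s, i)| s * i).sum()
}

fn main() {
    let (k_size, c_out) = (3, 2);
    let strides = [0, k_size * c_out, 1]; // same strides as kernel_l_mm
    // Batch 0 and batch 7 resolve to the same offset: a zero-copy broadcast.
    assert_eq!(offset(&strides, &[0, 1, 2]), offset(&strides, &[7, 1, 2]));
}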
@@ -53,26 +53,30 @@ fn conv1d(dev: &Device) -> Result<()> {
         test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
         [2.4509, 2.6357, -1.3336, 4.1393, 0.5657, 1.8091, -1.1784, 3.5675, 0.5069, 3.3352]
     );
-    let res = t.conv_transpose1d(&w.transpose(0, 1)?, 0, 0, 1, 1, 1)?;
-    assert_eq!(res.dims(), [1, 2, 7]);
-    assert_eq!(
-        test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
-        [
-            0.0699, -1.2899, 8.3018, 5.5873, 2.4572, -2.6143, -0.0706, 1.8765, 4.8318, 1.1538,
-            4.7076, -5.9745, -0.8276, 1.621
-        ],
-    );
-    let res = t.conv_transpose1d(&w.transpose(0, 1)?, 0, 0, 1, 1, 2)?;
-    assert_eq!(res.dims(), [1, 4, 7]);
-    assert_eq!(
-        test_utils::to_vec2_round(&res.squeeze(0)?, 4)?,
-        [
-            [-1.5596, -1.8099, 2.0407, 4.8764, -0.1743, -0.735, -0.7819],
-            [0.7816, 3.8152, -0.5926, 2.2515, -5.1844, -0.3157, 1.4721],
-            [1.6295, 0.52, 6.2611, 0.7109, 2.6315, -1.8793, 0.7113],
-            [1.0949, 1.0166, 1.7464, 2.4561, -0.79, -0.5119, 0.1488]
-        ]
-    );
+    let w = w.transpose(0, 1)?;
+    // The CPU kernels applied in the contiguous and non contiguous cases are different.
+    for w in [w.clone(), w.contiguous()?] {
+        let res = t.conv_transpose1d(&w, 0, 0, 1, 1, 1)?;
+        assert_eq!(res.dims(), [1, 2, 7]);
+        assert_eq!(
+            test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
+            [
+                0.0699, -1.2899, 8.3018, 5.5873, 2.4572, -2.6143, -0.0706, 1.8765, 4.8318, 1.1538,
+                4.7076, -5.9745, -0.8276, 1.621
+            ],
+        );
+        let res = t.conv_transpose1d(&w, 0, 0, 1, 1, 2)?;
+        assert_eq!(res.dims(), [1, 4, 7]);
+        assert_eq!(
+            test_utils::to_vec2_round(&res.squeeze(0)?, 4)?,
+            [
+                [-1.5596, -1.8099, 2.0407, 4.8764, -0.1743, -0.735, -0.7819],
+                [0.7816, 3.8152, -0.5926, 2.2515, -5.1844, -0.3157, 1.4721],
+                [1.6295, 0.52, 6.2611, 0.7109, 2.6315, -1.8793, 0.7113],
+                [1.0949, 1.0166, 1.7464, 2.4561, -0.79, -0.5119, 0.1488]
+            ]
+        );
+    }
 
     Ok(())
 }
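The test loop exercises both kernel layouts on purpose: the transposed view of `w` is non-contiguous and takes the existing `ConvTranspose1D` kernel, while the `.contiguous()?` copy satisfies `can_use_col2im` and goes through the new col2im path, so both implementations are checked against the same expected values.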
@@ -120,7 +120,7 @@ fn main() -> Result<()> {
         Some(w) => std::path::PathBuf::from(w),
         None => repo.get("first_stage.safetensors")?,
     };
-    let second_stage_weights = match &args.first_stage_weights {
+    let second_stage_weights = match &args.second_stage_weights {
         Some(w) => std::path::PathBuf::from(w),
         None => repo.get("second_stage.safetensors")?,
     };