mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 11:56:45 +00:00
Optimize the cuda conv transpose1d kernel.
This commit is contained in:
@ -608,6 +608,34 @@ impl Map1 for Elu {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct Col2Im1D {
|
||||||
|
stride: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Map1 for Col2Im1D {
|
||||||
|
fn f<T: DeviceRepr + WithDType>(
|
||||||
|
&self,
|
||||||
|
src: &CudaSlice<T>,
|
||||||
|
dev: &CudaDevice,
|
||||||
|
layout: &Layout,
|
||||||
|
) -> Result<CudaSlice<T>> {
|
||||||
|
let (b_size, l_in, c_out, k_size) = layout.shape().dims4()?;
|
||||||
|
let stride = self.stride;
|
||||||
|
let l_out = (l_in - 1) * stride + k_size;
|
||||||
|
|
||||||
|
let dst_el = b_size * c_out * l_out;
|
||||||
|
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
|
||||||
|
let src = &src.slice(layout.start_offset()..);
|
||||||
|
let func = dev.get_or_load_func(&kernel_name::<T>("col2im1d"), kernels::CONV)?;
|
||||||
|
// SAFETY: Set later by running the kernel.
|
||||||
|
let dst = unsafe { dev.alloc::<T>(dst_el) }.w()?;
|
||||||
|
let params = (l_in, l_out, c_out, k_size, b_size, stride, src, &dst);
|
||||||
|
// SAFETY: ffi.
|
||||||
|
unsafe { func.launch(cfg, params) }.w()?;
|
||||||
|
Ok(dst)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
struct Im2Col1D {
|
struct Im2Col1D {
|
||||||
l_k: usize,
|
l_k: usize,
|
||||||
stride: usize,
|
stride: usize,
|
||||||
@ -1865,9 +1893,55 @@ impl BackendStorage for CudaStorage {
|
|||||||
params: &crate::conv::ParamsConvTranspose1D,
|
params: &crate::conv::ParamsConvTranspose1D,
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
let device = self.device().clone();
|
let device = self.device().clone();
|
||||||
let slice =
|
const USE_COL2IM_CONV1D_TR: bool = true;
|
||||||
ConvTranspose1D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?;
|
|
||||||
Ok(Self { slice, device })
|
let can_use_col2im = kernel_l.is_contiguous()
|
||||||
|
&& params.dilation == 1
|
||||||
|
&& params.padding == 0
|
||||||
|
&& params.output_padding == 0;
|
||||||
|
if !can_use_col2im || !USE_COL2IM_CONV1D_TR {
|
||||||
|
let slice =
|
||||||
|
ConvTranspose1D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?;
|
||||||
|
return Ok(Self { slice, device });
|
||||||
|
}
|
||||||
|
|
||||||
|
let (b_size, c_in, l_in) = l.shape().dims3()?;
|
||||||
|
let (c_in2, c_out, k_size) = kernel_l.shape().dims3()?;
|
||||||
|
if !kernel_l.is_contiguous() {
|
||||||
|
crate::bail!("convtr1d: the second argument (kernel) has to be contiguous {kernel_l:?}")
|
||||||
|
}
|
||||||
|
if c_in != c_in2 {
|
||||||
|
crate::bail!(
|
||||||
|
"convtr1d: shape mismatch on c_in {:?} {:?}",
|
||||||
|
l.shape(),
|
||||||
|
kernel_l.shape()
|
||||||
|
)
|
||||||
|
}
|
||||||
|
let col = {
|
||||||
|
// This merges the last two dimensions of the kernel together.
|
||||||
|
let kernel_l_mm = Layout::new(
|
||||||
|
(b_size, c_in, k_size * c_out).into(),
|
||||||
|
vec![0, k_size * c_out, 1],
|
||||||
|
kernel_l.start_offset(),
|
||||||
|
);
|
||||||
|
self.matmul(
|
||||||
|
kernel,
|
||||||
|
(
|
||||||
|
b_size,
|
||||||
|
/* m */ l_in,
|
||||||
|
/* n */ c_out * k_size,
|
||||||
|
/* k */ c_in,
|
||||||
|
),
|
||||||
|
&l.transpose(1, 2)?,
|
||||||
|
&kernel_l_mm,
|
||||||
|
)?
|
||||||
|
};
|
||||||
|
let col_l = Layout::contiguous((b_size, l_in, c_out, k_size));
|
||||||
|
Col2Im1D {
|
||||||
|
stride: params.stride,
|
||||||
|
}
|
||||||
|
.map(&col.slice, &device, &col_l)?;
|
||||||
|
Ok(col)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(not(feature = "cudnn"))]
|
#[cfg(not(feature = "cudnn"))]
|
||||||
|
@ -51,6 +51,47 @@ __device__ void conv1d(
|
|||||||
dst[dst_i] = static_cast<T>(d);
|
dst[dst_i] = static_cast<T>(d);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
__device__ void col2im1d(
|
||||||
|
const size_t l_in,
|
||||||
|
const size_t l_out,
|
||||||
|
const size_t c_out,
|
||||||
|
const size_t k_size,
|
||||||
|
const size_t b_size,
|
||||||
|
const size_t stride,
|
||||||
|
const T *src,
|
||||||
|
T *dst
|
||||||
|
) {
|
||||||
|
const size_t dst_i = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
// src: (b_size, l_in, c_out, k_size)
|
||||||
|
// dst: (b_size, c_out, l_out)
|
||||||
|
if (dst_i >= b_size * c_out * l_out) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const size_t dst_s0 = c_out * l_out;
|
||||||
|
const size_t dst_s1 = l_out;
|
||||||
|
|
||||||
|
// dst_idx = b_i * dst_s0 + c_i * dst_s1 + l_in_i * stride + k_i
|
||||||
|
const size_t b_i = dst_i / dst_s0;
|
||||||
|
const size_t dst_i2 = dst_i - b_i * dst_s0;
|
||||||
|
const size_t c_i = dst_i2 / dst_s1;
|
||||||
|
const size_t dst_i3 = dst_i2 - c_i * dst_s1; // l_in_i * stride + k_i
|
||||||
|
|
||||||
|
const size_t src_s0 = c_out * k_size * l_in;
|
||||||
|
const size_t src_s1 = c_out * k_size;
|
||||||
|
const size_t src_s2 = k_size;
|
||||||
|
|
||||||
|
dst[dst_i] = 0;
|
||||||
|
for (size_t k_i = 0; k_i < min(dst_i3, k_size); ++k_i) {
|
||||||
|
const size_t l_in_i_times_stride = dst_i3 - k_i;
|
||||||
|
const size_t l_in_i = l_in_i_times_stride / stride;
|
||||||
|
const size_t src_i = b_i * src_s0 + l_in_i * src_s1 + c_i * src_s2 + k_i;
|
||||||
|
if (l_in_i * stride == l_in_i_times_stride) {
|
||||||
|
dst[dst_i] += src[src_i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
__device__ void im2col1d(
|
__device__ void im2col1d(
|
||||||
const size_t dst_numel,
|
const size_t dst_numel,
|
||||||
|
Reference in New Issue
Block a user