mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 19:18:50 +00:00
More efficient cuda implementation for ConvTranspose1d. (#2211)
* More efficient cuda implementation for ConvTranspose1d. * Small tweak.
This commit is contained in:
@ -10,7 +10,7 @@ pub use utils::{
|
|||||||
};
|
};
|
||||||
|
|
||||||
const USE_IM2COL_CONV1D: bool = true;
|
const USE_IM2COL_CONV1D: bool = true;
|
||||||
const USE_IM2COL_CONV1D_TR: bool = true;
|
const USE_COL2IM_CONV1D_TR: bool = true;
|
||||||
const USE_IM2COL_CONV2D: bool = true;
|
const USE_IM2COL_CONV2D: bool = true;
|
||||||
|
|
||||||
// TODO: Maybe we should not implement [Clone] here and instead have an explicit allocator +
|
// TODO: Maybe we should not implement [Clone] here and instead have an explicit allocator +
|
||||||
@ -2249,7 +2249,7 @@ impl BackendStorage for CpuStorage {
|
|||||||
&& params.dilation == 1
|
&& params.dilation == 1
|
||||||
&& params.padding == 0
|
&& params.padding == 0
|
||||||
&& params.output_padding == 0;
|
&& params.output_padding == 0;
|
||||||
if USE_IM2COL_CONV1D_TR && can_use_col2im {
|
if USE_COL2IM_CONV1D_TR && can_use_col2im {
|
||||||
let (b_size, c_in, l_in) = l.shape().dims3()?;
|
let (b_size, c_in, l_in) = l.shape().dims3()?;
|
||||||
let (c_in2, c_out, k_size) = kernel_l.shape().dims3()?;
|
let (c_in2, c_out, k_size) = kernel_l.shape().dims3()?;
|
||||||
if !kernel_l.is_contiguous() {
|
if !kernel_l.is_contiguous() {
|
||||||
|
@ -630,6 +630,31 @@ impl<'a> Map2 for Conv2D<'a> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct Col2Im1D {
|
||||||
|
stride: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Map1 for Col2Im1D {
|
||||||
|
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
||||||
|
&self,
|
||||||
|
col: &CudaSlice<T>,
|
||||||
|
dev: &CudaDevice,
|
||||||
|
l: &Layout,
|
||||||
|
) -> Result<CudaSlice<T>> {
|
||||||
|
let (b_size, l_in, c_out, k_size) = l.shape().dims4()?;
|
||||||
|
let stride = self.stride;
|
||||||
|
let l_out = (l_in - 1) * stride + k_size;
|
||||||
|
let dst_el = b_size * c_out * l_out;
|
||||||
|
let mut im = unsafe { dev.alloc::<T>(dst_el) }.w()?;
|
||||||
|
|
||||||
|
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
|
||||||
|
let params = (dst_el, l_out, l_in, c_out, k_size, stride, col, &mut im);
|
||||||
|
let func = dev.get_or_load_func(&kernel_name::<T>("col2im1d"), kernels::CONV)?;
|
||||||
|
unsafe { func.launch(cfg, params) }.w()?;
|
||||||
|
Ok(im)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
struct ConvTranspose1D<'a>(&'a crate::conv::ParamsConvTranspose1D);
|
struct ConvTranspose1D<'a>(&'a crate::conv::ParamsConvTranspose1D);
|
||||||
impl<'a> Map2 for ConvTranspose1D<'a> {
|
impl<'a> Map2 for ConvTranspose1D<'a> {
|
||||||
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
||||||
@ -1366,9 +1391,55 @@ impl BackendStorage for CudaStorage {
|
|||||||
kernel_l: &Layout,
|
kernel_l: &Layout,
|
||||||
params: &crate::conv::ParamsConvTranspose1D,
|
params: &crate::conv::ParamsConvTranspose1D,
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
|
const USE_COL2IM_CONV1D_TR: bool = true;
|
||||||
|
|
||||||
let device = self.device().clone();
|
let device = self.device().clone();
|
||||||
let slice =
|
let can_use_col2im = kernel_l.is_contiguous()
|
||||||
ConvTranspose1D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?;
|
&& params.dilation == 1
|
||||||
|
&& params.padding == 0
|
||||||
|
&& params.output_padding == 0;
|
||||||
|
let slice = if USE_COL2IM_CONV1D_TR && can_use_col2im {
|
||||||
|
let (b_size, c_in, l_in) = l.shape().dims3()?;
|
||||||
|
let (c_in2, c_out, k_size) = kernel_l.shape().dims3()?;
|
||||||
|
if !kernel_l.is_contiguous() {
|
||||||
|
crate::bail!(
|
||||||
|
"convtr1d: the second argument (kernel) has to be contiguous {kernel_l:?}"
|
||||||
|
)
|
||||||
|
}
|
||||||
|
if c_in != c_in2 {
|
||||||
|
crate::bail!(
|
||||||
|
"convtr1d: shape mismatch on c_in {:?} {:?}",
|
||||||
|
l.shape(),
|
||||||
|
kernel_l.shape()
|
||||||
|
)
|
||||||
|
}
|
||||||
|
let col = {
|
||||||
|
// This merges the last two dimensions of the kernel together.
|
||||||
|
let kernel_l_mm = Layout::new(
|
||||||
|
(b_size, c_in, k_size * c_out).into(),
|
||||||
|
vec![0, k_size * c_out, 1],
|
||||||
|
kernel_l.start_offset(),
|
||||||
|
);
|
||||||
|
self.matmul(
|
||||||
|
kernel,
|
||||||
|
(
|
||||||
|
b_size,
|
||||||
|
/* m */ l_in,
|
||||||
|
/* n */ c_out * k_size,
|
||||||
|
/* k */ c_in,
|
||||||
|
),
|
||||||
|
&l.transpose(1, 2)?,
|
||||||
|
&kernel_l_mm,
|
||||||
|
)?
|
||||||
|
};
|
||||||
|
let col_l = Layout::contiguous((b_size, l_in, c_out, k_size));
|
||||||
|
Col2Im1D {
|
||||||
|
stride: params.stride,
|
||||||
|
}
|
||||||
|
.map(&col.slice, &device, &col_l)?
|
||||||
|
} else {
|
||||||
|
ConvTranspose1D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?
|
||||||
|
};
|
||||||
Ok(Self { slice, device })
|
Ok(Self { slice, device })
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -97,6 +97,50 @@ __device__ void im2col1d(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
__device__ void col2im1d(
|
||||||
|
const size_t dst_el,
|
||||||
|
const size_t l_out,
|
||||||
|
const size_t l_in,
|
||||||
|
const size_t c_out,
|
||||||
|
const size_t k_size,
|
||||||
|
const size_t stride,
|
||||||
|
const T *src,
|
||||||
|
T *dst
|
||||||
|
) {
|
||||||
|
const size_t dst_i = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
// src: (b_size, l_in, c_out, l_k)
|
||||||
|
// dst: (b_size, c_out, l_out)
|
||||||
|
if (dst_i >= dst_el) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const size_t dst_s0 = c_out * l_out;
|
||||||
|
const size_t dst_s1 = l_out;
|
||||||
|
const size_t src_s0 = c_out * k_size * l_in;
|
||||||
|
const size_t src_s1 = c_out * k_size;
|
||||||
|
const size_t src_s2 = k_size;
|
||||||
|
|
||||||
|
size_t tmp_dst_i = dst_i;
|
||||||
|
const size_t b_idx = tmp_dst_i / dst_s0;
|
||||||
|
tmp_dst_i -= b_idx * dst_s0;
|
||||||
|
const size_t c_idx = tmp_dst_i / dst_s1;
|
||||||
|
tmp_dst_i -= c_idx * dst_s1;
|
||||||
|
const int l_out_idx = tmp_dst_i;
|
||||||
|
|
||||||
|
dst[dst_i] = static_cast<T>(0);
|
||||||
|
|
||||||
|
int l_in_idx = l_out_idx / stride;
|
||||||
|
int k0 = l_out_idx - l_in_idx * stride;
|
||||||
|
// l_out_idx = l_in_idx * stride + k0
|
||||||
|
for (; k0 < k_size && l_in_idx >= 0; k0 += stride, --l_in_idx) {
|
||||||
|
if (l_in_idx < l_in) {
|
||||||
|
const size_t src_i = b_idx * src_s0 + l_in_idx * src_s1 + c_idx * src_s2 + k0;
|
||||||
|
dst[dst_i] += src[src_i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
__device__ void im2col(
|
__device__ void im2col(
|
||||||
const size_t dst_numel,
|
const size_t dst_numel,
|
||||||
@ -542,6 +586,20 @@ extern "C" __global__ void FN_NAME( \
|
|||||||
im2col1d<TYPENAME>(dst_numel, l_out, l_k, stride, padding, dilation, info, src, dst); \
|
im2col1d<TYPENAME>(dst_numel, l_out, l_k, stride, padding, dilation, info, src, dst); \
|
||||||
} \
|
} \
|
||||||
|
|
||||||
|
#define COL2IM1D_OP(TYPENAME, FN_NAME) \
|
||||||
|
extern "C" __global__ void FN_NAME( \
|
||||||
|
const size_t dst_el, \
|
||||||
|
const size_t l_out, \
|
||||||
|
const size_t l_in, \
|
||||||
|
const size_t c_out, \
|
||||||
|
const size_t k_size, \
|
||||||
|
const size_t stride, \
|
||||||
|
const TYPENAME *src, \
|
||||||
|
TYPENAME *dst \
|
||||||
|
) { \
|
||||||
|
col2im1d<TYPENAME>(dst_el, l_out, l_in, c_out, k_size, stride, src, dst); \
|
||||||
|
} \
|
||||||
|
|
||||||
#define IM2COL_OP(TYPENAME, FN_NAME) \
|
#define IM2COL_OP(TYPENAME, FN_NAME) \
|
||||||
extern "C" __global__ void FN_NAME( \
|
extern "C" __global__ void FN_NAME( \
|
||||||
const size_t dst_numel, \
|
const size_t dst_numel, \
|
||||||
@ -643,6 +701,7 @@ MAX_POOL2D_OP(__nv_bfloat16, max_pool2d_bf16)
|
|||||||
UPSAMPLE_NEAREST2D_OP(__nv_bfloat16, upsample_nearest2d_bf16)
|
UPSAMPLE_NEAREST2D_OP(__nv_bfloat16, upsample_nearest2d_bf16)
|
||||||
IM2COL_OP(__nv_bfloat16, im2col_bf16)
|
IM2COL_OP(__nv_bfloat16, im2col_bf16)
|
||||||
IM2COL1D_OP(__nv_bfloat16, im2col1d_bf16)
|
IM2COL1D_OP(__nv_bfloat16, im2col1d_bf16)
|
||||||
|
COL2IM1D_OP(__nv_bfloat16, col2im1d_bf16)
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#if __CUDA_ARCH__ >= 530
|
#if __CUDA_ARCH__ >= 530
|
||||||
@ -655,6 +714,7 @@ MAX_POOL2D_OP(__half, max_pool2d_f16)
|
|||||||
UPSAMPLE_NEAREST2D_OP(__half, upsample_nearest2d_f16)
|
UPSAMPLE_NEAREST2D_OP(__half, upsample_nearest2d_f16)
|
||||||
IM2COL_OP(__half, im2col_f16)
|
IM2COL_OP(__half, im2col_f16)
|
||||||
IM2COL1D_OP(__half, im2col1d_f16)
|
IM2COL1D_OP(__half, im2col1d_f16)
|
||||||
|
COL2IM1D_OP(__half, col2im1d_f16)
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
CONV1D_OP(float, float, conv1d_f32)
|
CONV1D_OP(float, float, conv1d_f32)
|
||||||
@ -701,3 +761,8 @@ IM2COL1D_OP(float, im2col1d_f32)
|
|||||||
IM2COL1D_OP(double, im2col1d_f64)
|
IM2COL1D_OP(double, im2col1d_f64)
|
||||||
IM2COL1D_OP(uint8_t, im2col1d_u8)
|
IM2COL1D_OP(uint8_t, im2col1d_u8)
|
||||||
IM2COL1D_OP(uint32_t, im2col1d_u32)
|
IM2COL1D_OP(uint32_t, im2col1d_u32)
|
||||||
|
|
||||||
|
COL2IM1D_OP(float, col2im1d_f32)
|
||||||
|
COL2IM1D_OP(double, col2im1d_f64)
|
||||||
|
COL2IM1D_OP(uint8_t, col2im1d_u8)
|
||||||
|
COL2IM1D_OP(uint32_t, col2im1d_u32)
|
||||||
|
Reference in New Issue
Block a user