Mirror of https://github.com/huggingface/candle.git, synced 2025-06-16 18:48:51 +00:00

Compare commits: 0.9.0-alph...cuda-conv- (6 commits)
SHA1:

- 53f951f6e2
- 52e70856ea
- 3cae6f5e9a
- dffafd1049
- 75f2aea5fd
- 42ae70c458
Rust CUDA backend:

```diff
@@ -608,6 +608,34 @@ impl Map1 for Elu {
     }
 }
 
+struct Col2Im1D {
+    stride: usize,
+}
+
+impl Map1 for Col2Im1D {
+    fn f<T: DeviceRepr + WithDType>(
+        &self,
+        src: &CudaSlice<T>,
+        dev: &CudaDevice,
+        layout: &Layout,
+    ) -> Result<CudaSlice<T>> {
+        let (b_size, l_in, c_out, k_size) = layout.shape().dims4()?;
+        let stride = self.stride;
+        let l_out = (l_in - 1) * stride + k_size;
+
+        let dst_el = b_size * c_out * l_out;
+        let cfg = LaunchConfig::for_num_elems(dst_el as u32);
+        let src = &src.slice(layout.start_offset()..);
+        let func = dev.get_or_load_func(&kernel_name::<T>("col2im1d"), kernels::CONV)?;
+        // SAFETY: Set later by running the kernel.
+        let dst = unsafe { dev.alloc::<T>(dst_el) }.w()?;
+        let params = (l_in, l_out, c_out, k_size, b_size, stride, src, &dst);
+        // SAFETY: ffi.
+        unsafe { func.launch(cfg, params) }.w()?;
+        Ok(dst)
+    }
+}
+
 struct Im2Col1D {
     l_k: usize,
     stride: usize,
```
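Conceptually, `Col2Im1D` is the inverse of `im2col` for the 1D transposed convolution: each column entry `(l_in_i, k_i)` accumulates into output position `l_in_i * stride + k_i`, which is also why `l_out = (l_in - 1) * stride + k_size` once padding and dilation are out of the picture. For reference, here is a minimal CPU sketch of that scatter-add; it is not part of the diff, and `col2im1d_cpu` is a hypothetical helper assuming a row-major `(b_size, l_in, c_out, k_size)` input:

```rust
/// Hypothetical CPU reference for col2im1d (not from the diff).
/// src: (b_size, l_in, c_out, k_size), row-major.
/// dst: (b_size, c_out, l_out) with l_out = (l_in - 1) * stride + k_size.
fn col2im1d_cpu(
    src: &[f32],
    b_size: usize,
    l_in: usize,
    c_out: usize,
    k_size: usize,
    stride: usize,
) -> Vec<f32> {
    let l_out = (l_in - 1) * stride + k_size;
    let mut dst = vec![0f32; b_size * c_out * l_out];
    for b_i in 0..b_size {
        for l_in_i in 0..l_in {
            for c_i in 0..c_out {
                for k_i in 0..k_size {
                    // Each column entry scatters into exactly one output slot.
                    let src_i = ((b_i * l_in + l_in_i) * c_out + c_i) * k_size + k_i;
                    let dst_i = (b_i * c_out + c_i) * l_out + l_in_i * stride + k_i;
                    dst[dst_i] += src[src_i];
                }
            }
        }
    }
    dst
}
```

Overlapping writes (whenever stride < k_size) simply accumulate, which is exactly the summation a transposed convolution requires.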
```diff
@@ -1865,8 +1893,54 @@ impl BackendStorage for CudaStorage {
         params: &crate::conv::ParamsConvTranspose1D,
     ) -> Result<Self> {
         let device = self.device().clone();
-        let slice =
-            ConvTranspose1D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?;
+        const USE_COL2IM_CONV1D_TR: bool = true;
+
+        let can_use_col2im = kernel_l.is_contiguous()
+            && params.dilation == 1
+            && params.padding == 0
+            && params.output_padding == 0;
+        if !can_use_col2im || !USE_COL2IM_CONV1D_TR {
+            let slice =
+                ConvTranspose1D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?;
+            return Ok(Self { slice, device });
+        }
+
+        let (b_size, c_in, l_in) = l.shape().dims3()?;
+        let (c_in2, c_out, k_size) = kernel_l.shape().dims3()?;
+        if !kernel_l.is_contiguous() {
+            crate::bail!("convtr1d: the second argument (kernel) has to be contiguous {kernel_l:?}")
+        }
+        if c_in != c_in2 {
+            crate::bail!(
+                "convtr1d: shape mismatch on c_in {:?} {:?}",
+                l.shape(),
+                kernel_l.shape()
+            )
+        }
+        let col = {
+            // This merges the last two dimensions of the kernel together.
+            let kernel_l_mm = Layout::new(
+                (b_size, c_in, k_size * c_out).into(),
+                vec![0, k_size * c_out, 1],
+                kernel_l.start_offset(),
+            );
+            self.matmul(
+                kernel,
+                (
+                    b_size,
+                    /* m */ l_in,
+                    /* n */ c_out * k_size,
+                    /* k */ c_in,
+                ),
+                &l.transpose(1, 2)?,
+                &kernel_l_mm,
+            )?
+        };
+        let col_l = Layout::contiguous((b_size, l_in, c_out, k_size));
+        let slice = Col2Im1D {
+            stride: params.stride,
+        }
+        .map(&col.slice, &device, &col_l)?;
         Ok(Self { slice, device })
     }
 
```
CUDA kernels:

```diff
@@ -51,6 +51,48 @@ __device__ void conv1d(
   dst[dst_i] = static_cast<T>(d);
 }
 
+template <typename T>
+__device__ void col2im1d(
+    const size_t l_in,
+    const size_t l_out,
+    const size_t c_out,
+    const size_t k_size,
+    const size_t b_size,
+    const size_t stride,
+    const T *src,
+    T *dst
+) {
+  const size_t dst_i = blockIdx.x * blockDim.x + threadIdx.x;
+  // src: (b_size, l_in, c_out, k_size)
+  // dst: (b_size, c_out, l_out)
+  if (dst_i >= b_size * c_out * l_out) {
+    return;
+  }
+  const size_t dst_s0 = c_out * l_out;
+  const size_t dst_s1 = l_out;
+
+  // dst_idx = b_i * dst_s0 + c_i * dst_s1 + l_in_i * stride + k_i
+  const size_t b_i = dst_i / dst_s0;
+  const size_t dst_i2 = dst_i - b_i * dst_s0;
+  const size_t c_i = dst_i2 / dst_s1;
+  const size_t dst_i3 = dst_i2 - c_i * dst_s1; // l_in_i * stride + k_i
+
+  const size_t src_s0 = c_out * k_size * l_in;
+  const size_t src_s1 = c_out * k_size;
+  const size_t src_s2 = k_size;
+
+  T d = 0;
+  for (size_t k_i = 0; k_i < min(dst_i3 + 1, k_size); ++k_i) {
+    const size_t l_in_i_times_stride = dst_i3 - k_i;
+    const size_t l_in_i = l_in_i_times_stride / stride;
+    const size_t src_i = b_i * src_s0 + l_in_i * src_s1 + c_i * src_s2 + k_i;
+    if (l_in_i * stride == l_in_i_times_stride && l_in_i < l_in) {
+      d += src[src_i];
+    }
+  }
+  dst[dst_i] = d;
+}
+
 template <typename T>
 __device__ void im2col1d(
     const size_t dst_numel,
```
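Rather than scattering with atomics, the kernel assigns one thread per output element and gathers: the flat index `dst_i` is decomposed into `(b_i, c_i, dst_i3)`, and the loop over `k_i` keeps only taps where `dst_i3 - k_i` is an exact multiple of `stride` (the `l_in_i * stride == l_in_i_times_stride` test) and where `l_in_i` is in range. An illustrative Rust translation of a single thread's work (function name hypothetical, not from the diff):

```rust
/// Hypothetical single-thread equivalent of the col2im1d kernel body.
/// src: (b_size, l_in, c_out, k_size), row-major; returns dst[dst_i].
fn col2im1d_gather_one(
    src: &[f32],
    dst_i: usize,
    l_in: usize,
    l_out: usize,
    c_out: usize,
    k_size: usize,
    stride: usize,
) -> f32 {
    // Decompose dst_i = b_i * (c_out * l_out) + c_i * l_out + dst_i3.
    let b_i = dst_i / (c_out * l_out);
    let c_i = (dst_i % (c_out * l_out)) / l_out;
    let dst_i3 = dst_i % l_out; // equals l_in_i * stride + k_i for contributing taps
    let mut d = 0.0;
    for k_i in 0..k_size.min(dst_i3 + 1) {
        let t = dst_i3 - k_i; // candidate l_in_i * stride
        // Keep only taps where t is an exact multiple of stride and in range.
        if t % stride == 0 && t / stride < l_in {
            let l_in_i = t / stride;
            d += src[((b_i * l_in + l_in_i) * c_out + c_i) * k_size + k_i];
        }
    }
    d
}
```

Because every thread owns a distinct `dst_i`, the accumulation needs no synchronization; the divisibility test plays the role that bounds checking plays in the forward im2col.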
```diff
@@ -527,7 +569,7 @@ extern "C" __global__ void FN_NAME( \
     conv2d<TYPENAME, TYPEACC>(src_numel, w_out, h_out, stride, padding, dilation, info, src, kernel, dst); \
 } \
 
-#define IM2COL1D_OP(TYPENAME, FN_NAME) \
+#define IM2COL1D_OP(TYPENAME, FN_NAME, FN_NAME2) \
 extern "C" __global__ void FN_NAME( \
     const size_t dst_numel, \
     const size_t l_out, \
```
```diff
@@ -541,6 +583,18 @@ extern "C" __global__ void FN_NAME( \
 ) { \
     im2col1d<TYPENAME>(dst_numel, l_out, l_k, stride, padding, dilation, info, src, dst); \
 } \
+extern "C" __global__ void FN_NAME2( \
+    const size_t l_in, \
+    const size_t l_out, \
+    const size_t c_out, \
+    const size_t k_size, \
+    const size_t b_size, \
+    const size_t stride, \
+    const TYPENAME *src, \
+    TYPENAME *dst \
+) { \
+    col2im1d<TYPENAME>(l_in, l_out, c_out, k_size, b_size, stride, src, dst); \
+} \
 
 #define IM2COL_OP(TYPENAME, FN_NAME) \
 extern "C" __global__ void FN_NAME( \
```
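The extra `FN_NAME2` parameter makes each `IM2COL1D_OP` instantiation export a `col2im1d_*` entry point next to the existing `im2col1d_*` one; the Rust side then resolves the symbol by name via `kernel_name::<T>("col2im1d")` in the backend hunk above. A sketch of that naming convention, assumed from the call site to append a per-dtype suffix:

```rust
// Assumed naming scheme behind kernel_name::<T>(...): root + "_" + dtype suffix.
fn kernel_symbol(root: &str, dtype_suffix: &str) -> String {
    format!("{root}_{dtype_suffix}")
}

fn main() {
    // Must line up with the instantiation IM2COL1D_OP(float, im2col1d_f32, col2im1d_f32).
    assert_eq!(kernel_symbol("col2im1d", "f32"), "col2im1d_f32");
}
```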
```diff
@@ -642,7 +696,7 @@ AVG_POOL2D_OP(__nv_bfloat16, float, avg_pool2d_bf16)
 MAX_POOL2D_OP(__nv_bfloat16, max_pool2d_bf16)
 UPSAMPLE_NEAREST2D_OP(__nv_bfloat16, upsample_nearest2d_bf16)
 IM2COL_OP(__nv_bfloat16, im2col_bf16)
-IM2COL1D_OP(__nv_bfloat16, im2col1d_bf16)
+IM2COL1D_OP(__nv_bfloat16, im2col1d_bf16, col2im1d_bf16)
 #endif
 
 #if __CUDA_ARCH__ >= 530
```
```diff
@@ -654,7 +708,7 @@ AVG_POOL2D_OP(__half, float, avg_pool2d_f16)
 MAX_POOL2D_OP(__half, max_pool2d_f16)
 UPSAMPLE_NEAREST2D_OP(__half, upsample_nearest2d_f16)
 IM2COL_OP(__half, im2col_f16)
-IM2COL1D_OP(__half, im2col1d_f16)
+IM2COL1D_OP(__half, im2col1d_f16, col2im1d_f16)
 #endif
 
 CONV1D_OP(float, float, conv1d_f32)
```
```diff
@@ -697,7 +751,7 @@ IM2COL_OP(double, im2col_f64)
 IM2COL_OP(uint8_t, im2col_u8)
 IM2COL_OP(uint32_t, im2col_u32)
 
-IM2COL1D_OP(float, im2col1d_f32)
-IM2COL1D_OP(double, im2col1d_f64)
-IM2COL1D_OP(uint8_t, im2col1d_u8)
-IM2COL1D_OP(uint32_t, im2col1d_u32)
+IM2COL1D_OP(float, im2col1d_f32, col2im1d_f32)
+IM2COL1D_OP(double, im2col1d_f64, col2im1d_f64)
+IM2COL1D_OP(uint8_t, im2col1d_u8, col2im1d_u8)
+IM2COL1D_OP(uint32_t, im2col1d_u32, col2im1d_u32)
```