Cuda kernel for the conv1d op (#111)

* Boilerplate code for conv1d.
* Boilerplate code for conv1d.
* More boilerplate for conv1d.
* Conv1d work.
* Get the conv1d cuda kernel to work.
* Conv1d support when no batch dim.
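For context, a minimal usage sketch of the op this commit implements. The exact public API is an assumption on my part, not part of this diff: a `randn` constructor and a `conv1d(kernel, padding, stride)` signature are hypothesized from the crate's conventions of this era.

    use candle::{Device, Tensor};

    fn main() -> candle::Result<()> {
        let dev = Device::new_cuda(0)?;
        // Input: (b_size, c_in, l_in); kernel: (c_out, c_in, k_size).
        let inp = Tensor::randn(0f32, 1., (2, 3, 16), &dev)?;
        let k = Tensor::randn(0f32, 1., (4, 3, 5), &dev)?;
        // Hypothetical signature: padding = 0, stride = 1.
        let out = inp.conv1d(&k, 0, 1)?;
        println!("{:?}", out.shape()); // (b_size, c_out, l_out)
        Ok(())
    }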
@@ -505,6 +505,46 @@ impl<'a> Map1 for Embedding<'a> {
     }
 }
 
+struct Conv1D<'a>(&'a crate::conv::ParamsConv1D);
+impl<'a> Map2 for Conv1D<'a> {
+    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
+        &self,
+        inp: &CudaSlice<T>,
+        inp_l: &Layout,
+        k: &CudaSlice<T>,
+        k_l: &Layout,
+        dev: &CudaDevice,
+    ) -> Result<CudaSlice<T>> {
+        // Kernel shape: (c_out, c_in_k, k_size)
+        // Input shape: (b_size, c_in, l_in) or (c_in, l_in)
+        let p = &self.0;
+
+        let inp = &inp.slice(inp_l.start_offset()..);
+        let k = &k.slice(k_l.start_offset()..);
+        let shape = inp_l.shape();
+        let dims = shape.dims();
+        let el = shape.elem_count();
+        let l_out = p.l_out();
+        let dst_el = p.c_out * l_out * p.b_size.unwrap_or(1);
+        let cfg = LaunchConfig::for_num_elems(dst_el as u32);
+        let func = dev.get_or_load_func(&kernel_name::<T>("conv1d"), kernels::CONV)?;
+        // SAFETY: Set later by running the kernel.
+        let out = unsafe { dev.alloc::<T>(dst_el) }?;
+        let ds = if dims.len() == 3 {
+            [dims, inp_l.stride(), k_l.dims(), k_l.stride()].concat()
+        } else if dims.len() == 2 {
+            [&[1], dims, &[1], inp_l.stride(), k_l.dims(), k_l.stride()].concat()
+        } else {
+            panic!("unexpected input shape for conv1d {dims:?}")
+        };
+        let ds = dev.htod_copy(ds)?;
+        let params = (el, l_out, p.stride, &ds, inp, k, &out);
+        // SAFETY: ffi.
+        unsafe { func.launch(cfg, params) }?;
+        Ok(out)
+    }
+}
+
 struct WhereCond<'a>(&'a CudaStorage, &'a Layout);
 impl<'a> Map2 for WhereCond<'a> {
     fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
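Everything the kernel needs to know about shapes travels in the single `ds` buffer: twelve `usize` values, input dims, then input strides, then kernel dims, then kernel strides, which the device code reads at offsets 0, 3, 6, and 9. A minimal sketch of that layout; the helper name and example values are illustrative, not from the diff:

    // Hypothetical helper mirroring the `ds` construction in Conv1D::f.
    fn conv1d_info(
        inp_dims: [usize; 3],   // (b_size, c_in, l_in)
        inp_stride: [usize; 3],
        k_dims: [usize; 3],     // (c_out, c_in, k_size)
        k_stride: [usize; 3],
    ) -> Vec<usize> {
        // Kernel-side views: info -> src_dims, info+3 -> src_s,
        // info+6 -> k_dims, info+9 -> k_s.
        [&inp_dims[..], &inp_stride[..], &k_dims[..], &k_stride[..]].concat()
    }

    fn main() {
        // Contiguous (2, 3, 16) input with a (4, 3, 5) kernel:
        // strides (48, 16, 1) and (15, 5, 1).
        let info = conv1d_info([2, 3, 16], [48, 16, 1], [4, 3, 5], [15, 5, 1]);
        assert_eq!(info.len(), 12);
    }

For the unbatched (c_in, l_in) case, the host prepends a dim of 1 and a stride of 1, so the kernel can always assume three dimensions.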
@@ -854,12 +894,14 @@ impl CudaStorage {
 
     pub(crate) fn conv1d(
         &self,
-        _l: &Layout,
-        _kernel: &Self,
-        _kernel_l: &Layout,
-        _params: &crate::conv::ParamsConv1D,
+        l: &Layout,
+        kernel: &Self,
+        kernel_l: &Layout,
+        params: &crate::conv::ParamsConv1D,
     ) -> Result<Self> {
-        todo!()
+        let device = self.device().clone();
+        let slice = Conv1D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?;
+        Ok(Self { slice, device })
     }
 
     pub(crate) fn embedding(&self, layout: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> {
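The launch configuration maps one CUDA thread to each output element. A quick worked check with illustrative numbers, not from the diff:

    // b_size = 2, c_out = 4, l_out = 10 -> one thread per output element.
    let (b_size, c_out, l_out) = (2usize, 4usize, 10usize);
    let dst_el = c_out * l_out * b_size;
    assert_eq!(dst_el, 80);

Note that `LaunchConfig::for_num_elems` rounds the grid up to whole blocks, and the kernel below writes `dst[dst_i]` without a `dst_i < dst_el` guard, so the trailing threads of the last block appear able to write past the end of the output buffer; an early-return bounds check at the top of the kernel would be the usual fix.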
candle-kernels/src/conv.cu (new file, 74 lines)
@@ -0,0 +1,74 @@
+#include "cuda_utils.cuh"
+#include <stdint.h>
+
+template <typename T>
+__device__ void conv1d(
+    const size_t src_numel,
+    const size_t l_out,
+    const size_t stride,
+    const size_t *info,
+    const T *src,
+    const T *kernel,
+    T *dst
+) {
+  // src: (b_size, c_in, l_in)
+  // k: (c_out, c_in, k_size)
+  const size_t *src_dims = info;
+  const size_t *src_s = info + 3;
+  const size_t *k_dims = info + 6;
+  const size_t *k_s = info + 9;
+  const size_t dst_i = blockIdx.x * blockDim.x + threadIdx.x;
+  const size_t k_size = k_dims[2];
+  const size_t k_over_2 = k_size / 2;
+  const size_t c_out = k_dims[0];
+  const size_t c_in = src_dims[1];
+  const size_t l_in = src_dims[2];
+
+  // TODO
+  const size_t b_idx = dst_i / (l_out * c_out);
+  const size_t dst_c_idx = (dst_i / l_out) % c_out;
+  const size_t dst_l = dst_i % l_out;
+
+  const size_t src_idx0 = b_idx * src_s[0];
+  T d = 0;
+  for (size_t offset = 0; offset < k_size; ++offset) {
+    const size_t src_l_plus = stride * dst_l + offset;
+    if (k_over_2 <= src_l_plus && src_l_plus < l_in + k_over_2) {
+      const size_t src_l = src_l_plus - k_over_2;
+      for (size_t src_c_idx = 0; src_c_idx < c_in; ++src_c_idx) {
+        const size_t src_idx = src_idx0 + src_c_idx * src_s[1] + src_l * src_s[2];
+        const size_t k_idx = dst_c_idx * k_s[0] + src_c_idx * k_s[1] + offset * k_s[2];
+        d += src[src_idx] * kernel[k_idx];
+      }
+    }
+  }
+  dst[dst_i] = d;
+}
+
+// Note: the second macro parameter is the output length forwarded to the
+// device function's l_out argument (it was previously misnamed num_dims).
+#define CONV1D_OP(TYPENAME, FN_NAME) \
+extern "C" __global__ void FN_NAME( \
+    const size_t src_numel, \
+    const size_t l_out, \
+    const size_t stride, \
+    const size_t *info, \
+    const TYPENAME *src, \
+    const TYPENAME *kernel, \
+    TYPENAME *dst \
+) { \
+  conv1d(src_numel, l_out, stride, info, src, kernel, dst); \
+} \
+
+#if __CUDA_ARCH__ >= 800
+CONV1D_OP(__nv_bfloat16, conv1d_bf16)
+#endif
+
+#if __CUDA_ARCH__ >= 530
+CONV1D_OP(__half, conv1d_f16)
+#endif
+
+CONV1D_OP(float, conv1d_f32)
+CONV1D_OP(double, conv1d_f64)
+CONV1D_OP(uint8_t, conv1d_u8)
+CONV1D_OP(uint32_t, conv1d_u32)
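To make the indexing above concrete, here is a CPU reference of the same computation in Rust; this is my own illustrative sketch, not part of the commit. It assumes contiguous layouts and reproduces the implicit zero padding of k_size / 2 that the src_l_plus bounds test bakes in (i.e. "same"-style padding for odd kernel sizes).

    /// Reference conv1d matching the kernel's indexing, for a contiguous
    /// (b_size, c_in, l_in) input and (c_out, c_in, k_size) kernel.
    /// Illustrative helper, not part of the commit.
    fn conv1d_ref(
        src: &[f32], b_size: usize, c_in: usize, l_in: usize,
        kernel: &[f32], c_out: usize, k_size: usize,
        stride: usize, l_out: usize,
    ) -> Vec<f32> {
        let k_over_2 = k_size / 2;
        let mut dst = vec![0f32; b_size * c_out * l_out];
        for dst_i in 0..dst.len() {
            // Same decomposition of the flat output index as the kernel.
            let b_idx = dst_i / (l_out * c_out);
            let dst_c_idx = (dst_i / l_out) % c_out;
            let dst_l = dst_i % l_out;
            let mut d = 0f32;
            for offset in 0..k_size {
                let src_l_plus = stride * dst_l + offset;
                // Same bounds test as the kernel: implicit zero padding of
                // k_size / 2 on both sides of the length dimension.
                if k_over_2 <= src_l_plus && src_l_plus < l_in + k_over_2 {
                    let src_l = src_l_plus - k_over_2;
                    for src_c_idx in 0..c_in {
                        let src_idx = b_idx * c_in * l_in + src_c_idx * l_in + src_l;
                        let k_idx = dst_c_idx * c_in * k_size + src_c_idx * k_size + offset;
                        d += src[src_idx] * kernel[k_idx];
                    }
                }
            }
            dst[dst_i] = d;
        }
        dst
    }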
@@ -1,6 +1,7 @@
 pub const AFFINE: &str = include_str!(concat!(env!("OUT_DIR"), "/affine.ptx"));
 pub const BINARY: &str = include_str!(concat!(env!("OUT_DIR"), "/binary.ptx"));
 pub const CAST: &str = include_str!(concat!(env!("OUT_DIR"), "/cast.ptx"));
+pub const CONV: &str = include_str!(concat!(env!("OUT_DIR"), "/conv.ptx"));
 pub const EMBEDDINGS: &str = include_str!(concat!(env!("OUT_DIR"), "/embeddings.ptx"));
 pub const FILL: &str = include_str!(concat!(env!("OUT_DIR"), "/fill.ptx"));
 pub const REDUCE: &str = include_str!(concat!(env!("OUT_DIR"), "/reduce.ptx"));
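The new CONV constant embeds the PTX text built from conv.cu into OUT_DIR, presumably by the crate's build script, following the same pattern as the other kernels. On the Rust side, the entry point is then selected by a dtype-suffixed name via kernel_name::<T>, matching the CONV1D_OP instantiations above. A tiny illustrative sketch of that naming scheme; this helper is mine, not the crate's kernel_name:

    // Illustrative dtype -> entry-point mapping implied by the CONV1D_OP
    // instantiations (conv1d_bf16, conv1d_f16, conv1d_f32, ...).
    fn conv1d_kernel_name(dtype_suffix: &str) -> String {
        format!("conv1d_{dtype_suffix}")
    }

    fn main() {
        for dt in ["bf16", "f16", "f32", "f64", "u8", "u32"] {
            println!("{}", conv1d_kernel_name(dt));
        }
    }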