Mirror of https://github.com/huggingface/candle.git, synced 2025-06-15 02:16:37 +00:00
Support for cudnn conv1d. (#2888)
* Support for cudnn conv1d.
* More conv1d work.
* Get the conv1d to work with cudnn.
* Cleanup.
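For context, here is a minimal usage sketch (not part of the commit) of the path this change accelerates: candle's public `conv1d` call on a CUDA device, which dispatches to the new `launch_conv1d` below when the `cudnn` feature is enabled. The shapes, values, and `candle_core` calls are illustrative assumptions.

use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::new_cuda(0)?;
    // Input is (batch, c_in, l_in); the kernel is (c_out, c_in, k_size).
    let x = Tensor::randn(0f32, 1.0, (2, 4, 16), &dev)?;
    let w = Tensor::randn(0f32, 1.0, (8, 4, 3), &dev)?;
    // padding = 1, stride = 1, dilation = 1, groups = 1.
    let y = x.conv1d(&w, 1, 1, 1, 1)?;
    assert_eq!(y.dims(), &[2, 8, 16]);
    Ok(())
}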
@@ -14,6 +14,7 @@ pub struct ParamsConv1D {
    pub(crate) padding: usize,
    pub(crate) stride: usize,
    pub(crate) dilation: usize,
    pub(crate) cudnn_fwd_algo: Option<CudnnFwdAlgo>,
}

impl ParamsConv1D {

@@ -174,6 +175,7 @@ impl Tensor {
            padding,
            stride,
            dilation,
            cudnn_fwd_algo: None,
        };
        if groups == 1 {
            self.conv1d_single_group(kernel, &params)

@@ -122,3 +122,104 @@ pub(crate) fn launch_conv2d<
    }
    Ok(())
}

pub(crate) fn launch_conv1d<
    T: DeviceRepr + WithDType + ValidAsZeroBits + cudarc::cudnn::CudnnDataType,
    Y: cudarc::cudnn::CudnnDataType,
>(
    src: &CudaView<T>,
    src_l: &crate::Layout,
    filter: &CudaView<T>,
    dst: &mut CudaSlice<T>,
    params: &crate::conv::ParamsConv1D,
    dev: &crate::cuda_backend::CudaDevice,
) -> crate::Result<()> {
    use crate::conv::CudnnFwdAlgo as CandleAlgo;
    use cudarc::cudnn::sys::cudnnConvolutionFwdAlgo_t as A;

    let device_id = dev.id();
    let cudnn = CUDNN.with(|cudnn| {
        if let Some(cudnn) = cudnn.borrow().get(&device_id) {
            return Ok(cudnn.clone());
        }
        let c = Cudnn::new(dev.cuda_stream());
        if let Ok(c) = &c {
            cudnn.borrow_mut().insert(device_id, c.clone());
        }
        c
    })?;
    let conv = cudnn.create_conv2d::<Y>(
        /* pad */ [params.padding as i32, 0],
        /* stride */ [params.stride as i32, 1],
        /* dilation */ [params.dilation as i32, 1],
        cudarc::cudnn::sys::cudnnConvolutionMode_t::CUDNN_CROSS_CORRELATION,
    )?;
    // https://docs.nvidia.com/deeplearning/cudnn/backend/latest/api/cudnn-ops-library.html#cudnnsettensornddescriptor
    // > Tensors are restricted to having at least 4 dimensions, and at most CUDNN_DIM_MAX
    // > dimensions (defined in cudnn.h). When working with lower dimensional data, it is
    // > recommended that the user create a 4D tensor, and set the size along unused dimensions
    // > to 1.
    let x_shape = [
        params.b_size as i32,
        params.c_in as i32,
        params.l_in as i32,
        1,
    ];
    // Note that `src` already starts at the proper offset.
    let x = if src_l.is_contiguous() {
        cudnn.create_4d_tensor::<T>(
            cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
            x_shape,
        )?
    } else {
        let s = src_l.stride();
        cudnn.create_4d_tensor_ex::<T>(x_shape, [s[0] as i32, s[1] as i32, s[2] as i32, 1i32])?
    };
    let w = cudnn.create_4d_filter::<T>(
        cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
        [
            params.c_out as i32,
            params.c_in as i32,
            params.k_size as i32,
            1,
        ],
    )?;
    let l_out = params.l_out() as i32;
    let y = cudnn.create_4d_tensor::<T>(
        cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
        [params.b_size as i32, params.c_out as i32, l_out, 1],
    )?;
    let conv1d = ConvForward {
        conv: &conv,
        x: &x,
        w: &w,
        y: &y,
    };
    let alg = match params.cudnn_fwd_algo {
        None => conv1d.pick_algorithm()?,
        Some(CandleAlgo::ImplicitGemm) => A::CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM,
        Some(CandleAlgo::ImplicitPrecompGemm) => {
            A::CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM
        }
        Some(CandleAlgo::Gemm) => A::CUDNN_CONVOLUTION_FWD_ALGO_GEMM,
        Some(CandleAlgo::Direct) => A::CUDNN_CONVOLUTION_FWD_ALGO_DIRECT,
        Some(CandleAlgo::Fft) => A::CUDNN_CONVOLUTION_FWD_ALGO_FFT,
        Some(CandleAlgo::FftTiling) => A::CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING,
        Some(CandleAlgo::Winograd) => A::CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD,
        Some(CandleAlgo::WinogradNonFused) => A::CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED,
        Some(CandleAlgo::Count) => A::CUDNN_CONVOLUTION_FWD_ALGO_COUNT,
    };
    let workspace_size = conv1d.get_workspace_size(alg)?;
    let mut workspace = dev.cuda_stream().alloc_zeros::<u8>(workspace_size)?;
    unsafe {
        conv1d.launch::<CudaSlice<u8>, _, _, _>(
            alg,
            Some(&mut workspace),
            (T::one(), T::zero()),
            src,
            filter,
            dst,
        )?;
    }
    Ok(())
}
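The function above maps the 1D convolution onto cuDNN's 2D descriptors by appending a trailing unit dimension everywhere, as the quoted descriptor note recommends. A small illustrative sketch of that mapping (standalone, hypothetical names, not part of the commit):

fn main() {
    // 1D parameters become degenerate 2D arguments, mirroring the
    // create_conv2d / create_4d_tensor calls in launch_conv1d.
    let (padding, stride, dilation) = (1usize, 2usize, 1usize);
    let pad2d = [padding as i32, 0]; // pad the length axis only
    let stride2d = [stride as i32, 1]; // unit stride on the dummy axis
    let dilation2d = [dilation as i32, 1]; // unit dilation on the dummy axis
    // Shapes get the same treatment: (b, c_in, l_in) -> NCHW (b, c_in, l_in, 1),
    // and the kernel (c_out, c_in, k_size) -> (c_out, c_in, k_size, 1).
    let (b, c_in, l_in) = (2i32, 4i32, 16i32);
    let x_shape = [b, c_in, l_in, 1];
    println!("{pad2d:?} {stride2d:?} {dilation2d:?} {x_shape:?}");
}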

@@ -134,6 +134,7 @@ impl Map1 for Elu {
    }
}

#[allow(unused)]
struct Im2Col1D {
    l_k: usize,
    stride: usize,
@@ -142,6 +143,7 @@ struct Im2Col1D {
}

impl Im2Col1D {
    #[allow(unused)]
    fn l_out(&self, l: usize) -> usize {
        (l + 2 * self.padding - self.dilation * (self.l_k - 1) - 1) / self.stride + 1
    }
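The `l_out` formula above is the standard output-length computation for a padded, strided, dilated 1D convolution. A standalone check with illustrative values (not part of the commit):

fn l_out(l: usize, padding: usize, dilation: usize, l_k: usize, stride: usize) -> usize {
    (l + 2 * padding - dilation * (l_k - 1) - 1) / stride + 1
}

fn main() {
    // l = 16, padding = 1, dilation = 1, kernel length 3, stride 1: length is preserved.
    assert_eq!(l_out(16, 1, 1, 3, 1), 16);
    // Same setup with stride 2 roughly halves the output length.
    assert_eq!(l_out(16, 1, 1, 3, 2), 8);
}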

@@ -1435,6 +1437,7 @@ impl BackendStorage for CudaStorage {
        Ok(Self { slice, device })
    }

    #[cfg(not(feature = "cudnn"))]
    fn conv1d(
        &self,
        l: &Layout,
@@ -1485,6 +1488,72 @@ impl BackendStorage for CudaStorage {
        Ok(res_t)
    }

    #[cfg(feature = "cudnn")]
    fn conv1d(
        &self,
        inp_l: &Layout,
        kernel: &Self,
        kernel_l: &Layout,
        params: &crate::conv::ParamsConv1D,
    ) -> Result<Self> {
        let device = self.device().clone();
        if !kernel_l.is_contiguous() {
            let slice = Conv1D(params).map(&self.slice, inp_l, &kernel.slice, kernel_l, &device)?;
            return Ok(Self { slice, device });
        }
        let l_out = params.l_out();
        let dst_el = params.c_out * l_out * params.b_size;
        let slice = match (&self.slice, &kernel.slice) {
            (S::U8(inp), S::U8(k)) => {
                let inp = &inp.slice(inp_l.start_offset()..);
                let k = &k.slice(kernel_l.start_offset()..);
                let mut out = unsafe { device.alloc::<u8>(dst_el)? };
                crate::cudnn::launch_conv1d::<u8, u8>(inp, inp_l, k, &mut out, params, &device)
                    .map_err(crate::Error::wrap)?;
                S::U8(out)
            }
            (S::BF16(inp), S::BF16(k)) => {
                let inp = &inp.slice(inp_l.start_offset()..);
                let k = &k.slice(kernel_l.start_offset()..);
                let mut out = unsafe { device.alloc::<bf16>(dst_el)? };
                // Only PSEUDO_BFLOAT16_CONFIG is supported in cudnn, there is no "true bfloat16"
                // version.
                // https://docs.nvidia.com/deeplearning/cudnn/latest/api/cudnn-cnn-library.html#id88
                crate::cudnn::launch_conv1d::<bf16, f32>(inp, inp_l, k, &mut out, params, &device)
                    .map_err(crate::Error::wrap)?;
                S::BF16(out)
            }
            (S::F16(inp), S::F16(k)) => {
                let inp = &inp.slice(inp_l.start_offset()..);
                let k = &k.slice(kernel_l.start_offset()..);
                let mut out = unsafe { device.alloc::<f16>(dst_el)? };
                crate::cudnn::launch_conv1d::<f16, f16>(inp, inp_l, k, &mut out, params, &device)
                    .map_err(crate::Error::wrap)?;
                S::F16(out)
            }
            (S::F32(inp), S::F32(k)) => {
                let inp = &inp.slice(inp_l.start_offset()..);
                let k = &k.slice(kernel_l.start_offset()..);
                let mut out = unsafe { device.alloc::<f32>(dst_el)? };
                crate::cudnn::launch_conv1d::<f32, f32>(inp, inp_l, k, &mut out, params, &device)
                    .map_err(crate::Error::wrap)?;
                S::F32(out)
            }
            (S::F64(inp), S::F64(k)) => {
                let inp = &inp.slice(inp_l.start_offset()..);
                let k = &k.slice(kernel_l.start_offset()..);
                let mut out = unsafe { device.alloc::<f64>(dst_el)? };
                crate::cudnn::launch_conv1d::<f64, f64>(inp, inp_l, k, &mut out, params, &device)
                    .map_err(crate::Error::wrap)?;
                S::F64(out)
            }
            (S::U32(_), S::U32(_)) => Err(CudaError::InternalError("conv1d does not support u32"))?,
            (S::I64(_), S::I64(_)) => Err(CudaError::InternalError("conv1d does not support i64"))?,
            _ => Err(CudaError::InternalError("dtype mismatch in conv1d"))?,
        };
        Ok(Self { slice, device })
    }

    fn conv_transpose1d(
        &self,
        l: &Layout,